# Hist2Cell Model Training Tutorial

## Welcome to the Hist2Cell Training Guide! 🚀

This comprehensive tutorial will guide you through the complete process of training a **Hist2Cell** model for predicting fine-grained cell type abundances from histology images. Whether you're new to deep learning or an experienced researcher, this guide provides step-by-step instructions with detailed explanations.

## 📋 Table of Contents

1. [**Understanding Hist2Cell Training**](#understanding-hist2cell-training)
2. [**Environment Setup & Reproducibility**](#environment-setup--reproducibility)
3. [**Model Architecture Deep Dive**](#model-architecture-deep-dive)
4. [**Data Loading & Preparation**](#data-loading--preparation)
5. [**Training Configuration**](#training-configuration)
6. [**Training Process & Monitoring**](#training-process--monitoring)
7. [**Results Analysis & Next Steps**](#results-analysis--next-steps)

## 🎯 What You'll Learn

By the end of this tutorial, you will:
- ✅ Understand the complete Hist2Cell training pipeline
- ✅ Master multi-scale graph neural network training
- ✅ Learn advanced spatial transcriptomics modeling techniques
- ✅ Implement robust model evaluation and validation
- ✅ Optimize training for your own datasets

## 🧬 Understanding Hist2Cell Training

**Hist2Cell** represents a paradigm shift in spatial biology analysis by directly predicting cell type abundances from histology images without requiring explicit gene expression measurements at inference time.

### 🔬 **The Challenge**
Traditional spatial transcriptomics analysis requires:
- **Expensive sequencing**: Costly gene expression measurements
- **Limited resolution**: Constrained by sequencing spot size
- **Processing time**: Long sequencing and analysis pipelines

### 🎯 **The Hist2Cell Solution**
Our approach enables:
- **Cost-effective prediction**: Uses only histology images
- **High resolution**: Patch-level predictions at any scale
- **Fast inference**: Real-time analysis of tissue slides
- **Broad applicability**: Works across different tissue types

### 🏗️ **Model Innovation**
Hist2Cell combines three powerful approaches:
1. **Local Feature Extractor**: ResNet18 extracts visual features from tissue patches
2. **Graph Neural Networks**: Model spatial relationships between tissue regions
3. **Vision Transformers**: Capture global tissue context and patterns

## 📊 **Training Strategy**

### 🎯 **Multi-scale Learning**
- **Spot-level**: Individual patch analysis
- **Local-level**: Neighborhood pattern recognition
- **Global-level**: Tissue-wide context understanding
- **Fusion-level**: Integrated multi-scale predictions

### 📈 **Evaluation Strategy**
- **Donor-based splits**: Test generalization across individuals
- **Multi-metric evaluation**: Loss, correlation, and biological validation
- **Real-time monitoring**: Track training progress and convergence

## 🔧 **Prerequisites**

Before starting, ensure you have:
- **Processed data**: From the data preparation tutorial
- **GPU access**: 8GB+ VRAM recommended (24GB+ ideal)
- **Python environment**: With PyTorch, PyTorch Geometric, and dependencies

## 1. Environment Setup & Reproducibility

### 🎲 **Why Reproducibility Matters**

In machine learning research, **reproducibility** is crucial for:
- **Scientific validity**: Others can verify your results
- **Debugging**: Consistent behavior across runs
- **Comparison**: Fair evaluation of different approaches
- **Production deployment**: Predictable model behavior

### 🔧 **Random Seed Management**

We'll set random seeds for all major libraries to ensure deterministic behavior:

In [1]:
# Import necessary libraries
import random
import torch
import os
import numpy as np
import torch.utils.data

def setup_seed(seed):
    """
    Set random seeds for reproducibility across different libraries and frameworks.
    
    Args:
        seed (int): The random seed value to use
    
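    Note:
        This function:
        - Seeds the Python, NumPy, and PyTorch (CPU and CUDA) random number generators
        - Makes cuDNN select deterministic algorithms and disables its autotuner
        - Does not guarantee bit-identical results: some CUDA operations have no
          deterministic implementation, so small run-to-run differences can remain
    """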
    torch.manual_seed(seed)                    # Set PyTorch random seed
    os.environ['PYTHONHASHSEED'] = str(seed)   # Read at interpreter startup, so this only affects subprocesses
    torch.cuda.manual_seed(seed)               # Set CUDA random seed for current GPU
    torch.cuda.manual_seed_all(seed)           # Set CUDA random seed for all GPUs
    np.random.seed(seed)                       # Set NumPy random seed
    random.seed(seed)                          # Set Python random seed
    torch.backends.cudnn.benchmark = False    # Disable cuDNN benchmark for reproducibility
    torch.backends.cudnn.deterministic = True # Use deterministic algorithms
    torch.backends.cudnn.enabled = True       # Enable cuDNN

# Set random seed to 3407 (a commonly used seed value in ML research)
setup_seed(3407)
print("✓ Random seed setup completed successfully!")

✓ Random seed setup completed successfully!
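
A quick way to confirm that seeding behaves as expected is to reseed and check that the same draws come back. The sketch below uses only `random` and NumPy. `draw_after_seeding` is a hypothetical helper, not part of the Hist2Cell code; the same kind of check could be applied to `torch.rand` after calling `setup_seed`.

```python
import random
import numpy as np

def draw_after_seeding(seed):
    """Seed the Python and NumPy RNGs, then draw a few values from each."""
    random.seed(seed)
    np.random.seed(seed)
    return [random.random() for _ in range(3)], np.random.rand(3)

py_a, np_a = draw_after_seeding(3407)
py_b, np_b = draw_after_seeding(3407)  # same seed again
py_c, np_c = draw_after_seeding(3408)  # different seed

same_seed_matches = (py_a == py_b) and np.array_equal(np_a, np_b)
different_seed_differs = (py_a != py_c) and not np.array_equal(np_a, np_c)
print(same_seed_matches, different_seed_differs)
```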


## 💻 Device Configuration & GPU Setup

### 🎯 **Why GPU Training Matters**

Training Hist2Cell involves processing:
- **Large image tensors**: 224×224×3 patches for hundreds of spots
- **Complex graph operations**: Spatial relationships between neighboring regions
- **Transformer computations**: Self-attention mechanisms across tissue contexts
- **Multi-scale processing**: Simultaneous spot, local, and global analysis

**GPU acceleration** provides:
- **10-50x speedup**: Compared to CPU training
- **Parallel processing**: Handle multiple spots simultaneously
- **Memory efficiency**: Large VRAM for batch processing
- **Optimized operations**: Specialized CUDA kernels for deep learning

### 🔧 **GPU Requirements**

| GPU Memory | Recommended Batch Size | Training Time (5 epochs) |
|------------|------------------------|---------------------------|
| **8GB** | `subgraph_bs=4` | ~60 minutes |
| **12GB** | `subgraph_bs=8` | ~45 minutes |
| **16GB** | `subgraph_bs=12` | ~35 minutes |
| **24GB+** | `subgraph_bs=16` | ~30 minutes |
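
The table can be turned into a small lookup so the batch size follows the detected GPU. This is a sketch only: `suggest_subgraph_bs` is a hypothetical helper, the tiers are copied from the table above, and rounding the reported memory to the nearest GB is an assumption made because a nominal 24 GB card reports about 23.7 GB.

```python
def suggest_subgraph_bs(vram_gb):
    """Map reported GPU memory (GB) to the batch size in the table above.

    Returns None below 8 GB, which the table does not cover.
    """
    nominal_gb = round(vram_gb)  # e.g. a 24 GB card reports ~23.7 GB
    tiers = [(24, 16), (16, 12), (12, 8), (8, 4)]  # (minimum GB, subgraph_bs)
    for min_gb, batch_size in tiers:
        if nominal_gb >= min_gb:
            return batch_size
    return None

for vram in (6.0, 8.0, 11.8, 16.0, 23.7):
    print(f"{vram:>5.1f} GB -> subgraph_bs={suggest_subgraph_bs(vram)}")
```

Treat the result as a starting point and lower it if you hit out-of-memory errors.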

### ⚡ **Device Detection & Configuration**

Let's configure the optimal computing device for training:

In [2]:
# Import required libraries
import os
import torch
import warnings

# Suppress warning messages for cleaner output
warnings.filterwarnings('ignore')

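# Configure GPU settings
gpu_list = [0]  # Use GPU 0 (first GPU). Change this to select different GPUs
gpu_list_str = ','.join(map(str, gpu_list))
# setdefault keeps any CUDA_VISIBLE_DEVICES already set in the shell;
# assign os.environ["CUDA_VISIBLE_DEVICES"] directly to override it
os.environ.setdefault("CUDA_VISIBLE_DEVICES", gpu_list_str)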

# Set device for training (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Display device information
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("Warning: CUDA not available. Training will be slower on CPU.")

device

Using device: cuda
GPU Name: NVIDIA GeForce RTX 3090
GPU Memory: 23.7 GB


device(type='cuda')

## 2. Model Definition

### Understanding the Hist2Cell Architecture

The Hist2Cell model is a sophisticated neural network that combines multiple components:

1. **ResNet18 Backbone**: Extracts 512-dimensional features from histology image patches
2. **Graph Attention Network (GAT)**: Captures spatial relationships between neighboring spots
3. **Vision Transformer (ViT)**: Processes global context across the entire tissue slide
4. **Multi-level Prediction Fusion**: Combines spot-level, local, global, and fused predictions

**Key Innovation**: Unlike single-scale approaches, Hist2Cell leverages information at multiple spatial scales to achieve more accurate cell type abundance predictions.

Let's define the model architecture:

In [3]:
# Import necessary modules for model definition
from torch.nn import Linear
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn import GATv2Conv, LayerNorm
import sys, os

# Add parent directory to path to import custom modules
sys.path.append(os.path.dirname(os.getcwd()))
from model.ViT import Mlp, VisionTransformer

class Hist2Cell(nn.Module):
    """
    Hist2Cell model for predicting cell type abundances from histology images.
    
    This model combines multiple approaches:
    1. ResNet18 for image feature extraction
    2. Graph Attention Network for spatial relationship modeling
    3. Vision Transformer for global context understanding
    4. Multi-level prediction fusion
    """
    
    def __init__(self, cell_dim=80, vit_depth=3):
        """
        Initialize the Hist2Cell model.
        
        Args:
            cell_dim (int): Number of cell types to predict (default: 80)
            vit_depth (int): Depth of Vision Transformer (default: 3)
        """
        super(Hist2Cell, self).__init__()
        
        # Load pre-trained ResNet18 and remove the final classification layer
        self.resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.resnet18 = torch.nn.Sequential(*list(self.resnet18.children())[:-1])
        
        # Model hyperparameters
        self.embed_dim = 32 * 8  # Embedding dimension (256)
        self.head = 8            # Number of attention heads
        self.dropout = 0.3       # Dropout rate
        
        # Graph Attention Network for local spatial relationships
        self.conv1 = GATv2Conv(
            in_channels=512,  # ResNet18 output features
            out_channels=int(self.embed_dim/self.head),  # 32 features per head
            heads=self.head   # 8 attention heads
        )
        self.norm1 = LayerNorm(in_channels=self.embed_dim)
        
        # Vision Transformer for global context
        self.cell_transformer = VisionTransformer(
            num_classes=cell_dim,
            embed_dim=self.embed_dim,
            depth=vit_depth,
            mlp_head=True,
            drop_rate=self.dropout,
            attn_drop_rate=self.dropout
        )
        
        # Prediction heads for different levels
        self.spot_fc = Linear(in_features=512, out_features=256)
        self.spot_head = Mlp(in_features=256, hidden_features=512*2, out_features=cell_dim)
        self.local_head = Mlp(in_features=256, hidden_features=512*2, out_features=cell_dim)
        self.fused_head = Mlp(in_features=256, hidden_features=512*2, out_features=cell_dim)
    
    def forward(self, x, edge_index):
        """
        Forward pass of the Hist2Cell model.
        
        Args:
            x (torch.Tensor): Input histology images
            edge_index (torch.Tensor): Graph edge indices for spatial relationships
            
        Returns:
            torch.Tensor: Predicted cell type abundances
        """
        # Extract features using ResNet18
        x_spot = self.resnet18(x)
        x_spot = x_spot.flatten(1)  # [N, 512, 1, 1] -> [N, 512]; unlike squeeze(), keeps the node dimension when N == 1
        
        # Process with Graph Attention Network
        x_local = self.conv1(x=x_spot, edge_index=edge_index)
        x_local = self.norm1(x_local)
        
        # Prepare for Vision Transformer
        x_local = x_local.unsqueeze(0)
        x_cell = x_local
        
        # Generate predictions at different levels
        # 1. Spot-level prediction
        x_spot = self.spot_fc(x_spot)
        cell_predication_spot = self.spot_head(x_spot)
        
        # 2. Local-level prediction (GAT output)
        x_local = x_local.squeeze(0)
        cell_prediction_local = self.local_head(x_local)
        
        # 3. Global-level prediction (Vision Transformer)
        cell_prediction_global, x_global = self.cell_transformer(x_cell)
        cell_prediction_global = cell_prediction_global.squeeze()
        x_global = x_global.squeeze()
        
        # 4. Fused prediction (average of all feature representations)
        cell_prediction_fused = self.fused_head((x_spot + x_local + x_global) / 3.0)
        
        # Final prediction: average of all four prediction levels
        cell_prediction = (cell_predication_spot + cell_prediction_local + 
                          cell_prediction_global + cell_prediction_fused) / 4.0
        
        # Apply ReLU activation to ensure non-negative cell abundances
        cell_prediction = torch.relu(cell_prediction)
        
        return cell_prediction

# Initialize the model
print("Initializing Hist2Cell model...")
model = Hist2Cell(vit_depth=3)
model = model.to(device)

# Display model information
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✓ Model initialized successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model device: {next(model.parameters()).device}")

Initializing Hist2Cell model...
✓ Model initialized successfully!
Total parameters: 14,232,064
Trainable parameters: 14,232,064
Model device: cuda:0
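
The last lines of `forward` reduce to simple arithmetic: the four per-level predictions are averaged and passed through ReLU. Below is a minimal NumPy sketch with random stand-in arrays; the variable names and the toy node count are illustrative and not taken from the repository.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, cell_dim = 5, 80  # toy subgraph with 5 spots, 80 cell types

# Stand-ins for the spot, local, global, and fused head outputs
level_preds = [rng.normal(size=(num_nodes, cell_dim)) for _ in range(4)]

# Average the four levels, then clamp negatives to zero (ReLU)
averaged = sum(level_preds) / 4.0
final_pred = np.maximum(averaged, 0.0)

print(final_pred.shape, float(final_pred.min()))
```

Because the ReLU is applied after averaging, a single level can output negative values without making the final abundance negative.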


## 3. Data Preparation

### Understanding the Dataset Split

We use a **donor-based split** strategy for our training and testing:

- **Training**: 3 donors (9 slides total) - Donors A37, A32, A28
- **Testing**: 1 donor (2 slides total) - Donor A50

This approach ensures that the model is tested on completely unseen donor data, which is crucial for evaluating generalization capability across different individuals and biological conditions.

#### Dataset Overview:

**Test slides (Donor A50):**
- WSA_LngSP9258463
- WSA_LngSP9258467

**Training slides (Other 3 donors):**
- WSA_LngSP8759311
- WSA_LngSP8759312
- WSA_LngSP8759313
- WSA_LngSP9258464
- WSA_LngSP9258468
- WSA_LngSP10193347
- WSA_LngSP10193348
- WSA_LngSP10193345
- WSA_LngSP10193346
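
A donor-based split can be sketched as a filter over a slide-to-donor mapping. The mapping below is a hypothetical toy example, and `leave_one_donor_out` is an illustrative helper; the real assignments for this dataset are the ones stored in the split files.

```python
def leave_one_donor_out(slide_to_donor, held_out_donor):
    """Split slide IDs into (train, test) by holding out one donor."""
    train = [s for s, d in slide_to_donor.items() if d != held_out_donor]
    test = [s for s, d in slide_to_donor.items() if d == held_out_donor]
    if not test:
        raise ValueError(f"no slides found for donor {held_out_donor!r}")
    return train, test

# Hypothetical toy mapping; real assignments come from your own metadata
toy_slide_to_donor = {
    "slide_1": "donor_A", "slide_2": "donor_A",
    "slide_3": "donor_B",
    "slide_4": "donor_C", "slide_5": "donor_C",
}
train_ids, test_ids = leave_one_donor_out(toy_slide_to_donor, "donor_C")
print(train_ids, test_ids)
```

Splitting by donor rather than by slide keeps every slide from the held-out individual out of training.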

### Loading Train/Test Split Files

In [4]:
# Load train/test split files
print("Loading train/test split files...")

# Read slide names, one per line, skipping blank lines and closing the file afterwards
def read_slide_list(path):
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]

# Training slides (leave-one-donor-out strategy)
train_slides = read_slide_list("../train_test_splits/humanlung_cell2location/train_leave_A50.txt")

# Testing slides (donor A50 only)
test_slides = read_slide_list("../train_test_splits/humanlung_cell2location/test_leave_A50.txt")

# Display dataset information
print(f"Training slides: {len(train_slides)} slides")
print(f"Testing slides: {len(test_slides)} slides")
print(f"\nTraining slides: {train_slides}")
print(f"Testing slides: {test_slides}")
print("✓ Split files loaded successfully!")

Loading train/test split files...
Training slides: 9 slides
Testing slides: 2 slides

Training slides: ['WSA_LngSP8759311', 'WSA_LngSP8759312', 'WSA_LngSP8759313', 'WSA_LngSP9258464', 'WSA_LngSP9258468', 'WSA_LngSP10193347', 'WSA_LngSP10193348', 'WSA_LngSP10193345', 'WSA_LngSP10193346']
Testing slides: ['WSA_LngSP9258463', 'WSA_LngSP9258467']
✓ Split files loaded successfully!


### Loading Processed Graph Data

Each slide has been preprocessed into a graph structure where:
- **Nodes**: Represent spatial spots with histology image patches
- **Edges**: Connect neighboring spots based on spatial proximity
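- **Node Features**: The 224×224 RGB image patch for each spot (a `[3, 224, 224]` tensor per node)
- **Labels**: A 330-dimensional vector per spot; the training loop below uses columns 250 onward (80 values) as the cell type abundances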

The data is stored in PyTorch Geometric format (`.pt` files) for efficient loading and processing.
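
To make the edge format concrete, the sketch below builds a `[2, num_edges]` array of (source, target) pairs, the layout PyTorch Geometric uses for `edge_index`. The neighbor rule used to preprocess these slides is not specified here, so the distance threshold is an assumption for illustration only; `build_edge_index` and `radius` are hypothetical names.

```python
import numpy as np

def build_edge_index(coords, radius):
    """Connect every pair of distinct spots closer than `radius`.

    coords: [N, 2] array of spot positions. Returns a [2, E] integer array
    holding both directions of each edge, as in an undirected graph.
    """
    diffs = coords[:, None, :] - coords[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    src, dst = np.nonzero((dists < radius) & (dists > 0))
    return np.stack([src, dst])

# Four spots on a line, spaced 1 unit apart
coords = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
edge_index = build_edge_index(coords, radius=1.5)
print(edge_index)
```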

In [5]:
# Import PyTorch Geometric utilities
from torch_geometric.data import Batch

print("Loading processed graph data...")

# Load training data
train_graph_list = []
print("Loading training graphs:")
for item in train_slides:
    if item:  # Skip empty strings
        graph_path = os.path.join("../example_data/humanlung_cell2location", item + '.pt')
        graph = torch.load(graph_path)
        train_graph_list.append(graph)
        print(f"  - {item}: {graph.num_nodes} nodes, {graph.num_edges} edges")

# Combine all training graphs into a single batch
train_dataset = Batch.from_data_list(train_graph_list)

# Load testing data
test_graph_list = []
print("\nLoading testing graphs:")
for item in test_slides:
    if item:  # Skip empty strings
        graph_path = os.path.join("../example_data/humanlung_cell2location", item + '.pt')
        graph = torch.load(graph_path)
        test_graph_list.append(graph)
        print(f"  - {item}: {graph.num_nodes} nodes, {graph.num_edges} edges")

# Combine all testing graphs into a single batch
test_dataset = Batch.from_data_list(test_graph_list)

# Display dataset statistics
print(f"\n📊 Dataset Statistics:")
print(f"Training dataset:")
print(f"  - Total nodes: {train_dataset.num_nodes}")
print(f"  - Total edges: {train_dataset.num_edges}")
print(f"  - Node features: {train_dataset.x.shape}")
print(f"  - Labels: {train_dataset.y.shape}")

print(f"\nTesting dataset:")
print(f"  - Total nodes: {test_dataset.num_nodes}")
print(f"  - Total edges: {test_dataset.num_edges}")
print(f"  - Node features: {test_dataset.x.shape}")
print(f"  - Labels: {test_dataset.y.shape}")

print("✓ Graph data loaded successfully!")

Loading processed graph data...
Loading training graphs:
  - WSA_LngSP8759311: 2251 nodes, 15033 edges
  - WSA_LngSP8759312: 2564 nodes, 17216 edges
  - WSA_LngSP8759313: 2001 nodes, 13477 edges
  - WSA_LngSP9258464: 2323 nodes, 15413 edges
  - WSA_LngSP9258468: 2285 nodes, 14999 edges
  - WSA_LngSP10193347: 1937 nodes, 12171 edges
  - WSA_LngSP10193348: 937 nodes, 5903 edges
  - WSA_LngSP10193345: 3234 nodes, 21332 edges
  - WSA_LngSP10193346: 2430 nodes, 16300 edges

Loading testing graphs:
  - WSA_LngSP9258463: 386 nodes, 2442 edges
  - WSA_LngSP9258467: 422 nodes, 2732 edges

📊 Dataset Statistics:
Training dataset:
  - Total nodes: 19962
  - Total edges: 131844
  - Node features: torch.Size([19962, 3, 224, 224])
  - Labels: torch.Size([19962, 330])

Testing dataset:
  - Total nodes: 808
  - Total edges: 5174
  - Node features: torch.Size([808, 3, 224, 224])
  - Labels: torch.Size([808, 330])
✓ Graph data loaded successfully!
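
The shapes above explain why the full graph cannot simply be pushed through the model in one pass. A rough estimate of the raw patch tensor size follows, assuming 4-byte `float32` elements; the dtype is not printed above, so that is an assumption, and `tensor_gib` is an illustrative helper.

```python
def tensor_gib(shape, bytes_per_element=4):
    """Memory in GiB for a dense tensor of the given shape."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return num_elements * bytes_per_element / 1024**3

train_shape = (19962, 3, 224, 224)  # from the output above
test_shape = (808, 3, 224, 224)

train_gib = tensor_gib(train_shape)
test_gib = tensor_gib(test_shape)
print(f"train patches: {train_gib:.2f} GiB, test patches: {test_gib:.2f} GiB")
```

Under that assumption the training patches alone take about 11 GiB before any activations are stored, which is why the next section samples small subgraphs instead.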


### DataLoader Configuration

For efficient training on large graphs, we use **subgraph sampling** with `NeighborLoader`. This approach:
- Samples subgraphs around center nodes
- Reduces memory usage by processing smaller graph portions
- Maintains spatial context through neighborhood sampling

#### Key Parameters:

- **`hop`**: Defines the receptive field (neighborhood size)
  - `hop=2` means we include 2-hop neighbors around each center node
  - Larger hop → more context but higher memory usage
  - We use 2-hop for optimal balance between performance and efficiency

- **`subgraph_bs`**: Subgraph batch size
  - Number of center nodes processed simultaneously
  - `subgraph_bs=16` works well on RTX 3090 GPU (24GB memory)
  - For smaller GPUs: try `subgraph_bs=8` (12GB) or `subgraph_bs=4` (8GB)
  - Larger values improve training efficiency but require more memory

- **`num_neighbors`**: Number of neighbors to sample at each hop
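  - `-1` at a hop keeps every neighbor at that hop (no subsampling); the loaders below use `[-1] * hop`, one entry per hop
  - Smaller values such as `[10, 5]` cap the neighbors sampled per hop for very large graphs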
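
What `hop=2` with all neighbors kept collects around one center node can be sketched as a breadth-first expansion. This is a plain-Python illustration of the idea, not the `NeighborLoader` implementation; `k_hop_nodes` and the toy adjacency list are hypothetical.

```python
from collections import deque

def k_hop_nodes(adjacency, center, hops):
    """Return all nodes within `hops` edges of `center` (center included)."""
    seen = {center}
    frontier = deque([(center, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == hops:
            continue  # do not expand beyond the requested hop count
        for neighbor in adjacency[node]:
            if neighbor not in seen:
                seen.add(neighbor)
                frontier.append((neighbor, depth + 1))
    return seen

# Toy chain of six spots: 0 - 1 - 2 - 3 - 4 - 5
adjacency = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}
edge_spot = sorted(k_hop_nodes(adjacency, center=0, hops=2))
inner_spot = sorted(k_hop_nodes(adjacency, center=3, hops=2))
print(edge_spot, inner_spot)
```

Spots in the tissue interior pull in more neighbors than spots on the boundary, so the number of nodes per batch varies even with a fixed `subgraph_bs`.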

In [None]:
# Import necessary modules
from torch_geometric.loader import NeighborLoader
import torch_geometric

# Disable PyG lib for compatibility
torch_geometric.typing.WITH_PYG_LIB = False

print("Configuring DataLoaders...")

# DataLoader parameters
hop = 2              # 2-hop neighborhood sampling
subgraph_bs = 16     # Subgraph batch size (adjust based on GPU memory)

print(f"Parameters:")
print(f"  - Hop distance: {hop}")
print(f"  - Subgraph batch size: {subgraph_bs}")
print(f"  - Neighbor sampling: all neighbors ([-1] * {hop})")

# Create training data loader
train_loader = NeighborLoader(
    train_dataset,
    num_neighbors=[-1] * hop,    # Sample all neighbors at each hop
    batch_size=subgraph_bs,      # Number of center nodes per batch
    directed=False,              # Keep all edges between the sampled nodes
    input_nodes=None,            # Use all nodes as potential centers
    shuffle=True,                # Shuffle for training
    num_workers=2,               # Parallel data loading
)

# Create testing data loader
test_loader = NeighborLoader(
    test_dataset,
    num_neighbors=[-1] * hop,    # Sample all neighbors at each hop
    batch_size=subgraph_bs,      # Number of center nodes per batch
    directed=False,              # Keep all edges between the sampled nodes
    input_nodes=None,            # Use all nodes as potential centers
    shuffle=False,               # No shuffling for testing
    num_workers=2,               # Parallel data loading
)

# Display loader information
print(f"\n📋 DataLoader Information:")
print(f"Training loader:")
print(f"  - Approximate batches per epoch: {len(train_loader)}")
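print(f"  - Shuffling: {isinstance(train_loader.sampler, torch.utils.data.RandomSampler)}")  # DataLoader stores a sampler, not a .shuffle attribute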

print(f"\nTesting loader:")
print(f"  - Approximate batches per epoch: {len(test_loader)}")
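print(f"  - Shuffling: {isinstance(test_loader.sampler, torch.utils.data.RandomSampler)}")  # DataLoader stores a sampler, not a .shuffle attribute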

print("✓ DataLoaders configured successfully!")
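
With `input_nodes=None` every node serves as a center node once per epoch, and the loader keeps the last partial batch (the `DataLoader` default), so the batch count is the node count divided by `subgraph_bs`, rounded up. A small sketch using the node totals reported earlier; `batches_per_epoch` is an illustrative helper.

```python
import math

def batches_per_epoch(num_center_nodes, batch_size):
    """Number of mini-batches when every node is a center exactly once."""
    return math.ceil(num_center_nodes / batch_size)

train_batches = batches_per_epoch(19962, 16)  # training graph
test_batches = batches_per_epoch(808, 16)     # testing graph
print(train_batches, test_batches)
```

Each batch contains more than `subgraph_bs` nodes, because the sampled neighbors of the center nodes are included as well.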

## 4. Training Configuration

### Setting Up Training Components

Before we start training, we need to configure several key components:

#### 🎯 Loss Function
- **MSE Loss**: Mean Squared Error for regression (predicting continuous cell abundance values)

#### 🚀 Optimizer
- **Adam**: Adaptive learning rate optimizer
- **Learning Rate**: `1e-4` (0.0001), a common starting point when fine-tuning a pre-trained backbone with Adam
- **Weight Decay**: `1e-4` for regularization to prevent overfitting

#### 📈 Learning Rate Scheduler
- **CosineAnnealingLR**: Gradually decreases learning rate following a cosine curve
- **T_max**: 20 epochs for the decay from the initial rate to `eta_min`, so the 5-epoch demo below covers only the first quarter of the schedule
- **eta_min**: Minimum learning rate of `1e-5`

These settings have been empirically validated and work well for the Hist2Cell model.
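
The schedule can be checked by hand with the closed-form cosine annealing expression, lr(t) = eta_min + (lr_0 - eta_min) · (1 + cos(π t / T_max)) / 2. The sketch below evaluates it for the settings listed above; `cosine_annealing_lr` is an illustrative helper, not a PyTorch function.

```python
import math

def cosine_annealing_lr(step, base_lr=1e-4, eta_min=1e-5, t_max=20):
    """Closed-form learning rate after `step` scheduler steps."""
    cosine = (1 + math.cos(math.pi * step / t_max)) / 2
    return eta_min + (base_lr - eta_min) * cosine

for step in (0, 1, 2, 5, 20):
    print(f"step {step:>2}: lr = {cosine_annealing_lr(step):.6e}")
```

The values for steps 1 and 2 agree with the `lr =` lines in the recorded training output further down (about 9.9446e-05 and 9.7798e-05).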

In [None]:
# Training configuration
print("Setting up training configuration...")

# Learning rate
lr = 1e-4
print(f"Learning rate: {lr}")

# Get model parameters
params = model.parameters()

# Define loss function (Mean Squared Error for regression)
criterion = nn.MSELoss().to(device)
print(f"Loss function: MSE Loss")

# Define optimizer (Adam with weight decay for regularization)
optimizer = torch.optim.Adam(
    params, 
    lr=lr, 
    weight_decay=1e-4  # L2 regularization
)
print(f"Optimizer: Adam (lr={lr}, weight_decay=1e-4)")

# Define learning rate scheduler (Cosine Annealing)
# `last_epoch` and `verbose` are left at their defaults; `verbose` is deprecated in recent PyTorch releases
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 
    T_max=20,           # Number of epochs to anneal from lr down to eta_min
    eta_min=1e-5,       # Minimum learning rate
)
print(f"Scheduler: CosineAnnealingLR (T_max=20, eta_min=1e-5)")

print("✓ Training configuration completed!")

### Training Loop

Now we're ready to start training! This section will:

1. **Train the model** for multiple epochs
2. **Monitor performance** on both training and test sets
3. **Save the best model** based on test performance
4. **Track metrics** including loss and Pearson correlation

#### Key Metrics:
- **Loss**: MSE loss for training optimization
- **Pearson R**: Correlation coefficient measuring prediction quality
- **Best model**: Saved based on highest test Pearson R

#### Training Process:
- **Epochs**: 5 (for demonstration - increase to 20-50 for production)
- **Evaluation**: After each epoch on both training and test sets
- **Model Saving**: Only when test Pearson R improves (best model checkpoint)
- **Progress Tracking**: Real-time loss and correlation monitoring

Let's start training!
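
The headline metric is the Pearson correlation computed per cell type and then averaged. Below is a NumPy-only sketch of that computation. It differs from the loop below in one respect, which is an assumption of this sketch: columns with zero variance are skipped, since the correlation is undefined for them (for example, a cell type whose predictions are all zero after the final ReLU). `mean_columnwise_pearson` is an illustrative helper.

```python
import numpy as np

def mean_columnwise_pearson(pred, label):
    """Average Pearson r over columns, skipping zero-variance columns."""
    pred_c = pred - pred.mean(axis=0)
    label_c = label - label.mean(axis=0)
    denom = np.sqrt((pred_c ** 2).sum(axis=0) * (label_c ** 2).sum(axis=0))
    valid = denom > 0  # r is undefined when either column is constant
    r = (pred_c * label_c).sum(axis=0)[valid] / denom[valid]
    return float(r.mean()) if valid.any() else float("nan")

rng = np.random.default_rng(0)
label = rng.random((100, 4))
perfect = 2.0 * label + 1.0  # linear in the label: r = 1 in every column
inverted = -label            # r = -1 in every column
print(mean_columnwise_pearson(perfect, label))
print(mean_columnwise_pearson(inverted, label))
```

Pearson R is insensitive to scale and offset, which is why it is tracked alongside the MSE loss rather than instead of it.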

In [None]:
# Import necessary libraries for training
import numpy as np
from scipy.stats import pearsonr
import time

# Training parameters
num_epochs = 5
print(f"🚀 Starting training for {num_epochs} epochs...\n")

# Initialize tracking variables
best_cell_abundance_all_average = 0.0
since = time.time()

# Create directory for saving model weights
os.makedirs("../model_weights", exist_ok=True)

# Main training loop
for epoch in range(num_epochs):
    print("=" * 100)
    print(f'📅 Epoch: {epoch + 1}/{num_epochs}')
    print(f'📊 Current learning rate: {optimizer.param_groups[0]["lr"]:.2e}')
    
    # =============================================================================
    # TRAINING PHASE
    # =============================================================================
    model.train()  # Set model to training mode
    
    # Initialize training metrics
    train_sample_num = 0
    train_cell_pred_array = []
    train_cell_label_array = []
    train_cell_abundance_loss = 0
    
    print("🏋️ Training...")
    for batch_idx, graph in enumerate(train_loader):
        # Move data to device
        x = graph.x.to(device)
        y = graph.y.to(device)
        edge_index = graph.edge_index.to(device)
        
        # Extract cell abundance labels (columns 250 onwards)
        cell_label = y[:, 250:]
        
        # Forward pass
        cell_pred = model(x=x, edge_index=edge_index)
        
        # Calculate loss
        cell_loss = criterion(cell_pred, cell_label)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        cell_loss.backward()
        optimizer.step()
        
        # Collect predictions and labels for evaluation
        center_num = len(graph.input_id)  # Number of center nodes
        center_cell_label = cell_label[:center_num, :]
        center_cell_pred = cell_pred[:center_num, :]
        
        # Store predictions and labels
        train_cell_label_array.append(center_cell_label.squeeze().cpu().detach().numpy())
        train_cell_pred_array.append(center_cell_pred.squeeze().cpu().detach().numpy())
        train_sample_num += center_num
        train_cell_abundance_loss += cell_loss.item() * center_num
        
        # Print progress
        if (batch_idx + 1) % 20 == 0:
            print(f"  Batch {batch_idx + 1}/{len(train_loader)}, Loss: {cell_loss.item():.6f}")
    
    # Calculate average training loss
    train_cell_abundance_loss = train_cell_abundance_loss / train_sample_num
    
    # Prepare arrays for correlation calculation
    if len(train_cell_pred_array[-1].shape) == 1:
        train_cell_pred_array[-1] = np.expand_dims(train_cell_pred_array[-1], axis=0)
    train_cell_pred_array = np.concatenate(train_cell_pred_array)
    
    if len(train_cell_label_array[-1].shape) == 1:
        train_cell_label_array[-1] = np.expand_dims(train_cell_label_array[-1], axis=0)
    train_cell_label_array = np.concatenate(train_cell_label_array)
    
    # Calculate average Pearson correlation across all cell types
    train_cell_abundance_all_pearson_average = 0.0
    for i in range(train_cell_pred_array.shape[1]):
        r, p = pearsonr(train_cell_pred_array[:, i], train_cell_label_array[:, i])
        train_cell_abundance_all_pearson_average += r
    train_cell_abundance_all_pearson_average /= train_cell_pred_array.shape[1]
    
    # Update learning rate
    scheduler.step()
    
    # =============================================================================
    # EVALUATION PHASE
    # =============================================================================
    print("🧪 Evaluating...")
    with torch.no_grad():
        model.eval()  # Set model to evaluation mode
        
        # Initialize evaluation metrics
        test_sample_num = 0
        test_cell_pred_array = []
        test_cell_label_array = []
        test_cell_abundance_loss = 0
        
        for graph in test_loader:
            # Move data to device
            x = graph.x.to(device)
            y = graph.y.to(device)
            edge_index = graph.edge_index.to(device)
            
            # Extract cell abundance labels
            cell_label = y[:, 250:]
            
            # Forward pass (no gradients needed)
            cell_pred = model(x=x, edge_index=edge_index)
            
            # Calculate loss
            cell_loss = criterion(cell_pred, cell_label)
            
            # Collect predictions and labels
            center_num = len(graph.input_id)
            center_cell_label = cell_label[:center_num, :]
            center_cell_pred = cell_pred[:center_num, :]
            
            test_cell_label_array.append(center_cell_label.squeeze().cpu().detach().numpy())
            test_cell_pred_array.append(center_cell_pred.squeeze().cpu().detach().numpy())
            test_sample_num += center_num
            test_cell_abundance_loss += cell_loss.item() * center_num
        
        # Calculate average test loss
        test_cell_abundance_loss = test_cell_abundance_loss / test_sample_num
    
    # Prepare arrays for correlation calculation
    if len(test_cell_pred_array[-1].shape) == 1:
        test_cell_pred_array[-1] = np.expand_dims(test_cell_pred_array[-1], axis=0)
    test_cell_pred_array = np.concatenate(test_cell_pred_array)
    
    if len(test_cell_label_array[-1].shape) == 1:
        test_cell_label_array[-1] = np.expand_dims(test_cell_label_array[-1], axis=0)
    test_cell_label_array = np.concatenate(test_cell_label_array)
    
    # Calculate average Pearson correlation across all cell types
    test_cell_abundance_all_pearson_average = 0.0
    for i in range(test_cell_pred_array.shape[1]):
        r, p = pearsonr(test_cell_pred_array[:, i], test_cell_label_array[:, i])
        test_cell_abundance_all_pearson_average += r
    test_cell_abundance_all_pearson_average /= test_cell_pred_array.shape[1]
    
    # =============================================================================
    # MODEL SAVING AND LOGGING
    # =============================================================================
    
    # Save model if test performance improved
    if test_cell_abundance_all_pearson_average > best_cell_abundance_all_average:
        best_cell_abundance_all_average = test_cell_abundance_all_pearson_average
        torch.save(model.state_dict(), os.path.join("../model_weights", "demo_ckpt.pth"))
        print(f"💾 New best model saved! Test Pearson R: {test_cell_abundance_all_pearson_average:.6f}")
    
    # Calculate and display the time elapsed since training started
    time_elapsed = time.time() - since
    print(f'⏱️ Elapsed time: {(time_elapsed // 60):.0f}m {(time_elapsed % 60):.0f}s')
    
    # Display epoch results
    print(f'\n📈 Epoch {epoch + 1} Results:')
    print(f'  Training Loss: {train_cell_abundance_loss:.6f}')
    print(f'  Training Pearson R: {train_cell_abundance_all_pearson_average:.6f}')
    print(f'  Test Loss: {test_cell_abundance_loss:.6f}')
    print(f'  Test Pearson R: {test_cell_abundance_all_pearson_average:.6f}')
    print(f'  Best Test Pearson R: {best_cell_abundance_all_average:.6f}')
    print()

print("🎉 Training completed successfully!")
print(f"🏆 Best test Pearson R achieved: {best_cell_abundance_all_average:.6f}")
print(f"💾 Best model saved to: ../model_weights/demo_ckpt.pth")

------------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch: 1 	
lr =  9.94459753267812e-05
saving best cell all abundance average 0.25637321132934027
Training complete in 6m 35s
Epoch: 1 	Training Cell abundance Loss: 0.083641
Epoch: 1 	Training Cell abundance pearson all average: 0.807833
Epoch: 1 	Test Cell abundance Loss: 0.474147
Epoch: 1 	Test Cell abundance pearson all average: 0.256373
------------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch: 2 	
lr =  9.779754323328192e-05
Training complete in 13m 9s
Epoch: 2 	Training Cell abundance Loss: 0.055636
Epoch: 2 	Training Cell abundance pearson all average: 0.841353
Epoch: 2 	Test Cell abundance Loss: 0.492704
Epoch: 2 	Test Cell abundance pearson all average: 0.255606
----------------------------------------------------------
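
In the two epochs recorded above, the training Pearson R rises from 0.808 to 0.841 while the test Pearson R stays near 0.256, and the checkpoint is saved only in epoch 1. One way to act on that pattern is a patience counter on the test metric. The sketch below is an optional addition, not part of the tutorial loop; `BestMetricTracker` and `patience` are hypothetical names.

```python
class BestMetricTracker:
    """Track the best value of a metric and the epochs since it improved."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def update(self, value):
        """Record one epoch; return True if `value` is a new best."""
        if value > self.best:
            self.best = value
            self.epochs_without_improvement = 0
            return True
        self.epochs_without_improvement += 1
        return False

    @property
    def should_stop(self):
        return self.epochs_without_improvement >= self.patience

# Test Pearson R values from the two epochs logged above
tracker = BestMetricTracker(patience=3)
improved = [tracker.update(r) for r in (0.256373, 0.255606)]
print(improved, tracker.best, tracker.should_stop)
```

Inside the training loop, `update` would decide when to save the checkpoint and `should_stop` when to break out of the epoch loop.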

## 5. Tutorial Complete! 🎉

### What You've Accomplished

Congratulations! You've successfully completed the Hist2Cell training tutorial. Here's what you've learned:

✅ **Environment Setup**: Configured random seeds and GPU settings for reproducible training  
✅ **Model Architecture**: Understood the multi-component Hist2Cell model structure  
✅ **Data Loading**: Loaded and prepared graph-structured histology data  
✅ **Training Configuration**: Set up loss functions, optimizers, and learning rate schedulers  
✅ **Training Loop**: Implemented a complete training pipeline with evaluation and model saving  

### Key Results

- **Model**: Successfully trained Hist2Cell for cell type abundance prediction
- **Saved Model**: Best checkpoint saved to `../model_weights/demo_ckpt.pth`

### Next Steps

Now that you've trained your model, explore these additional tutorials to maximize your analysis:

#### 1. **Data Preparation**
- `../tutorial_data_preparation/data_preparation_tutorial.ipynb`
- Learn how to preprocess your own histology data for Hist2Cell training
- Understand the data structure and format requirements

#### 2. **Analysis and Evaluation**
Explore the comprehensive analysis tutorials in `../tutorial_analysis_evaluation/`:

- **`cell_abundance_visulization_tutorial.ipynb`**: Visualize predicted cell abundances for biological discovery
- **`key_cell_evaluation_tutorial.ipynb`**: Evaluate model performance on specific cell types of interest  
- **`cell_colocalization_tutorial.ipynb`**: Analyze spatial cell co-localization patterns
- **`super_resovled_cell_abundance_tutorial.ipynb`**: Generate super-resolved cell abundance maps

#### 3. **Model Improvements**
- **More Epochs**: Train for 20-50 epochs for better convergence (this demo used only 5 for quick demonstration)
- **Hyperparameter Tuning**: 
  - Learning rates: Try `5e-5`, `1e-4` (current), `2e-4`, `5e-4`
  - Subgraph batch sizes: Adjust `8`, `16` (current), `32` based on GPU memory
  - ViT depths: Experiment with `2`, `3` (current), `4`, `5` layers
  - Model dimensions: `cell_dim` must equal the number of cell types in your labels (`80` here); change it only when your dataset has a different number
- **Custom Datasets**: Apply the data preparation pipeline to your own histology data
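
The candidate values listed above can be enumerated with `itertools.product`. This is a sketch of the bookkeeping only: the `search_space` dictionary is illustrative, and each resulting configuration would still need its own full training run.

```python
from itertools import product

search_space = {
    "lr": [5e-5, 1e-4, 2e-4, 5e-4],
    "subgraph_bs": [8, 16, 32],
    "vit_depth": [2, 3, 4, 5],
}

# One dict per combination of learning rate, batch size, and ViT depth
configs = [dict(zip(search_space, values)) for values in product(*search_space.values())]

print(len(configs))
print(configs[0])
```

The full grid has 48 combinations, so it is often more practical to vary one setting at a time around the current values.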

#### 4. **Advanced Usage**
- **Fine-tuning**: Adapt pre-trained models to new datasets
- **Integration**: Incorporate Hist2Cell into your spatial biology analysis pipeline
- **Batch Processing**: Process multiple slides efficiently

### Available Resources

This project provides:
- **Pre-trained Models**: Check `../model_weights/` for trained checkpoints
- **Example Data**: Processed datasets in `../example_data/`
- **Train/Test Splits**: Donor-based splits in `../train_test_splits/`
- **Paper Reference**: [Hist2Cell: Deciphering Fine-grained Cellular Architectures from Histology Images](https://www.biorxiv.org/content/10.1101/2024.02.17.580852v1.full.pdf)

### Important Notes

- **Demo Limitations**: This demo used only 5 epochs for quick demonstration (~30 minutes on RTX 3090)
- **Production Training**: For actual research, train for 20-50 epochs depending on convergence (several hours)
- **GPU Requirements**: Training requires significant GPU memory (≥8GB recommended, 24GB+ ideal)
- **Performance Monitoring**: Monitor both training and test metrics to avoid overfitting
- **Data Split Strategy**: The donor-based split ensures realistic evaluation of model generalization capability

### Citation

If you use this code in your research, please cite:
```bibtex
@article{zhao2024hist2cell,
  title={Hist2Cell: Deciphering Fine-grained Cellular Architectures from Histology Images},
  author={Zhao, Weiqin and Liang, Zhuo and Huang, Xianjie and Huang, Yuanhua and Yu, Lequan},
  journal={bioRxiv},
  pages={2024--02},
  year={2024},
  publisher={Cold Spring Harbor Laboratory}
}
```

Happy training and analyzing! 🚀
