# Hist2Cell Data Preparation Tutorial

## Welcome to the Hist2Cell Data Preparation Guide! 🧬

This comprehensive tutorial will guide you through the essential process of preparing your spatial transcriptomics data for **Hist2Cell** training and inference. Whether you're new to spatial biology or an experienced researcher, this guide provides step-by-step instructions to transform your raw histology data into the format required by our Vision Graph-Transformer framework.

## 📋 Table of Contents

1. [**Understanding Spatial Transcriptomics Data**](#understanding-spatial-transcriptomics-data)
2. [**Hist2Cell Architecture Overview**](#hist2cell-architecture-overview)
3. [**Exploring Processed Data Examples**](#exploring-processed-data-examples)
4. [**Raw Data Structure Overview**](#raw-data-structure-overview)
5. [**Step-by-Step Data Processing Pipeline**](#step-by-step-data-processing-pipeline)
6. [**Graph Construction and Validation**](#graph-construction-and-validation)
7. [**Troubleshooting and Best Practices**](#troubleshooting-and-best-practices)

## 🎯 What You'll Learn

By the end of this tutorial, you will:
- ✅ Understand the spatial transcriptomics data structure used in Hist2Cell
- ✅ Learn how to process raw histology images and spatial gene expression data
- ✅ Master the graph construction process for spatial neighborhood relationships
- ✅ Prepare your own datasets for Hist2Cell training and inference
- ✅ Validate data quality and troubleshoot common issues

## 🔬 Understanding Spatial Transcriptomics Data

**Spatial Transcriptomics (ST)** is a family of technologies that measure gene expression while preserving where in the tissue section each measurement was taken. Unlike bulk or single-cell RNA-seq, which dissociate the tissue and lose spatial information, ST provides:

- **Spatial Gene Expression**: Gene expression profiles at specific tissue locations
- **Histology Context**: High-resolution tissue images showing cellular architecture
- **Spatial Relationships**: Information about how different tissue regions interact

### Key Concepts:
- **Spots**: Discrete locations where gene expression is measured (think of them as pixels in a gene expression image)
- **Patches**: Image regions extracted around each spot from the histology slide
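- **Cell Deconvolution**: Estimating the cell type composition of each spot, since a single spot usually captures a mixture of several cells (tools such as Cell2location do this with a single-cell reference)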
- **Spatial Graphs**: Networks representing spatial relationships between neighboring spots

## 🏗️ Hist2Cell Architecture Overview

**Hist2Cell** is a Vision Graph-Transformer framework that predicts fine-grained cell type abundances directly from histology images. The model combines:

1. **Local Feature Extractor**: Analyzes histology image patches using ResNet18
2. **Graph Neural Networks**: Captures spatial relationships between neighboring tissue regions
3. **Transformer Architecture**: Processes global tissue context and patterns
4. **Multi-scale Fusion**: Integrates information from spot, local, and global levels

## 📊 Data Requirements

To prepare data for Hist2Cell, you need:

### Required Files:
- **Histology Images**: High-resolution whole slide images (WSIs) of the tissue section
- **Image Patches**: 224×224 pixel patches extracted around each spot
- **Gene Expression Data**: Spatial gene expression measurements
- **Cell Type Abundances**: Fine-grained cell type composition (from deconvolution)
- **Spatial Coordinates**: X,Y positions of each spot on the slide

### Provided Example Data:
- `./example_data/humanlung_cell2location`: Processed data from healthy human lung tissue
- `./example_data/humanlung_cell2location_2x`: Super-resolution data (2x higher resolution)

Let's start by exploring the data structure and then learn how to process your own data!

# 🔍 Exploring Processed Data Examples

## Understanding the Hist2Cell Data Structure

Before diving into data processing, let's examine the structure of processed data that Hist2Cell expects. We'll use slide `WSA_LngSP9258467` from donor A50 in our healthy human lung dataset as an example.

### 📁 Data Format: PyTorch Geometric

Hist2Cell uses **PyTorch Geometric** format, which is specifically designed for graph neural networks. This format efficiently stores:
- **Node features** (image patches) and **node positions** (spot coordinates)
- **Edge connections** (spatial relationships between neighboring spots)
- **Labels** (gene expression and cell type abundances)

### Loading a Processed Example

Let's load a processed data file and explore its structure:

In [None]:
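import torch

# PyTorch >= 2.6 defaults to weights_only=True, which cannot unpickle PyTorch Geometric
# Data objects; weights_only=False restores full loading (use only for files you trust)
processed_data = torch.load("../example_data/humanlung_cell2location/WSA_LngSP9258467.pt", weights_only=False)
processed_data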

## 📊 Data Structure Breakdown

The processed data contains the following key components:

### 🖼️ **Node Features (`x`)**: Image Patches
- **Shape**: `[422, 3, 224, 224]` 
- **Meaning**: 224×224 RGB image patches for each of the 422 spots in this slide
- **Purpose**: Each spot becomes a node in our spatial graph, and the image patch provides visual context
- **Format**: Standard PyTorch tensor layout (C×H×W), normalized with the ImageNet channel means and standard deviations expected by ResNet18
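Because the stored patches are normalized, they do not look like ordinary images when plotted directly. A minimal NumPy sketch for undoing the normalization, assuming the ImageNet means and standard deviations used in this tutorial's transform and a float array in C×H×W layout; `denormalize_patch` is a hypothetical helper, not part of Hist2Cell:

```python
import numpy as np

# ImageNet channel statistics (the values used by this tutorial's Normalize transform)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def denormalize_patch(patch_chw):
    """Turn a normalized C×H×W float patch back into an H×W×C uint8 image."""
    rgb = patch_chw * IMAGENET_STD + IMAGENET_MEAN   # undo (value - mean) / std
    rgb = np.clip(rgb, 0.0, 1.0)                     # ToTensor() had scaled pixels to [0, 1]
    return (rgb.transpose(1, 2, 0) * 255).round().astype(np.uint8)

# Usage: a mid-grey patch (0.5 everywhere), normalized the same way
grey = np.full((3, 224, 224), 0.5)
normalized = (grey - IMAGENET_MEAN) / IMAGENET_STD
image = denormalize_patch(normalized)
print(image.shape, image.dtype)
```

For a real patch, pass `processed_data['x'][0].numpy()` and display the result with any image viewer, for example matplotlib's `imshow`.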

### 🔗 **Graph Connectivity (`edge_index`)**: Spatial Relationships  
- **Shape**: `[2, 2732]` (COO format)
- **Meaning**: 2,732 edges connecting neighboring spots based on spatial proximity
- **Purpose**: Defines which spots are neighbors in the tissue, enabling spatial information flow
- **Format**: Each column represents an edge (source_node, target_node)
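The COO layout is easy to reproduce by hand. A minimal NumPy sketch, using a made-up 4-spot chain, that turns a list of undirected neighbor pairs into a `[2, num_edges]` array listing every connection in both directions (the convention PyTorch Geometric uses for undirected graphs) and reads the average degree from it:

```python
import numpy as np

# Hypothetical undirected neighbor pairs between 4 spots (a chain 0-1-2-3)
pairs = np.array([[0, 1], [1, 2], [2, 3]])

# Stack both directions: row 0 holds source nodes, row 1 holds target nodes
edge_index = np.concatenate([pairs.T, pairs.T[::-1]], axis=1)

num_nodes = 4
num_edges = edge_index.shape[1]          # directed entries, 2 per undirected pair
avg_degree = num_edges / num_nodes       # average number of neighbors per spot

print(edge_index)
print(num_edges, avg_degree)
```

The same `num_edges / num_nodes` ratio is how the average degree is computed for the example slide further down.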

### 🏷️ **Labels (`y`)**: Multi-modal Ground Truth
- **Shape**: `[422, 330]` 
- **Composition**: 
  - **250 genes**: Top highly expressed genes (normalized expression values)
  - **80 cell types**: Fine-grained cell type abundances (from deconvolution)
- **Purpose**: Training targets for both gene expression and cell type prediction
- **Total**: 330 labels per spot (250 + 80)
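Since genes and cell types share one label matrix, downstream code has to slice `y` by column. A minimal sketch with a random stand-in array of the documented shape (422 × 330), assuming the gene columns come first as described above; the constant names are illustrative:

```python
import numpy as np

NUM_GENES = 250        # columns 0-249: gene expression
NUM_CELL_TYPES = 80    # columns 250-329: cell type abundances

# Random stand-in for processed_data['y'] (422 spots × 330 labels)
rng = np.random.default_rng(0)
y_stand_in = rng.random((422, NUM_GENES + NUM_CELL_TYPES))

gene_expr = y_stand_in[:, :NUM_GENES]          # shape (422, 250)
cell_abundance = y_stand_in[:, NUM_GENES:]     # shape (422, 80)

print(gene_expr.shape, cell_abundance.shape)
```

The same slicing works on the PyTorch tensor directly, for example `processed_data['y'][:, :250]`.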

### 📍 **Spatial Coordinates (`pos`)**: Physical Positions
- **Shape**: `[422, 2]`
- **Meaning**: X,Y pixel coordinates of each spot on the original histology slide
- **Purpose**: Used for visualization, spatial analysis, and cell co-localization metrics
- **Units**: Pixel coordinates on the original whole slide image (WSI)

### 💡 Key Insights:
- **Graph Structure**: This slide contains 422 spots (nodes) connected by 2,732 edges
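- **Spatial Density**: Each spot has ~6.5 edges on average (2,732 ÷ 422 ≈ 6.47). Undirected graphs in PyTorch Geometric list every connection in both directions, so if this graph is stored that way, ~6.5 is also the average number of neighbors per spot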
- **Multi-scale Information**: Combines local image features with global spatial context
- **Rich Labels**: Both gene expression and cell type information for comprehensive training

Let's examine each component in detail:

In [None]:
# Examining the image patch dimensions for the first spot
print("📊 Image Patch Analysis:")
print(f"   Individual patch shape: {processed_data['x'][0].shape}")
print(f"   Format: [Channels, Height, Width]")
print(f"   Channels: {processed_data['x'][0].shape[0]} (RGB)")
print(f"   Image size: {processed_data['x'][0].shape[1]}×{processed_data['x'][0].shape[2]} pixels")
print(f"   Total patches: {processed_data['x'].shape[0]}")
print("\n💡 Each spot has a 224×224 RGB image patch extracted from the original slide")
print("   This provides visual context for the Hist2Cell model to analyze tissue architecture")

# Show the actual shape for reference
processed_data['x'][0].shape

In [None]:
# Examining the label structure for the first spot
print("🏷️ Label Analysis (First Spot):")
print(f"   Total labels per spot: {processed_data['y'].shape[1]}")
print(f"   Label composition: 250 genes + 80 cell types = 330 total")
print(f"   First 5 gene expression values: {processed_data['y'][0][:5].tolist()}")
print(f"   Sample cell type abundances (positions 250-254): {processed_data['y'][0][250:255].tolist()}")
print("\n💡 Label Structure:")
print("   - Positions 0-249: Top 250 highly expressed genes (log-normalized)")
print("   - Positions 250-329: 80 fine-grained cell type abundances")
print("   - Values represent normalized expression levels or abundance fractions")

# Show the actual first 5 values for reference
processed_data['y'][0][:5]

In [None]:
# Examining the graph connectivity structure
print("🔗 Graph Connectivity Analysis:")
print(f"   Edge index shape: {processed_data['edge_index'].shape}")
print(f"   Format: [2, num_edges] - COO (Coordinate) format")
print(f"   Total edges: {processed_data['edge_index'].shape[1]}")
print(f"   Source nodes (row 0): {processed_data['edge_index'][0][:10].tolist()}... (showing first 10)")
print(f"   Target nodes (row 1): {processed_data['edge_index'][1][:10].tolist()}... (showing first 10)")

# Calculate connectivity statistics
source_nodes = processed_data['edge_index'][0]
target_nodes = processed_data['edge_index'][1]
unique_nodes = torch.unique(torch.cat([source_nodes, target_nodes]))
avg_degree = processed_data['edge_index'].shape[1] / processed_data['x'].shape[0]

print(f"\n📊 Connectivity Statistics:")
print(f"   Total nodes: {processed_data['x'].shape[0]}")
print(f"   Connected nodes: {len(unique_nodes)}")
print(f"   Average degree: {avg_degree:.2f} neighbors per spot")
print(f"   Graph density: {(processed_data['edge_index'].shape[1] / (processed_data['x'].shape[0] * (processed_data['x'].shape[0] - 1))) * 100:.2f}%")

print("\n💡 The graph represents spatial relationships between neighboring spots")
print("   Each edge connects two spatially adjacent spots in the tissue")

# Show the actual edge_index for reference
processed_data['edge_index']

In [None]:
# Examining the spatial coordinate information
print("📍 Spatial Coordinates Analysis:")
print(f"   Position shape: {processed_data['pos'].shape}")
print(f"   Format: [num_spots, 2] - (X, Y) pixel coordinates")
print(f"   Sample coordinates (first 5 spots):")
for i in range(5):
    x, y = processed_data['pos'][i]
    print(f"     Spot {i}: ({x:.0f}, {y:.0f}) pixels")

# Calculate spatial extent
min_x, max_x = processed_data['pos'][:, 0].min(), processed_data['pos'][:, 0].max()
min_y, max_y = processed_data['pos'][:, 1].min(), processed_data['pos'][:, 1].max()
width = max_x - min_x
height = max_y - min_y

print(f"\n📏 Spatial Extent:")
print(f"   X range: {min_x:.0f} to {max_x:.0f} pixels (width: {width:.0f})")
print(f"   Y range: {min_y:.0f} to {max_y:.0f} pixels (height: {height:.0f})")
print(f"   Tissue area: {width:.0f} × {height:.0f} pixels")

# Calculate average spot spacing
import numpy as np
positions = processed_data['pos'].numpy()
distances = []
for i in range(min(100, len(positions))):  # Use the first 100 spots to keep this quick
    pos_i = positions[i]
    other_pos = positions[np.arange(len(positions)) != i]
    dists = np.sqrt(np.sum((other_pos - pos_i)**2, axis=1))
    distances.append(np.min(dists))
avg_spacing = np.mean(distances)

print(f"   Average spot spacing: {avg_spacing:.0f} pixels")

print("\n💡 Spatial coordinates are used for:")
print("   - Visualizing predictions on tissue slides")
print("   - Calculating cell co-localization metrics")
print("   - Constructing spatial neighborhood graphs")

# Show the actual first 5 positions for reference
processed_data['pos'][:5]

## 🔄 Subgraph Sampling with NeighborLoader

For efficient training on large graphs, we use **subgraph sampling** via `NeighborLoader` from PyTorch Geometric. This technique is crucial for handling large spatial transcriptomics datasets that may contain thousands of spots.

### 🎯 Why Subgraph Sampling?

**Challenge**: Every node carries a 3×224×224 image patch that must pass through the CNN feature extractor, so processing a whole slide graph at once quickly exhausts GPU memory.

**Solution**: Sample smaller subgraphs around center nodes for training/testing, maintaining spatial context while reducing memory usage.
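A back-of-the-envelope calculation shows the scale. Assuming the patches are stored as 32-bit floats, the input images alone for the 422-spot example slide take about 242 MiB, before counting the much larger intermediate CNN activations and the gradients kept for training:

```python
# Bytes needed to hold the input patches of one slide as float32
num_spots = 422
channels, height, width = 3, 224, 224
bytes_per_value = 4  # float32

bytes_per_patch = channels * height * width * bytes_per_value
total_bytes = num_spots * bytes_per_patch

print(f"Per patch: {bytes_per_patch:,} bytes")              # Per patch: 602,112 bytes
print(f"Whole slide: {total_bytes / 1024**2:.1f} MiB")      # Whole slide: 242.3 MiB
```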

### 🔧 Key Parameters Explained

#### 1. **`hop`**: Receptive Field Size
- **Definition**: Number of neighborhood layers to include around center nodes
- **Our choice**: `hop=2` (2-hop subgraphs)
- **Meaning**: Include immediate neighbors (1-hop) + neighbors of neighbors (2-hop)
- **Trade-off**: 
  - ✅ Larger hop → More spatial context, which can improve performance
  - ❌ Larger hop → Higher memory usage, slower training
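To see what `hop` controls, here is a small breadth-first search that collects every node within `hop` steps of a center node. It is a plain-Python sketch on a hypothetical 5×5 square grid, not the sampler PyTorch Geometric uses internally; it shows how quickly the neighborhood grows with each extra hop:

```python
from collections import deque

def grid_neighbors(r, c, size):
    """4-connected neighbors of (r, c) on a size×size grid (hypothetical layout)."""
    for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        nr, nc = r + dr, c + dc
        if 0 <= nr < size and 0 <= nc < size:
            yield nr, nc

def k_hop_nodes(center, hop, size):
    """Return all grid cells reachable from `center` in at most `hop` steps."""
    seen = {center: 0}                 # node -> hop distance from the center
    queue = deque([center])
    while queue:
        node = queue.popleft()
        if seen[node] == hop:          # do not expand past the receptive field
            continue
        for nb in grid_neighbors(*node, size):
            if nb not in seen:
                seen[nb] = seen[node] + 1
                queue.append(nb)
    return set(seen)

for h in (1, 2, 3):
    print(h, len(k_hop_nodes((2, 2), h, size=5)))
```

On this grid the center's neighborhood grows from 5 nodes (1 hop) to 13 (2 hops) to 21 (3 hops, clipped by the grid border), which is why memory use rises quickly with `hop`.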

#### 2. **`subgraph_bs`**: Subgraph Batch Size
- **Definition**: Number of center nodes (subgraphs) processed simultaneously
- **Our choice**: `subgraph_bs=16` (optimized for RTX 3090 24GB)
- **GPU Memory Guidelines**:
  - 🟢 RTX 3090 (24GB): `subgraph_bs=16`
  - 🟡 RTX 3080 (12GB): `subgraph_bs=8`
  - 🔴 RTX 3070 (8GB): `subgraph_bs=4`

#### 3. **`num_neighbors`**: Sampling Strategy
- **Our choice**: `[-1] * hop` (sample all neighbors)
- **Alternative**: `[20, 10]` (sample 20 at 1-hop, 10 at 2-hop)
- **Use case**: Limit neighbors for very dense graphs
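With a list such as `[20, 10]`, the sampler keeps at most that many neighbors per node at each hop instead of all of them. A plain-Python sketch of this fan-out idea on a hypothetical adjacency dictionary; PyTorch Geometric's actual sampler is implemented differently and also re-indexes the sampled nodes:

```python
import random

def sample_fanout(adj, seeds, fanouts, rng):
    """Collect nodes reached by sampling at most fanouts[k] neighbors per node at hop k+1."""
    sampled = set(seeds)
    frontier = list(seeds)
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            neighbors = adj[node]
            # -1 means "take every neighbor", mirroring the num_neighbors convention
            k = len(neighbors) if fanout < 0 else min(fanout, len(neighbors))
            for nb in rng.sample(neighbors, k):
                if nb not in sampled:
                    sampled.add(nb)
                    next_frontier.append(nb)
        frontier = next_frontier
    return sampled

# Hypothetical graph: node 0 has 6 neighbors, each of which has 3 leaf neighbors
adj = {0: [1, 2, 3, 4, 5, 6]}
for i in range(1, 7):
    adj[i] = [0] + [10 * i + j for j in range(3)]
    for j in range(3):
        adj[10 * i + j] = [i]

rng = random.Random(0)
print(len(sample_fanout(adj, [0], [-1, -1], rng)))   # full 2-hop neighborhood
print(len(sample_fanout(adj, [0], [2, 1], rng)))     # capped fan-out
```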

### 💡 Benefits of This Approach:
- **Memory Efficiency**: Process large graphs on consumer GPUs
- **Spatial Context**: Maintain neighborhood relationships
- **Scalability**: Handle datasets with thousands of spots
- **Parallelization**: Batch multiple subgraphs for efficiency

Let's configure and test the DataLoader:

In [None]:
# Import required modules for subgraph sampling
from torch_geometric.loader import NeighborLoader
import torch_geometric
# Tell PyG not to use the optional pyg-lib backend for neighbor sampling
torch_geometric.typing.WITH_PYG_LIB = False

# Configure subgraph sampling parameters
hop = 2               # 2-hop neighborhood (optimal balance of context vs. memory)
subgraph_bs = 16      # Batch size for RTX 3090 (adjust based on your GPU memory)

print("⚙️ DataLoader Configuration:")
print(f"   Hop size: {hop} (include {hop}-hop neighbors)")
print(f"   Subgraph batch size: {subgraph_bs}")
print(f"   Total spots in slide: {processed_data.num_nodes}")
print(f"   Expected batches per epoch: {(processed_data.num_nodes + subgraph_bs - 1) // subgraph_bs}")  # ceil(spots / batch size)

# Create the NeighborLoader for subgraph sampling
dataloader_loader = NeighborLoader(
    processed_data,
    num_neighbors=[-1]*hop,    # Sample all neighbors at each hop
    batch_size=subgraph_bs,    # Number of center nodes per batch
    directed=False,            # Keep all edges among the sampled nodes (induced subgraph)
    input_nodes=None,          # Use all nodes as potential center nodes
    shuffle=True,              # Randomly sample center nodes for training
    num_workers=2,             # Parallel data loading workers
)

print("✅ NeighborLoader created successfully!")
print(f"   This will create subgraphs with {hop}-hop neighborhoods")
print(f"   Each batch contains {subgraph_bs} center nodes and their neighbors")
print(f"   The loader will iterate through all {processed_data.num_nodes} spots as center nodes")

## 🔄 Subgraph Batching and Merging

### How PyTorch Geometric Handles Subgraphs

In PyTorch Geometric, `NeighborLoader` samples the neighborhoods of all center nodes in a batch **jointly and returns them as a single graph**. This approach:

1. **Batches Multiple Neighborhoods**: Combines the neighborhoods of `subgraph_bs` center nodes into one batch
2. **Shares Overlapping Nodes**: A node that lies in several neighborhoods appears only once, so the neighborhoods of nearby center nodes are connected rather than kept separate
3. **Parallel Processing**: The GPU processes all neighborhoods in a single forward pass
4. **Memory Efficiency**: Avoids storing shared nodes twice and the overhead of processing neighborhoods one at a time

### 🔍 Key Features of Merged Subgraphs:

- **Node Remapping**: Original node IDs are preserved in `n_id` field
- **Edge Remapping**: Original edge IDs are preserved in `e_id` field  
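- **Input Identification**: `input_id` records which nodes were used as center nodes; these center nodes also occupy the first `batch_size` positions of the batch
- **Batch Information**: `batch_size` indicates how many center nodes the batch was built from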

### 📊 Expected Structure:
- **Nodes**: All distinct nodes across the sampled neighborhoods (a node shared by several neighborhoods is counted once)
- **Edges**: The sampled edges among those nodes, re-indexed to local node positions
- **Features**: Same format as original data but for merged subgraphs
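The merging and re-indexing can be illustrated without PyTorch Geometric. A minimal sketch on a hypothetical chain of 6 spots, assuming the default (non-disjoint) behavior: the 1-hop neighborhoods of two center nodes are combined, shared nodes appear once, center nodes come first, and an array playing the role of `n_id` records each local node's original ID (the remaining nodes are sorted here for simplicity; PyG orders them by discovery):

```python
import numpy as np

# Hypothetical undirected chain 0-1-2-3-4-5, stored in both directions
src = np.array([0, 1, 1, 2, 2, 3, 3, 4, 4, 5])
dst = np.array([1, 0, 2, 1, 3, 2, 4, 3, 5, 4])

centers = [1, 3]

# 1-hop neighborhoods of both centers, merged; node 2 is shared and kept once
neighbors = set(dst[np.isin(src, centers)].tolist())
others = sorted(neighbors - set(centers))
n_id = np.array(centers + others)             # local index -> original node ID

# Keep every edge between sampled nodes (the induced subgraph) and relabel it
global_to_local = {g: l for l, g in enumerate(n_id.tolist())}
mask = np.isin(src, n_id) & np.isin(dst, n_id)
local_edge_index = np.array([[global_to_local[s] for s in src[mask].tolist()],
                             [global_to_local[d] for d in dst[mask].tolist()]])

print("n_id:", n_id.tolist())
print("center nodes come first:", n_id[:len(centers)].tolist())
print(local_edge_index)
```

Node 2 belongs to both neighborhoods, so the two center nodes end up connected through it in the merged batch.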

Let's examine a sampled subgraph batch:

In [None]:
# Sample and analyze a subgraph batch
print("🔍 Analyzing Subgraph Sampling:")
print("=" * 50)

for subgraphs in dataloader_loader:
    print(f"📊 Subgraph Batch Structure:")
    print(f"   Merged subgraph: {subgraphs}")
    print(f"   Total nodes in batch: {subgraphs.num_nodes}")
    print(f"   Total edges in batch: {subgraphs.num_edges}")
    print(f"   Input nodes (centers): {subgraphs.input_id.shape[0]}")
    print(f"   Actual batch size: {subgraphs.batch_size}")
    
    # Analyze batch composition (overlapping neighborhoods share nodes, so these
    # are per-center-node averages, not sizes of separate subgraphs)
    print(f"\n🔍 Batch Composition Analysis:")
    print(f"   Nodes per center node: {subgraphs.num_nodes / subgraphs.batch_size:.1f}")
    print(f"   Edges per center node: {subgraphs.num_edges / subgraphs.batch_size:.1f}")
    
    # Show center node information
    print(f"\n🎯 Center Node Information:")
    print(f"   Center node IDs (first 5): {subgraphs.input_id[:5].tolist()}")
    print(f"   Original node IDs (first 10): {subgraphs.n_id[:10].tolist()}")
    
    # Calculate neighborhood expansion
    expansion_ratio = subgraphs.num_nodes / subgraphs.batch_size
    print(f"\n📈 Neighborhood Expansion:")
    print(f"   Expansion ratio: {expansion_ratio:.2f}x")
    print(f"   ~{expansion_ratio:.0f} nodes per center node (shared neighbors counted once)")
    
    print(f"\n💡 This batch merges the {hop}-hop neighborhoods of {subgraphs.batch_size} center nodes")
    print(f"   The first {subgraphs.batch_size} nodes of the batch are the center nodes")
    print(f"   The model processes all of these neighborhoods in parallel")
    break

# 🛠️ Step-by-Step Data Processing Pipeline

Now that we understand the target data structure, let's learn how to transform **raw spatial transcriptomics data** into the format required by Hist2Cell.

## 📂 Raw Data Structure Overview

We've provided example raw data for slide `WSA_LngSP9258467` in the directory:
```
./example_data/example_raw_data/WSA_LngSP9258467/
```

This represents a typical spatial transcriptomics dataset with all the necessary components for Hist2Cell processing.

## 📁 Raw Data Components Explained

### 🖼️ **Image Files**
| File | Purpose | Description |
|------|---------|-------------|
| `WSA_LngSP9258467.jpg` | Original slide image | High-resolution whole slide image (WSI) |
| `WSA_LngSP9258467_low_res.jpg` | Low-resolution image | Quick visualization and processing |
| `spot_view.jpg` | Spot visualization | Original slide with spot locations marked |
| `2x_spot_view.jpg` | Super-resolution spots | 2x resolution spot visualization |

### 📊 **Gene Expression Data**
| File | Content | Format (rows × columns) | Use Case |
|------|---------|-------------------------|----------|
| `stdata.csv` | Raw gene expression | Spots × Genes | Original measurements |
| `log1p_stdata.csv` | Log-normalized expression | Spots × Genes | Processed for analysis |
| `high_250_stdata.csv` | Top 250 genes | Spots × 250 | Highly expressed genes |
| `high_250_stdata_log1p.csv` | Log-normalized top 250 | Spots × 250 | Ready for training |
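The relationship between these files can be sketched with pandas. This sketch assumes spots are rows and genes are columns, that "highly expressed" means the highest mean raw count, and that normalization scales each spot to 50,000 total counts before `log1p` (the AnnData file name `sp.X_norm5e4_log1p.h5ad` used later hints at this scheme); the original preprocessing may differ, and `preprocess_counts` and the toy data are hypothetical:

```python
import numpy as np
import pandas as pd

def preprocess_counts(stdata, top_k=250, target_sum=5e4):
    """Sketch: normalize, log-transform and keep the top_k most expressed genes.

    stdata: DataFrame of raw counts, spots as rows and genes as columns (assumed).
    """
    # Scale every spot to the same total count, then apply log(1 + x)
    totals = stdata.sum(axis=1).replace(0, np.nan)        # avoid dividing by zero
    log_norm = np.log1p(stdata.div(totals, axis=0).fillna(0) * target_sum)

    # Rank genes by mean raw count and keep the top_k columns
    top_genes = stdata.mean(axis=0).nlargest(top_k).index
    return stdata[top_genes], log_norm[top_genes]

# Toy data: 3 spots × 4 genes (hypothetical)
counts = pd.DataFrame([[10, 0, 5, 1], [20, 2, 0, 3], [0, 1, 1, 0]],
                      index=["spot_a", "spot_b", "spot_c"],
                      columns=["G1", "G2", "G3", "G4"])
high_raw, high_log = preprocess_counts(counts, top_k=2)
print(high_raw.columns.tolist())
```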

### 🎯 **Cell Type Information**
| File | Content | Description |
|------|---------|-------------|
| `cell_ratio.csv` | Cell type abundances | 80 fine-grained cell types per spot |

### 📍 **Spatial Information**
| File | Content | Description |
|------|---------|-------------|
| `spots.csv` | Spot coordinates | X,Y pixel positions on slide |
| `2x_spots.csv` | Super-resolution coordinates | 2x resolution positions |

### 🔲 **Image Patches**
| Folder | Content | Description |
|--------|---------|-------------|
| `patches/` | Image patches | 224×224 patches around each spot |
| `2x_patches/` | Super-resolution patches | 2x resolution patches |
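Before running the pipeline on your own slides, it can help to check that each slide folder contains what the `STDataset` class below reads. A minimal sketch; the `REQUIRED` list is an assumption covering only the standard-resolution inputs used by `STDataset`, and `missing_inputs` is a hypothetical helper:

```python
import os

# Inputs read by STDataset for one slide (standard resolution)
REQUIRED = ["patches", "high_250_stdata.csv", "cell_ratio.csv"]

def missing_inputs(root, slide):
    """Return the required files/folders that are absent for `slide`."""
    slide_dir = os.path.join(root, slide)
    missing = [name for name in REQUIRED
               if not os.path.exists(os.path.join(slide_dir, name))]
    patch_dir = os.path.join(slide_dir, "patches")
    # A patches/ folder without any .jpg files is as good as missing
    if os.path.isdir(patch_dir) and not any(f.endswith(".jpg") for f in os.listdir(patch_dir)):
        missing.append("patches/*.jpg")
    return missing

# Usage: missing_inputs("../example_data/example_raw_data", "WSA_LngSP9258467")
```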

## 🔧 **Image Patch Extraction**

**Important**: To extract image patches from whole slide images (WSI), we recommend using the **DSMIL pipeline**:
- **Repository**: [DSMIL-WSI](https://github.com/binli123/dsmil-wsi)
- **Function**: Tiles whole slide images into patches; cropping patches centered on spot coordinates may require adapting its scripts
- **Output**: 224×224 pixel patches suitable for deep learning
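If you prefer to crop patches yourself, the core step is a centered crop at each spot's pixel coordinate. A minimal NumPy sketch, assuming the slide is already loaded as an H×W×3 array and that spot centers are in the same pixel space; regions that fall off the slide edge are padded with white, an assumption based on the usually white background of H&E scans. `crop_patch` is hypothetical and not part of DSMIL or Hist2Cell:

```python
import numpy as np

def crop_patch(slide, x, y, size=224, fill=255):
    """Crop a size×size patch centered on pixel (x, y); pad outside the slide with `fill`."""
    half = size // 2
    x0, y0 = int(round(x)) - half, int(round(y)) - half   # top-left corner (x = column, y = row)
    patch = np.full((size, size, slide.shape[2]), fill, dtype=slide.dtype)

    # Intersection of the requested window with the slide
    sx0, sy0 = max(x0, 0), max(y0, 0)
    sx1, sy1 = min(x0 + size, slide.shape[1]), min(y0 + size, slide.shape[0])
    if sx0 < sx1 and sy0 < sy1:
        patch[sy0 - y0:sy1 - y0, sx0 - x0:sx1 - x0] = slide[sy0:sy1, sx0:sx1]
    return patch

# Usage with a synthetic black 500×400 "slide" (height × width)
slide = np.zeros((500, 400, 3), dtype=np.uint8)
inner = crop_patch(slide, x=200, y=250)   # fully inside the slide
edge = crop_patch(slide, x=10, y=10)      # partly outside, so padded
print(inner.shape, edge[0, 0], edge[-1, -1])
```

Patches produced this way can be saved as `patches/<spot_id>.jpg`, the naming that `STDataset` expects.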

### 🏗️ **Processing Workflow**

The data processing pipeline follows these steps:

1. **Image Patch Extraction**: Extract 224×224 patches from WSI around each spot
2. **Gene Expression Processing**: Normalize and filter gene expression data
3. **Cell Type Deconvolution**: Estimate cell type abundances (using tools like Cell2location)
4. **Spatial Graph Construction**: Build neighborhood graphs based on spot positions
5. **Data Integration**: Combine all components into PyTorch Geometric format

Let's implement this pipeline step by step!

## 📦 Step 1: Creating a Custom Dataset Class

To efficiently process our raw spatial transcriptomics data, we'll create a custom `STDataset` class that handles:

### 🎯 **STDataset Functionality**

1. **Data Loading**: Automatically loads image patches and corresponding labels
2. **Data Validation**: Ensures image patches and labels are properly aligned
3. **Preprocessing**: Applies image transformations (resize, normalize, etc.)
4. **Integration**: Combines gene expression and cell type abundance data

### 🔧 **Key Features**
- **Automatic Matching**: Matches image patches with corresponding gene/cell data
- **Flexible Transforms**: Supports PyTorch transforms for image preprocessing
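- **Missing-Data Handling**: Keeps only spots that have both an image patch and labels, and reports how many were dropped
- **Memory Efficient**: Loads image patches on demand during iteration; only the label table is kept in memory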

### 📊 **Input Files Used**
- `patches/`: Image patches (224×224 pixels)
- `high_250_stdata.csv`: Top 250 gene expression values
- `cell_ratio.csv`: Cell type abundance ratios

Let's implement the STDataset class:

In [None]:
# Import required libraries for data processing
import os
import numpy as np
import pandas as pd
from torchvision import transforms
from PIL import Image

print("📦 Creating STDataset class for processing spatial transcriptomics data...")

class STDataset(torch.utils.data.Dataset):
    """
    Custom Dataset class for processing spatial transcriptomics data.
    
    This class handles loading and preprocessing of:
    - Image patches from histology slides
    - Gene expression data (top 250 genes)
    - Cell type abundance labels (80 cell types)
    """
    
    def __init__(self, root, slide, transform=None):
        """
        Initialize the STDataset.
        
        Args:
            root (str): Root directory containing the raw data
            slide (str): Slide name (e.g., 'WSA_LngSP9258467')
            transform: Optional transforms to apply to images
        """
        super(STDataset, self).__init__()
        self.root = root
        self.slide = slide
        self.transform = transform
        
        print(f"   📁 Loading data from: {os.path.join(root, slide)}")

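        # 1. Load available image patches (spot IDs are the .jpg file names)
        patch_path = os.path.join(root, slide, 'patches')
        patch_files = [f for f in os.listdir(patch_path) if f.endswith('.jpg')]  # Skip non-image files
        patch_list = [os.path.splitext(f)[0] for f in patch_files]  # Remove .jpg extension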
        print(f"   🖼️  Found {len(patch_list)} image patches")

        # 2. Load cell type abundance labels (80 fine-grained cell types)
        cell_label_path = os.path.join(root, slide, 'cell_ratio.csv')
        cell_label = pd.read_csv(cell_label_path, index_col=0)
        print(f"   🎯 Loaded cell type data: {cell_label.shape[1]} cell types")
        
        # 3. Load gene expression labels (top 250 genes)
        gene_label_path = os.path.join(root, slide, 'high_250_stdata.csv')
        gene_label = pd.read_csv(gene_label_path, index_col=0)
        print(f"   🧬 Loaded gene expression data: {gene_label.shape[1]} genes")

        # 4. Merge gene and cell type labels
        label_df = pd.merge(gene_label, cell_label, left_index=True, right_index=True)
        print(f"   🔗 Combined labels: {label_df.shape[1]} total (250 genes + 80 cell types)")

        # 5. Find intersection of spots with both images and labels
        label_index_set = set(label_df.index)
        patch_index_set = set(patch_list)
        valid_spots = label_index_set & patch_index_set
        
        print(f"   ✅ Valid spots (with both image and labels): {len(valid_spots)}")
        if len(valid_spots) < len(patch_list):
            print(f"   ⚠️  Missing labels for {len(patch_list) - len(valid_spots)} patches")
        if len(valid_spots) < len(label_df):
            print(f"   ⚠️  Missing images for {len(label_df) - len(valid_spots)} labels")

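        # 6. Store final data; sort so the spot order is reproducible across runs
        #    (iteration order of a set of strings can change between Python sessions)
        self.patch = sorted(valid_spots)
        self.label_df = label_df.loc[self.patch]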
        
        print(f"   📊 Dataset ready: {len(self.patch)} spots with complete data")

    def __getitem__(self, index):
        """
        Get a single data sample.
        
        Args:
            index (int): Index of the sample to retrieve
            
        Returns:
            tuple: (patch_id, image_data, labels)
        """
        # Get spot ID
        patch_id = self.patch[index]
        
        # Load and preprocess image patch
        patch_path = os.path.join(self.root, self.slide, 'patches', patch_id + '.jpg')
        patch = Image.open(patch_path).convert('RGB')
        
        # Resize to 224x224 (required for ResNet18)
        data = transforms.Resize((224, 224))(patch)
        
        # Apply additional transforms if provided
        if self.transform is not None:
            data = self.transform(data)
        
        # Get corresponding labels (250 genes + 80 cell types)
        label = self.label_df.loc[patch_id].values
        label = torch.Tensor(label)

        return patch_id, data, label

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.patch)

print("✅ STDataset class defined successfully!")
print("   This class will automatically handle data loading and preprocessing")

In [None]:
print("🔄 Step 2: Setting up data transformations and dataset...")

# Define image preprocessing pipeline
print("   📋 Configuring image transformations:")
print("      - Convert PIL Image to PyTorch tensor")
print("      - Normalize with ImageNet statistics (required for ResNet18)")
print("      - Mean: [0.485, 0.456, 0.406] (ImageNet RGB means)")
print("      - Std:  [0.229, 0.224, 0.225] (ImageNet RGB stds)")

test_transform_pcam = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL Image to tensor and scale to [0,1]
    transforms.Normalize(   # Normalize with ImageNet statistics
        mean=[0.485, 0.456, 0.406],  # RGB channel means
        std=[0.229, 0.224, 0.225]    # RGB channel standard deviations
    )
])

print("\n🏗️ Creating dataset instance...")
# Create dataset instance
test_data = STDataset(
    root="../example_data/example_raw_data", 
    slide="WSA_LngSP9258467",
    transform=test_transform_pcam
)

print(f"\n📊 Dataset Statistics:")
print(f"   Total samples: {len(test_data)}")
print(f"   Data source: ../example_data/example_raw_data/WSA_LngSP9258467")
print(f"   Transform applied: ImageNet normalization")

print("\n🔄 Creating DataLoader...")
# Create DataLoader for batch processing
test_loader = torch.utils.data.DataLoader(
    test_data, 
    batch_size=512,      # Process 512 spots at once
    shuffle=False,       # Maintain order for reproducibility
    num_workers=4        # Use 4 parallel workers for faster loading
)

print(f"   DataLoader configuration:")
print(f"      Batch size: 512 spots")
print(f"      Total batches: {len(test_loader)}")
print(f"      Parallel workers: 4")
print(f"      Shuffle: False (maintains spot order)")

print("\n✅ Dataset and DataLoader ready for processing!")

## 📊 Step 3: Processing and Collecting Data

Now we'll iterate through our `STDataset` to collect all the necessary components:

### 🎯 **Data Collection Process**

1. **Image Patches**: Extract preprocessed 224×224 image patches
2. **Labels**: Collect gene expression + cell type abundance labels  
3. **Spot IDs**: Maintain spot identifiers for later graph construction

### 🔧 **Processing Steps**

- **Batch Processing**: Process data in batches of 512 for efficiency
- **Single-Spot Batches**: Restore the batch dimension that `squeeze()` removes when a batch holds only one spot
- **Data Validation**: Ensure all arrays have consistent shapes
- **Progress Tracking**: Monitor processing progress
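The single-spot case arises because `squeeze()` removes every size-1 dimension, including the batch dimension. A NumPy sketch of the problem and of one alternative way to avoid it, guaranteeing a 2-D shape with `np.atleast_2d` (this is not what the tutorial code below does; it restores the dimension with `np.newaxis` instead):

```python
import numpy as np

# Two label batches: a full batch of 3 spots and a last batch holding 1 spot
batches = [np.ones((3, 330)), np.ones((1, 330))]

squeezed = [b.squeeze() for b in batches]
print([s.shape for s in squeezed])          # [(3, 330), (330,)] -> shapes disagree

# atleast_2d turns a lone (330,) vector back into (1, 330)
fixed = [np.atleast_2d(s) for s in squeezed]
labels = np.concatenate(fixed)
print(labels.shape)                         # (4, 330)
```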

### 📋 **Expected Output**

After processing, we'll have:
- `spot_data_array`: Shape [num_spots, 3, 224, 224] - Image patches
- `spot_label_array`: Shape [num_spots, 330] - Gene + cell type labels
- `spot_id_array`: List of spot identifiers

Let's process the data:

In [None]:
print("🔄 Processing data through DataLoader...")
print("=" * 50)

# Initialize storage arrays
spot_data_array = []
spot_label_array = []
spot_id_array = []

# Process data in batches
batch_count = 0
total_spots_processed = 0

for name, data, label in test_loader:
    batch_count += 1
    batch_size = len(name)
    total_spots_processed += batch_size
    
    print(f"   📦 Processing batch {batch_count}/{len(test_loader)}: {batch_size} spots")
    
    # Store batch data
    spot_id_array.append(list(name))
    
    # Process labels (ensure float type and proper shape)
    label = label.float()
    label = label.squeeze()
    spot_label_array.append(label.detach().numpy())
    
    # Store image data
    spot_data_array.append(data.detach().numpy())

print(f"✅ Batch processing complete! Processed {total_spots_processed} spots")

print("\n🔧 Post-processing data arrays...")

# squeeze() drops the batch dimension of a single-spot batch; restore it here
print("   📏 Ensuring consistent array dimensions...")
for i in range(len(spot_data_array)):
    if len(spot_data_array[i].shape) <= 1:
        spot_data_array[i] = spot_data_array[i][np.newaxis, :]
        print(f"      Fixed data array {i} shape")

for i in range(len(spot_label_array)):
    if len(spot_label_array[i].shape) <= 1:
        spot_label_array[i] = spot_label_array[i][np.newaxis, :]
        print(f"      Fixed label array {i} shape")

# Concatenate all batch arrays into final arrays
print("   🔗 Concatenating batch arrays...")
spot_data_array = np.concatenate(spot_data_array)
spot_label_array = np.concatenate(spot_label_array)

# Flatten spot ID list
spot_ids = []
for ids in spot_id_array:
    spot_ids.extend(ids)
spot_id_array = spot_ids

print(f"✅ Data processing complete!")
print(f"   📊 Final array shapes:")
print(f"      Image data: {spot_data_array.shape}")
print(f"      Label data: {spot_label_array.shape}")
print(f"      Spot IDs: {len(spot_id_array)} identifiers")
print(f"   🎯 Data validation:")
print(f"      All arrays have {spot_data_array.shape[0]} spots")
print(f"      Labels contain {spot_label_array.shape[1]} values (250 genes + 80 cell types)")
print(f"      Image patches are {spot_data_array.shape[1]}×{spot_data_array.shape[2]}×{spot_data_array.shape[3]}")

In [None]:
# Display final processed data shape
print("📊 Final Processed Data Summary:")
print(f"   Image data shape: {spot_data_array.shape}")
print(f"   Format: [num_spots, channels, height, width]")
print(f"   Spots: {spot_data_array.shape[0]}")
print(f"   Channels: {spot_data_array.shape[1]} (RGB)")
print(f"   Image size: {spot_data_array.shape[2]}×{spot_data_array.shape[3]} pixels")
print(f"   Memory usage: {spot_data_array.nbytes / (1024**2):.1f} MB")

# Show the actual shape for reference
spot_data_array.shape

In [None]:
# Display sample spot IDs and their structure
print("🔍 Spot ID Analysis:")
print(f"   Total spot IDs: {len(spot_id_array)}")
print(f"   ID format: slide_spotbarcode")
print(f"   Sample IDs (first 5):")
for i, spot_id in enumerate(spot_id_array[:5]):
    print(f"      {i}: {spot_id}")

# Analyze ID structure (split on the last underscore so the slide name,
# which itself contains an underscore, stays intact)
sample_id = spot_id_array[0]
slide_name, barcode = sample_id.rsplit('_', 1)
print(f"\n📋 ID Structure Analysis:")
print(f"   Slide name: {slide_name}")
print(f"   Barcode format: {barcode}")
print(f"   Example full ID: {sample_id}")

print(f"\n💡 These IDs will be used to:")
print(f"   - Match spots with spatial coordinates")
print(f"   - Build spatial neighborhood graphs")
print(f"   - Track data provenance")

# Show the actual first 5 IDs for reference
spot_id_array[:5]
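Matching spots with spatial coordinates amounts to reordering a coordinate table by the spot ID list. A minimal pandas sketch, assuming the coordinate table is indexed by the same spot IDs and has `array_row`/`array_col` columns like the AnnData metadata loaded in Step 4; the IDs are made up, and `toy_spot_ids` stands in for `spot_id_array` so the real variable is not overwritten:

```python
import pandas as pd

# Hypothetical coordinate table, indexed by spot ID (order differs from the images)
coords = pd.DataFrame({"array_row": [0, 1, 0], "array_col": [2, 1, 0]},
                      index=["slide_C", "slide_B", "slide_A"])
toy_spot_ids = ["slide_A", "slide_B", "slide_C"]   # order of the image/label arrays

# Reindex so row i of the coordinates belongs to toy_spot_ids[i];
# spots without coordinates become NaN, which we check for explicitly
aligned = coords.reindex(toy_spot_ids)
missing = aligned.index[aligned.isna().any(axis=1).to_numpy()].tolist()
if missing:
    raise ValueError(f"No coordinates for {len(missing)} spots, e.g. {missing[:3]}")

print(aligned[["array_row", "array_col"]].to_numpy())
```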

## 📍 Step 4: Loading Spatial Coordinate Information

To build the spatial neighborhood graph, we need to understand the **spatial layout** of spots on the tissue slide. 

### 🎯 **Coordinate Systems in Spatial Transcriptomics**

Spatial transcriptomics data uses **two coordinate systems**:

1. **Array Coordinates** (`array_col`, `array_row`):
   - **Grid-based positions** on the regular spot array
   - **Integer values** representing row/column positions
   - **Used for**: Determining spatial neighbors (adjacency)
   - **Example**: (col=27, row=29) means spot is in column 27, row 29 of the grid

2. **Pixel Coordinates** (`x`, `y`):
   - **Actual pixel positions** on the histology image
   - **Continuous values** in pixel units
   - **Used for**: Visualization and precise spatial analysis
   - **Example**: (x=12701, y=9136) means spot is at pixel (12701, 9136)

### 🔗 **Why Array Coordinates Matter**

- **Neighborhood Definition**: Spots are neighbors if their array coordinates are adjacent
- **Graph Construction**: Edges connect spots with nearby array positions
- **Regular Grid Structure**: Spatial transcriptomics uses hexagonal/square grid patterns
- **Consistent Spacing**: Array coordinates ensure uniform spatial relationships

### 📊 **Data Source**

We'll load array coordinates from a pre-processed AnnData file that contains spatial metadata for all spots.
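After the AnnData file is loaded in the cell below, coordinates for the processed spots can be pulled out by spot ID in a single vectorized `.loc` call instead of one lookup per spot. The sketch below is a hypothetical helper (`lookup_array_coords` is not part of Hist2Cell, scanpy, or AnnData). It assumes `obs` is indexed by spot ID and has `array_col` and `array_row` columns, as `adata.obs` does in this tutorial, and it reports missing IDs explicitly instead of failing partway through a loop.

```python
import pandas as pd


def lookup_array_coords(obs, spot_ids):
    """Return an (n, 2) integer array of (array_col, array_row), in spot_ids order.

    Hypothetical helper: assumes `obs` is a DataFrame indexed by spot ID with
    'array_col' and 'array_row' columns (e.g. adata.obs).
    """
    missing = pd.Index(spot_ids).difference(obs.index)
    if len(missing) > 0:
        raise KeyError(f"{len(missing)} spot IDs not found, e.g. {list(missing[:3])}")
    # .loc keeps the order of spot_ids, so row i matches spot_ids[i]
    return obs.loc[list(spot_ids), ["array_col", "array_row"]].to_numpy(dtype=int)


# Toy obs table with three spots
obs = pd.DataFrame(
    {"array_col": [27, 28, 29], "array_row": [29, 30, 29]},
    index=["S1_AAA-1", "S1_CCC-1", "S1_GGG-1"],
)
print(lookup_array_coords(obs, ["S1_GGG-1", "S1_AAA-1"]).tolist())  # [[29, 29], [27, 29]]
```

With the real data, `lookup_array_coords(adata.obs, spot_id_array)` would give the same (column, row) pairs that the per-spot loop in Step 5 collects, as a NumPy array.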

Let's load the spatial coordinate information:

In [None]:
# Import scanpy for reading spatial transcriptomics data
import scanpy as sc

print("📍 Loading spatial coordinate metadata...")

# Load the AnnData file containing spatial metadata
adata_path = "../example_data/example_raw_data/sp.X_norm5e4_log1p.h5ad"
print(f"   📁 Loading from: {adata_path}")

adata = sc.read(adata_path)

# Extract array coordinates (grid positions)
spot_array_cols = adata.obs.array_col
spot_array_rows = adata.obs.array_row

print(f"✅ Spatial metadata loaded successfully!")
print(f"   📊 Dataset information:")
print(f"      Total spots in AnnData: {adata.n_obs}")
print(f"      Array column range: {spot_array_cols.min()} to {spot_array_cols.max()}")
print(f"      Array row range: {spot_array_rows.min()} to {spot_array_rows.max()}")
print(f"      Grid dimensions: {spot_array_cols.max() - spot_array_cols.min() + 1} × {spot_array_rows.max() - spot_array_rows.min() + 1}")

print(f"\n📋 Coordinate format:")
print(f"   Index: Spot barcode (e.g., 'WSA_LngSP8759311_AAACAAGTATCTCCCA-1')")
print(f"   array_col: Column position in spatial grid")
print(f"   array_row: Row position in spatial grid")

print(f"\n💡 These coordinates will be used to:")
print(f"   - Determine spatial neighbors for each spot")
print(f"   - Build adjacency matrix for graph construction")
print(f"   - Ensure consistent spatial relationships")

In [None]:
spot_array_cols.head()

## 🔗 Step 5: Graph Construction and Spatial Relationships

### 🎯 **Understanding Spatial Neighborhood Patterns**

From the spot visualization image (`../example_data/example_raw_data/WSA_LngSP9258467/spot_view.jpg`), we can see that each interior spot has **6 nearest neighbors** arranged in a hexagonal pattern. This layout is characteristic of spatial transcriptomics platforms such as Visium.

### 📐 **Hexagonal Grid Structure**

In spatial transcriptomics:
- **Hexagonal Pattern**: Spots are arranged in a hexagonal grid
- **6 Neighbors**: Each interior spot (one not on the tissue boundary) has exactly 6 spatial neighbors
- **Regular Spacing**: Neighbors are at consistent distances
- **Symmetric Relations**: If A is a neighbor of B, then B is a neighbor of A

### 🔧 **Neighborhood Definition Algorithm**

We'll create the adjacency matrix based on array coordinates:

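```python
import numpy as np
# coords[i] = (array_col, array_row) of spot i
def build_adjacency(coords, max_col=2, max_row=1):
    n = len(coords)
    adjacency = np.zeros((n, n))
    for i in range(n):
        for j in range(n):  # i == j gives the self-connection
            if abs(coords[i][0] - coords[j][0]) <= max_col and abs(coords[i][1] - coords[j][1]) <= max_row:
                adjacency[i][j] = 1  # they are neighbors
    return adjacency
```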
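The same neighborhood rule can be evaluated for all spot pairs at once with NumPy broadcasting, which avoids the Python double loop used in the cell further down and is much faster for slides with thousands of spots. This is a minimal sketch, not Hist2Cell's own code: `build_adjacency_vectorized` is a hypothetical name, and it assumes `array_coords` is a list of `(array_col, array_row)` pairs such as `spot_array_x_y`. Because array coordinates are integers, `<= 2` and `<= 1` are equivalent to the `col_dist < 3` and `row_dist < 2` tests used in that cell.

```python
import numpy as np


def build_adjacency_vectorized(array_coords, max_col_dist=2, max_row_dist=1):
    """Dense 0/1 adjacency using a rectangular window in array coordinates.

    Spots i and j are neighbors when |d_col| <= max_col_dist and
    |d_row| <= max_row_dist; the diagonal (self-connection) is always 1.
    """
    coords = np.asarray(array_coords, dtype=np.int64)  # shape (n, 2): (col, row)
    # Pairwise absolute differences via broadcasting: each has shape (n, n)
    col_dist = np.abs(coords[:, None, 0] - coords[None, :, 0])
    row_dist = np.abs(coords[:, None, 1] - coords[None, :, 1])
    return ((col_dist <= max_col_dist) & (row_dist <= max_row_dist)).astype(np.float64)


# Toy Visium-like patch: a center spot, its six hexagonal neighbors,
# and one distant spot that should stay unconnected
coords = [(10, 10), (8, 10), (12, 10), (9, 9), (11, 9), (9, 11), (11, 11), (20, 20)]
adj = build_adjacency_vectorized(coords)
print(int(adj[0].sum()))  # 7 (six neighbors plus the self-connection)
```

The intermediate distance arrays are n × n, so memory grows quadratically with the number of spots; for very large slides, a sparse neighbor search (for example a KD-tree) avoids materializing the full matrix.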

### 🎚️ **Distance Thresholds**

Based on the hexagonal grid pattern:
- **Column distance**: ±2 (within 2 columns)
- **Row distance**: ±1 (within 1 row)
- **Self-connection**: Each spot connects to itself (diagonal = 1)
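Why ±2 columns but only ±1 row? On Visium slides, alternate rows are shifted by half a spot, and the array coordinates encode this by placing spots only at positions where `array_row + array_col` is even (an assumption here; check it on your own `adata.obs`). Under that convention, the two same-row neighbors are 2 columns away and the four neighbors in adjacent rows are 1 column away. The rectangular window also covers offsets such as `(±1, 0)`, but no spot can occupy those positions, so each interior spot still ends up with exactly 6 neighbors. The sketch below (the helper name is hypothetical) enumerates the offsets that survive the parity constraint:

```python
from itertools import product


def hex_neighbor_offsets(max_col_dist=2, max_row_dist=1):
    """List (d_col, d_row) offsets inside the window that can hold a spot.

    Assumes the Visium-style convention that occupied array positions have an
    even array_row + array_col, so a reachable offset must have an even sum too.
    """
    offsets = []
    for d_col, d_row in product(range(-max_col_dist, max_col_dist + 1),
                                range(-max_row_dist, max_row_dist + 1)):
        if (d_col, d_row) == (0, 0):
            continue  # the self-connection is handled separately
        if (d_col + d_row) % 2 == 0:
            offsets.append((d_col, d_row))
    return offsets


print(hex_neighbor_offsets())
# [(-2, 0), (-1, -1), (-1, 1), (1, -1), (1, 1), (2, 0)]
```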

### 📊 **Expected Graph Properties**

After construction, we expect:
- **Nodes**: 422 spots (same as our processed data)
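- **Edges**: ~2,600-2,800 entries in `edge_index` (each spot has up to 6 neighbors plus a self-loop, and every undirected edge is stored once in each direction)
- **Symmetry**: Undirected graph (mutual connections)
- **Connectivity**: A single contiguous tissue section usually forms one connected component; separate tissue fragments produce several components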
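These properties can be checked on the dense adjacency matrix once it is built. Below is a NumPy-only sketch: `graph_summary` is a hypothetical helper, it assumes a 0/1 matrix with self-loops on the diagonal (as built in this step), and it counts connected components with a breadth-first search (`scipy.sparse.csgraph.connected_components` is a library alternative).

```python
from collections import deque

import numpy as np


def graph_summary(adj):
    """Summarize a dense 0/1 adjacency matrix that includes self-loops."""
    adj = np.asarray(adj)
    n = adj.shape[0]
    is_symmetric = bool(np.array_equal(adj, adj.T))
    # Neighbors per spot, excluding the self-loop on the diagonal
    degrees = adj.sum(axis=1) - np.diag(adj)

    # Count connected components with a breadth-first search
    seen = np.zeros(n, dtype=bool)
    n_components = 0
    for start in range(n):
        if seen[start]:
            continue
        n_components += 1
        seen[start] = True
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nb in np.flatnonzero(adj[node]):
                if not seen[nb]:
                    seen[nb] = True
                    queue.append(nb)

    return {
        "num_nodes": n,
        "num_edge_index_entries": int(np.count_nonzero(adj)),
        "symmetric": is_symmetric,
        "max_neighbors": int(degrees.max()),
        "num_components": n_components,
    }


# Toy example: two connected spots plus one isolated spot
toy = np.array([[1, 1, 0],
                [1, 1, 0],
                [0, 0, 1]])
print(graph_summary(toy))
# {'num_nodes': 3, 'num_edge_index_entries': 5, 'symmetric': True, 'max_neighbors': 1, 'num_components': 2}
```

Applied to this tutorial's `adj`, `num_edge_index_entries` should match the edge count PyTorch Geometric reports in Step 7, and on a Visium grid `max_neighbors` should be 6.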

Let's build the spatial graph:

In [None]:
spot_array_x_y = []
for item in spot_id_array:
    spot_array_x_y.append([int(spot_array_cols[item]), int(spot_array_rows[item])])
    
spot_array_x_y[:5]

In [None]:
print("🔗 Building spatial adjacency matrix...")
print("=" * 50)

# Initialize adjacency matrix
num_spots = len(spot_array_x_y)
adj = np.zeros((num_spots, num_spots))

print(f"   📊 Matrix dimensions: {num_spots} × {num_spots}")
print(f"   🔧 Neighborhood criteria:")
print(f"      - Column distance: within ±2 positions")
print(f"      - Row distance: within ±1 positions")
print(f"      - Self-connections: included (diagonal = 1)")

# Build adjacency matrix with progress tracking
connections_made = 0
progress_interval = max(1, num_spots // 10)  # Show progress roughly every 10% (at least 1 to avoid modulo by zero)

for i in range(num_spots):
    # Show progress
    if i % progress_interval == 0:
        progress = (i / num_spots) * 100
        print(f"   📈 Progress: {progress:.0f}% ({i}/{num_spots} spots processed)")
    
    for j in range(num_spots):
        if i == j:
            # Self-connection
            adj[i][j] = 1.0
            connections_made += 1
        else:
            # Get array coordinates
            x1, y1 = spot_array_x_y[i]
            x2, y2 = spot_array_x_y[j]
            
            # Check if spots are neighbors based on hexagonal grid pattern
            col_dist = abs(x2 - x1)
            row_dist = abs(y2 - y1)
            
            # Hexagonal neighborhood: within 2 columns and 1 row
            if col_dist < 3 and row_dist < 2:
                adj[i][j] = 1.0
                connections_made += 1

print(f"✅ Adjacency matrix construction complete!")
print(f"   📊 Graph statistics:")
print(f"      Total possible connections: {num_spots * num_spots:,}")
print(f"      Actual connections made: {connections_made:,}")
print(f"      Graph density: {(connections_made / (num_spots * num_spots)) * 100:.2f}%")
print(f"      Average degree: {connections_made / num_spots:.2f} (including the self-loop)")
print(f"      Expected edges in PyTorch Geometric: {connections_made:,} (edge_index stores both directions + self-loops)")

# Verify symmetry (should be symmetric for undirected graph)
is_symmetric = np.allclose(adj, adj.T)
print(f"   ✅ Graph symmetry check: {'PASSED' if is_symmetric else 'FAILED'}")

print(f"\n💡 This adjacency matrix defines the spatial relationships")
print(f"   Each 1 indicates two spots are spatial neighbors")
print(f"   This will be converted to edge_index format for PyTorch Geometric")

## 📍 Step 6: Loading Pixel Coordinates

In addition to array coordinates for graph construction, we need **pixel coordinates** for visualization and spatial analysis.

### 🎯 **Pixel Coordinates vs Array Coordinates**

| Coordinate Type | Purpose | Format | Example |
|----------------|---------|---------|---------|
| **Array Coordinates** | Graph construction | Grid positions (int) | (27, 29) |
| **Pixel Coordinates** | Visualization & analysis | Image positions (float) | (12701.5, 9136.2) |

### 📊 **Uses of Pixel Coordinates**

- **Visualization**: Plotting predictions on tissue slides
- **Spatial Analysis**: Calculating distances between spots
- **Co-localization**: Measuring spatial relationships between cell types
- **Quality Control**: Verifying spatial patterns
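For example, the distance from each spot to its closest neighbor in pixel space gives the center-to-center spot spacing on the image, and unusually large values can flag spots whose coordinates do not fit the grid. A minimal NumPy sketch (the helper name is hypothetical; it assumes an `(n, 2)` array of `(x, y)` pixel positions such as `spots_coord` below):

```python
import numpy as np


def nearest_neighbor_distances(pixel_coords):
    """Euclidean distance (in pixels) from each spot to its closest other spot."""
    xy = np.asarray(pixel_coords, dtype=np.float64)  # shape (n, 2)
    diff = xy[:, None, :] - xy[None, :, :]           # shape (n, n, 2)
    dist = np.sqrt((diff ** 2).sum(axis=-1))         # shape (n, n)
    np.fill_diagonal(dist, np.inf)                   # ignore the distance to itself
    return dist.min(axis=1)


# Toy example: three spots on a line, 100 px apart
toy_xy = np.array([[0.0, 0.0], [100.0, 0.0], [200.0, 0.0]])
print(nearest_neighbor_distances(toy_xy))  # [100. 100. 100.]
```

The pairwise array is n × n × 2, which is fine for a few hundred spots; for larger slides, the second column of the distances returned by `scipy.spatial.cKDTree(xy).query(xy, k=2)` gives the same values without the quadratic memory cost.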

### 📁 **Data Source**

The pixel coordinates are stored in `spots.csv` and contain the exact (x,y) positions of each spot on the original histology slide.

Let's load the pixel coordinates:

In [None]:
print("📍 Loading pixel coordinates...")

# Load pixel coordinates from spots.csv
spots_coord_path = "../example_data/example_raw_data/WSA_LngSP9258467/spots.csv"
print(f"   📁 Loading from: {spots_coord_path}")

spots_coord_df = pd.read_csv(spots_coord_path, index_col=0)
print(f"   📊 Loaded coordinates for {len(spots_coord_df)} spots")

# Check for missing coordinates *before* indexing: .loc raises a KeyError
# if any requested spot ID is absent from spots.csv
missing_ids = pd.Index(spot_id_array).difference(spots_coord_df.index)
if len(missing_ids) > 0:
    print(f"   ⚠️  Warning: {len(missing_ids)} spots missing pixel coordinates, e.g. {list(missing_ids[:3])}")
else:
    print(f"   ✅ All {len(spot_id_array)} spots have pixel coordinates")

# Filter to only spots that we have processed data for (rows follow spot_id_array order)
print(f"   🔍 Filtering to {len(spot_id_array)} processed spots...")
spots_coord = spots_coord_df.loc[spot_id_array].values

# Validate coordinate data
print(f"   ✅ Coordinate validation:")
print(f"      Shape: {spots_coord.shape}")
print(f"      X range: {spots_coord[:, 0].min():.0f} to {spots_coord[:, 0].max():.0f}")
print(f"      Y range: {spots_coord[:, 1].min():.0f} to {spots_coord[:, 1].max():.0f}")
print(f"      Data type: {spots_coord.dtype}")

# Calculate spatial extent
width = spots_coord[:, 0].max() - spots_coord[:, 0].min()
height = spots_coord[:, 1].max() - spots_coord[:, 1].min()
print(f"   📏 Spatial extent: {width:.0f} × {height:.0f} pixels")

# Show sample coordinates
print(f"\n📋 Sample pixel coordinates:")
for i in range(min(5, len(spots_coord))):
    spot_id = spot_id_array[i]
    x, y = spots_coord[i]
    print(f"      {spot_id}: ({x:.0f}, {y:.0f})")

print(f"\n💡 These coordinates will be used for:")
print(f"   - Visualizing predictions on tissue slides")
print(f"   - Calculating spatial distances and relationships")
print(f"   - Quality control and validation")

# Show the first 5 coordinates for reference
spots_coord[:5]

## 🎯 Step 7: Creating PyTorch Geometric Data Object

Finally, we'll combine all processed components into a single `torch_geometric.data.Data` object that Hist2Cell can use for training and inference.

### 📦 **Data Integration Process**

We'll combine the following components:

| Component | Variable | Shape | Description |
|-----------|----------|-------|-------------|
| **Image Features** | `x` | [422, 3, 224, 224] | Preprocessed histology patches |
| **Labels** | `y` | [422, 330] | Gene expression + cell abundances |
| **Graph Structure** | `edge_index` | [2, num_edges] | Spatial connectivity |
| **Positions** | `pos` | [422, 2] | Pixel coordinates |
| **Metadata** | `spot_id` | [422] | Spot identifiers |

### 🔧 **PyTorch Geometric Format**

The final `Data` object will contain:
- **Node features (`x`)**: Image patches for each spot
- **Edge indices (`edge_index`)**: Spatial graph connectivity in COO format
- **Labels (`y`)**: Training targets (genes + cell types)
- **Positions (`pos`)**: Pixel coordinates for visualization
- **Metadata (`spot_id`)**: Spot identifiers for tracking

### 🎯 **Key Transformations**

1. **Adjacency → Edge Index**: Convert adjacency matrix to COO format
2. **Numpy → PyTorch**: Convert all arrays to PyTorch tensors
3. **Data Validation**: Ensure all components have consistent dimensions
4. **Graph Properties**: Verify graph connectivity and structure

### 💾 **Output**

The final processed data will be saved as:
- **Format**: `.pt` file (PyTorch format)
- **Size**: roughly 0.6 MB per spot, dominated by the float32 image patches (3 × 224 × 224 × 4 bytes), i.e. about 240 MB for 422 spots
- **Structure**: Ready for Hist2Cell training/inference
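Because the image patches dominate the file, its size can be estimated before processing a slide. A back-of-the-envelope sketch, assuming the patches are stored as float32 (4 bytes per value, which is what `Tensor(...)` produces); labels, coordinates, and edges add comparatively little, and the helper name is hypothetical:

```python
def estimate_patch_size_mb(num_spots, channels=3, height=224, width=224, bytes_per_value=4):
    """Approximate storage in MB (1024**2 bytes) for float32 image patches."""
    total_bytes = num_spots * channels * height * width * bytes_per_value
    return total_bytes / 1024**2


print(f"{estimate_patch_size_mb(422):.1f} MB")  # 242.3 MB
```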

Let's create the final data object:

In [None]:
print("🎯 Creating PyTorch Geometric Data object...")
print("=" * 50)

# Import required modules
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Data
from torch import Tensor

print("📦 Converting data components to PyTorch tensors...")

# Convert arrays to PyTorch tensors
print("   🖼️  Converting image data...")
x = Tensor(spot_data_array)
print(f"      Shape: {x.shape}")
print(f"      Memory: {x.numel() * x.element_size() / (1024**2):.1f} MB")

print("   🏷️  Converting label data...")
y = Tensor(spot_label_array)
print(f"      Shape: {y.shape}")
print(f"      Components: {y.shape[1]} (250 genes + 80 cell types)")

print("   📍 Converting position data...")
pos = Tensor(spots_coord)
print(f"      Shape: {pos.shape}")
print(f"      X range: {pos[:, 0].min():.0f} to {pos[:, 0].max():.0f}")
print(f"      Y range: {pos[:, 1].min():.0f} to {pos[:, 1].max():.0f}")

print("   🔗 Converting adjacency matrix to edge index...")
adj_tensor = Tensor(adj)
edge_index, edge_weights = dense_to_sparse(adj_tensor)
print(f"      Adjacency matrix shape: {adj_tensor.shape}")
print(f"      Edge index shape: {edge_index.shape}")
print(f"      Number of edges: {edge_index.shape[1]}")

# Create PyTorch Geometric Data object
print("\n🏗️ Creating PyTorch Geometric Data object...")
data = Data(
    x=x,                    # Node features (image patches)
    edge_index=edge_index,  # Graph connectivity
    y=y,                    # Labels (genes + cell types)
    pos=pos,                # Pixel coordinates
    spot_id=spot_id_array   # Spot identifiers
)

print("✅ Data object created successfully!")
print(f"   📊 Final data structure:")
print(f"      Nodes: {data.num_nodes}")
print(f"      Edges: {data.num_edges}")
print(f"      Node features: {data.x.shape}")
print(f"      Labels: {data.y.shape}")
print(f"      Positions: {data.pos.shape}")
print(f"      Spot IDs: {len(data.spot_id)}")

# Data validation checks
print(f"\n✅ Data validation:")
print(f"   - Node consistency: {data.x.shape[0] == data.y.shape[0] == data.pos.shape[0] == len(data.spot_id)}")
print(f"   - Edge validation: {data.edge_index.max().item() < data.num_nodes}")
print(f"   - Feature dimensions: {data.x.shape[1:]} (3 channels, 224x224)")
print(f"   - Label dimensions: {data.y.shape[1]} (330 total)")

# Memory usage summary
total_memory = (data.x.numel() * data.x.element_size() + 
                data.y.numel() * data.y.element_size() + 
                data.pos.numel() * data.pos.element_size() + 
                data.edge_index.numel() * data.edge_index.element_size()) / (1024**2)
print(f"   - Total memory: {total_memory:.1f} MB")

print(f"\n💡 This data object is now ready for:")
print(f"   - Hist2Cell model training")
print(f"   - Inference and prediction")
print(f"   - Subgraph sampling with NeighborLoader")
print(f"   - Visualization and analysis")

# Show the final data object
print(f"\n📋 Final Data Object:")
data

In [None]:
print("💾 Saving processed data to file...")
print("=" * 50)

# Define output path
output_path = "../example_data/example_processed_data/WSA_LngSP9258467.pt"
print(f"   📁 Output path: {output_path}")

# Create directory if it doesn't exist
import os
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Save the data object
print("   🔄 Saving PyTorch Geometric Data object...")
torch.save(data, output_path)

# Verify the saved file
if os.path.exists(output_path):
    file_size = os.path.getsize(output_path) / (1024**2)  # MB
    print(f"✅ Data saved successfully!")
    print(f"   📊 File information:")
    print(f"      Path: {output_path}")
    print(f"      Size: {file_size:.1f} MB")
    print(f"      Format: PyTorch (.pt)")
    
    # Test loading the saved data
    print(f"\n🧪 Testing data loading...")
    try:
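        # weights_only=False: since PyTorch 2.6, torch.load defaults to weights_only=True,
        # which refuses to unpickle objects such as a PyG Data (only load trusted files)
        loaded_data = torch.load(output_path, weights_only=False)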
        print(f"   ✅ Load test successful!")
        print(f"      Loaded data: {loaded_data}")
        print(f"      Nodes: {loaded_data.num_nodes}")
        print(f"      Edges: {loaded_data.num_edges}")
    except Exception as e:
        print(f"   ❌ Load test failed: {e}")
        
else:
    print(f"❌ Failed to save data to {output_path}")

print(f"\n🎯 Next Steps:")
print(f"   1. Use this data for Hist2Cell training:")
print(f"      - See ../tutorial_training/training_tutorial.ipynb")
print(f"   2. Process additional slides using the same pipeline")
print(f"   3. Create train/test splits for model evaluation")
print(f"   4. Analyze predictions with evaluation tutorials")

print(f"\n💡 This processed data contains:")
print(f"   - Ready-to-use histology image patches")
print(f"   - Spatial graph structure for neighboring relationships")
print(f"   - Gene expression and cell type abundance labels")
print(f"   - Pixel coordinates for visualization")
print(f"   - All metadata needed for downstream analysis")

print(f"\n✅ Data preparation tutorial complete!")
print(f"   Your spatial transcriptomics data is now ready for Hist2Cell!")

# 🎉 Tutorial Complete! Your Data is Ready for Hist2Cell

## 📋 Summary of What We've Accomplished

Congratulations! You've successfully completed the **Hist2Cell Data Preparation Tutorial**. Here's what we've achieved:

### ✅ **Data Processing Pipeline**
1. **📊 Data Structure Understanding**: Learned PyTorch Geometric format for spatial transcriptomics
2. **🖼️ Image Processing**: Loaded and preprocessed 224×224 histology patches
3. **🧬 Label Integration**: Combined gene expression and cell type abundance data
4. **📍 Spatial Mapping**: Extracted both array and pixel coordinates
5. **🔗 Graph Construction**: Built spatial neighborhood relationships
6. **📦 Final Integration**: Created PyTorch Geometric Data object
7. **💾 Data Saving**: Saved processed data for training and inference

### 📊 **Final Data Structure**
- **Nodes**: 422 spots with complete data
- **Edges**: ~2,700 spatial connections
- **Features**: 224×224 RGB image patches
- **Labels**: 330 values (250 genes + 80 cell types)
- **Coordinates**: Both array and pixel positions
- **Format**: Ready for Hist2Cell training

## 🔄 Next Steps in Your Hist2Cell Journey

### 1. **🏋️ Model Training**
- **Tutorial**: `../tutorial_training/training_tutorial.ipynb`
- **Purpose**: Train Hist2Cell on your processed data
- **Output**: Trained model weights for inference

### 2. **📈 Analysis & Evaluation**
Navigate to `../tutorial_analysis_evaluation/` for:
- **Cell Abundance Visualization**: Visualize predictions on tissue
- **Key Cell Evaluation**: Evaluate specific cell types of interest
- **Cell Co-localization**: Analyze spatial relationships
- **Super-resolution**: Generate high-resolution abundance maps

### 3. **🔄 Process Additional Data**
Apply this pipeline to your own datasets:
- **Multiple Slides**: Process entire experiments
- **Different Tissues**: Adapt to various tissue types
- **Custom Cell Types**: Modify cell type definitions
- **Batch Processing**: Scale to large datasets

## 🛠️ Troubleshooting Guide

### ❓ Common Issues and Solutions

#### **Issue**: Memory errors during processing
- **Solution**: Reduce batch size in DataLoader
- **Code**: `batch_size=256` instead of `batch_size=512`

#### **Issue**: Missing image patches
- **Solution**: Check patch extraction with DSMIL pipeline
- **Reference**: [DSMIL-WSI](https://github.com/binli123/dsmil-wsi)

#### **Issue**: Coordinate mismatches
- **Solution**: Verify spot IDs match between image and label files
- **Check**: Ensure consistent spot naming conventions
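To see exactly which IDs disagree, compare the ID sets of all sources for one slide at once, for example the IDs attached to the image patches, the label table, and `spots_coord_df.index`. A minimal sketch (the helper name is hypothetical); filter multi-slide tables such as `adata.obs_names` to a single slide first, otherwise IDs from other slides show up as missing:

```python
def missing_spot_ids(**sources):
    """For each named source, list spot IDs present in some other source but missing here."""
    id_sets = {name: set(ids) for name, ids in sources.items()}
    all_ids = set().union(*id_sets.values())
    return {name: sorted(all_ids - ids) for name, ids in id_sets.items()}


report = missing_spot_ids(
    images=["S1_AAA-1", "S1_CCC-1"],
    labels=["S1_AAA-1", "S1_CCC-1", "S1_GGG-1"],
    coords=["S1_AAA-1", "S1_GGG-1", "S1_CCC-1"],
)
print(report)  # {'images': ['S1_GGG-1'], 'labels': [], 'coords': []}
```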

#### **Issue**: Graph connectivity problems
- **Solution**: Adjust distance thresholds in adjacency matrix
- **Modify**: `col_dist < 3` and `row_dist < 2` parameters

### 🔧 **Performance Optimization**

#### **For Large Datasets**:
- **Parallel Processing**: Increase `num_workers` in DataLoader
- **Memory Management**: Process data in smaller chunks
- **Storage**: Use SSD for faster I/O operations

#### **For GPU Memory**:
- **Reduce Batch Size**: Lower `subgraph_bs` in NeighborLoader
- **Mixed Precision**: Use `torch.cuda.amp` for training
- **Gradient Checkpointing**: Trade extra computation for lower activation memory during training (`torch.utils.checkpoint`)

## 📚 Best Practices for Your Own Data

### 🎯 **Data Quality Checks**
1. **Image Quality**: Ensure patches are clear and well-focused
2. **Coordinate Accuracy**: Verify spatial positions are correct
3. **Label Completeness**: Check for missing gene/cell type data
4. **Graph Connectivity**: Ensure all spots are properly connected

### 📏 **Standardization Tips**
1. **Image Normalization**: Always use ImageNet statistics
2. **Coordinate Systems**: Maintain consistent spatial references
3. **Spot Naming**: Use consistent barcode formats
4. **File Organization**: Keep organized directory structures

### 🔄 **Batch Processing Workflow**
```python
import torch

# Example batch processing loop
slides = ['slide1', 'slide2', 'slide3']
for slide in slides:
    # process_slide is a hypothetical helper that wraps the steps of this
    # tutorial and returns a torch_geometric.data.Data object
    processed_data = process_slide(slide)
    torch.save(processed_data, f'processed_{slide}.pt')
```

## 🌟 Advanced Usage Tips

### 🎨 **Custom Cell Types**
- **Cell2location**: Use custom reference for deconvolution
- **scRNA-seq**: Integrate your own single-cell reference
- **Manual Annotation**: Include expert annotations

### 🔬 **Multi-modal Integration**
- **Protein Data**: Add protein abundance information
- **Clinical Data**: Include patient metadata
- **Temporal Data**: Process time-series experiments

### 🚀 **Performance Scaling**
- **GPU Acceleration**: Use CUDA for faster processing
- **Distributed Processing**: Scale across multiple GPUs
- **Cloud Computing**: Leverage cloud resources for large datasets

## 📖 Additional Resources

### 📝 **Documentation**
- **PyTorch Geometric**: [https://pytorch-geometric.readthedocs.io/](https://pytorch-geometric.readthedocs.io/)
- **Scanpy**: [https://scanpy.readthedocs.io/](https://scanpy.readthedocs.io/)
- **Cell2location**: [https://cell2location.readthedocs.io/](https://cell2location.readthedocs.io/)

### 🔬 **Research Papers**
- **Hist2Cell**: [Deciphering Fine-grained Cellular Architectures from Histology Images](https://www.biorxiv.org/content/10.1101/2024.02.17.580852v1.full.pdf)
- **Spatial Transcriptomics**: Key papers in spatial biology
- **Graph Neural Networks**: GNN applications in biology

### 🤝 **Community Support**
- **GitHub Issues**: Report bugs and request features
- **Discussions**: Join community discussions
- **Contributions**: Contribute to the project development

## 🎯 Final Reminders

### ✅ **Success Checklist**
- [ ] Data processed successfully
- [ ] PyTorch Geometric format validated
- [ ] Graph structure verified
- [ ] File saved and tested
- [ ] Ready for model training

### 💡 **Key Takeaways**
1. **Data Quality**: High-quality input data is crucial for good results
2. **Spatial Context**: Graph structure preserves tissue architecture
3. **Standardization**: Consistent preprocessing ensures reproducibility
4. **Validation**: Always validate processed data before training
5. **Documentation**: Keep detailed records of processing steps

## 🚀 **You're Ready to Go!**

Your spatial transcriptomics data is now fully prepared for Hist2Cell analysis. The processed data contains all the necessary components for training, inference, and visualization. 

**Next up**: Head to the training tutorial to learn how to train your Hist2Cell model!

---

*Happy analyzing! 🧬✨*
