# HEIST: Hierarchical Embeddings for Integrated Spatial Transcriptomics

This notebook demonstrates how to:
1. Load spatial transcriptomics data (AnnData format)
2. Preprocess data and build hierarchical graphs
3. Generate cell representations using a pre-trained model
4. Visualize embeddings with PHATE colored by cell type

In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import scanpy as sc
from utils.preprocess import preprocess
from utils.dataloader import create_dataloader
from model.model import GraphEncoder
from torch_geometric.nn.pool import global_mean_pool
from tqdm import tqdm
import matplotlib.pyplot as plt
import phate
import numpy as np

## Configuration

**Update these paths before running:**
- `data_path`: Path to your AnnData file (.h5ad)
- `model_path`: Path to pre-trained model checkpoint (.pth)

**Required data structure:**
- `adata.X`: gene expression matrix
- `adata.obsm['spatial']`: spatial coordinates (required)
- `adata.obs['cell_type']`: cell type labels (optional, will cluster if missing)

In [None]:
# ---- User-editable inputs (update before running) ----
data_path = "path/to/adata"    # AnnData file (.h5ad)
model_path = "path/to/model"   # pre-trained model checkpoint (.pth)

# ---- Where the preprocessed graphs are written / cached ----
save_root = "data/preprocessed"
save_file_name = "sample_data"

# ---- Compute device: first CUDA GPU when one is visible, otherwise CPU ----
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(f"Using device: {device}")

## Step 1: Load AnnData file

In [None]:
# Read the AnnData object and report its contents, so missing keys are
# spotted before the expensive preprocessing step.
print(f"Loading data from {data_path}")
adata = sc.read_h5ad(data_path)
print(f"Loaded AnnData with {adata.n_obs} cells and {adata.n_vars} genes")
print(f"Available obsm keys: {list(adata.obsm.keys())}")
print(f"Available obs columns: {list(adata.obs.columns)}")

# Spatial coordinates are required to build the cell graph (preprocessing is
# called with spatial='spatial'); fail fast with a readable message instead of
# an obscure error deep inside preprocessing.
if 'spatial' not in adata.obsm:
    raise KeyError(
        "adata.obsm['spatial'] is required but missing; "
        f"available obsm keys: {list(adata.obsm.keys())}"
    )

## Step 2: Preprocess data and create graphs

This step will:
1. Filter and normalize gene expression
2. Select top 200 highly variable genes
3. Build spatial cell graph using Voronoi tessellation
4. Create gene regulatory networks (GRNs) for each cell type using mutual information
5. Save graphs to `{save_root}/{save_file_name}.pt`

**Note:** If the .pt file already exists, it will load from cache instead of reprocessing.

In [None]:
# Use existing annotations when present; otherwise pass None, which per the
# configuration notes makes `preprocess` assign cell types by automatic clustering.
cell_type_key = 'cell_type' if 'cell_type' in adata.obs.columns else None
if cell_type_key is None:
    print("No 'cell_type' column in adata.obs; cell types will be inferred by clustering.")

print("Preprocessing data and creating graphs...")
graphs = preprocess(
    adata=adata,
    save_root=save_root,
    save_file_name=save_file_name,
    max_genes=200,  # Number of highly variable genes to select
    spatial='spatial',  # Key in adata.obsm for spatial coordinates
    cell_type=cell_type_key,  # 'cell_type' if annotated, None for automatic clustering
)
# graphs[0] is the high-level cell graph; the remaining entries are per-cell gene graphs.
print(f"Created {len(graphs)} graphs:")
print(f"  - 1 high-level cell graph with {graphs[0].num_nodes} nodes")
print(f"  - {len(graphs)-1} low-level gene graphs (one per cell)")

## Step 3: Load pre-trained model

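The checkpoint should contain:
- `args`: the hyperparameters used to rebuild `GraphEncoder` (`pe_dim`, `init_dim`, `hidden_dim`, `output_dim`, `num_layers`, `num_heads`, `cross_message_passing`, `pe`, `blending`)
- `model_state_dict`: trained model weights

**Security note:** the checkpoint is loaded with `weights_only=False`, which unpickles arbitrary Python objects. Only load checkpoints from a trusted source.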

In [None]:
# Restore the checkpoint onto the selected device.
# NOTE(review): weights_only=False unpickles arbitrary Python objects (needed
# here for the stored `args` object), which can execute code on load -- only
# use checkpoints from a trusted source.
print(f"Loading model from {model_path}")
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
args = checkpoint['args']

print("Model configuration:")
print(f"  - PE dimension: {args.pe_dim}")
print(f"  - Hidden dimension: {args.hidden_dim}")
print(f"  - Output dimension: {args.output_dim}")

# GraphEncoder receives its hyperparameters positionally, in exactly this order.
encoder_hparams = (
    args.pe_dim,
    args.init_dim,
    args.hidden_dim,
    args.output_dim,
    args.num_layers,
    args.num_heads,
    args.cross_message_passing,
    args.pe,
    args.blending,
)
model = GraphEncoder(*encoder_hparams).to(device)

# Load the trained weights, then switch to inference mode.
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded successfully")

## Step 4: Calculate cell representations

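This uses the pre-trained model to encode both:
- **High-level embeddings**: from the spatial cell graph
- **Low-level embeddings**: from each cell's gene regulatory network, mean-pooled over its gene nodes into one vector per cell

The final representation concatenates both embeddings, giving `2 * output_dim` features per cell.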

In [None]:
# Encode every cell in mini-batches. Each batch yields a high-level subgraph,
# the batched low-level (gene) graphs, and the indices of the cells it covers.
print("Calculating cell representations...")
batch_size = 128  # Process cells in batches to fit GPU memory
# permute=False presumably keeps cells in a fixed (unshuffled) order -- confirm in utils.dataloader
dataloader = create_dataloader(graphs, batch_size, permute=False)

# One row per cell; 2 * output_dim because high- and low-level embeddings are concatenated.
graph_embeddings = torch.zeros((graphs[0].num_nodes, 2 * args.output_dim), device=device)
high_level_graph = graphs[0].to(device)

with torch.no_grad():
    for high_level_subgraph, low_level_batch, batch_idx in tqdm(dataloader, desc="Encoding cells"):
        batch_idx = batch_idx.to(device)
        high_level_subgraph = high_level_subgraph.to(device)
        low_level_batch = low_level_batch.to(device)
        low_level_batch.batch_idx = batch_idx

        # Encode both high-level (spatial) and low-level (gene) graphs
        high_emb, low_emb = model.encode(high_level_subgraph, low_level_batch, args.pe_dim)

        # Mean-pool gene-node embeddings into one vector per cell, then
        # combine: [high-level spatial | low-level gene]
        low_emb_per_cell = global_mean_pool(low_emb, low_level_batch.batch)
        graph_embeddings[batch_idx] = torch.cat([high_emb, low_emb_per_cell], dim=1)

# Store embeddings in the graph
high_level_graph.X = graph_embeddings
print(f"Cell representations calculated: {graph_embeddings.shape}")

## Step 5: Save results

Save the graph with embeddings for downstream analysis.

In [None]:
# Persist the spatial cell graph, which now carries the learned embeddings in `.X`.
output_path = f"data/{save_file_name}_with_embeddings.pt"
torch.save(high_level_graph, output_path)
print(f"Saved graph with embeddings to {output_path}")

# Final run summary, framed by a banner line.
banner = "=" * 60
print("\n" + banner)
print("Pipeline completed successfully!")
print(banner)
print("Summary:")
print(f"  - Processed {graphs[0].num_nodes} cells")
print(f"  - Embedding dimension: {graph_embeddings.shape[1]}")
print("  - Outputs saved:")
print(f"    • Graph: {output_path}")