## Training and deploying a SpatialProp model

This tutorial demonstrates how to train a SpatialProp model from scratch and deploy it for scoring on a coronal mouse brain tissue section.

We use the `aging_coronal.h5ad` dataset from [Sun et al., 2025](https://www.nature.com/articles/s41586-024-08334-8). This dataset includes coronal brain sections from mice at 20 different ages spanning the entire lifespan. The expression of 300 genes was profiled spatially with MERFISH. To download the dataset, run the following cell:

In [1]:
! mkdir -p ./data
! if [ ! -f ./data/aging_coronal.h5ad ]; then \
      echo "Downloading aging_coronal.h5ad..."; \
      wget https://zenodo.org/records/13883177/files/aging_coronal.h5ad -O ./data/aging_coronal.h5ad; \
    else \
      echo "File already exists: ./data/aging_coronal.h5ad — skipping download."; \
    fi

File already exists: ./data/aging_coronal.h5ad — skipping download.


With the `spatial-prop` conda environment activated (see Installation section of [README.md](../README.md)), run the following cell to import the API:

In [2]:
import numpy as np
import scanpy as sc 
import torch

from spatial_gnn.api.perturbation_api import (
    train_perturbation_model,
    create_perturbation_input_matrix,
    predict_perturbation_effects,
)



### Define and train the GNN

Training the SpatialProp GNN requires defining the set of training arguments detailed in the [perturbation training API](../src/spatial_gnn/api/perturbation_api.py) docstring. Here we reuse the model configuration reported in the paper. 

Each graph is the 2-hop neighborhood of a center cell (`k_hop`). Center cells are drawn from all cell types, limited to 100 cells per cell type (`num_cells_per_ct_id`). We augment the training and test sets with 2-hop neighborhood graphs around each surrounding cell. Here we train the base model, which does not use or predict cell type labels.

In [3]:
training_args = {
    "dataset": "aging_coronal",
    "exp_name": "api_demo",
    "base_path": "./data",
    "k_hop": 2,
    "augment_hop": 2,
    "center_celltypes": "all",
    "node_feature": "expression",
    "inject_feature": "none",
    "learning_rate": 0.0001,
    "loss": "weightedl1",
    "epochs": 50,
    "normalize_total": True,
    "num_cells_per_ct_id": 100,
    "adata_path": "./data/aging_coronal.h5ad",
    "predict_celltype": False,
    "pool": "center",
    "do_eval": True,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
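
To make the `k_hop` setting concrete, here is a minimal sketch of collecting a k-hop neighborhood with a breadth-first search. It is an illustration only, not SpatialProp's graph construction code: the function `k_hop_neighborhood` and the toy adjacency dictionary are hypothetical, and the real pipeline builds its graphs from spatial coordinates inside the training API.

```python
from collections import deque


def k_hop_neighborhood(adjacency, center, k):
    """Return {node: hop distance} for every node within k hops of `center`."""
    visited = {center: 0}
    queue = deque([center])
    while queue:
        node = queue.popleft()
        depth = visited[node]
        if depth == k:
            # Nodes at the hop limit are kept but not expanded further.
            continue
        for neighbor in adjacency.get(node, ()):
            if neighbor not in visited:
                visited[neighbor] = depth + 1
                queue.append(neighbor)
    return visited


# Hypothetical toy graph: a chain of five cells, 0 - 1 - 2 - 3 - 4.
toy_graph = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}

two_hop = k_hop_neighborhood(toy_graph, center=2, k=2)
one_hop = k_hop_neighborhood(toy_graph, center=0, k=1)
```

With `k=2`, the neighborhood of cell 2 covers the whole chain, while a 1-hop neighborhood of cell 0 contains only cells 0 and 1.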

The training API call constructs the graph dataset in the `./data/gnn_datasets/` directory, or reuses it if it already exists there.

In [4]:
test_loader, gene_names, (model, model_config, trained_model_path) = train_perturbation_model(
    **training_args,
)

Training new perturbation model from scratch...
Model will be saved to: output/api_demo/aging_coronal_expression_2hop_2augment_expression_none/weightedl1_1en04
Training on device: cuda
Dataset already exists at:  ./data/gnn_datasets/aging_coronal_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_all_NoneInject/test
Finished processing test dataset
Dataset already exists at:  ./data/gnn_datasets/aging_coronal_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_all_NoneInject/train
Finished processing train dataset


100%|██████████| 775/775 [01:30<00:00,  8.57it/s]
100%|██████████| 190/190 [00:22<00:00,  8.41it/s]

Train samples: 387234
Test samples: 94792





Expression model initialized on cuda
Starting Baseline training for 50 epochs...
 Epoch: 001, Train WL1: 6.7997, Test WL1: 6.7990, Test Spearman: 0.5687
 Epoch: 002, Train WL1: 6.4684, Test WL1: 6.4551, Test Spearman: 0.5806
 Epoch: 003, Train WL1: 6.3286, Test WL1: 6.3037, Test Spearman: 0.5857
 Epoch: 004, Train WL1: 6.2463, Test WL1: 6.2213, Test Spearman: 0.5876
 Epoch: 005, Train WL1: 6.1936, Test WL1: 6.1705, Test Spearman: 0.5883
 Epoch: 006, Train WL1: 6.1517, Test WL1: 6.1303, Test Spearman: 0.5898
 Epoch: 007, Train WL1: 6.1238, Test WL1: 6.1033, Test Spearman: 0.5907
 Epoch: 008, Train WL1: 6.1026, Test WL1: 6.0834, Test Spearman: 0.5916
 Epoch: 009, Train WL1: 6.0882, Test WL1: 6.0707, Test Spearman: 0.5913
 Epoch: 010, Train WL1: 6.0823, Test WL1: 6.0658, Test Spearman: 0.5921
 Epoch: 011, Train WL1: 6.0733, Test WL1: 6.0568, Test Spearman: 0.5925
 Epoch: 012, Train WL1: 6.0659, Test WL1: 6.0502, Test Spearman: 0.5923
 Epoch: 013, Train WL1: 6.0589, Test WL1: 6.0425, Test 

NotADirectoryError: [Errno 20] Not a directory: 'output/api_demo/aging_coronal_expression_2hop_2augment_expression_none/weightedl1_1en04/model.pth/training.pkl'

### Inference with perturbation model

Define the set of perturbations as a dictionary mapping `cell type` → `gene name` → `multiplier`. For instance, the entry `'T cell': {'Igf2': 0.0}` indicates knockout of Igf2 in all T cells in the input dataset. A multiplier of 2.0 doubles expression and 0.5 halves it.

In [None]:
# Define perturbations
perturbation_dict = {
    'T cell': {'Igf2': 0.0},  
    'NSC': {'Sox9': 2.0},         
    'Pericyte': {'Ccl4': 0.5}    
}
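
The following sketch shows how such a dictionary could be applied to an expression matrix by scaling the selected gene columns in the rows of each cell type. It is a simplified illustration under stated assumptions, not the SpatialProp implementation: the helper `apply_multipliers` and all `toy_*` values are hypothetical, and the API may represent the perturbation differently (for example as a mask stored on the AnnData object).

```python
import numpy as np


def apply_multipliers(expression, cell_types, gene_names, perturbation_dict):
    """Return a copy of `expression` (cells x genes) with multipliers applied."""
    perturbed = np.array(expression, dtype=float, copy=True)
    cell_types = np.asarray(cell_types)
    gene_index = {gene: i for i, gene in enumerate(gene_names)}
    for cell_type, gene_multipliers in perturbation_dict.items():
        rows = cell_types == cell_type  # boolean mask over cells
        for gene, multiplier in gene_multipliers.items():
            perturbed[rows, gene_index[gene]] *= multiplier
    return perturbed


# Hypothetical toy data: three cells, two genes.
toy_expression = np.array([[4.0, 2.0], [6.0, 8.0], [1.0, 3.0]])
toy_cell_types = ["T cell", "NSC", "T cell"]
toy_genes = ["Igf2", "Sox9"]
toy_perturbations = {"T cell": {"Igf2": 0.0}, "NSC": {"Sox9": 2.0}}

toy_perturbed = apply_multipliers(
    toy_expression, toy_cell_types, toy_genes, toy_perturbations
)
```

In this toy example, Igf2 is zeroed in both T cells and Sox9 is doubled in the NSC, while the input matrix is left unchanged.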

In [None]:
# Save perturbation mask to anndata
test_adata = sc.read_h5ad(test_data_path)
test_data_path_perturbed = create_perturbation_mask(test_adata, perturbation_dict, save_path=test_data_path)

In [None]:
print("\n=== Predicting perturbation effects ===")
adata_perturbed = predict_perturbation_effects(
    adata_path=test_data_path_perturbed,
    model_path=trained_model_path,  # path returned by train_perturbation_model above
    exp_name=training_args["exp_name"],
    perturbation_dict=perturbation_dict,
    perturbation_mask_key="perturbation_mask"
)

In [3]:
result_path = "/oak/stanford/groups/akundaje/abuen/spatial/spatial-gnn/data/perturbed/aging_coronal_perturbed_result.h5ad"

In [4]:
import scanpy as sc 

adata_result = sc.read_h5ad(result_path)

In [5]:
adata_result

AnnData object with n_obs × n_vars = 1453144 × 300
    obs: 'volume', 'center_x', 'center_y', 'min_x', 'min_y', 'max_x', 'max_y', 'transcript_count', 'num_detected_genes', 'barcodeCount', 'mouse_id', 'slide_id', 'cohort', 'age', 'batch', 'celltype', 'region', 'subregion'
    uns: 'neighbors', 'pca', 'perturbation_info', 'umap'
    obsm: 'X_pca', 'X_umap', 'perturbation_mask', 'spatial'
    varm: 'PCs'
    layers: 'perturbation_effects', 'predicted_perturbed'
    obsp: 'connectivities', 'distances'

In [6]:
effects = adata_result.layers['perturbation_effects']
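
One way to summarize a layer like this is to rank genes by their mean absolute effect across cells. The sketch below assumes only that the layer is a cells × genes array (dense or SciPy sparse), as AnnData layers are. The helper `top_affected_genes` and the toy values are hypothetical, and the exact definition of the values in `perturbation_effects` should be checked against the API docstring.

```python
import numpy as np


def top_affected_genes(effects, gene_names, k=3):
    """Rank genes by mean absolute effect across cells, largest first."""
    # Sparse matrices expose .toarray(); dense arrays pass through unchanged.
    dense = effects.toarray() if hasattr(effects, "toarray") else np.asarray(effects)
    mean_abs = np.abs(dense).mean(axis=0)
    order = np.argsort(-mean_abs, kind="stable")[:k]
    return [(gene_names[i], float(mean_abs[i])) for i in order]


# Hypothetical toy effects: two cells, three genes.
toy_effects = np.array([[0.0, -2.0, 0.5], [0.0, 4.0, -0.5]])
toy_gene_names = ["Igf2", "Sox9", "Ccl4"]

ranking = top_affected_genes(toy_effects, toy_gene_names, k=2)
```

On the real result, the same helper could be called as `top_affected_genes(effects, list(adata_result.var_names))`. For a dataset of this size, consider subsetting cells first, because converting a sparse layer to dense can use a lot of memory.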