## Training and deploying a SpatialProp model

This tutorial demonstrates how to train a SpatialProp model from scratch and deploy it to predict perturbation effects on a coronal mouse brain tissue section.

We will use the `aging_coronal.h5ad` dataset from [Sun et al., 2025](https://www.nature.com/articles/s41586-024-08334-8). It contains coronal brain sections from mice at 20 different ages spanning the entire lifespan, with the expression of 300 genes profiled by MERFISH spatial transcriptomics. To download the dataset, run the following cell:

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
! mkdir -p ./data
! if [ ! -f ./data/aging_coronal.h5ad ]; then \
      echo "Downloading aging_coronal.h5ad..."; \
      wget https://zenodo.org/records/13883177/files/aging_coronal.h5ad -O ./data/aging_coronal.h5ad; \
    else \
      echo "File already exists: ./data/aging_coronal.h5ad — skipping download."; \
  fi

File already exists: ./data/aging_coronal.h5ad — skipping download.
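If `wget` is not available in your environment, the same skip-if-present download can be sketched with the Python standard library alone. This is a minimal sketch, not part of the SpatialProp API; the helper name `download_if_missing` is hypothetical.

```python
import urllib.request
from pathlib import Path


def download_if_missing(url: str, dest: str) -> bool:
    """Download `url` to `dest` unless the file already exists.

    Returns True if a download happened, False if it was skipped.
    """
    dest_path = Path(dest)
    if dest_path.is_file():
        print(f"File already exists: {dest_path} - skipping download.")
        return False
    # Create ./data (or any other parent directory) on first use
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Downloading {dest_path.name}...")
    urllib.request.urlretrieve(url, dest_path)
    return True
```

Usage, with the same Zenodo URL and destination as the shell cell: `download_if_missing("https://zenodo.org/records/13883177/files/aging_coronal.h5ad", "./data/aging_coronal.h5ad")`.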


With the `spatial-prop` conda environment activated (see Installation section of [README.md](../README.md)), run the following cell to import the API:

In [3]:
import numpy as np
import scanpy as sc
import torch

from spatial_gnn.api.perturbation_api import (
    train_perturbation_model,
    create_perturbation_input_matrix,
    predict_perturbation_effects,
)
from spatial_gnn.utils.plot_utils import (
    plot_loss_curves,
    plot_celltype_performance,
    plot_gene_in_section,
)



### Define and train the GNN

Training the SpatialProp GNN requires defining the set of training arguments detailed in the [perturbation training API](../src/spatial_gnn/api/perturbation_api.py) docstring. Here we reuse the model configuration reported in the paper. 

Graphs are built from the 2-hop neighborhood around center cells of every cell type (`center_celltypes="all"`), sampling at most 100 center cells per cell type (`num_cells_per_ct_id`). We augment the training and test sets with additional 2-hop neighborhood graphs centered on each surrounding cell (`augment_hop`). Here we train the base model, which does not use cell type labels as a feature.
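To make the graph-construction step concrete, here is a minimal sketch of extracting k-hop neighborhoods from a cell adjacency list, capping the number of center cells per cell type, and augmenting with graphs centered on each surrounding cell. It assumes the spatial neighbor graph is already available as an adjacency list; the helpers `k_hop_nodes` and `sample_centers` are hypothetical and are not SpatialProp's internal implementation.

```python
import random
from collections import deque


def k_hop_nodes(adj: dict[int, list[int]], center: int, k: int) -> set[int]:
    """Return all cells within k hops of `center`, including the center itself."""
    visited = {center}
    frontier = deque([(center, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue  # do not expand beyond k hops
        for nbr in adj.get(node, []):
            if nbr not in visited:
                visited.add(nbr)
                frontier.append((nbr, depth + 1))
    return visited


def sample_centers(celltypes: list[str], max_per_type: int, seed: int = 0) -> list[int]:
    """Pick at most `max_per_type` center cells for every cell type."""
    rng = random.Random(seed)
    by_type: dict[str, list[int]] = {}
    for idx, ct in enumerate(celltypes):
        by_type.setdefault(ct, []).append(idx)
    centers = []
    for idxs in by_type.values():
        centers.extend(rng.sample(idxs, min(max_per_type, len(idxs))))
    return sorted(centers)


# Toy chain graph 0-1-2-3-4 with two cell types
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
celltypes = ["Neuron", "Neuron", "Microglia", "Neuron", "Microglia"]
centers = sample_centers(celltypes, max_per_type=1)
subgraphs = {c: k_hop_nodes(adj, c, k=2) for c in centers}
# Augmentation sketch: one extra 2-hop graph around every cell in each subgraph
augmented = {n: k_hop_nodes(adj, n, k=2) for c in centers for n in subgraphs[c]}
print(k_hop_nodes(adj, 2, k=2))  # {0, 1, 2, 3, 4}
```

Breadth-first search is used here because it visits cells in order of hop distance, so the `depth == k` check cleanly stops expansion at the k-hop boundary.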

For the purposes of the demo, we train using a single mouse's data and evaluate on a held-out mouse's data.
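The held-out split amounts to selecting cells by mouse ID and checking that no mouse appears in both sets. A minimal sketch with NumPy, assuming each cell carries a mouse ID stored as a string; the helper `split_by_mouse` is hypothetical, and the API performs its own splitting from `train_ids` and `test_ids`.

```python
import numpy as np


def split_by_mouse(mouse_ids, train_ids, test_ids):
    """Return boolean masks for training cells and held-out test cells."""
    overlap = set(map(str, train_ids)) & set(map(str, test_ids))
    if overlap:
        raise ValueError(f"Mice used for both training and testing: {sorted(overlap)}")
    mouse_ids = np.asarray(mouse_ids).astype(str)
    train_mask = np.isin(mouse_ids, [str(i) for i in train_ids])
    test_mask = np.isin(mouse_ids, [str(i) for i in test_ids])
    # An empty selection usually means the IDs do not match the stored values
    if not train_mask.any() or not test_mask.any():
        raise ValueError("No cells matched train_ids or test_ids")
    return train_mask, test_mask


# Toy example: four cells from three mice
train_mask, test_mask = split_by_mouse(["14", "14", "11", "3"], ["14"], ["11"])
print(train_mask, test_mask)  # [ True  True False False] [False False  True False]
```

With an AnnData object you would pass the per-cell ID column, for example `adata.obs["mouse_id"]`; that column name is an assumption, so check `adata.obs.columns` for the actual one.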

In [4]:
adata_path = "./data/aging_coronal.h5ad"
train_ids = ["14"]
test_ids = ["11"]
exp_name = "api_demo"

In [8]:
training_args = {
    "dataset": "aging_coronal",
    "exp_name": exp_name,
    "file_path": adata_path,
    "train_ids": train_ids,
    "test_ids": test_ids,
    "k_hop": 2,
    "augment_hop": 2,
    "center_celltypes": "all",
    "node_feature": "expression",
    "inject_feature": "none",
    "learning_rate": 0.0001,
    "loss": "weightedl1",
    "epochs": 30,
    "normalize_total": True,
    "num_cells_per_ct_id": 100,
    "predict_celltype": False,
    "pool": "center",
    "do_eval": True,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "overwrite": True,
}

The training API call will trigger construction of the graph dataset in the `./data/gnn_datasets/` directory.

In [None]:
test_loader, gene_names, (model, model_config, trained_model_path) = train_perturbation_model(
    **training_args,
)

Training new perturbation model from scratch...
Model will be saved to: ./output/api_demo/aging_coronal_expression_2hop_2augment_expression_none/weightedl1_1en04
Training on device: cuda
Dataset already exists at:  ./data/gnn_datasets/aging_coronal_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_all_NoneInject_subset_11/test
Finished processing test dataset
Dataset already exists at:  ./data/gnn_datasets/aging_coronal_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_all_NoneInject_subset_14/train
Finished processing train dataset



ValueError: num_samples should be a positive integer value, but got num_samples=0


### Inspect model training performance

In [None]:

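# Output directory reported in the training log ("Model will be saved to: ...")
save_dir = "./output/api_demo/aging_coronal_expression_2hop_2augment_expression_none/weightedl1_1en04"
plot_loss_curves(save_dir)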

In [None]:
plot_celltype_performance(save_dir)
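`plot_celltype_performance` summarizes model performance separately for each cell type; the exact metric it plots is defined in `plot_utils` and is not restated here. As a hypothetical illustration of a per-cell-type summary, this sketch computes the mean absolute error between predicted and observed expression for each cell type, assuming both are cells x genes arrays.

```python
import numpy as np


def per_celltype_mae(pred, true, celltypes):
    """Mean absolute error per cell type: averaged over genes, then over cells."""
    pred = np.asarray(pred, dtype=float)
    true = np.asarray(true, dtype=float)
    celltypes = np.asarray(celltypes)
    cell_err = np.abs(pred - true).mean(axis=1)  # one error value per cell
    return {str(ct): float(cell_err[celltypes == ct].mean()) for ct in np.unique(celltypes)}


pred = [[1.0, 2.0], [0.0, 0.0], [3.0, 3.0]]
true = [[1.0, 1.0], [1.0, 1.0], [3.0, 5.0]]
print(per_celltype_mae(pred, true, ["Astrocyte", "Astrocyte", "Microglia"]))
# {'Astrocyte': 0.75, 'Microglia': 1.0}
```

Averaging per cell first keeps each cell's contribution equal regardless of how many genes are scored.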

### Inference with the deployed SpatialProp model

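Define the set of perturbations as a nested dictionary mapping `cell type` → `gene name` → `multiplier`. For instance, the entry `'T cell': {'Igf2': 0.0}` knocks out Igf2 in all T cells in the input dataset, while a multiplier of `10.0`, as used below for Il6, Tnf, and Ifng in T cells and microglia, scales their expression tenfold.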
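Conceptually, each `cell type` → `gene name` → `multiplier` entry scales one gene's expression in every cell of one type. The sketch below applies such a dictionary to a plain cells x genes NumPy matrix; it is a hypothetical helper (`apply_perturbations`), not the implementation of `create_perturbation_input_matrix`, whose handling of normalization and of genes missing from the 300-gene panel is not described here.

```python
import warnings

import numpy as np


def apply_perturbations(X, celltypes, gene_names, perturbation_dict):
    """Return a copy of X (cells x genes) scaled per cell type and gene."""
    X_pert = np.array(X, dtype=float)  # np.array copies the input by default
    celltypes = np.asarray(celltypes)
    gene_index = {g: j for j, g in enumerate(gene_names)}
    for celltype, gene_mults in perturbation_dict.items():
        rows = celltypes == celltype  # boolean mask of cells of this type
        for gene, mult in gene_mults.items():
            if gene not in gene_index:
                warnings.warn(f"{gene} is not in the gene panel; skipping")
                continue
            X_pert[rows, gene_index[gene]] *= mult
    return X_pert


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
celltypes = ["T cell", "Microglia", "Neuron"]
genes = ["Il6", "Tnf"]
perturbations = {"T cell": {"Il6": 10.0}, "Microglia": {"Tnf": 0.0}}
print(apply_perturbations(X, celltypes, genes, perturbations).tolist())
# [[10.0, 2.0], [3.0, 0.0], [5.0, 6.0]]
```

A multiplier of `0.0` zeroes the gene (knockout), `10.0` scales it tenfold, and cells of unlisted types are left unchanged.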

In [None]:
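import os

# Tenfold upregulation of Il6, Tnf, and Ifng in T cells and microglia
perturbation_dict = {
    'T cell': {'Il6': 10.0, 'Tnf': 10.0, 'Ifng': 10.0},
    'Microglia': {'Il6': 10.0, 'Tnf': 10.0, 'Ifng': 10.0},
}
adata_path = "./data/aging_coronal.h5ad"
save_path = "./data/perturbed_adata/aging_coronal_perturbed.h5ad"
# Make sure the output directory exists before the perturbed AnnData is written
os.makedirs(os.path.dirname(save_path), exist_ok=True)

adata = sc.read_h5ad(adata_path)
save_path = create_perturbation_input_matrix(adata, perturbation_dict, save_path=save_path)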

In [None]:
adata_result = predict_perturbation_effects(
    save_path, trained_model_path, exp_name, use_ids=test_ids
)