## Tutorial for perturbation API

In [1]:
import numpy as np
import scanpy as sc 

from spatial_gnn.api.perturbation_api import (
    train_perturbation_model,
    predict_perturbation_effects, 
    get_perturbation_summary, 
    visualize_perturbation_effects, 
    create_perturbation_mask
)

  from pkg_resources import DistributionNotFound, get_distribution


In [2]:
train_data_path = "/oak/stanford/groups/akundaje/abuen/spatial/spatial-gnn/data/raw/aging_coronal.h5ad"
test_data_path = "/oak/stanford/groups/akundaje/abuen/spatial/spatial-gnn/data/raw/aging_coronal.h5ad" 

### Define and train perturbation model 

This requires defining the training arguments and supplying an anndata for the training dataset.

In [3]:
training_args = {
    "k_hop": 2,
    "augment_hop": 2,
    "center_celltypes": "T cell,NSC,Pericyte",
    "node_feature": "expression",
    "inject_feature": "None",
    "debug": True,
    "debug_subset_size": 10,
    "num_cells_per_ct_id": 100,
    "epochs": 10,
}

In [4]:
print("=== Training a new perturbation model ===")
model, model_config, model_path = train_perturbation_model(
    adata_path=train_data_path,
    **training_args
)

=== Training a new perturbation model ===
Training new perturbation model from scratch...
Training on device: cuda
Starting dataset processing at 23:34:33
Dataset already exists at:  ./data/gnn_datasets/None_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_TNP_NoneInject/test
Finished processing test dataset
Starting dataset processing at 23:34:33
Dataset already exists at:  ./data/gnn_datasets/None_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_TNP_NoneInject/train
Finished processing train dataset
DEBUG MODE: Using subset of 10 samples from each dataset
DEBUG: Train dataset subset to 10 samples
DEBUG: Test dataset subset to 10 samples


100%|██████████| 10/10 [00:00<00:00, 34.89it/s]
100%|██████████| 10/10 [00:00<00:00, 33.25it/s]

Train samples: 10
Test samples: 10





Model initialized on cuda
Starting training for 10 epochs...
Epoch: 001, Train WL1: 20.1060, Test WL1: 18.1426
Epoch: 002, Train WL1: 10.6469, Test WL1: 10.0698
Epoch: 003, Train WL1: 5.8632, Test WL1: 5.3231
Epoch: 004, Train WL1: 3.9010, Test WL1: 3.4831
Epoch: 005, Train WL1: 3.0822, Test WL1: 2.7088
Epoch: 006, Train WL1: 2.6656, Test WL1: 2.3524
Epoch: 007, Train WL1: 2.4030, Test WL1: 2.1120
Epoch: 008, Train WL1: 2.2571, Test WL1: 2.0301
Epoch: 009, Train WL1: 1.8365, Test WL1: 1.5772
Epoch: 010, Train WL1: 1.7198, Test WL1: 1.4779
Training completed. Model saved to results/gnn/None_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_TNP_NoneInject/DEBUG_weightedl1_1en04/model.pth
Model configuration saved to results/gnn/None_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_TNP_NoneInject/DEBUG_weightedl1_1en04/model_config.json
Training logs saved
Training completed. Model saved to: results/gnn/None_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_TNP_NoneI

### Inference with perturbation model

Define set of perturbations in the form of a dictionary mapping `cell type` → `gene name` → `multiplier`. For instance, entry `'T cell': {'Igf2': 0.0},` indicates knockout of IGF2 in all T-cells in the input dataset.

In [5]:
# Define perturbations
perturbation_dict = {
    'T cell': {'Igf2': 0.0},  
    'NSC': {'Sox9': 2.0},         
    'Pericyte': {'Ccl4': 0.5}    
}

In [6]:
# Save perturbation mask to anndata
test_adata = sc.read_h5ad(test_data_path)
test_data_path_perturbed = create_perturbation_mask(test_adata, perturbation_dict, save_path=test_data_path)

Applying perturbations to 1042 cells of type 'T cell'
  - Gene 'Igf2': multiplier = 0.0
Applying perturbations to 2582 cells of type 'NSC'
  - Gene 'Sox9': multiplier = 2.0
Applying perturbations to 28257 cells of type 'Pericyte'
  - Gene 'Ccl4': multiplier = 0.5
Saved AnnData with perturbation mask to: /oak/stanford/groups/akundaje/abuen/spatial/spatial-gnn/data/raw/aging_coronal.h5ad

Perturbation mask created:
- Shape: (1453144, 300)
- Cell types affected: ['T cell', 'NSC', 'Pericyte']
- Cells affected: 31881
- Genes affected: ['Igf2', 'Ccl4', 'Sox9']
- Mask stored in adata.obsm['perturbation_mask']


In [7]:
print("\n=== Predicting perturbation effects ===")
adata_perturbed = predict_perturbation_effects(
    adata_path=test_data_path_perturbed,
    model_path=model_path,
    perturbation_dict=perturbation_dict,
    perturbation_mask_key="perturbation_mask"
)


=== Predicting perturbation effects ===
Loaded model configuration:
  - input_dim: 300
  - output_dim: 300
  - inject_dim: 0
  - num_layers: 2
  - method: GIN
  - pool: add
Loaded pretrained model from: results/gnn/None_expression_100per_2hop_2C0aug_200delaunay_expressionFeat_TNP_NoneInject/DEBUG_weightedl1_1en04/model.pth
Creating perturbation mask...
Applying perturbations to 1042 cells of type 'T cell'
  - Gene 'Igf2': multiplier = 0.0
Applying perturbations to 2582 cells of type 'NSC'
  - Gene 'Sox9': multiplier = 2.0
Applying perturbations to 28257 cells of type 'Pericyte'
  - Gene 'Ccl4': multiplier = 0.5

Perturbation mask created:
- Shape: (1453144, 300)
- Cell types affected: ['T cell', 'NSC', 'Pericyte']
- Cells affected: 31881
- Genes affected: ['Igf2', 'Ccl4', 'Sox9']
- Mask stored in adata.obsm['perturbation_mask']
Creating graphs from input data...
Starting dataset processing at 23:38:05
Dataset already exists at:  ./data/gnn_datasets/temp_expression_100per_2hop_2C0aug_2

Processing data groups: 100%|██████████| 1/1 [00:02<00:00,  2.43s/it]

Batch 1: 2.313s total
Batch 2: 2.321s total
Batch 3: 2.327s total
Batch 4: 2.332s total
Batch 5: 2.337s total
Batch 6: 2.342s total
Batch 7: 2.350s total
Batch 8: 2.355s total
Batch 9: 2.361s total
Batch 10: 2.366s total
Batch 11: 2.371s total
Batch 12: 2.376s total
Batch 13: 2.382s total
Batch 14: 2.387s total
Batch 15: 2.392s total
Batch 16: 2.397s total
Batch 17: 2.401s total
Batch 18: 2.406s total
Batch 19: 2.412s total
Batch 20: 2.418s total
Batch 21: 2.423s total
Batch 22: 2.428s total

=== FINAL SUMMARY ===
Total time: 2.430s
Processed 22 batches
Average time per batch: 1677164131.582s





Predicted 1608 cells out of 1453144 total cells
Perturbation effects calculated as: predicted - original
Perturbation prediction completed successfully!


In [8]:
perturbation_summary = get_perturbation_summary(adata_perturbed)