## Tutorial for perturbation API

In [1]:
import numpy as np
import scanpy as sc 

from spatial_gnn.api.perturbation_api import (
    train_perturbation_model,
    predict_perturbation_effects, 
    get_perturbation_summary, 
    visualize_perturbation_effects, 
    create_perturbation_mask
)

  from pkg_resources import DistributionNotFound, get_distribution


In [2]:
# Single place to repoint the tutorial at a different copy of the raw data.
RAW_DATA_DIR = "/oak/stanford/groups/akundaje/abuen/spatial/spatial-gnn/data/raw"

# Training and testing currently read the same exercise dataset file.
train_data_path = f"{RAW_DATA_DIR}/exercise.h5ad"
test_data_path = f"{RAW_DATA_DIR}/exercise.h5ad"

### Define and train perturbation model 

This requires defining the training arguments and supplying the path to an AnnData (`.h5ad`) file containing the training dataset.

In [3]:
# Keyword arguments that are unpacked into the train_perturbation_model call.
training_args = dict(
    k_hop=2,
    augment_hop=2,
    # Several cell types are packed into one comma-separated string.
    center_celltypes="T cell,NSC,Pericyte",
    node_feature="expression",
    # NOTE(review): this is the string "None", not the None object — confirm the API expects the string.
    inject_feature="None",
    debug=True,
    debug_subset_size=10,
    num_cells_per_ct_id=100,
    epochs=10,
)

In [None]:
# Announce the step, then run training with the arguments defined above.
print("=== Training a new perturbation model ===")
training_result = train_perturbation_model(
    adata_path=train_data_path, exp_name="exercise", **training_args
)
# The call's result is unpacked into three values; model_path is reused later for prediction.
model, model_config, model_path = training_result

=== Training a new perturbation model ===
Training new perturbation model from scratch...
Training on device: cuda
Starting dataset processing at 00:11:42
Gene processing: 4.202s
Gene saving: 0.133s

Processing file 1/1: exercise.h5ad
  File loading: 3.653s
  Cell type filtering: 0.368s
  Normalizing data


  view_to_actual(adata)


  Normalization: 5.699s
  Missing gene handling: 3.106s
  Gene ordering: 6.837s
  Sample ID processing: 7.507s
  Processing 3 samples
    Sample 1/3: OC1
      Sample subsetting: 0.025s


  obj[key] = data


      Spatial graph building: 29.413s
      PyG conversion: 29.424s
      Node label construction: 0.049s
      Center cell selection: 0.004s
      Selected 208 center cells
      Subgraph extraction: 1.343s
      Created 208 subgraphs
      Subgraph saving: 0.360s
      Augmentation setup: 0.754s
      Augmentation processing: 26.451s
      Created 3346 augmented subgraphs
    Sample OC1 completed in 58.412s
    Sample 2/3: OC4
      Sample subsetting: 0.043s


  obj[key] = data


      Spatial graph building: 32.270s
      PyG conversion: 32.278s
      Node label construction: 0.068s
      Center cell selection: 0.003s
      Selected 217 center cells
      Subgraph extraction: 2.371s
      Created 217 subgraphs
      Subgraph saving: 0.337s
      Augmentation setup: 0.804s
      Augmentation processing: 26.245s
      Created 3155 augmented subgraphs
    Sample OC4 completed in 62.150s
    Sample 3/3: OE3
      Sample subsetting: 0.037s


  obj[key] = data


      Spatial graph building: 23.726s
      PyG conversion: 23.738s
      Node label construction: 0.060s
      Center cell selection: 0.003s
      Selected 174 center cells
      Subgraph extraction: 1.121s
      Created 174 subgraphs
      Subgraph saving: 0.230s
      Augmentation setup: 0.597s
      Augmentation processing: 21.599s
      Created 2946 augmented subgraphs
    Sample OE3 completed in 47.385s
  File exercise.h5ad completed in 188.280s
  Total subgraphs created: 10046

Dataset processing completed in 192.620s (3.2 minutes)
Total subgraphs created: 10046
Finished processing test dataset
Starting dataset processing at 00:14:55
Gene processing: 3.590s
Gene saving: 0.034s

Processing file 1/1: exercise.h5ad
  File loading: 3.372s
  Cell type filtering: 0.306s
  Normalizing data


  view_to_actual(adata)


  Normalization: 5.593s
  Missing gene handling: 3.108s
  Gene ordering: 6.648s
  Sample ID processing: 7.202s
  Processing 9 samples
    Sample 1/9: OC2
      Sample subsetting: 0.026s


KeyboardInterrupt: 

: 

### Inference with perturbation model

Define the set of perturbations as a nested dictionary mapping `cell type` → `gene name` → `multiplier`. For instance, the entry `'T cell': {'Igf2': 0.0}` indicates a knockout of `Igf2` in all T cells in the input dataset.

In [None]:
# Perturbations to apply, keyed by cell type, then gene name, with a multiplier as the value.
perturbation_dict = {
    "T cell": {"Igf2": 0.0},  # multiplier 0.0: knockout
    "NSC": {"Sox9": 2.0},  # multiplier 2.0
    "Pericyte": {"Ccl4": 0.5},  # multiplier 0.5
}

In [None]:
# Save the perturbation mask to a separate AnnData file.
# The original cell passed save_path=test_data_path, which writes over the raw
# input file (the same file used for training). Saving to a sibling
# "*_perturbed.h5ad" file keeps the raw data intact across re-runs.
test_adata = sc.read_h5ad(test_data_path)
perturbed_save_path = test_data_path.rsplit(".h5ad", 1)[0] + "_perturbed.h5ad"
# The returned value is used as the input path for prediction.
test_data_path_perturbed = create_perturbation_mask(
    test_adata, perturbation_dict, save_path=perturbed_save_path
)

In [None]:
print("\n=== Predicting perturbation effects ===")
# Collect the prediction inputs first so each one is easy to inspect or tweak.
# NOTE(review): exp_name is "aging_sagittal" here while training used "exercise" — confirm this is intentional.
prediction_inputs = dict(
    adata_path=test_data_path_perturbed,
    model_path=model_path,
    exp_name="aging_sagittal",
    perturbation_dict=perturbation_dict,
    perturbation_mask_key="perturbation_mask",
)
adata_perturbed = predict_perturbation_effects(**prediction_inputs)