## Tutorial for perturbation API

In [None]:
import numpy as np
import scanpy as sc 

from spatial_gnn.api.perturbation_api import (
    train_perturbation_model,
    predict_perturbation_effects, 
    get_perturbation_summary, 
    visualize_perturbation_effects, 
    create_perturbation_mask
)

In [None]:
# Directory holding the raw AnnData (.h5ad) files. Change this single value to
# point the tutorial at a different data location.
DATA_DIR = "/oak/stanford/groups/akundaje/abuen/spatial/spatial-gnn/data/raw"

# This tutorial trains and tests on the same file.
train_data_path = DATA_DIR + "/exercise.h5ad"
test_data_path = DATA_DIR + "/exercise.h5ad"

### Define and train perturbation model 

This requires defining the training arguments and supplying the path to an AnnData (`.h5ad`) file containing the training dataset.

In [None]:
# Keyword arguments that are unpacked into train_perturbation_model(**training_args).
# NOTE(review): the exact meaning of each option is defined by the training API;
# consult its documentation before changing values.
training_args = dict(
    k_hop=2,
    augment_hop=2,
    center_celltypes="T cell,NSC,Pericyte",  # one comma-separated string, not a list
    node_feature="expression",
    inject_feature="None",  # the string "None", not the Python None object
    debug=True,
    debug_subset_size=10,
    num_cells_per_ct_id=100,
    epochs=10,
)

In [None]:
print("=== Training a new perturbation model ===")

# Train on the dataset at train_data_path, forwarding every entry of
# training_args as a keyword argument.
training_result = train_perturbation_model(
    adata_path=train_data_path, exp_name="exercise", **training_args
)

# The call returns three values; model_path is reused by the inference cell below.
model, model_config, model_path = training_result

### Inference with perturbation model

Define the set of perturbations as a dictionary mapping `cell type` → `gene name` → `multiplier`. For instance, the entry `'T cell': {'Igf2': 0.0}` indicates a knockout of Igf2 in all T cells in the input dataset.

In [None]:
# Perturbations to apply, keyed as: cell type -> gene name -> multiplier.
perturbation_dict = {
    "T cell": {"Igf2": 0.0},    # multiplier 0.0, i.e. knockout
    "NSC": {"Sox9": 2.0},       # multiplier 2.0
    "Pericyte": {"Ccl4": 0.5},  # multiplier 0.5
}

In [None]:
# Load the test dataset and attach the perturbation mask to it.
test_adata = sc.read_h5ad(test_data_path)

# Save the masked dataset to its own file rather than passing the raw input path
# as save_path. test_data_path is also the training file in this tutorial, so
# reusing it puts the raw data at risk of being overwritten and makes a re-run
# of the notebook start from an already-modified file.
# NOTE(review): assumes create_perturbation_mask writes to save_path and returns
# the path it wrote — confirm against the API.
perturbed_save_path = test_data_path.replace(".h5ad", "_perturbed.h5ad")
test_data_path_perturbed = create_perturbation_mask(
    test_adata, perturbation_dict, save_path=perturbed_save_path
)

In [None]:
print("\n=== Predicting perturbation effects ===")

# Run inference with the model trained above on the masked test dataset.
# exp_name is set to "exercise" to match the training call and the dataset file;
# the previous value "aging_sagittal" did not correspond to anything else in
# this notebook.
# NOTE(review): assumes exp_name should name the same experiment for training
# and inference — confirm against the API.
adata_perturbed = predict_perturbation_effects(
    adata_path=test_data_path_perturbed,
    model_path=model_path,
    exp_name="exercise",
    perturbation_dict=perturbation_dict,
    perturbation_mask_key="perturbation_mask",
)