## Tutorial notebook: training CRISP

In this notebook, we use the NeurIPS dataset as an example to show how to train CRISP on a single-cell RNA-seq dataset with measured perturbations. \
In practice, single-cell training data are large and gene features are high-dimensional, so we recommend training CRISP from a shell script rather than a notebook. A single GPU node is sufficient.

In [1]:
from CRISP.trainer import Trainer
import scanpy as sc
import torch
import pandas as pd

In [3]:
adata = sc.read('adata_pp_filtered_scFM_resplit.h5ad')
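Before passing `adata` to the trainer, it can help to confirm that every field named in `dataset_params` (next cell) is present. The sketch below is a minimal, hypothetical check. The helper `missing_fields` is not part of CRISP. The assignment of keys to AnnData slots is also an assumption: per-cell labels in `.obs`, the `X_scGPT` embedding in `.obsm`, and the DEG ranking in `.uns`. In particular, `neg_control` may be a value rather than a column name in CRISP.

```python
import pandas as pd

# ASSUMPTION: where each key from dataset_params lives in the AnnData object.
OBS_KEYS = ["condition", "dose_val", "SMILES", "cell_type",
            "neg_control", "type_donor", "cov_drug_name", "split"]
OBSM_KEYS = ["X_scGPT"]
UNS_KEYS = ["rank_genes_groups_cov"]


def missing_fields(obs_columns, obsm_keys, uns_keys):
    """Hypothetical helper: list the expected keys absent from each AnnData slot."""
    present = {"obs": set(obs_columns), "obsm": set(obsm_keys), "uns": set(uns_keys)}
    expected = {"obs": OBS_KEYS, "obsm": OBSM_KEYS, "uns": UNS_KEYS}
    return {slot: [k for k in keys if k not in present[slot]]
            for slot, keys in expected.items()}


# With the real object:
# missing_fields(adata.obs.columns, adata.obsm.keys(), adata.uns.keys())
toy_columns = pd.Index(["condition", "dose_val", "SMILES", "cell_type", "split"])
print(missing_fields(toy_columns, ["X_scGPT"], []))
```

An empty list for every slot means the keys are where the sketch expects them. Anything listed should be checked against how the `.h5ad` file was preprocessed.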

In [4]:
dataset_params = {
    'perturbation_key':'condition',
    'dose_key': 'dose_val',
    'smiles_key': 'SMILES',
    'celltype_key': 'cell_type',
    'FM_key': 'X_scGPT',
    'control_key': 'neg_control',
    'pc_cov': 'type_donor',
    'degs_key': "rank_genes_groups_cov",
    'pert_category': "cov_drug_name",
    'split_ood': True,
    'split_key': "split",
    'seed': 1327,
}
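`split_key` names the column that assigns each cell to a data split, and `pert_category` names a per-cell label that combines covariate and drug. The toy sketch below illustrates how such columns might relate. It assumes the split values `train`/`test`/`ood` and that `cov_drug_name` is built as `cell_type + "_" + condition`. Both are illustrative assumptions, not a description of how this dataset was prepared.

```python
import pandas as pd

# Toy per-cell table: column names mirror dataset_params, values are invented.
toy_obs = pd.DataFrame({
    "cell_type": ["T cells", "T cells", "B cells", "B cells", "NK cells"],
    "condition": ["drugA", "drugB", "drugA", "drugB", "drugA"],
    "split": ["train", "train", "train", "ood", "test"],
})

# ASSUMPTION: pert_category column formed as "<covariate>_<drug>".
toy_obs["cov_drug_name"] = toy_obs["cell_type"] + "_" + toy_obs["condition"]

# Number of cells in each split.
split_counts = toy_obs["split"].value_counts()
print(split_counts.to_dict())

# (cell type, drug) combinations assigned to the out-of-distribution split.
ood_categories = sorted(toy_obs.loc[toy_obs["split"] == "ood", "cov_drug_name"].unique())
print(ood_categories)  # ['B cells_drugB']
```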

In [5]:
exp = Trainer()

In [6]:
exp.init_dataset(adata_obj=adata,**dataset_params)


In [7]:
chem_df = pd.read_parquet('../data/drug_embeddings/rdkit2D_embedding_lincs_nips.parquet')
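`init_drug_embedding` needs an embedding for every drug in the dataset, so a quick coverage check can catch missing molecules early. The sketch below uses toy data. It assumes the parquet table is indexed by SMILES string with one column per RDKit 2D descriptor; the real layout may differ. The toy tables are named so they do not overwrite the notebook's `chem_df`.

```python
import pandas as pd

# Toy stand-in for the parquet table. ASSUMPTION: indexed by SMILES,
# one column per RDKit 2D descriptor (values invented).
toy_chem_df = pd.DataFrame(
    {"desc_0": [0.1, 0.4], "desc_1": [1.2, 0.7]},
    index=["CCO", "c1ccccc1"],
)
# Toy stand-in for adata.obs["SMILES"].
toy_smiles = pd.Series(["CCO", "c1ccccc1", "CC(=O)O", "CCO"])

# Drugs in the dataset that have no embedding row.
missing = sorted(set(toy_smiles) - set(toy_chem_df.index))
print(missing)  # ['CC(=O)O']

# One embedding row per covered drug, in a fixed (sorted) order.
covered = sorted(set(toy_smiles) & set(toy_chem_df.index))
drug_matrix = toy_chem_df.loc[covered].to_numpy()
print(drug_matrix.shape)  # (2, 2)
```

With the real objects, the same check would compare `adata.obs['SMILES']` against `chem_df.index`.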

In [8]:
exp.init_drug_embedding(chem_model='rdkit',chem_df=chem_df)

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"  # note: not passed to any call in this cell
exp.init_model(
    hparams='',
    seed=1337,
)
exp.load_train()

In [10]:
train_params = {
    'checkpoint_freq': 51, # how often (in epochs) to run evaluation
    'num_epochs': 51,
    'max_minutes': 1000, # time limit for training, in minutes
    'save_dir': '../experiments/results/nips_test',
}
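With `checkpoint_freq` equal to `num_epochs`, evaluation presumably runs only once, at the end of training. The hypothetical helper below makes that schedule explicit. It assumes evaluation fires on every epoch (1-indexed) divisible by `checkpoint_freq`. The actual rule inside `Trainer.train` may differ; for example, it may also evaluate before the first epoch or always on the last one.

```python
def eval_epochs(num_epochs, checkpoint_freq):
    """Hypothetical: 1-indexed epochs at which evaluation would run under the assumed rule."""
    return [e for e in range(1, num_epochs + 1) if e % checkpoint_freq == 0]


print(eval_epochs(51, 51))  # [51]
print(eval_epochs(51, 10))  # [10, 20, 30, 40, 50]
```

A smaller `checkpoint_freq` gives intermediate evaluations at the cost of extra evaluation time per run.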

In [11]:
results = exp.train(**train_params)

100%|██████████| 51/51 [39:34<00:00, 46.56s/it]
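As a sanity check on the timing, the per-epoch rate in the progress bar reproduces the reported total wall time. That total is far below the `max_minutes` budget set in the training parameters:

$$
51 \times 46.56\ \text{s} = 2374.56\ \text{s} \approx 39\ \text{min}\ 34.6\ \text{s} \ll 1000\ \text{min}
$$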


In [12]:
results['ood']

[{'r2score': 0.9322444459834656,
  'r2score_de': 0.23656416248965573,
  'pearson': 0.9696602225764862,
  'pearson_de': 0.44427059194645313,
  'mse': 0.07321457,
  'mse_de': 0.45216295,
  'pearson_delta': 0.40612612374275175,
  'pearson_delta_de': 0.6822232767848168,
  'sinkhorn_de': 16.342594424625496}]
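These metrics compare predicted and measured expression for the out-of-distribution conditions. The `_de` variants restrict the comparison to differentially expressed genes, presumably those ranked under `rank_genes_groups_cov`. The sketch below shows common definitions of these quantities computed on mean expression profiles. It is an assumption for illustration, not CRISP's evaluation code, and it leaves out `sinkhorn_de`.

```python
import numpy as np


def summary_metrics(pred, true, ctrl, de_idx):
    """Generic sketch of the reported metrics, computed on mean expression profiles.

    pred, true: (cells, genes) predicted and measured expression for one condition.
    ctrl: (cells, genes) control expression; de_idx: integer indices of DE genes.
    ASSUMED definitions, not CRISP's exact implementation (sinkhorn_de omitted).
    """
    p, t, c = pred.mean(axis=0), true.mean(axis=0), ctrl.mean(axis=0)

    def r2(y, y_hat):
        # Coefficient of determination of the predicted profile against the measured one.
        return 1.0 - np.sum((y - y_hat) ** 2) / np.sum((y - y.mean()) ** 2)

    def pearson(a, b):
        return np.corrcoef(a, b)[0, 1]

    out = {}
    for suffix, idx in (("", slice(None)), ("_de", de_idx)):
        out["r2score" + suffix] = r2(t[idx], p[idx])
        out["pearson" + suffix] = pearson(t[idx], p[idx])
        out["mse" + suffix] = float(np.mean((t[idx] - p[idx]) ** 2))
        # Correlation of perturbation effects (shift from control) rather than raw levels.
        out["pearson_delta" + suffix] = pearson(t[idx] - c[idx], p[idx] - c[idx])
    return out


rng = np.random.default_rng(0)
true = rng.normal(size=(50, 8))
ctrl = rng.normal(size=(50, 8))
pred = true + rng.normal(scale=0.1, size=true.shape)  # a noisy toy "prediction"
print(summary_metrics(pred, true, ctrl, de_idx=np.array([0, 3, 5])))
```

Whole-transcriptome scores tend to be much higher than the `_de` scores. Most genes change little under perturbation, so a control-like mean already explains much of their variance. `pearson_delta` is designed to avoid this by correlating shifts from control instead of raw levels.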