## Tutorial notebook: training CRISP

In this notebook, we use the NeurIPS dataset as an example to show how to train CRISP on a single-cell RNA-seq dataset with measured perturbation responses. \
In practice, because single-cell training sets are large and the gene features are high-dimensional, we recommend training from a terminal script rather than from the notebook; a single GPU node is sufficient.
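Following that recommendation, the notebook cells can be collected into a standalone script along the lines below. It is a minimal sketch that reuses only the `Trainer` calls shown in this notebook; the file name `train_crisp.py`, the functions `build_parser` and `main`, and every command-line flag are hypothetical choices for illustration, not part of CRISP.

```python
# train_crisp.py: a hypothetical script that bundles the notebook cells below.
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical command-line flags mirroring the notebook's settings."""
    parser = argparse.ArgumentParser(description="Train CRISP from the terminal.")
    parser.add_argument("--adata", required=True, help="path to the NeurIPS AnnData file")
    parser.add_argument("--chem-df", required=True, help="path to the drug-embedding parquet file")
    parser.add_argument("--save-dir", default="../experiments/results/test")
    parser.add_argument("--num-epochs", type=int, default=51)
    parser.add_argument("--checkpoint-freq", type=int, default=25)
    parser.add_argument("--max-minutes", type=int, default=1000)
    parser.add_argument("--seed", type=int, default=1337)
    return parser


def main(argv=None):
    args = build_parser().parse_args(argv)

    # Heavy imports are deferred so that `--help` works even where CRISP is not installed.
    import pandas as pd
    import scanpy as sc
    from CRISP.trainer import Trainer

    # Same dataset settings as in the notebook.
    dataset_params = {
        'perturbation_key': 'condition',
        'dose_key': 'dose_val',
        'smiles_key': 'SMILES',
        'celltype_key': 'cell_type',
        'FM_key': 'X_scGPT',
        'control_key': 'control',
        'pc_cov': 'cell_type',
        'degs_key': 'rank_genes_groups_cov',
        'pert_category': 'cov_drug_name',
        'split_ood': True,
        'split_key': 'split',
        'seed': 1327,
    }

    exp = Trainer()
    exp.init_dataset(adata_obj=sc.read(args.adata), **dataset_params)
    exp.init_drug_embedding(chem_model='rdkit', chem_df=pd.read_parquet(args.chem_df))
    exp.init_model(hparams='', seed=args.seed)
    exp.load_train()
    return exp.train(
        checkpoint_freq=args.checkpoint_freq,
        num_epochs=args.num_epochs,
        max_minutes=args.max_minutes,
        save_dir=args.save_dir,
    )
```

To use it, end the file with `if __name__ == "__main__": main()` and launch it on the GPU node, for example `python train_crisp.py --adata PATH/TO/NEURIPS/DATASET --chem-df ../data/drug_embeddings/rdkit2D_embedding_lincs_nips.parquet`.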

In [1]:
from CRISP.trainer import Trainer
import scanpy as sc
import torch
import pandas as pd



In [2]:
adata = sc.read('PATH/TO/NEURIPS/DATASET')  # replace the placeholder with the path to the NeurIPS AnnData file

In [3]:
dataset_params = {
    'perturbation_key': 'condition',
    'dose_key': 'dose_val',
    'smiles_key': 'SMILES',
    'celltype_key': 'cell_type',
    'FM_key': 'X_scGPT',
    'control_key': 'control',
    'pc_cov': 'cell_type',
    'degs_key': "rank_genes_groups_cov",
    'pert_category': "cov_drug_name",
    'split_ood': True,
    'split_key': "split",
    'seed': 1327,
}
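Before calling `init_dataset`, it can help to confirm that every key named in `dataset_params` actually exists in `adata`. The helper below is a hypothetical sketch: which AnnData slot each parameter refers to (`obs`, `obsm` or `uns`) is our assumption, inferred from the key names and values in this tutorial (for example, `X_scGPT` looks like an `obsm` embedding and `rank_genes_groups_cov` like an `uns` entry), not something CRISP documents here.

```python
from typing import Dict, Iterable, List, Tuple

# ASSUMPTION: the slot each parameter points to is inferred from this tutorial's
# key names, not taken from CRISP's documentation.
OBS_PARAMS = ("perturbation_key", "dose_key", "smiles_key", "celltype_key",
              "control_key", "pc_cov", "pert_category", "split_key")
OBSM_PARAMS = ("FM_key",)
UNS_PARAMS = ("degs_key",)


def missing_keys(params: Dict[str, object], obs_columns: Iterable[str],
                 obsm_keys: Iterable[str], uns_keys: Iterable[str]) -> List[Tuple[str, str, str]]:
    """Return (parameter, value, slot) for every referenced key absent from the data."""
    slots = {
        "obs": (OBS_PARAMS, set(obs_columns)),
        "obsm": (OBSM_PARAMS, set(obsm_keys)),
        "uns": (UNS_PARAMS, set(uns_keys)),
    }
    missing = []
    for slot, (names, available) in slots.items():
        for name in names:
            value = params.get(name)
            # Parameters not present in `params` are skipped rather than reported.
            if value is not None and value not in available:
                missing.append((name, value, slot))
    return missing
```

In the notebook, `missing_keys(dataset_params, adata.obs.columns, adata.obsm.keys(), adata.uns.keys())` should return an empty list if the assumed layout holds.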

In [6]:
exp = Trainer()

In [7]:
exp.init_dataset(adata_obj=adata, **dataset_params)

In [8]:
chem_df = pd.read_parquet('../data/drug_embeddings/rdkit2D_embedding_lincs_nips.parquet')
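A drug without a row in `chem_df` cannot be embedded. A quick coverage check might look like the sketch below, assuming (hypothetically) that `chem_df` is indexed by SMILES strings that match the `SMILES` column named by `smiles_key`; the function name `drugs_without_embedding` is illustrative.

```python
from typing import Iterable, List


def drugs_without_embedding(smiles_in_data: Iterable[object],
                            embedding_index: Iterable[str]) -> List[str]:
    """Return the sorted, de-duplicated SMILES from the dataset that have no
    matching entry in the embedding table's index."""
    available = set(embedding_index)
    return sorted({
        smiles for smiles in smiles_in_data
        # Non-string entries (e.g. missing values) are skipped.
        if isinstance(smiles, str) and smiles not in available
    })
```

Called as `drugs_without_embedding(adata.obs['SMILES'], chem_df.index)`. A non-empty result may also reflect differences in SMILES canonicalization rather than genuinely missing drugs.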

In [9]:
exp.init_drug_embedding(chem_model='rdkit', chem_df=chem_df)

In [10]:
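# `device` is computed here but is not passed to `init_model` in this cell
device = "cuda" if torch.cuda.is_available() else "cpu"
exp.init_model(
    hparams='',  # left empty in this tutorial
    seed=1337,
)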
exp.load_train()

In [11]:
train_params = {
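    'checkpoint_freq': 25,  # how often (in epochs) to run evaluation and save a checkpoint
    'num_epochs': 51,
    'max_minutes': 1000,  # wall-clock time budget for training, in minutes
    'save_dir': '../experiments/results/test',  # output directory for checkpoints and results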
}
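To see how these settings interact, the sketch below assumes, hypothetically, that epochs are counted from 0 and that evaluation runs whenever `epoch % checkpoint_freq == 0`; the trainer's actual rule is not shown in this notebook. Under that assumption, 51 epochs with `checkpoint_freq = 25` evaluate at epochs 0, 25 and 50, while 50 epochs would stop before epoch 50. The `out_of_time` helper illustrates the kind of wall-clock check that `max_minutes` implies; both functions are illustrative, not CRISP's code.

```python
import time
from typing import List, Optional


def evaluation_epochs(num_epochs: int, checkpoint_freq: int) -> List[int]:
    """0-indexed epochs on which evaluation would run under the assumed
    rule `epoch % checkpoint_freq == 0`."""
    return [epoch for epoch in range(num_epochs) if epoch % checkpoint_freq == 0]


def out_of_time(start: float, max_minutes: float, now: Optional[float] = None) -> bool:
    """True once more than `max_minutes` of wall-clock time have elapsed since
    `start`; `start` and `now` are time.monotonic() readings in seconds."""
    if now is None:
        now = time.monotonic()
    return (now - start) / 60.0 > max_minutes


print(evaluation_epochs(51, 25))  # [0, 25, 50]
print(evaluation_epochs(50, 25))  # [0, 25]
```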

In [ ]:
results = exp.train(**train_params)