In [11]:
import os
import cpa
import scanpy as sc
import pandas as pd

In [12]:
sc.settings.set_figure_params(dpi=100)

In [13]:
path = "/home/katerchen/Code/data/open-problems-single-cell-perturbations/"
adata_obs = pd.read_csv(path + "adata_obs_meta.csv")
id_map = pd.read_csv(path + "id_map.csv")
multiome_obs_meta = pd.read_csv(path + "multiome_obs_meta.csv")
multiome_var_meta = pd.read_csv(path + "multiome_var_meta.csv")

de_train = pd.read_parquet(path + "de_train.parquet")

de_train_meta = de_train.iloc[:, :5]

de_train_exp = de_train.drop(columns = ["cell_type", "sm_name", "sm_lincs_id", "SMILES", "control"])

de_train_meta.head()

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False


Next, we just replace `de_train_exp.X` with raw counts to be able to train CPA with Negative Binomial (aka NB) loss.

In [14]:
# turn de_train_exp into scanpy object
# adata = sc.AnnData(de_train_exp)
# adata.obs = de_train_meta

In [None]:
adata = pd.read_parquet(path + "adata_train.parquet")

In [15]:
adata.X

array([[ 0.10472047, -0.07752421, -1.62559604, ...,  0.03412678,
         0.22137655,  0.36875538],
       [ 0.91595324, -0.88438038,  0.37183448, ...,  0.70477983,
         1.09670189, -0.86988664],
       [-0.38772076, -0.30537826,  0.56777737, ...,  0.41576793,
         0.07843919, -0.25936541],
       ...,
       [ 0.33816764, -0.10907872,  0.27018167, ..., -0.20907666,
         0.38975144, -0.33708204],
       [ 0.10113796, -0.40972434, -0.60629187, ..., -0.02968417,
         0.00550565, -1.73311173],
       [-0.75711627,  0.08591048, -0.73002496, ..., -0.60328012,
        -0.09804148, -0.75068134]])

In [16]:
adata.obs['cell_type'].value_counts()

cell_type
NK cells              146
T cells CD4+          146
T regulatory cells    146
T cells CD8+          142
B cells                17
Myeloid cells          17
Name: count, dtype: int64

In [17]:
# replace True with 'ctrl' and False with 'stimulated' in adata.obs['control']
adata.obs['control'] = adata.obs['control'].replace([True, False], ['ctrl', 'stimulated'])

In [18]:
# setting up dummy variable for dosage
adata.obs['dose'] = adata.obs['control'].apply(lambda x: '+'.join(['1.0' for _ in x.split('+')]))

In [19]:
adata.obs['control'].value_counts()

control
stimulated    602
ctrl           12
Name: count, dtype: int64

In [10]:
cpa.CPA.setup_anndata(adata, 
                      perturbation_key='control',
                      control_group='ctrl',
                      dosage_key='dose',
                      categorical_covariate_keys=['cell_type'],
                      is_count_data=False,
                      deg_uns_key='rank_genes_groups_cov',
                      deg_uns_cat_key='cov_cond',
                      max_comb_len=1,
                     )

100%|██████████| 614/614 [00:00<00:00, 89086.16it/s]
100%|██████████| 614/614 [00:00<00:00, 1384571.32it/s]


KeyError: 'cov_cond'

In [None]:
model_params = {
    "n_latent": 64,
    "recon_loss": "nb",
    "doser_type": "linear",
    "n_hidden_encoder": 128,
    "n_layers_encoder": 2,
    "n_hidden_decoder": 512,
    "n_layers_decoder": 2,
    "use_batch_norm_encoder": True,
    "use_layer_norm_encoder": False,
    "use_batch_norm_decoder": False,
    "use_layer_norm_decoder": True,
    "dropout_rate_encoder": 0.0,
    "dropout_rate_decoder": 0.1,
    "variational": False,
    "seed": 6977,
}

trainer_params = {
    "n_epochs_kl_warmup": None,
    "n_epochs_pretrain_ae": 30,
    "n_epochs_adv_warmup": 50,
    "n_epochs_mixup_warmup": 0,
    "mixup_alpha": 0.0,
    "adv_steps": None,
    "n_hidden_adv": 64,
    "n_layers_adv": 3,
    "use_batch_norm_adv": True,
    "use_layer_norm_adv": False,
    "dropout_rate_adv": 0.3,
    "reg_adv": 20.0,
    "pen_adv": 5.0,
    "lr": 0.0003,
    "wd": 4e-07,
    "adv_lr": 0.0003,
    "adv_wd": 4e-07,
    "adv_loss": "cce",
    "doser_lr": 0.0003,
    "doser_wd": 4e-07,
    "do_clip_grad": True,
    "gradient_clip_value": 1.0,
    "step_size_lr": 10,
}

In [None]:
model = cpa.CPA(de_train_exp=de_train_exp, 
                split_key='split_B',
                train_split='train',
                valid_split='valid',
                test_split='ood',
                **model_params,
               )