In [1]:
import numpy as np
import scanpy as sc
import torch
import random
import cpa
import os
from icecream import ic
import anndata as ad
import gc

def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Applies the same ``seed`` to torch (CPU and all CUDA devices), NumPy's
    legacy global RNG and Python's ``random`` module. It then turns off cuDNN
    benchmark mode and requests deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    ic('Setting seed to', seed)
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    # cuDNN flags are independent of each other; both favour reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

Global seed set to 0


In [2]:
# --- Experiment configuration -------------------------------------------------
# Preprocessed 'kang' AnnData. Set the KANG_ADATA_PATH environment variable to
# run on a machine where the data lives elsewhere; the default is unchanged.
adata_path = os.environ.get(
    'KANG_ADATA_PATH',
    '/data/Experiments/Benchmark/SCDISENTANGLE_REPRODUCE/Datasets/preprocessed_datasets/kang.h5ad',
)
cov_key = "cell_type"      # obs column with the covariate (cell type)
cond_key = "condition"     # obs column with the perturbation label
ood_cov = "B"              # covariate value naming the split column / cells to predict
control_name = "control"   # perturbation label of control cells
stim_name = "stimulated"   # perturbation label of stimulated cells
vars_to_predict = ['stimulated', 'control']
# Covariates handed to CPA as `categorical_covariate_keys`. The perturbation is
# passed separately (`perturbation_key=cond_key`), so only the covariate goes here.
categorical_attributes = [cov_key]
seed_nb = 1
device_nb = 1              # GPU index intended for training

In [3]:
adata = sc.read_h5ad(adata_path)

# Densify X only when it is a sparse matrix (scipy sparse matrices expose
# `.toarray()`; dense numpy arrays do not). The previous bare `except:` also
# hid unrelated failures such as MemoryError during densification and then
# reported them as "already array".
if hasattr(adata.X, 'toarray'):
    adata.X = adata.X.toarray()
else:
    print('Data is already array')

# Set seed
set_seed(seed_nb)

Data is already array


[38;5;247mic[39m[38;5;245m|[39m[38;5;245m [39m[38;5;36m'[39m[38;5;36mSetting seed to[39m[38;5;36m'[39m[38;5;245m,[39m[38;5;245m [39m[38;5;247mseed[39m[38;5;245m:[39m[38;5;245m [39m[38;5;36m1[39m


In [4]:
# Compute data stats on the train / val / ood partitions of the split column.
# NOTE(review): _val_adata, _ood_adata and data_median are not read by any
# later cell shown in this notebook.
_split_labels = adata.obs[f'split_{stim_name}_{ood_cov}']
_train_adata, _val_adata, _ood_adata = (
    adata[_split_labels == part].copy() for part in ('train', 'val', 'ood')
)

# Median over training cells of the per-cell total (row sum of X).
_sums = _train_adata.X.sum(axis=1, keepdims=True)
data_median = np.median(_sums)

In [5]:
# Dose col: one '1.0' per '+'-separated perturbation in the condition label
# (e.g. 'A+B' -> '1.0+1.0').
adata.obs['dose'] = adata.obs[cond_key].apply(lambda x: '+'.join(['1.0' for _ in x.split('+')]))

# cov_cond col: "<covariate>_<condition>" for every cell. Built vectorised
# instead of a Python loop with `Series[i]`, whose positional fallback on a
# non-integer index is deprecated in pandas >= 2.1. `.to_numpy()` keeps the
# assignment positional (no index alignment), like the original list.
adata.obs['cov_cond'] = (
    adata.obs[cov_key].astype(str) + '_' + adata.obs[cond_key].astype(str)
).to_numpy()

In [6]:
# The CPA model is built with valid_split='valid', so relabel the 'val'
# partition. Only the exact label is mapped: a substring replace would turn an
# already-renamed 'valid' into 'validid' if this cell were re-run, silently
# emptying the validation split.
_split_col = f'split_{stim_name}_{ood_cov}'
adata.obs[_split_col] = ['valid' if x == 'val' else x for x in adata.obs[_split_col]]

# Setup anndata
cpa.CPA.setup_anndata(
        adata, 
        perturbation_key=cond_key,
        control_group=control_name,
        dosage_key='dose',
        categorical_covariate_keys=categorical_attributes,
        is_count_data=True,
        deg_uns_key=f'rank_genes_groups_{cond_key}',
        deg_uns_cat_key='cov_cond',
        max_comb_len=1,
        )

# Model params: architecture settings unpacked into cpa.CPA(...) as keyword
# arguments when the model is built.
model_params = dict(
    n_latent=64,
    recon_loss="nb",
    doser_type="linear",
    n_hidden_encoder=128,
    n_layers_encoder=2,
    n_hidden_decoder=512,
    n_layers_decoder=2,
    use_batch_norm_encoder=True,
    use_layer_norm_encoder=False,
    use_batch_norm_decoder=False,
    use_layer_norm_decoder=True,
    dropout_rate_encoder=0.0,
    dropout_rate_decoder=0.1,
    variational=False,
    seed=seed_nb,
)

# Train params: optimiser, adversary and warm-up settings forwarded to
# model.train(...) through `plan_kwargs`.
trainer_params = dict(
    n_epochs_kl_warmup=None,
    n_epochs_pretrain_ae=30,
    n_epochs_adv_warmup=50,
    n_epochs_mixup_warmup=0,
    mixup_alpha=0.0,
    adv_steps=None,
    n_hidden_adv=64,
    n_layers_adv=3,
    use_batch_norm_adv=True,
    use_layer_norm_adv=False,
    dropout_rate_adv=0.3,
    reg_adv=20.0,
    pen_adv=5.0,
    lr=0.0003,
    wd=4e-07,
    adv_lr=0.0003,
    adv_wd=4e-07,
    adv_loss="cce",
    doser_lr=0.0003,
    doser_wd=4e-07,
    do_clip_grad=True,
    gradient_clip_value=1.0,
    step_size_lr=10,
)

100%|██████████| 13576/13576 [00:00<00:00, 125980.94it/s]
100%|██████████| 13576/13576 [00:00<00:00, 1305168.04it/s]
100%|██████████| 16/16 [00:00<00:00, 2148.45it/s]
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


In [7]:
# CPA model: the train / valid / test partitions are read from the split
# column; architecture settings come from model_params.
_split_kwargs = dict(
    split_key=f'split_{stim_name}_{ood_cov}',
    train_split='train',
    valid_split='valid',
    test_split='ood',
)
model = cpa.CPA(adata=adata, **_split_kwargs, **model_params)

Global seed set to 1


In [8]:
# Train CPA. `device_nb` from the config cell was previously ignored, so
# training always ran on the default GPU. It is now passed through as the device index.
# NOTE(review): assumes this CPA version forwards `use_gpu` to scvi-tools'
# device parsing, which accepts an int GPU index (scvi-tools < 1.0) — confirm
# for the installed version.
model.train(
        max_epochs=2000,
        use_gpu=device_nb,
        batch_size=512,
        plan_kwargs=trainer_params,
        early_stopping_patience=5,
        check_val_every_n_epoch=5,
        save_path=None,
        )

100%|██████████| 2/2 [00:00<00:00, 22.41it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Epoch 5/2000:   0%|          | 4/2000 [00:08<1:05:18,  1.96s/it, v_num=1, recon=2.39e+3, r2_mean=0.853, adv_loss=1.45, acc_pert=0.884, acc_cell_type=0.678]


Epoch 00004: cpa_metric reached. Module best state updated.


Epoch 10/2000:   0%|          | 9/2000 [00:18<1:04:52,  1.96s/it, v_num=1, recon=2.34e+3, r2_mean=0.922, adv_loss=1.12, acc_pert=0.898, acc_cell_type=0.731, val_recon=2.39e+3, disnt_basal=1.21, disnt_after=1.62, val_r2_mean=0.879, val_KL=nan]


Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.1057463768855293
disnt_after = 1.6246264677307567
val_r2_mean = 0.9020187086529203
val_r2_var = 0.0433932794464959
Epoch 15/2000:   1%|          | 14/2000 [00:28<1:04:19,  1.94s/it, v_num=1, recon=2.32e+3, r2_mean=0.929, adv_loss=1.11, acc_pert=0.872, acc_cell_type=0.74, val_recon=2.37e+3, disnt_basal=1.11, disnt_after=1.62, val_r2_mean=0.902, val_KL=nan] 


Epoch 00014: cpa_metric reached. Module best state updated.


Epoch 20/2000:   1%|          | 19/2000 [00:37<1:02:28,  1.89s/it, v_num=1, recon=2.3e+3, r2_mean=0.933, adv_loss=1.07, acc_pert=0.867, acc_cell_type=0.746, val_recon=2.35e+3, disnt_basal=1.04, disnt_after=1.61, val_r2_mean=0.911, val_KL=nan] 


Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.0030831421746094
disnt_after = 1.6152865924943178
val_r2_mean = 0.9080110059844122
val_r2_var = 0.18406146499845719
Epoch 25/2000:   1%|          | 24/2000 [00:47<1:01:26,  1.87s/it, v_num=1, recon=2.28e+3, r2_mean=0.936, adv_loss=1.04, acc_pert=0.864, acc_cell_type=0.749, val_recon=2.34e+3, disnt_basal=1, disnt_after=1.62, val_r2_mean=0.908, val_KL=nan]  


Epoch 00024: cpa_metric reached. Module best state updated.


Epoch 30/2000:   1%|▏         | 29/2000 [00:56<1:01:43,  1.88s/it, v_num=1, recon=2.26e+3, r2_mean=0.934, adv_loss=1.01, acc_pert=0.866, acc_cell_type=0.757, val_recon=2.35e+3, disnt_basal=1, disnt_after=1.61, val_r2_mean=0.909, val_KL=nan]
disnt_basal = 1.0071393980017698
disnt_after = 1.607313817542012
val_r2_mean = 0.9119944082366095
val_r2_var = 0.22497564554214478
Epoch 35/2000:   2%|▏         | 34/2000 [01:06<1:04:04,  1.96s/it, v_num=1, recon=2.25e+3, r2_mean=0.936, adv_loss=1.11, acc_pert=0.832, acc_cell_type=0.748, val_recon=2.36e+3, disnt_basal=1.01, disnt_after=1.61, val_r2_mean=0.912, val_KL=nan]


Epoch 00034: cpa_metric reached. Module best state updated.


Epoch 40/2000:   2%|▏         | 39/2000 [01:17<1:06:29,  2.03s/it, v_num=1, recon=2.24e+3, r2_mean=0.937, adv_loss=1.58, acc_pert=0.678, acc_cell_type=0.678, val_recon=2.36e+3, disnt_basal=0.94, disnt_after=1.61, val_r2_mean=0.914, val_KL=nan]


Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 0.8386064946139883
disnt_after = 1.6117723895209763
val_r2_mean = 0.9111208452118768
val_r2_var = 0.2443843417697483
Epoch 45/2000:   2%|▏         | 44/2000 [01:27<1:06:22,  2.04s/it, v_num=1, recon=2.23e+3, r2_mean=0.937, adv_loss=2.37, acc_pert=0.55, acc_cell_type=0.427, val_recon=2.36e+3, disnt_basal=0.839, disnt_after=1.61, val_r2_mean=0.911, val_KL=nan] 


Epoch 00044: cpa_metric reached. Module best state updated.


Epoch 50/2000:   2%|▏         | 49/2000 [01:37<1:03:49,  1.96s/it, v_num=1, recon=2.22e+3, r2_mean=0.939, adv_loss=2.29, acc_pert=0.542, acc_cell_type=0.431, val_recon=2.35e+3, disnt_basal=0.755, disnt_after=1.61, val_r2_mean=0.908, val_KL=nan]


Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 0.7175031014563457
disnt_after = 1.6048649260759908
val_r2_mean = 0.9032185475031534
val_r2_var = 0.2476418972015381
Epoch 55/2000:   3%|▎         | 54/2000 [01:47<1:03:44,  1.97s/it, v_num=1, recon=2.21e+3, r2_mean=0.938, adv_loss=2.3, acc_pert=0.532, acc_cell_type=0.424, val_recon=2.35e+3, disnt_basal=0.718, disnt_after=1.6, val_r2_mean=0.903, val_KL=nan]  


Epoch 00054: cpa_metric reached. Module best state updated.


Epoch 60/2000:   3%|▎         | 59/2000 [01:57<1:03:55,  1.98s/it, v_num=1, recon=2.2e+3, r2_mean=0.938, adv_loss=2.32, acc_pert=0.518, acc_cell_type=0.411, val_recon=2.37e+3, disnt_basal=0.713, disnt_after=1.6, val_r2_mean=0.907, val_KL=nan] 


Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 0.7077656592035545
disnt_after = 1.6023623643600733
val_r2_mean = 0.9062639792760213
val_r2_var = 0.2586240397559272
Epoch 65/2000:   3%|▎         | 64/2000 [02:07<1:02:29,  1.94s/it, v_num=1, recon=2.19e+3, r2_mean=0.939, adv_loss=2.33, acc_pert=0.52, acc_cell_type=0.406, val_recon=2.36e+3, disnt_basal=0.708, disnt_after=1.6, val_r2_mean=0.906, val_KL=nan] 


Epoch 00064: cpa_metric reached. Module best state updated.


Epoch 70/2000:   3%|▎         | 69/2000 [02:17<1:02:52,  1.95s/it, v_num=1, recon=2.18e+3, r2_mean=0.939, adv_loss=2.33, acc_pert=0.515, acc_cell_type=0.402, val_recon=2.37e+3, disnt_basal=0.691, disnt_after=1.61, val_r2_mean=0.905, val_KL=nan]


Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 0.6845819798205207
disnt_after = 1.6064080426682281
val_r2_mean = 0.9064804103639391
val_r2_var = 0.26037815279430815
Epoch 75/2000:   4%|▎         | 74/2000 [02:27<1:03:01,  1.96s/it, v_num=1, recon=2.18e+3, r2_mean=0.94, adv_loss=2.35, acc_pert=0.509, acc_cell_type=0.39, val_recon=2.38e+3, disnt_basal=0.685, disnt_after=1.61, val_r2_mean=0.906, val_KL=nan]  


Epoch 00074: cpa_metric reached. Module best state updated.


Epoch 80/2000:   4%|▍         | 79/2000 [02:36<1:00:30,  1.89s/it, v_num=1, recon=2.17e+3, r2_mean=0.94, adv_loss=2.35, acc_pert=0.512, acc_cell_type=0.389, val_recon=2.36e+3, disnt_basal=0.678, disnt_after=1.6, val_r2_mean=0.907, val_KL=nan] 
disnt_basal = 0.6791487789065714
disnt_after = 1.5970721834063042
val_r2_mean = 0.906959366798401
val_r2_var = 0.268030223581526
Epoch 85/2000:   4%|▍         | 84/2000 [02:46<1:01:47,  1.94s/it, v_num=1, recon=2.16e+3, r2_mean=0.94, adv_loss=2.35, acc_pert=0.508, acc_cell_type=0.388, val_recon=2.37e+3, disnt_basal=0.679, disnt_after=1.6, val_r2_mean=0.907, val_KL=nan] 


Epoch 00084: cpa_metric reached. Module best state updated.


Epoch 90/2000:   4%|▍         | 89/2000 [02:56<1:01:42,  1.94s/it, v_num=1, recon=2.15e+3, r2_mean=0.94, adv_loss=2.35, acc_pert=0.511, acc_cell_type=0.392, val_recon=2.38e+3, disnt_basal=0.675, disnt_after=1.6, val_r2_mean=0.906, val_KL=nan]


Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 0.6754778750930068
disnt_after = 1.6080992535611334
val_r2_mean = 0.9019965476459927
val_r2_var = 0.2693125393655565
Epoch 95/2000:   5%|▍         | 94/2000 [03:06<1:00:55,  1.92s/it, v_num=1, recon=2.14e+3, r2_mean=0.941, adv_loss=2.35, acc_pert=0.509, acc_cell_type=0.382, val_recon=2.37e+3, disnt_basal=0.675, disnt_after=1.61, val_r2_mean=0.902, val_KL=nan]


Epoch 00094: cpa_metric reached. Module best state updated.


Epoch 100/2000:   5%|▍         | 99/2000 [03:16<1:01:42,  1.95s/it, v_num=1, recon=2.14e+3, r2_mean=0.94, adv_loss=2.34, acc_pert=0.516, acc_cell_type=0.393, val_recon=2.38e+3, disnt_basal=0.67, disnt_after=1.6, val_r2_mean=0.907, val_KL=nan]  
disnt_basal = 0.682439370340641
disnt_after = 1.599729617004791
val_r2_mean = 0.9044288860427009
val_r2_var = 0.2814134902424283
Epoch 110/2000:   5%|▌         | 109/2000 [03:35<1:01:03,  1.94s/it, v_num=1, recon=2.13e+3, r2_mean=0.941, adv_loss=2.34, acc_pert=0.518, acc_cell_type=0.392, val_recon=2.39e+3, disnt_basal=0.671, disnt_after=1.61, val_r2_mean=0.905, val_KL=nan]


Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 0.6610552642329962
disnt_after = 1.6026633886249495
val_r2_mean = 0.9062826977835762
val_r2_var = 0.28635290066401164
Epoch 115/2000:   6%|▌         | 114/2000 [03:45<1:02:00,  1.97s/it, v_num=1, recon=2.12e+3, r2_mean=0.941, adv_loss=2.33, acc_pert=0.521, acc_cell_type=0.393, val_recon=2.38e+3, disnt_basal=0.661, disnt_after=1.6, val_r2_mean=0.906, val_KL=nan] 


Epoch 00114: cpa_metric reached. Module best state updated.


Epoch 120/2000:   6%|▌         | 119/2000 [03:55<1:01:27,  1.96s/it, v_num=1, recon=2.11e+3, r2_mean=0.942, adv_loss=2.33, acc_pert=0.52, acc_cell_type=0.39, val_recon=2.38e+3, disnt_basal=0.665, disnt_after=1.6, val_r2_mean=0.907, val_KL=nan]  
disnt_basal = 0.667226100858314
disnt_after = 1.5996180433888565
val_r2_mean = 0.9063876854048835
val_r2_var = 0.28664151827494305
Epoch 130/2000:   6%|▋         | 129/2000 [04:15<1:01:55,  1.99s/it, v_num=1, recon=2.1e+3, r2_mean=0.941, adv_loss=2.32, acc_pert=0.529, acc_cell_type=0.398, val_recon=2.4e+3, disnt_basal=0.678, disnt_after=1.6, val_r2_mean=0.9, val_KL=nan]    
disnt_basal = 0.670774901410066
disnt_after = 1.595640923989924
val_r2_mean = 0.9066732000736963
val_r2_var = 0.2933745922550322
Epoch 140/2000:   7%|▋         | 139/2000 [04:35<1:00:05,  1.94s/it, v_num=1, recon=2.09e+3, r2_mean=0.943, adv_loss=2.32, acc_pert=0.521, acc_cell_type=0.396, val_recon=2.39e+3, disnt_basal=0.669, disnt_after=1.6, val_r2_mean=0.907, val_KL=nan]
di

In [9]:
# Control cells of covariate `ood_cov` taken from the training split; they are
# the input to the prediction cell below.
_subset_mask = (
    (adata.obs[cond_key] == control_name)
    & (adata.obs[cov_key] == ood_cov)
    & (adata.obs[f'split_{stim_name}_{ood_cov}'] == 'train')
)
adata_subset = adata[_subset_mask].copy()

In [27]:
# Condition to predict for the selected control cells. Taken from the config
# cell instead of re-typing the literal, so the two cannot drift apart.
var_to_predict = stim_name

In [None]:
# Relabel a copy of the selected cells with the target condition and let the
# trained CPA model predict their expression.
adata_pred = adata_subset.copy()
adata_pred.obs[f'{cond_key}_org'] = adata_pred.obs[cond_key].copy()
adata_pred.obs[cond_key] = [var_to_predict] * adata_pred.shape[0]

adata_pred.obs[f'{cond_key}_pred'] = [var_to_predict] * adata_pred.shape[0]
adata_pred.obs[f'{cond_key}_pred'] = adata_pred.obs[f'{cond_key}_pred'].astype('category')

# "<covariate>_<predicted condition>" labels, built vectorised instead of a
# loop with positional `Series[i]` access (deprecated in pandas >= 2.1).
adata_pred.obs['cov_cond'] = (
    adata_pred.obs[cov_key].astype(str)
    + '_'
    + adata_pred.obs[f'{cond_key}_pred'].astype(str)
).to_numpy()

cpa.CPA.setup_anndata(
                      adata_pred, 
                      perturbation_key=cond_key,
                      control_group=control_name,
                      dosage_key='dose',
                      categorical_covariate_keys=categorical_attributes,
                      is_count_data=True,
                      deg_uns_key=f'rank_genes_groups_{cond_key}',
                      deg_uns_cat_key='cov_cond',
                      max_comb_len=1,
                    )

# model.predict stores the predictions in obsm['CPA_pred']; move them into X.
model.predict(adata_pred, batch_size=2048)
adata_pred.X = adata_pred.obsm['CPA_pred'].copy()
del adata_pred.obsm['CPA_pred']

# Rebuild a slim AnnData. Keep `var` so predicted columns retain the original
# gene names; without it they would get sequential placeholder names.
adata_pred = sc.AnnData(
    X=adata_pred.X, 
    obs=adata_pred.obs,
    var=adata_pred.var.copy(),
)

100%|██████████| 542/542 [00:00<00:00, 118506.63it/s]
100%|██████████| 542/542 [00:00<00:00, 1209852.46it/s]
100%|██████████| 1/1 [00:00<00:00, 2035.08it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                                        


100%|██████████| 1/1 [00:00<00:00, 14.99it/s]


In [29]:
# Largest predicted value across all cells and genes (sanity check on scale).
np.max(adata_pred.X)

386.54196

In [None]:
# Display the AnnData holding CPA's predicted expression for the selected cells.
adata_pred