In [2]:
import numpy as np
import scanpy as sc

# Labels from the 'treatment' column of adata.obs that this notebook compares:
# `un_treatment` is the untreated/control label, `pert_treatment` the perturbed one.
un_treatment = 'DMSO_48hr'
pert_treatment = 'Tram_48hr'

# GEARS

In [None]:
import pickle

import gears

In [None]:
# Read and format for GEARS
adata = sc.read_h5ad('../plots/drugseries/tram_data.h5ad')
# adata = adata[adata.obs['Training']]  # Filter to training
adata.obs['cell_type'] = adata.obs['cell_line']
# Map the raw treatment labels onto the condition names handed to GEARS
conditions = {
    un_treatment: 'ctrl',
    pert_treatment: 'MAP2K1+MAP2K2'}
# Keep only the two treatments of interest; copy so later writes to .obs and the
# in-place preprocessing act on a real AnnData instead of a view of the full data
adata = adata[adata.obs['treatment'].isin(list(conditions))].copy()
adata.obs['condition'] = adata.obs['treatment'].cat.rename_categories(conditions)

# Preprocess
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
# sc.pp.highly_variable_genes(adata, n_top_genes=5000, subset=True)  # TODO: Maybe readd
pert_data = gears.PertData('../plots/drugseries')
pert_data.new_data_process(dataset_name='tram_gears', adata=adata)
pert_data.load(data_path = '../plots/drugseries/tram_gears')

# Split and get dataloader
# Every cell is labelled 'train', so grouping by split yields a single 'train' entry
adata.obs['split'] = 'train'
set2conditions = {
    split_name: split_conditions.unique().tolist()
    for split_name, split_conditions in adata.obs.groupby('split')['condition']}
# Validation and test reuse the training conditions (we limit val ourselves).
# The source must be the 'train' entry: it is the only key that exists at this point.
set2conditions['val'] = set2conditions['test'] = set2conditions['train']
# NOTE(review): 'ctrl' therefore also appears in val/test; this is presumably what
# triggers the KeyError: 'SKIN_ctrl_1' recorded in the training cell -- TODO confirm
split_fname = '../plots/drugseries/tram_gears/splits/custom.pkl'
# Context manager so the pickle is flushed and the handle closed before GEARS reads it
with open(split_fname, 'wb') as split_file:
    pickle.dump(set2conditions, split_file)
# {'train': ['ctrl', 'MAP2K1+MAP2K2'], 'val': ['ctrl', 'MAP2K1+MAP2K2'], 'test': ['ctrl', 'MAP2K1+MAP2K2']}
pert_data.prepare_split(split='custom', split_dict_path=split_fname, seed=42)
pert_data.get_dataloader(batch_size=32, test_batch_size=128)

Found local copy...
Found local copy...
Creating pyg object for each cell in the data...
Creating dataset file...
100%|██████████| 2/2 [01:03<00:00, 31.80s/it]
Done!
Saving new dataset pyg object at ../plots/drugseries/tram_gears/data_pyg/cell_graphs.pkl
Done!
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
[]
Local copy of pyg dataset is detected. Loading...
Done!
Creating dataloaders....
Done!


In [None]:
# Train model
# Build the GEARS model on top of the prepared `pert_data`, placed on device 'cuda:0'
gears_model = gears.GEARS(pert_data, device='cuda:0')
gears_model.model_initialize(hidden_size=32)
# NOTE(review): the recorded run of this cell ends with KeyError: 'SKIN_ctrl_1' after
# several epochs of logging; the cause is not visible in this notebook -- presumably
# it is tied to 'ctrl' being part of the custom val/test split. TODO confirm
gears_model.train(epochs=20)  # KeyError: 'SKIN_ctrl_1' ??


Found local copy...
Start Training...
Epoch 1 Step 1 Train Loss: 0.3667
Epoch 1 Step 51 Train Loss: 0.4371
Epoch 1 Step 101 Train Loss: 0.4648
Epoch 1: Train Overall MSE: 204.3902 Validation Overall MSE: 204.3886. 
Train Top 20 DE MSE: 29814.7441 Validation Top 20 DE MSE: 29814.4883. 
Epoch 2 Step 1 Train Loss: 0.5038
Epoch 2 Step 51 Train Loss: 0.4167
Epoch 2 Step 101 Train Loss: 0.3986
Epoch 2: Train Overall MSE: 31.4396 Validation Overall MSE: 31.4395. 
Train Top 20 DE MSE: 3914.1646 Validation Top 20 DE MSE: 3914.3132. 
Epoch 3 Step 1 Train Loss: 0.4522
Epoch 3 Step 51 Train Loss: 0.4526
Epoch 3 Step 101 Train Loss: 0.3964
Epoch 3: Train Overall MSE: 1.3181 Validation Overall MSE: 1.3181. 
Train Top 20 DE MSE: 189.0000 Validation Top 20 DE MSE: 188.9983. 
Epoch 4 Step 1 Train Loss: 0.5512
Epoch 4 Step 51 Train Loss: 0.3640
Epoch 4 Step 101 Train Loss: 0.4906
Epoch 4: Train Overall MSE: 0.0038 Validation Overall MSE: 0.0038. 
Train Top 20 DE MSE: 0.7020 Validation Top 20 DE MSE: 0.7

KeyError: 'SKIN_ctrl_1'

In [81]:
# Save model
# Write the trained GEARS model to `model_fname`, then immediately load it back from
# the same location (presumably so later cells run on the saved weights -- TODO confirm)
model_fname = '../plots/drugseries/tram_gears/model'
gears_model.save_model(model_fname)
gears_model.load_pretrained(model_fname)

In [86]:
# Predict perturbation
# Ask GEARS for the joint MAP2K1 + MAP2K2 perturbation, pull that entry out of the
# returned mapping, and store it on disk as a .npy array
target_genes = ['MAP2K1', 'MAP2K2']
gears_predictions = gears_model.predict([target_genes])
pert = gears_predictions['MAP2K1_MAP2K2']
np.save('../plots/drugseries/GEARS_perturbation.npy', pert)

# CPA

In [None]:
import cpa
import pandas as pd

Global seed set to 0


In [7]:
# Read and format
adata = sc.read_h5ad('../plots/drugseries/tram_data.h5ad')
# Derive the split label from the boolean 'Training' flag in one pass:
# flagged cells go to 'train', everything else to 'val'
adata.obs['split'] = np.where(adata.obs['Training'], 'train', 'val')
# # adata = adata[adata.obs['Training']]  # Filter to training
# adata.obs['cell_type'] = adata.obs['cell_line']
# conditions = {
#     un_treatment: 'ctrl',
#     pert_treatment: 'MAP2K1+MAP2K2'}
# adata = adata[adata.obs['treatment'].isin(list(conditions.keys()))]
# adata.obs['condition'] = adata.obs['treatment'].cat.rename_categories(conditions)

In [None]:
# Prepare data
# Register `adata` with CPA: perturbations come from the 'treatment' column (no dosage
# or batch column), 'cell_line' is the only categorical covariate, and the expression
# matrix is declared to be counts.
cpa.CPA.setup_anndata(
    adata,
    perturbation_key='treatment',
    dosage_key=None,
    # Use the shared constant so the control label is defined in exactly one place
    control_group=un_treatment,
    batch_key=None,
    is_count_data=True,
    categorical_covariate_keys=['cell_line'],
    # deg_uns_key='rank_genes_groups_cov',
    # deg_uns_cat_key='cov_drug_dose',
    max_comb_len=2)




100%|██████████| 13713/13713 [00:00<00:00, 78401.47it/s]
100%|██████████| 13713/13713 [00:00<00:00, 976825.98it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


In [None]:
# Prepare and train model
# The 'split' column of adata.obs decides which cells are used for training and which
# for validation; the validation cells are also passed as the test split.
split_config = {
    'split_key': 'split',
    'train_split': 'train',
    'valid_split': 'val',
    'test_split': 'val',
}
model = cpa.CPA(adata=adata, n_latent=32, **split_config)

# Training settings, collected in one place for easy tweaking
train_config = {
    'max_epochs': 2000,
    'use_gpu': True,
    'batch_size': 128,
    'early_stopping_patience': 10,
    'check_val_every_n_epoch': 5,
    'save_path': '../plots/drugseries/cpa/',
}
model.train(**train_config)
# model.save('../plots/drugseries/cpa/', overwrite=True)

100%|██████████| 11/11 [00:00<00:00, 141.33it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 5/2000:   0%|          | 4/2000 [01:23<11:21:38, 20.49s/it, v_num=1, recon=1.62e+4, r2_mean=0.519, adv_loss=3.36, acc_pert=0.13, acc_cell_line=0.833]   


Epoch 00004: cpa_metric reached. Module best state updated.


Epoch 10/2000:   0%|          | 9/2000 [03:15<11:34:36, 20.93s/it, v_num=1, recon=1.51e+4, r2_mean=0.688, adv_loss=2.59, acc_pert=0.161, acc_cell_line=0.945, val_recon=1.63e+4, disnt_basal=0.318, disnt_after=0.363, val_r2_mean=0.55, val_KL=nan]


Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 0.3112568525924838
disnt_after = 0.36250775395560797
val_r2_mean = 0.6913695387053383
val_r2_var = 0.05775611917910067
Epoch 15/2000:   1%|          | 14/2000 [05:05<11:33:25, 20.95s/it, v_num=1, recon=1.48e+4, r2_mean=0.74, adv_loss=2.38, acc_pert=0.191, acc_cell_line=0.955, val_recon=1.49e+4, disnt_basal=0.311, disnt_after=0.363, val_r2_mean=0.691, val_KL=nan] 


Epoch 00014: cpa_metric reached. Module best state updated.


Epoch 20/2000:   1%|          | 19/2000 [07:28<13:29:55, 24.53s/it, v_num=1, recon=1.46e+4, r2_mean=0.768, adv_loss=2.29, acc_pert=0.198, acc_cell_line=0.96, val_recon=1.46e+4, disnt_basal=0.286, disnt_after=0.363, val_r2_mean=0.74, val_KL=nan] 


Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 0.27303373047757434
disnt_after = 0.3634999655369918
val_r2_mean = 0.7646110348775182
val_r2_var = 0.07040883465123543
Epoch 25/2000:   1%|          | 24/2000 [09:19<11:51:26, 21.60s/it, v_num=1, recon=1.45e+4, r2_mean=0.784, adv_loss=2.24, acc_pert=0.21, acc_cell_line=0.965, val_recon=1.44e+4, disnt_basal=0.273, disnt_after=0.363, val_r2_mean=0.765, val_KL=nan] 


Epoch 00024: cpa_metric reached. Module best state updated.


Epoch 30/2000:   1%|▏         | 29/2000 [11:10<11:34:32, 21.14s/it, v_num=1, recon=1.44e+4, r2_mean=0.794, adv_loss=2.19, acc_pert=0.219, acc_cell_line=0.965, val_recon=1.43e+4, disnt_basal=0.267, disnt_after=0.364, val_r2_mean=0.783, val_KL=nan]


Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 0.2602108952039197
disnt_after = 0.3648730073543875
val_r2_mean = 0.7959153337824356
val_r2_var = 0.08484860977697081
Epoch 35/2000:   2%|▏         | 34/2000 [13:01<11:33:24, 21.16s/it, v_num=1, recon=1.43e+4, r2_mean=0.804, adv_loss=2.16, acc_pert=0.233, acc_cell_line=0.965, val_recon=1.42e+4, disnt_basal=0.26, disnt_after=0.365, val_r2_mean=0.796, val_KL=nan] 


Epoch 00034: cpa_metric reached. Module best state updated.


Epoch 40/2000:   2%|▏         | 39/2000 [14:53<11:29:01, 21.08s/it, v_num=1, recon=1.43e+4, r2_mean=0.808, adv_loss=2.11, acc_pert=0.243, acc_cell_line=0.968, val_recon=1.41e+4, disnt_basal=0.255, disnt_after=0.361, val_r2_mean=0.801, val_KL=nan]


Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 0.2582869621774076
disnt_after = 0.36471693463474997
val_r2_mean = 0.8046112125118194
val_r2_var = 0.09395654123536168
Epoch 45/2000:   2%|▏         | 44/2000 [16:44<11:25:00, 21.01s/it, v_num=1, recon=1.42e+4, r2_mean=0.813, adv_loss=2.08, acc_pert=0.254, acc_cell_line=0.968, val_recon=1.4e+4, disnt_basal=0.258, disnt_after=0.365, val_r2_mean=0.805, val_KL=nan] 


Epoch 00044: cpa_metric reached. Module best state updated.


Epoch 50/2000:   2%|▏         | 49/2000 [18:35<11:23:27, 21.02s/it, v_num=1, recon=1.41e+4, r2_mean=0.816, adv_loss=2.07, acc_pert=0.258, acc_cell_line=0.965, val_recon=1.4e+4, disnt_basal=0.252, disnt_after=0.362, val_r2_mean=0.807, val_KL=nan]


Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 0.2503237272349801
disnt_after = 0.3648081277948857
val_r2_mean = 0.8158274635650385
val_r2_var = 0.10375423461500069
Epoch 60/2000:   3%|▎         | 59/2000 [22:25<12:11:09, 22.60s/it, v_num=1, recon=1.4e+4, r2_mean=0.822, adv_loss=2.07, acc_pert=0.285, acc_cell_line=0.947, val_recon=1.39e+4, disnt_basal=0.255, disnt_after=0.367, val_r2_mean=0.814, val_KL=nan] 


Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 0.2531884209607144
disnt_after = 0.36464219646787854
val_r2_mean = 0.8198364634525177
val_r2_var = 0.10875049354863371
Epoch 65/2000:   3%|▎         | 64/2000 [24:17<11:27:12, 21.30s/it, v_num=1, recon=1.4e+4, r2_mean=0.825, adv_loss=2.09, acc_pert=0.274, acc_cell_line=0.932, val_recon=1.38e+4, disnt_basal=0.253, disnt_after=0.365, val_r2_mean=0.82, val_KL=nan] 


Epoch 00064: cpa_metric reached. Module best state updated.


Epoch 70/2000:   3%|▎         | 69/2000 [26:16<12:35:16, 23.47s/it, v_num=1, recon=1.39e+4, r2_mean=0.827, adv_loss=2.09, acc_pert=0.285, acc_cell_line=0.919, val_recon=1.38e+4, disnt_basal=0.248, disnt_after=0.363, val_r2_mean=0.821, val_KL=nan]


Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 0.24505922166555863
disnt_after = 0.36129180477751865
val_r2_mean = 0.824166988901387
val_r2_var = 0.11027381792737168
Epoch 75/2000:   4%|▎         | 74/2000 [28:07<11:27:13, 21.41s/it, v_num=1, recon=1.39e+4, r2_mean=0.828, adv_loss=2.1, acc_pert=0.29, acc_cell_line=0.912, val_recon=1.37e+4, disnt_basal=0.245, disnt_after=0.361, val_r2_mean=0.824, val_KL=nan]  


Epoch 00074: cpa_metric reached. Module best state updated.


Epoch 80/2000:   4%|▍         | 79/2000 [29:58<11:14:33, 21.07s/it, v_num=1, recon=1.38e+4, r2_mean=0.83, adv_loss=2.08, acc_pert=0.297, acc_cell_line=0.911, val_recon=1.37e+4, disnt_basal=0.25, disnt_after=0.365, val_r2_mean=0.824, val_KL=nan] 


Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 0.24499543023264322
disnt_after = 0.3640407708967268
val_r2_mean = 0.8229760968128971
val_r2_var = 0.11203359267035765
Epoch 85/2000:   4%|▍         | 84/2000 [31:49<11:11:04, 21.01s/it, v_num=1, recon=1.38e+4, r2_mean=0.831, adv_loss=2.1, acc_pert=0.302, acc_cell_line=0.903, val_recon=1.37e+4, disnt_basal=0.245, disnt_after=0.364, val_r2_mean=0.823, val_KL=nan] 


Epoch 00084: cpa_metric reached. Module best state updated.


Epoch 90/2000:   4%|▍         | 89/2000 [33:40<11:09:33, 21.02s/it, v_num=1, recon=1.37e+4, r2_mean=0.832, adv_loss=2.07, acc_pert=0.304, acc_cell_line=0.909, val_recon=1.36e+4, disnt_basal=0.239, disnt_after=0.364, val_r2_mean=0.825, val_KL=nan]
disnt_basal = 0.24298323716708398
disnt_after = 0.3606706502151565
val_r2_mean = 0.8297388164513205
val_r2_var = 0.11536988331316844
Epoch 95/2000:   5%|▍         | 94/2000 [35:32<11:12:06, 21.16s/it, v_num=1, recon=1.37e+4, r2_mean=0.833, adv_loss=2.08, acc_pert=0.302, acc_cell_line=0.904, val_recon=1.36e+4, disnt_basal=0.243, disnt_after=0.361, val_r2_mean=0.83, val_KL=nan] 


Epoch 00094: cpa_metric reached. Module best state updated.


Epoch 100/2000:   5%|▍         | 99/2000 [37:23<11:05:11, 20.99s/it, v_num=1, recon=1.37e+4, r2_mean=0.833, adv_loss=2.08, acc_pert=0.312, acc_cell_line=0.902, val_recon=1.36e+4, disnt_basal=0.24, disnt_after=0.362, val_r2_mean=0.83, val_KL=nan]


Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 0.23766060240654333
disnt_after = 0.36343903408263994
val_r2_mean = 0.8317215403100818
val_r2_var = 0.12019940429480568
Epoch 105/2000:   5%|▌         | 104/2000 [39:14<11:03:06, 20.98s/it, v_num=1, recon=1.36e+4, r2_mean=0.835, adv_loss=2.03, acc_pert=0.326, acc_cell_line=0.908, val_recon=1.35e+4, disnt_basal=0.238, disnt_after=0.363, val_r2_mean=0.832, val_KL=nan]


Epoch 00104: cpa_metric reached. Module best state updated.


Epoch 110/2000:   5%|▌         | 109/2000 [41:06<11:04:43, 21.09s/it, v_num=1, recon=1.36e+4, r2_mean=0.835, adv_loss=2.05, acc_pert=0.316, acc_cell_line=0.905, val_recon=1.35e+4, disnt_basal=0.241, disnt_after=0.367, val_r2_mean=0.83, val_KL=nan] 


Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 0.2376580874434773
disnt_after = 0.36552739601958895
val_r2_mean = 0.8302782673800848
val_r2_var = 0.12798137542529062
Epoch 115/2000:   6%|▌         | 114/2000 [43:00<11:32:51, 22.04s/it, v_num=1, recon=1.35e+4, r2_mean=0.837, adv_loss=2.05, acc_pert=0.324, acc_cell_line=0.902, val_recon=1.35e+4, disnt_basal=0.238, disnt_after=0.366, val_r2_mean=0.83, val_KL=nan]


Epoch 00114: cpa_metric reached. Module best state updated.


Epoch 120/2000:   6%|▌         | 119/2000 [44:55<11:11:25, 21.42s/it, v_num=1, recon=1.35e+4, r2_mean=0.836, adv_loss=2.07, acc_pert=0.319, acc_cell_line=0.896, val_recon=1.34e+4, disnt_basal=0.234, disnt_after=0.363, val_r2_mean=0.829, val_KL=nan]


Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 0.2301736005803818
disnt_after = 0.36178311701087484
val_r2_mean = 0.829807965705183
val_r2_var = 0.12356228583148357
Epoch 130/2000:   6%|▋         | 129/2000 [48:58<11:29:19, 22.11s/it, v_num=1, recon=1.35e+4, r2_mean=0.838, adv_loss=2.06, acc_pert=0.325, acc_cell_line=0.894, val_recon=1.34e+4, disnt_basal=0.233, disnt_after=0.363, val_r2_mean=0.831, val_KL=nan]


Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 0.23461238642020696
disnt_after = 0.3601473651774933
val_r2_mean = 0.8329944996719272
val_r2_var = 0.1312118517039162
Epoch 135/2000:   7%|▋         | 134/2000 [50:49<10:59:12, 21.20s/it, v_num=1, recon=1.34e+4, r2_mean=0.838, adv_loss=2.06, acc_pert=0.323, acc_cell_line=0.899, val_recon=1.34e+4, disnt_basal=0.235, disnt_after=0.36, val_r2_mean=0.833, val_KL=nan] 


Epoch 00134: cpa_metric reached. Module best state updated.


Epoch 140/2000:   7%|▋         | 139/2000 [52:44<11:19:21, 21.90s/it, v_num=1, recon=1.34e+4, r2_mean=0.838, adv_loss=2.07, acc_pert=0.332, acc_cell_line=0.89, val_recon=1.33e+4, disnt_basal=0.232, disnt_after=0.366, val_r2_mean=0.832, val_KL=nan] 
disnt_basal = 0.22935446259315073
disnt_after = 0.36279361973484764
val_r2_mean = 0.8327453964353657
val_r2_var = 0.12334743378104257
Epoch 150/2000:   7%|▋         | 149/2000 [56:26<10:48:05, 21.01s/it, v_num=1, recon=1.33e+4, r2_mean=0.841, adv_loss=2.06, acc_pert=0.336, acc_cell_line=0.89, val_recon=1.33e+4, disnt_basal=0.232, disnt_after=0.361, val_r2_mean=0.834, val_KL=nan] 
disnt_basal = 0.23182702251526885
disnt_after = 0.3645416564454672
val_r2_mean = 0.8348963153197781
val_r2_var = 0.12714532948002968
Epoch 155/2000:   8%|▊         | 154/2000 [58:17<10:44:52, 20.96s/it, v_num=1, recon=1.33e+4, r2_mean=0.842, adv_loss=2.07, acc_pert=0.336, acc_cell_line=0.888, val_recon=1.33e+4, disnt_basal=0.232, disnt_after=0.365, val_r2_mean=0.835


Epoch 00154: cpa_metric reached. Module best state updated.


Epoch 160/2000:   8%|▊         | 159/2000 [1:00:08<10:43:43, 20.98s/it, v_num=1, recon=1.33e+4, r2_mean=0.841, adv_loss=2.06, acc_pert=0.342, acc_cell_line=0.888, val_recon=1.33e+4, disnt_basal=0.226, disnt_after=0.362, val_r2_mean=0.834, val_KL=nan]
disnt_basal = 0.2267897185055339
disnt_after = 0.3615286642492839
val_r2_mean = 0.8345825090299485
val_r2_var = 0.12840272377100564
Epoch 170/2000:   8%|▊         | 169/2000 [1:03:50<10:41:22, 21.02s/it, v_num=1, recon=1.33e+4, r2_mean=0.841, adv_loss=2.06, acc_pert=0.338, acc_cell_line=0.885, val_recon=1.32e+4, disnt_basal=0.226, disnt_after=0.363, val_r2_mean=0.833, val_KL=nan]


Epoch 00169: cpa_metric reached. Module best state updated.



disnt_basal = 0.22895687961254452
disnt_after = 0.36516188339357775
val_r2_mean = 0.8363002596945177
val_r2_var = 0.13119162220632039
Epoch 175/2000:   9%|▊         | 174/2000 [1:06:06<13:22:28, 26.37s/it, v_num=1, recon=1.32e+4, r2_mean=0.841, adv_loss=2.05, acc_pert=0.337, acc_cell_line=0.89, val_recon=1.32e+4, disnt_basal=0.229, disnt_after=0.365, val_r2_mean=0.836, val_KL=nan] 


Epoch 00174: cpa_metric reached. Module best state updated.


Epoch 180/2000:   9%|▉         | 179/2000 [1:08:04<11:21:09, 22.44s/it, v_num=1, recon=1.32e+4, r2_mean=0.842, adv_loss=2.06, acc_pert=0.342, acc_cell_line=0.88, val_recon=1.32e+4, disnt_basal=0.224, disnt_after=0.364, val_r2_mean=0.834, val_KL=nan] 
disnt_basal = 0.22564430764008642
disnt_after = 0.3584417654665027
val_r2_mean = 0.8352246314393054
val_r2_var = 0.1358809192636314
Epoch 190/2000:   9%|▉         | 189/2000 [1:11:46<10:35:32, 21.06s/it, v_num=1, recon=1.32e+4, r2_mean=0.842, adv_loss=2.06, acc_pert=0.344, acc_cell_line=0.883, val_recon=1.32e+4, disnt_basal=0.226, disnt_after=0.361, val_r2_mean=0.836, val_KL=nan]


Epoch 00189: cpa_metric reached. Module best state updated.



disnt_basal = 0.22158782274268157
disnt_after = 0.36353022943259045
val_r2_mean = 0.8341621220191789
val_r2_var = 0.13076819221497482
Epoch 200/2000:  10%|▉         | 199/2000 [1:15:28<10:31:00, 21.02s/it, v_num=1, recon=1.31e+4, r2_mean=0.842, adv_loss=2.03, acc_pert=0.347, acc_cell_line=0.885, val_recon=1.32e+4, disnt_basal=0.222, disnt_after=0.363, val_r2_mean=0.832, val_KL=nan]
disnt_basal = 0.22399287696883327
disnt_after = 0.36092239487208755
val_r2_mean = 0.8369308727033595
val_r2_var = 0.1303050636820574
Epoch 210/2000:  10%|█         | 209/2000 [1:19:11<10:28:19, 21.05s/it, v_num=1, recon=1.31e+4, r2_mean=0.843, adv_loss=2.03, acc_pert=0.352, acc_cell_line=0.885, val_recon=1.31e+4, disnt_basal=0.223, disnt_after=0.363, val_r2_mean=0.833, val_KL=nan]
disnt_basal = 0.22142222558823293
disnt_after = 0.36058800378858125
val_r2_mean = 0.8327396852704018
val_r2_var = 0.12699151590502947
Epoch 220/2000:  11%|█         | 219/2000 [1:23:18<12:05:30, 24.44s/it, v_num=1, recon=1.31e+4, 


Epoch 00229: cpa_metric reached. Module best state updated.



disnt_basal = 0.2221716033032125
disnt_after = 0.36102152334540666
val_r2_mean = 0.8358691539955792
val_r2_var = 0.13508744958549082
Epoch 240/2000:  12%|█▏        | 239/2000 [1:30:42<10:16:35, 21.01s/it, v_num=1, recon=1.3e+4, r2_mean=0.843, adv_loss=2.04, acc_pert=0.353, acc_cell_line=0.879, val_recon=1.31e+4, disnt_basal=0.222, disnt_after=0.363, val_r2_mean=0.831, val_KL=nan]
disnt_basal = 0.2209018198653417
disnt_after = 0.3614017755486403
val_r2_mean = 0.8335287695015
val_r2_var = 0.133018372780066
Epoch 245/2000:  12%|█▏        | 244/2000 [1:32:33<10:13:32, 20.96s/it, v_num=1, recon=1.3e+4, r2_mean=0.844, adv_loss=2.03, acc_pert=0.364, acc_cell_line=0.877, val_recon=1.3e+4, disnt_basal=0.221, disnt_after=0.361, val_r2_mean=0.834, val_KL=nan] 


Epoch 00244: cpa_metric reached. Module best state updated.


Epoch 250/2000:  12%|█▏        | 249/2000 [1:34:24<10:11:47, 20.96s/it, v_num=1, recon=1.3e+4, r2_mean=0.845, adv_loss=2.03, acc_pert=0.359, acc_cell_line=0.881, val_recon=1.31e+4, disnt_basal=0.217, disnt_after=0.361, val_r2_mean=0.836, val_KL=nan]
disnt_basal = 0.21999189235050715
disnt_after = 0.36034669511240786
val_r2_mean = 0.8341419550730004
val_r2_var = 0.13611378952055636
Epoch 260/2000:  13%|█▎        | 259/2000 [1:38:13<10:13:40, 21.15s/it, v_num=1, recon=1.29e+4, r2_mean=0.843, adv_loss=2.04, acc_pert=0.358, acc_cell_line=0.88, val_recon=1.3e+4, disnt_basal=0.22, disnt_after=0.362, val_r2_mean=0.834, val_KL=nan]  
disnt_basal = 0.2189352620984105
disnt_after = 0.3594436533361183
val_r2_mean = 0.8335151110562139
val_r2_var = 0.1347584672850516
Epoch 270/2000:  13%|█▎        | 269/2000 [1:41:55<10:07:25, 21.05s/it, v_num=1, recon=1.29e+4, r2_mean=0.844, adv_loss=2.03, acc_pert=0.361, acc_cell_line=0.875, val_recon=1.3e+4, disnt_basal=0.219, disnt_after=0.361, val_r2_mean=0.83


Epoch 00274: cpa_metric reached. Module best state updated.


Epoch 280/2000:  14%|█▍        | 279/2000 [1:45:37<10:02:32, 21.01s/it, v_num=1, recon=1.29e+4, r2_mean=0.846, adv_loss=2.03, acc_pert=0.363, acc_cell_line=0.877, val_recon=1.3e+4, disnt_basal=0.219, disnt_after=0.364, val_r2_mean=0.834, val_KL=nan]
disnt_basal = 0.217780779608271
disnt_after = 0.35749148919660373
val_r2_mean = 0.8336594542600815
val_r2_var = 0.1366866738199318
Epoch 290/2000:  14%|█▍        | 289/2000 [1:49:20<9:58:55, 21.00s/it, v_num=1, recon=1.29e+4, r2_mean=0.845, adv_loss=2.02, acc_pert=0.365, acc_cell_line=0.873, val_recon=1.3e+4, disnt_basal=0.221, disnt_after=0.361, val_r2_mean=0.836, val_KL=nan] 


Epoch 00289: cpa_metric reached. Module best state updated.



disnt_basal = 0.22048708531445618
disnt_after = 0.3630063256958039
val_r2_mean = 0.8370702519972997
val_r2_var = 0.1390535088267698
Epoch 300/2000:  15%|█▍        | 299/2000 [1:53:01<9:55:43, 21.01s/it, v_num=1, recon=1.29e+4, r2_mean=0.846, adv_loss=2, acc_pert=0.37, acc_cell_line=0.878, val_recon=1.3e+4, disnt_basal=0.216, disnt_after=0.359, val_r2_mean=0.832, val_KL=nan]     


Epoch 00299: cpa_metric reached. Module best state updated.



disnt_basal = 0.218819266164188
disnt_after = 0.3618891991244186
val_r2_mean = 0.8362826852327498
val_r2_var = 0.14101181146271577
Epoch 310/2000:  15%|█▌        | 309/2000 [1:56:43<9:52:13, 21.01s/it, v_num=1, recon=1.28e+4, r2_mean=0.846, adv_loss=2.01, acc_pert=0.364, acc_cell_line=0.879, val_recon=1.3e+4, disnt_basal=0.216, disnt_after=0.357, val_r2_mean=0.833, val_KL=nan] 
disnt_basal = 0.2152914383799755
disnt_after = 0.3586681730070986
val_r2_mean = 0.8346480092813999
val_r2_var = 0.13362810875435904
Epoch 320/2000:  16%|█▌        | 319/2000 [2:00:25<9:48:54, 21.02s/it, v_num=1, recon=1.28e+4, r2_mean=0.845, adv_loss=2, acc_pert=0.373, acc_cell_line=0.881, val_recon=1.3e+4, disnt_basal=0.218, disnt_after=0.36, val_r2_mean=0.834, val_KL=nan]     
disnt_basal = 0.21864591965633112
disnt_after = 0.35858445922786936
val_r2_mean = 0.8359903254789725
val_r2_var = 0.13411811644879565
Epoch 330/2000:  16%|█▋        | 329/2000 [2:04:07<9:44:48, 21.00s/it, v_num=1, recon=1.28e+4, r2_mean


Epoch 00349: cpa_metric reached. Module best state updated.



disnt_basal = 0.21635082235900704
disnt_after = 0.3607681201231233
val_r2_mean = 0.8351466893761739
val_r2_var = 0.14297566893943772
Epoch 360/2000:  18%|█▊        | 359/2000 [2:15:13<9:35:07, 21.03s/it, v_num=1, recon=1.28e+4, r2_mean=0.846, adv_loss=1.99, acc_pert=0.383, acc_cell_line=0.876, val_recon=1.3e+4, disnt_basal=0.213, disnt_after=0.359, val_r2_mean=0.834, val_KL=nan] 
disnt_basal = 0.21580566240875193
disnt_after = 0.3595484005878319
val_r2_mean = 0.835434922258568
val_r2_var = 0.1407967806167381
Epoch 370/2000:  18%|█▊        | 369/2000 [2:18:55<9:31:51, 21.04s/it, v_num=1, recon=1.28e+4, r2_mean=0.846, adv_loss=1.98, acc_pert=0.376, acc_cell_line=0.879, val_recon=1.3e+4, disnt_basal=0.216, disnt_after=0.359, val_r2_mean=0.831, val_KL=nan] 
disnt_basal = 0.2148614647691582
disnt_after = 0.35769189669427204
val_r2_mean = 0.832197070983762
val_r2_var = 0.1407263039807501
Epoch 380/2000:  19%|█▉        | 379/2000 [2:22:37<9:26:52, 20.98s/it, v_num=1, recon=1.27e+4, r2_mean=0


Epoch 00389: cpa_metric reached. Module best state updated.



disnt_basal = 0.21161492981009056
disnt_after = 0.36013156306252914
val_r2_mean = 0.8362216434337031
val_r2_var = 0.13843472902494788
Epoch 400/2000:  20%|█▉        | 399/2000 [2:30:02<9:20:53, 21.02s/it, v_num=1, recon=1.27e+4, r2_mean=0.847, adv_loss=2, acc_pert=0.383, acc_cell_line=0.875, val_recon=1.29e+4, disnt_basal=0.216, disnt_after=0.36, val_r2_mean=0.832, val_KL=nan]    
disnt_basal = 0.21591535706673035
disnt_after = 0.3590443061287641
val_r2_mean = 0.8331812524433979
val_r2_var = 0.13468087993301575
Epoch 410/2000:  20%|██        | 409/2000 [2:33:44<9:17:46, 21.03s/it, v_num=1, recon=1.27e+4, r2_mean=0.846, adv_loss=1.97, acc_pert=0.376, acc_cell_line=0.884, val_recon=1.29e+4, disnt_basal=0.218, disnt_after=0.36, val_r2_mean=0.835, val_KL=nan]  
disnt_basal = 0.2169181369259394
disnt_after = 0.3575956032512808
val_r2_mean = 0.8328639419679499
val_r2_var = 0.1358556258286511
Epoch 420/2000:  21%|██        | 419/2000 [2:37:26<9:14:37, 21.05s/it, v_num=1, recon=1.27e+4, r2_me

In [43]:
# Predict perturbations
# Run the trained CPA model over the full `adata` in batches of 1024 cells.
# NOTE(review): a later cell reads obsm['CPA_pred'], so the predictions are presumably
# written there in place -- confirm against the CPA documentation.
model.predict(adata, batch_size=1024)

100%|██████████| 14/14 [00:08<00:00,  1.71it/s]


In [58]:
# Prepare data
# Counterfactual input: take the untreated cells and relabel them with the perturbed
# treatment, so the model is asked what those control cells would look like when treated
adata_pert = adata[adata.obs['treatment']==un_treatment].copy()
adata_pert.obs['treatment'] = pert_treatment
cpa.CPA.setup_anndata(
    adata_pert,
    perturbation_key='treatment',
    dosage_key=None,
    # Same control label as used for the subset above, via the shared constant
    control_group=un_treatment,
    batch_key=None,
    is_count_data=True,
    categorical_covariate_keys=['cell_line'],
    # deg_uns_key='rank_genes_groups_cov',
    # deg_uns_cat_key='cov_drug_dose',
    max_comb_len=2)

# Predict on the relabelled cells. Without this call, any obsm['CPA_pred'] present on
# `adata_pert` is only what `.copy()` carried over from `adata`, i.e. predictions made
# under the original (control) treatment label rather than the perturbed one.
model.predict(adata_pert, batch_size=1024)




100%|██████████| 2256/2256 [00:00<00:00, 74346.29it/s]
100%|██████████| 2256/2256 [00:00<00:00, 963972.07it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        





In [None]:
# Filter to val
common_mask = np.ones(adata_pert.shape[0], dtype=bool)
# common_mask = ~adata_pert.obs['Training']
# Plain boolean array so the same mask can index both the prediction matrix and obs values
row_mask = np.asarray(common_mask, dtype=bool)

# Get CPA output
# Apply the mask once, up front, so the mean profile, the per-gene table and the
# cell-line labels below all refer to the same set of cells (boolean indexing
# already returns a new array, so no extra copy is needed)
pred = adata_pert.obsm['CPA_pred'][row_mask]
# Mean predicted profile over all selected cells
np.save('../plots/drugseries/CPA_perturbation.npy', pred.mean(axis=0))
df = pd.DataFrame(pred, columns=adata_pert.var['gene_name'])
df['cell_line'] = adata_pert.obs['cell_line'].to_numpy()[row_mask]
# One mean predicted profile per cell line, each saved to its own file
for cell_line, row in df.groupby('cell_line').mean().iterrows():
    np.save(f'../plots/drugseries/CPA_{cell_line}_perturbation.npy', row.to_numpy())