In [1]:
import scanpy as sc
import numpy as np
import os

from pathlib import Path
from pprint import pprint
from seml.config import generate_configs, read_config
from chemCPA.experiments_run import ExperimentWrapper

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Declare variables
train_split = 0.8
rng = np.random.default_rng(1337)
ood = ["Dacinostat", "Givinostat", "Belinostat", "Hesperadin", "Quisinostat", "Alvespimycin", "Tanespimycin", "TAK-901", "Flavopiridol"]

In [3]:
os.getcwd()

'/dss/dsshome1/0A/di93hoq/forkCPA/compare'

In [4]:
# Load the dataset
adata = sc.read_h5ad("/dss/dsshome1/0A/di93hoq/ConditionalMongeGap/Datasets/sciplex_complete_middle_subset.h5ad")

In [5]:
# Add a new split to the dataset to be used as a comparisson to Conditional Monge Gap
adata.obs["compare_split"] = np.where(
    adata.obs["condition"].isin(ood),
    "ood",
    rng.choice(["training", "test"], p=[train_split, 1-train_split]),
) 

In [6]:
sc.pp.subsample(
    adata,
    n_obs=5000
)

In [7]:
adata.obs["split_random"].unique()

['ood', 'train', 'test']
Categories (3, object): ['ood', 'test', 'train']

In [8]:
# Save the dataset
adata.write_h5ad("/dss/dsshome1/0A/di93hoq/ConditionalMongeGap/Datasets/sciplex_complete_middle_subset_compare.h5ad")

In [9]:
cd /dss/dsshome1/0A/di93hoq/forkCPA/

/dss/dsshome1/0A/di93hoq/forkCPA


In [10]:
exp = ExperimentWrapper(init_all=False)
# this is how seml loads the config file internally
assert Path(
    "manual_run.yaml"
).exists(), "config file not found"
seml_config, slurm_config, experiment_config = read_config(
   "manual_run.yaml"
)
# we take the first config generated
configs = generate_configs(experiment_config)
if len(configs) > 1:
    print("Careful, more than one config generated from the yaml file")
args = configs[0]
pprint(args)

exp.seed = 1337
# loads the dataset splits
exp.init_dataset(**args["dataset"])

{'dataset': {'data_params': {'covariate_keys': 'cell_type',
                             'dataset_path': '/dss/dsshome1/0A/di93hoq/ConditionalMongeGap/Datasets/sciplex_complete_middle_subset_compare.h5ad',
                             'degs_key': 'all_DEGs',
                             'dose_key': 'dose',
                             'pert_category': 'cov_drug_dose_name',
                             'perturbation_key': 'condition',
                             'smiles_key': 'SMILES',
                             'split_key': 'split_random',
                             'use_drugs_idx': True},
             'dataset_type': 'trapnell'},
 'model': {'additional_params': {'decoder_activation': 'ReLU',
                                 'doser_type': 'amortized',
                                 'multi_task': False,
                                 'patience': 50,
                                 'seed': 1337},
           'append_ae_layer': True,
           'embedding': {'directory': 'embeddi

In [11]:
exp.init_drug_embedding(embedding=args["model"]["embedding"])
exp.init_model(
    hparams=args["model"]["hparams"],
    additional_params=args["model"]["additional_params"],
    load_pretrained=args["model"]["load_pretrained"],
    append_ae_layer=args["model"]["append_ae_layer"],
    enable_cpa_mode=args["model"]["enable_cpa_mode"],
    pretrained_model_path=args["model"]["pretrained_model_path"],
    pretrained_model_hashes=args["model"]["pretrained_model_hashes"],
)

In [12]:
exp.datasets.keys()

dict_keys(['training', 'training_control', 'training_treated', 'test', 'test_control', 'test_treated', 'ood'])

In [13]:
# setup the torch DataLoader
exp.update_datasets()

In [14]:
exp.train(**args["training"])

CWD: /dss/dsshome1/0A/di93hoq/forkCPA
Save dir: compare/checkpoints




Size of disentanglement testdata: 731

Took 0.4 min for evaluation.



{'epoch': [0, 1, 2, 3, 4],
 'stats_epoch': [4],
 'loss_reconstruction': [-9.099171981215477,
  -24.392128109931946,
  -38.112823724746704,
  -48.48217499256134,
  -58.22065711021423],
 'loss_adv_drugs': [148.400963306427,
  148.87963104248047,
  148.8683567047119,
  148.32721185684204,
  148.072283744812],
 'loss_adv_covariates': [38.79986619949341,
  36.247304916381836,
  32.05805975198746,
  38.10633182525635,
  45.02868127822876],
 'penalty_adv_drugs': [4.205532506108284,
  0.7807078063488007,
  0.17748189065605402,
  0.10121615417301655,
  0.07536241970956326],
 'penalty_adv_covariates': [0.11973072402179241,
  0.08446761826053262,
  0.06932178698480129,
  0.07250320492312312,
  0.04423150490038097],
 'loss_multi_task': [0.0, 0.0, 0.0, 0.0, 0.0],
 'elapsed_time_min': 0.10051395495732625,
 'perturbation disentanglement': [1.0],
 'optimal for perturbations': [0.03967168262653899],
 'covariate disentanglement': [[1.0]],
 'optimal for covariates': [[0.49931600689888]],
 'training': [[0

In [15]:
!ls compare/checkpoints

model_4.pt


# Model loading and preparation

In [16]:
os.getcwd()

'/dss/dsshome1/0A/di93hoq/forkCPA'

In [17]:
from notebooks.utils import load_config, load_dataset, load_model, load_smiles, compute_pred
from chemCPA.train import evaluate, compute_prediction, repeat_n
import torch
from chemCPA.model import ComPert

In [18]:
seml_collection = "multi_task"
model_hash_pretrained_rdkit = "model_4"  # Fine-tuned

In [19]:
dir(exp.dataset)

['__annotations__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_drugs_name_to_idx',
 'atomic_сovars_dict',
 'canon_smiles_unique_sorted',
 'covariate_keys',
 'covariate_names',
 'covariate_names_unique',
 'covariates',
 'ctrl',
 'ctrl_name',
 'de_genes',
 'degs',
 'dosages',
 'dose_key',
 'dose_names',
 'drug_name_to_idx',
 'drugs_idx',
 'drugs_names',
 'drugs_names_unique_sorted',
 'genes',
 'indices',
 'max_num_perturbations',
 'num_covariates',
 'num_drugs',
 'num_genes',
 'pert_categories',
 'perturbation_key',
 'smiles_key',
 'subset',
 'use_drugs_idx',
 'var_names']

In [20]:
dir(exp)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'autoencoder',
 'dataset',
 'datasets',
 'drug_embeddings',
 'embedding_model_type',
 'init_all',
 'init_dataset',
 'init_drug_embedding',
 'init_model',
 'load_state_dict',
 'seed',
 'train',
 'update_datasets']

In [21]:
model = torch.load("compare/checkpoints/model_4.pt")

In [22]:
(
    state_dict,
    cov_adv_state_dicts,
    cov_emb_state_dicts,
    init_args,
    history,
        ) = model

In [23]:
model = ComPert(
        **init_args, drug_embeddings=exp.drug_embeddings, append_layer_width=256
)
model = model.eval()

In [24]:
compute_pred(
    model,
    exp.datasets["training"]
)

['A549', 'K562', 'MCF7']


992it [00:00, 41382.79it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x2000 and 256x2000)

In [25]:
model = ComPert(exp.dataset.num_genes, exp.dataset.num_drugs, exp.dataset.num_covariates)

In [26]:
model.load_state_dict(torch.load("compare/checkpoints/model_4.pt")[0])

RuntimeError: Error(s) in loading state_dict for ComPert:
	Missing key(s) in state_dict: "drug_embedding_encoder.network.0.weight", "drug_embedding_encoder.network.0.bias", "dosers.beta", "dosers.bias". 
	Unexpected key(s) in state_dict: "adversary_drugs.network.10.weight", "adversary_drugs.network.10.bias", "adversary_drugs.network.10.running_mean", "adversary_drugs.network.10.running_var", "adversary_drugs.network.10.num_batches_tracked", "adversary_drugs.network.12.weight", "adversary_drugs.network.12.bias", "dosers.network.0.weight", "dosers.network.0.bias", "dosers.network.1.weight", "dosers.network.1.bias", "dosers.network.1.running_mean", "dosers.network.1.running_var", "dosers.network.1.num_batches_tracked", "dosers.network.3.weight", "dosers.network.3.bias", "dosers.network.4.weight", "dosers.network.4.bias", "dosers.network.4.running_mean", "dosers.network.4.running_var", "dosers.network.4.num_batches_tracked", "dosers.network.6.weight", "dosers.network.6.bias", "dosers.network.7.weight", "dosers.network.7.bias", "dosers.network.7.running_mean", "dosers.network.7.running_var", "dosers.network.7.num_batches_tracked", "dosers.network.9.weight", "dosers.network.9.bias". 
	size mismatch for encoder.network.0.weight: copying a param with shape torch.Size([256, 2000]) from checkpoint, the shape in current model is torch.Size([512, 2000]).
	size mismatch for encoder.network.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.3.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.network.3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.4.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.4.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.4.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.4.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.6.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.network.6.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.7.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.7.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.7.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.7.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.9.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.network.9.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.10.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.10.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.10.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.10.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.network.12.weight: copying a param with shape torch.Size([194, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for encoder.network.12.bias: copying a param with shape torch.Size([194]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.network.0.weight: copying a param with shape torch.Size([256, 194]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for decoder.network.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.3.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.network.3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.4.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.4.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.4.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.4.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.6.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.network.6.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.7.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.7.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.7.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.7.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.9.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.network.9.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.10.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.10.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.10.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.10.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.network.12.weight: copying a param with shape torch.Size([4000, 256]) from checkpoint, the shape in current model is torch.Size([4000, 512]).
	size mismatch for adversary_drugs.network.0.weight: copying a param with shape torch.Size([128, 194]) from checkpoint, the shape in current model is torch.Size([128, 256]).
	size mismatch for adversary_drugs.network.9.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([188, 128]).
	size mismatch for adversary_drugs.network.9.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([188]).
	size mismatch for drug_embeddings.weight: copying a param with shape torch.Size([188, 194]) from checkpoint, the shape in current model is torch.Size([188, 256]).

In [27]:
config = configs[0]
config["config_hash"] = "model_4"
load_model(config, exp.dataset.canon_smiles_unique_sorted,"compare/checkpoints/model_4.pt")

AssertionError: ['drug_embeddings.weight', 'drug_embedding_encoder.network.0.weight', 'drug_embedding_encoder.network.0.bias', 'drug_embedding_encoder.network.1.weight', 'drug_embedding_encoder.network.1.bias', 'drug_embedding_encoder.network.1.running_mean', 'drug_embedding_encoder.network.1.running_var', 'drug_embedding_encoder.network.3.weight', 'drug_embedding_encoder.network.3.bias', 'drug_embedding_encoder.network.4.weight', 'drug_embedding_encoder.network.4.bias', 'drug_embedding_encoder.network.4.running_mean', 'drug_embedding_encoder.network.4.running_var', 'drug_embedding_encoder.network.6.weight', 'drug_embedding_encoder.network.6.bias', 'drug_embedding_encoder.network.7.weight', 'drug_embedding_encoder.network.7.bias', 'drug_embedding_encoder.network.7.running_mean', 'drug_embedding_encoder.network.7.running_var', 'drug_embedding_encoder.network.9.weight', 'drug_embedding_encoder.network.9.bias', 'drug_embedding_encoder.network.10.weight', 'drug_embedding_encoder.network.10.bias', 'drug_embedding_encoder.network.10.running_mean', 'drug_embedding_encoder.network.10.running_var', 'drug_embedding_encoder.network.12.weight', 'drug_embedding_encoder.network.12.bias']