# Obtaining distributional metrics for the ChemCPA
**Goal**: To obtain distributional metrics for the ChemCPA, so that we can directly compare the results to the ConditionalMongeGap model.

In [1]:
PATH = "/home/thesis/"
# PATH = "/dss/dsshome1/0A/di93hoq/"

## 1. Dataset preparation
By design, chemCPA aims to solve a broader problem that the ConditionalMongeGap model, since it tries to make predictions for different cell lines and drug doses. To make the comparison fair, we will only select a single cell line and the highest dose on the dataset (which is what we train the ConditionalMongeGap model on).


In [2]:
import scanpy as sc
import jax.numpy as jnp
import os
from pathlib import Path
from pprint import pprint
from seml.config import generate_configs, read_config
from chemCPA.experiments_run import ExperimentWrapper

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the dataset
adata = sc.read_h5ad(PATH + "/ConditionalMongeGap/Datasets/sciplex_complete_middle_subset.h5ad")

In [4]:
adata = adata[
    (adata.obs["cell_type"] == "A549")
    & ((adata.obs["dose"] == 10000) | (adata.obs["dose"] == 0))
]

In [5]:
# Save the dataset
adata.write_h5ad(PATH + "ConditionalMongeGap/Datasets/sciplex_complete_middle_subset_compare.h5ad")

## 2. Experiment
We load the manual configuration file and run the experiment (i.e. train the model and save it)

In [6]:
exp = ExperimentWrapper(init_all=False)
# this is how seml loads the config file internally
assert Path(
    PATH + "forkCPA/manual_run.yaml"
).exists(), "config file not found"
seml_config, slurm_config, experiment_config = read_config(
   PATH + "forkCPA/manual_run.yaml"
)
# we take the first config generated
configs = generate_configs(experiment_config)
if len(configs) > 1:
    print("Careful, more than one config generated from the yaml file")
args = configs[0]
pprint(args)

exp.seed = 1337
# loads the dataset splits
exp.init_dataset(**args["dataset"])

{'dataset': {'data_params': {'covariate_keys': 'cell_type',
                             'dataset_path': '/home/thesis/ConditionalMongeGap/Datasets/sciplex_complete_middle_subset_compare.h5ad',
                             'degs_key': 'all_DEGs',
                             'dose_key': 'dose',
                             'pert_category': 'cov_drug_dose_name',
                             'perturbation_key': 'condition',
                             'smiles_key': 'SMILES',
                             'split_key': 'split_ood_finetuning',
                             'use_drugs_idx': True},
             'dataset_type': 'trapnell'},
 'model': {'additional_params': {'decoder_activation': 'ReLU',
                                 'doser_type': 'amortized',
                                 'multi_task': False,
                                 'patience': 50,
                                 'seed': 1337},
           'append_ae_layer': True,
           'embedding': {'directory': 'embeddings'

In [7]:
exp.init_drug_embedding(embedding=args["model"]["embedding"])
exp.init_model(
    hparams=args["model"]["hparams"],
    additional_params=args["model"]["additional_params"],
    load_pretrained=args["model"]["load_pretrained"],
    append_ae_layer=args["model"]["append_ae_layer"],
    enable_cpa_mode=args["model"]["enable_cpa_mode"],
    pretrained_model_path=args["model"]["pretrained_model_path"],
    pretrained_model_hashes=args["model"]["pretrained_model_hashes"],
)

In [8]:
# setup the torch DataLoader
exp.update_datasets()

In [9]:
train_results = exp.train(**args["training"])

CWD: /home/thesis/forkCPA
Save dir: compare/checkpoints
Size of disentanglement testdata: 2667

Took 1.6 min for evaluation.



In [10]:
os.listdir(PATH + "forkCPA/compare/checkpoints")

['model_24.pt', 'model_21.pt', '.gitignore']

## 3. Prediction
We load the trained model and make predictions on the test and out of distribution (ood) datasets. The prediction of the chemCPA model is a vector of means and standard deviations for each point in the source/control. We will use the mean as the prediction (and use it to compute the distributional metrics). 

In [11]:
from notebooks.utils import compute_pred
from chemCPA.model import ComPert
import csv
import torch
import numpy as np

In [12]:
os.chdir(PATH + "ConditionalMongeGap/")
from losses import sinkhorn_div
from utils import calculate_metrics

In [13]:
os.listdir(PATH + f"forkCPA/compare/checkpoints")

['model_24.pt', 'model_21.pt', '.gitignore']

In [14]:
model_hash = "model_24"  # Fine-tuned

In [15]:
model = torch.load(PATH + f"forkCPA/compare/checkpoints/{model_hash}.pt")

In [16]:
(
    state_dict,
    cov_adv_state_dicts,
    cov_emb_state_dicts,
    init_args,
    history,

) = model

In [17]:
model = ComPert(
        **init_args, drug_embeddings=exp.drug_embeddings
)
model = model.eval()

In [18]:
drug_r2_pretrained_degs_rdkit, _ = compute_pred(
    model,
    exp.datasets["ood"]
)

['A549', 'K562', 'MCF7']


5it [00:01,  4.37it/s]

A549_CUDC-101_1.0: -0.77
A549_CUDC-907_1.0: -0.39
A549_Dacinostat_1.0: -0.58
A549_Givinostat_1.0: -1.31
A549_Hesperadin_1.0: -0.39
A549_Pirarubicin_1.0: -0.45
A549_Raltitrexed_1.0: -1.69


9it [00:01,  4.56it/s]

A549_Tanespimycin_1.0: -1.07
A549_Trametinib_1.0: -0.51





In [19]:
prediction, embeddings = model.predict(
    genes=exp.datasets["ood"].genes,
    drugs_idx=exp.datasets["ood"].drugs_idx,
    dosages=exp.datasets["ood"].dosages,
    covariates=exp.datasets["ood"].covariates
)

In [20]:
for name in np.unique(exp.datasets["ood"].drugs_names):
    section = (exp.datasets["ood"].drugs_names == name)
    value = sinkhorn_div(
        jnp.asarray(prediction.detach().numpy()[section, 0:2000]),
        jnp.asarray(exp.datasets["ood"].genes[section]),
    )
    print(f"Sinkhorn divergence target and prediction for {name:12s}: {value:12.5f}")



Sinkhorn divergence target and prediction for CUDC-101    :     91.97215
Sinkhorn divergence target and prediction for CUDC-907    :     92.39673
Sinkhorn divergence target and prediction for Dacinostat  :     95.05641
Sinkhorn divergence target and prediction for Givinostat  :     92.18755
Sinkhorn divergence target and prediction for Hesperadin  :     84.31055
Sinkhorn divergence target and prediction for Pirarubicin :    101.57546
Sinkhorn divergence target and prediction for Raltitrexed :     93.96259
Sinkhorn divergence target and prediction for Tanespimycin:    105.46171
Sinkhorn divergence target and prediction for Trametinib  :    106.65128


In [21]:
with open(PATH+ "forkCPA/compare/results.csv", 'w') as f:
    w = csv.DictWriter(f, ["name", "type", "r2", "mae", "sinkhorn_source_target", "sinkhorn_target_pred", "mmd_source_target", "mmd_target_pred", "fid_source_target", "fid_target_pred", "e_source_target", "e_target_pred"])
    w.writeheader()
    print(
        f"\n{'Condition':25s}{'':5s}" +
        f"{'Typr':10s}{'':5s}{'r2':15s}{'':5s}{'mae':15s}" +
        f"{'':5s}{'SINK(S,T)':15s}{'':5s}{'SINK(T,P)':15s}" +
        f"{'':5s}{'MMD(S,T)':>15s}{'':5s}{'MMD(T,P)':>15s}"+
        f"{'':5s}{'FID(S,T)':>15s}{'':5s}{'FID(S,P)':>12s}" +
        f"{'':5s}{'E(S,T)':>15s}{'':5s}{'E(T,P)':>15s}"
    )
    for type_ in ["test", "ood"]:
        prediction, embeddings = model.predict(
            genes=exp.datasets[type_].genes,
            drugs_idx=exp.datasets[type_].drugs_idx,
            dosages=exp.datasets[type_].dosages,
            covariates=exp.datasets[type_].covariates
        )

        for name in np.unique(exp.datasets[type_].drugs_names):
            section = (exp.datasets[type_].drugs_names == name)
            results = calculate_metrics(
                name=name,
                type=type_,
                source=jnp.asarray(exp.datasets["training_control"].genes[0:len(section)]),
                target=jnp.asarray(exp.datasets[type_].genes[section]),
                predicted=jnp.asarray(prediction.detach().numpy()[section, 0:2000]),
                epsilon=0.1,
                epsilon_mmd=100
            )
            

            print(
                ("{:25s}{:5s}{:10s}{:5s}" + "{:>15.3f}{:5s}" * 9 +"{:15.3f}").format(
                    name,
                    '',
                    type_,
                    '',
                    results['r2'],
                    '',
                    results["mae"],
                    '',
                    results['sinkhorn_source_target'],
                    '',
                    results['sinkhorn_target_pred'],
                    '' ,
                    results['mmd_source_target'],
                    '',
                    results['mmd_target_pred'],
                    '',
                    results['fid_target_pred'],
                    '',
                    results['fid_source_target'],
                    '',
                    results['e_source_target'],
                    '',
                    results['e_target_pred']
                )
        )
        print(["-"]*210)
        w.writerow(results)


Condition                     Typr           r2                  mae                 SINK(S,T)           SINK(T,P)                  MMD(S,T)            MMD(T,P)            FID(S,T)         FID(S,P)              E(S,T)              E(T,P)
2-Methoxyestradiol            test                    -0.071               0.067             139.663             117.117              31.054              74.519             135.658             152.878               3.016              12.488
A-366                         test                    -0.091               0.064             122.607             100.221              13.097              55.139             107.606             121.955               1.374              10.560
ABT-737                       test                    -0.083               0.071             127.905             105.567              25.684              70.175             119.974             137.599               3.013              12.494
AC480                         test    

: 

: 