# Obtaining distributional metrics for the ChemCPA
**Goal**: To obtain distributional metrics for the ChemCPA, so that we can directly compare the results to the ConditionalMongeGap model.

In [1]:
PATH = "/home/thesis/"
# PATH = "/dss/dsshome1/0A/di93hoq/"

## 1. Dataset preparation
By design, chemCPA aims to solve a broader problem that the ConditionalMongeGap model, since it tries to make predictions for different cell lines and drug doses. To make the comparison fair, we will only select a single cell line and the highest dose on the dataset (which is what we train the ConditionalMongeGap model on).


In [2]:
import scanpy as sc
import jax.numpy as jnp
import os
from pathlib import Path
from pprint import pprint
from seml.config import generate_configs, read_config
from chemCPA.experiments_run import ExperimentWrapper

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the dataset
adata = sc.read_h5ad(PATH + "/ConditionalMongeGap/Datasets/sciplex_complete_middle_subset.h5ad")

In [4]:
adata = adata[
    (adata.obs["cell_type"] == "A549")
    & ((adata.obs["dose"] == 10000) | (adata.obs["dose"] == 0))
]

In [6]:
# Save the dataset
adata.write_h5ad(PATH + "ConditionalMongeGap/Datasets/sciplex_complete_middle_subset_compare.h5ad")

## 2. Experiment
We load the manual configuration file and run the experiment (i.e. train the model and save it)

In [5]:
exp = ExperimentWrapper(init_all=False)
# this is how seml loads the config file internally
assert Path(
    PATH + "forkCPA/manual_run.yaml"
).exists(), "config file not found"
seml_config, slurm_config, experiment_config = read_config(
   PATH + "forkCPA/manual_run.yaml"
)
# we take the first config generated
configs = generate_configs(experiment_config)
if len(configs) > 1:
    print("Careful, more than one config generated from the yaml file")
args = configs[0]
pprint(args)

exp.seed = 1337
# loads the dataset splits
exp.init_dataset(**args["dataset"])

{'dataset': {'data_params': {'covariate_keys': 'cell_type',
                             'dataset_path': '/home/thesis/ConditionalMongeGap/Datasets/sciplex_complete_middle_subset_compare.h5ad',
                             'degs_key': 'all_DEGs',
                             'dose_key': 'dose',
                             'pert_category': 'cov_drug_dose_name',
                             'perturbation_key': 'condition',
                             'smiles_key': 'SMILES',
                             'split_key': 'split_ood_finetuning',
                             'use_drugs_idx': True},
             'dataset_type': 'trapnell'},
 'model': {'additional_params': {'decoder_activation': 'ReLU',
                                 'doser_type': 'amortized',
                                 'multi_task': False,
                                 'patience': 50,
                                 'seed': 1337},
           'append_ae_layer': True,
           'embedding': {'directory': 'embeddings'

In [6]:
exp.init_drug_embedding(embedding=args["model"]["embedding"])
exp.init_model(
    hparams=args["model"]["hparams"],
    additional_params=args["model"]["additional_params"],
    load_pretrained=args["model"]["load_pretrained"],
    append_ae_layer=args["model"]["append_ae_layer"],
    enable_cpa_mode=args["model"]["enable_cpa_mode"],
    pretrained_model_path=args["model"]["pretrained_model_path"],
    pretrained_model_hashes=args["model"]["pretrained_model_hashes"],
)

In [7]:
# setup the torch DataLoader
exp.update_datasets()

In [11]:
train_results = exp.train(**args["training"])

CWD: /home/thesis/forkCPA
Save dir: compare/checkpoints



KeyboardInterrupt



In [None]:
os.listdir(PATH + "forkCPA/compare/checkpoints")

['model_4.pt']

## 3. Prediction
We load the trained model and make predictions on the test and out of distribution (ood) datasets. The prediction of the chemCPA model is a vector of means and standard deviations for each point in the source/control. We will use the mean as the prediction (and use it to compute the distributional metrics). 

In [29]:
from notebooks.utils import compute_pred
from chemCPA.model import ComPert
import csv
import torch
import numpy as np

In [20]:
os.chdir(PATH + "ConditionalMongeGap/")
from losses import sinkhorn_div
from utils import calculate_metrics

In [9]:
model_hash = "model_4"  # Fine-tuned

In [10]:
model = torch.load(PATH + f"forkCPA/compare/checkpoints/{model_hash}.pt")

In [11]:
(
    state_dict,
    cov_adv_state_dicts,
    cov_emb_state_dicts,
    init_args,
    history,

) = model

In [12]:
model = ComPert(
        **init_args, drug_embeddings=exp.drug_embeddings
)
model = model.eval()

In [13]:
drug_r2_pretrained_degs_rdkit, _ = compute_pred(
    model,
    exp.datasets["ood"]
)

['A549', 'K562', 'MCF7']


5it [00:05,  1.23it/s]

A549_CUDC-101_1.0: -0.77
A549_CUDC-907_1.0: -0.36
A549_Dacinostat_1.0: -0.58
A549_Givinostat_1.0: -1.34
A549_Hesperadin_1.0: -0.40
A549_Pirarubicin_1.0: -0.46


9it [00:05,  1.61it/s]

A549_Raltitrexed_1.0: -1.69
A549_Tanespimycin_1.0: -1.08
A549_Trametinib_1.0: -0.53





In [14]:
prediction, embeddings = model.predict(
    genes=exp.datasets["ood"].genes,
    drugs_idx=exp.datasets["ood"].drugs_idx,
    dosages=exp.datasets["ood"].dosages,
    covariates=exp.datasets["ood"].covariates
)

In [21]:
for name in np.unique(exp.datasets["ood"].drugs_names):
    section = (exp.datasets["ood"].drugs_names == name)
    value = sinkhorn_div(
        jnp.asarray(prediction.detach().numpy()[section, 0:2000]),
        jnp.asarray(exp.datasets["ood"].genes[section]),
    )
    print(f"Sinkhorn divergence target and prediction for {name:12s}: {value:12.5f}")



Sinkhorn divergence target and prediction for CUDC-101: 91.75827026367188
Sinkhorn divergence target and prediction for CUDC-907: 92.26095581054688
Sinkhorn divergence target and prediction for Dacinostat: 94.89117431640625
Sinkhorn divergence target and prediction for Givinostat: 92.05887603759766
Sinkhorn divergence target and prediction for Hesperadin: 83.93119049072266
Sinkhorn divergence target and prediction for Pirarubicin: 101.37113189697266
Sinkhorn divergence target and prediction for Raltitrexed: 93.80936431884766
Sinkhorn divergence target and prediction for Tanespimycin: 105.04225158691406
Sinkhorn divergence target and prediction for Trametinib: 106.17483520507812


In [33]:
with open(PATH+ "forkCPA/compare/results.csv", 'w') as f:
    w = csv.DictWriter(f, ["name", "type", "r2", "mae", "sinkhorn_source_target", "sinkhorn_target_pred", "mmd_source_target", "mmd_target_pred", "fid_source_target", "fid_target_pred", "e_source_target", "e_target_pred"])
    w.writeheader()
    print(
        f"\n{'Condition':15s}{'':5s}" +
        f"{'Typr':10s}{'':5s}{'r2':15s}{'':5s}{'mae':15s}" +
        f"{'':5s}{'SINK(S,T)':15s}{'':5s}{'SINK(T,P)':15s}" +
        f"{'':5s}{'MMD(S,T)':>15s}{'':5s}{'MMD(T,P)':>15s}"+
        f"{'':5s}{'FID(S,T)':>15s}{'':5s}{'FID(S,P)':>12s}" +
        f"{'':5s}{'E(S,T)':>15s}{'':5s}{'E(T,P)':>15s}"
    )
    
    for name in np.unique(exp.datasets["ood"].drugs_names):
        section = (exp.datasets["ood"].drugs_names == name)
        results = calculate_metrics(
            name=name,
            type="ood",
            source=jnp.asarray(exp.datasets["training_control"].genes[0:len(section)]),
            target=jnp.asarray(exp.datasets["ood"].genes[section]),
            predicted=jnp.asarray(prediction.detach().numpy()[section, 0:2000]),
            epsilon=0.1,
            epsilon_mmd=100
        )
        

        print(
            ("{:15s}{:5s}{:10s}{:5s}" + "{:>15.3f}{:5s}" * 9 +"{:15.3f}").format(
                name,
                '',
                'ood',
                '',
                results['r2'],
                '',
                results["mae"],
                '',
                results['sinkhorn_source_target'],
                '',
                results['sinkhorn_target_pred'],
                '' ,
                results['mmd_source_target'],
                '',
                results['mmd_source_target'],
                '',
                results['fid_target_pred'],
                '',
                results['fid_source_target'],
                '',
                results['e_source_target'],
                '',
                results['e_target_pred']
            )
    )
        
    w.writerow(results)


Condition           Typr           r2                  mae                 SINK(S,T)           SINK(T,P)                  MMD(S,T)            MMD(T,P)            FID(S,T)         FID(S,P)              E(S,T)              E(T,P)
CUDC-101            ood                     -0.112               0.058             132.398             103.970              11.608              11.608             104.763              88.078               0.975               8.963
