# Obtaining distributional metrics for chemCPA
**Goal**: Obtain distributional metrics for chemCPA so that its results can be compared directly with those of the ConditionalMongeGap model.

In [1]:
# Re-import edited project modules (e.g. chemCPA, notebooks.utils) automatically before each cell runs
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

# Parent of the current working directory (presumably the repository root when the notebook is
# run from `notebooks/`); appended to `sys.path` further down so the local packages import.
PATH = str(Path.cwd().parent)

## 1. Dataset preparation
By design, chemCPA solves a broader problem than the ConditionalMongeGap model: it predicts responses across different cell lines and drug doses. To make the comparison fair, we restrict the evaluation to a single cell line and the highest dose in the dataset, which is the setting the ConditionalMongeGap model is trained on.


In [4]:
import scanpy as sc
# import jax.numpy as jnp
import sys
from pathlib import Path
from pprint import pprint
from seml.config import generate_configs, read_config
import torch

In [5]:
# Make the project's local packages (chemCPA, notebooks) importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
if PATH not in sys.path:
    sys.path.append(PATH)

In [6]:
from chemCPA.experiments_run import ExperimentWrapper
from chemCPA.model import ComPert
from chemCPA.train import compute_prediction
from notebooks.utils import repeat_n
from notebooks import utils

In [7]:
# Load the trained chemCPA checkpoint (a tuple, unpacked in the next cell) onto the CPU.
# The location can be overridden with the CHEMCPA_CHECKPOINT environment variable instead of
# editing this user-specific absolute path.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
CHECKPOINT_PATH = Path(
    os.environ.get(
        'CHEMCPA_CHECKPOINT',
        '/home/icb/gori.camps/forkCPA/28c172ee2884c3204fa0df4b7223ff93.pt',
    )
)
model = torch.load(
    CHECKPOINT_PATH,
    map_location=torch.device('cpu')
)

In [8]:
# Unpack the 5-element checkpoint tuple. `init_args` holds the ComPert constructor kwargs; the other
# element meanings follow their names -- NOTE(review): confirm against chemCPA's checkpoint save code.
# `model` is rebound to a ComPert instance in the next cell, so re-running this cell afterwards
# fails unless the torch.load cell is re-run first.
state_dict, cov_adv_state_dicts, cov_emb_state_dicts, init_args, history = model

In [9]:
# Rebuild the architecture from the saved constructor arguments, then restore the trained weights.
# Without `load_state_dict` the model would keep its random initialisation and every prediction
# below would come from an untrained network.
# NOTE(review): if the checkpoint was trained with a pretrained drug embedding, ComPert may need
# `drug_embeddings=` passed explicitly, otherwise loading can fail with a shape mismatch -- confirm.
model = ComPert(
        **init_args,
)
model.load_state_dict(state_dict)

# Inference mode: dropout off and batch-norm (if present) uses running statistics, so predictions
# are deterministic and not influenced by the batch composition.
model.eval()

# The checkpoint carries covariate-embedding weights as a separate element, presumably because they
# are not part of `state_dict`; restore them explicitly so the cell-line covariate is the trained one.
# NOTE(review): assumes ComPert exposes them as `covariates_embeddings` -- TODO confirm in chemCPA/model.py.
# The covariate-adversary weights (`cov_adv_state_dicts`) are only needed for training, not prediction.
assert len(model.covariates_embeddings) == len(cov_emb_state_dicts)
for covariate_embedding, embedding_state in zip(model.covariates_embeddings, cov_emb_state_dicts):
    covariate_embedding.load_state_dict(embedding_state)

In [10]:
# Load the sciplex subset (sciplex_complete_middle_subset.h5ad). The location can be overridden
# with the SCIPLEX_ADATA environment variable instead of editing this user-specific absolute path.
ADATA_PATH = Path(
    os.environ.get(
        'SCIPLEX_ADATA',
        '/home/icb/gori.camps/ConditionalOT_Perturbations/Datasets/sciplex_complete_middle_subset.h5ad',
    )
)
adata = sc.read_h5ad(
    ADATA_PATH
)

In [48]:
# Control cells of one cell line are the inputs; chemCPA predicts their response to one drug at one dose.
CELL_LINE = 'A549'
CONTROL_CONDITION = 'control'
# NOTE(review): assumes A549 is the first of three cell-line covariate categories and that drug
# index 63 is the intended compound -- TODO confirm against the encodings used in training.
CELL_LINE_ONE_HOT = [1, 0, 0]
DRUG_IDX = 63
# NOTE(review): confirm 10000 is on the same scale as the doses the checkpoint was trained with
# (some sciplex preprocessing rescales doses so that the maximum dose is 1.0).
DOSE = 10000.0

control_mask = (
    (adata.obs['cell_type'] == CELL_LINE)
    & (adata.obs['condition'] == CONTROL_CONDITION)
)
control_X = adata[control_mask].X
# Densify sparse expression matrices with `.toarray()` (the `.A` shorthand is deprecated and removed
# in recent SciPy releases); an already-dense matrix is used as-is.
genes = torch.Tensor(control_X.toarray() if hasattr(control_X, 'toarray') else control_X)
n_cells = genes.shape[0]
assert n_cells > 0, f"no '{CONTROL_CONDITION}' cells found for cell line {CELL_LINE}"

# Covariates are passed as a list; here only the cell-line one-hot, presumably repeated once per cell.
cell_lines = [
    repeat_n(
        torch.Tensor(CELL_LINE_ONE_HOT),
        n_cells
    )
]

# (drug indices, dosages), one value per cell. `.flatten()` rather than `.squeeze()` keeps both
# 1-D even when a single cell is selected; dosages stay floating point because doses are continuous.
drugs = (
    repeat_n(torch.Tensor([DRUG_IDX]), n_cells).flatten().to(torch.long),
    repeat_n(torch.Tensor([DOSE]), n_cells).flatten(),
)

# Inference only, so no autograd graph is recorded.
# NOTE(review): confirm whether the second output of `compute_prediction` is a standard deviation
# or a variance before treating `stds` as standard deviations downstream.
with torch.no_grad():
    means, stds = compute_prediction(
        model,
        genes=genes,
        emb_covs=cell_lines,
        emb_drugs=drugs,
    )