In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import pathlib
import yaml

pd.set_option('display.max_rows', 100)


In [2]:
matplotlib.style.use("fivethirtyeight")
matplotlib.pyplot.rcParams['savefig.facecolor'] = 'white'
sns.set_style("whitegrid")

matplotlib.rcParams['font.family'] = "monospace"
matplotlib.rcParams['figure.dpi'] = 300
sns.set_context("poster")
plt.rcParams.update({'font.size': 5, "font.weight": "bold"})

In [3]:
drugs = ["abexinostat", "belinostat", "dacinostat", "entinostat", "givinostat", 
             "mocetinostat", "pracinostat", "tacedinaline", "trametinib"]
Drugs = [d.capitalize() for d in drugs]
Drugs

['Abexinostat',
 'Belinostat',
 'Dacinostat',
 'Entinostat',
 'Givinostat',
 'Mocetinostat',
 'Pracinostat',
 'Tacedinaline',
 'Trametinib']

# ID - trainded on all drugs

In [4]:
res = pd.read_csv("/Users/alicedriessen/Box/otperturb/alice_exp/chemCPA/chemCPA_on_cmonge_sciplex/no_ood/r2_mean_results.csv")
res.groupby(["drug", "dose"]).size()

drug                dose 
2-Methoxyestradiol  0.001    3
                    0.010    3
                    0.100    3
                    1.000    3
A-366               0.001    3
                            ..
ZM                  1.000    3
Zileuton            0.001    3
                    0.010    3
                    0.100    3
                    1.000    3
Length: 748, dtype: int64

In [5]:
len(res["drug"].unique())*4

748

In [6]:
res = res[res["drug"].isin(Drugs)]
res

Unnamed: 0.1,Unnamed: 0,cmonge_r2,cell line,drug,dose
48,A549_Abexinostat_0.001,0.895290,A549,Abexinostat,0.001
49,A549_Abexinostat_0.01,0.899969,A549,Abexinostat,0.010
50,A549_Abexinostat_0.1,0.865841,A549,Abexinostat,0.100
51,A549_Abexinostat_1.0,0.727243,A549,Abexinostat,1.000
124,A549_Belinostat_0.001,0.938475,A549,Belinostat,0.001
...,...,...,...,...,...
2132,MCF7_Tacedinaline_1.0,0.948407,MCF7,Tacedinaline,1.000
2169,MCF7_Trametinib_0.001,0.825543,MCF7,Trametinib,0.001
2170,MCF7_Trametinib_0.01,0.780346,MCF7,Trametinib,0.010
2171,MCF7_Trametinib_0.1,0.764849,MCF7,Trametinib,0.100


In [7]:
res.groupby("dose").mean(numeric_only=True)

Unnamed: 0_level_0,cmonge_r2
dose,Unnamed: 1_level_1
0.001,0.896272
0.01,0.84862
0.1,0.735158
1.0,0.736319


In [8]:
res.groupby("dose").std(numeric_only=True)

Unnamed: 0_level_0,cmonge_r2
dose,Unnamed: 1_level_1
0.001,0.136813
0.01,0.185168
0.1,0.23101
1.0,0.154024


# 5% OOD - CV

In [14]:
with open("/Users/alicedriessen/Box/otperturb/alice_exp/chemCPA/chemCPA_on_cmonge_sciplex/5perc_ood/cmonge_eval_drugdose.yaml") as f:
    logs = yaml.safe_load(f)

In [15]:
all_res = []
for drugdose in logs.keys():
    res = logs[drugdose]["mean_statistics"]
    res = pd.DataFrame.from_dict(res, orient="index").T
    res["drugdose"] = drugdose
    all_res.append(res)
res = pd.concat(all_res).reset_index()
res[["drug", "dose"]] = [c.split("_") for c in res['drugdose']]
res["dose"] = (res["dose"].astype(float) * 10000).astype(int).astype(str)

In [16]:
res

Unnamed: 0,index,mean_drug_signature,mean_mmd,mean_monge_gap,mean_r2,mean_sinkhorn div,mean_wasserstein,drugdose,drug,dose
0,0,1.964186,0.203498,14.263881,0.611097,5.776612,6.262887,2-Methoxyestradiol_0.001,2-Methoxyestradiol,10
1,0,2.105005,0.205171,14.217563,0.617975,5.961744,6.446365,2-Methoxyestradiol_0.01,2-Methoxyestradiol,100
2,0,1.840803,0.200833,14.226675,0.588790,5.622760,6.107369,2-Methoxyestradiol_0.1,2-Methoxyestradiol,1000
3,0,2.968639,0.213198,14.086956,0.378828,6.113287,6.591412,2-Methoxyestradiol_1.0,2-Methoxyestradiol,10000
4,0,0.778569,0.181318,13.877471,0.946436,3.673875,4.107665,A-366_0.001,A-366,10
...,...,...,...,...,...,...,...,...,...,...
743,0,1.779837,0.205713,14.218963,0.838126,4.795526,5.275452,ZM_1.0,ZM,10000
744,0,0.538017,0.155248,14.024675,0.970957,3.435529,3.850911,Zileuton_0.001,Zileuton,10
745,0,0.804322,0.147255,13.865276,0.967168,3.337262,3.743446,Zileuton_0.01,Zileuton,100
746,0,0.563835,0.147676,13.999237,0.970010,3.324787,3.731151,Zileuton_0.1,Zileuton,1000


In [17]:
res.groupby("dose").mean(numeric_only=True)

Unnamed: 0_level_0,index,mean_drug_signature,mean_mmd,mean_monge_gap,mean_r2,mean_sinkhorn div,mean_wasserstein
dose,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
10,0.0,1.175923,0.185564,13.725155,0.836408,4.252337,4.707038
100,0.0,1.286943,0.18718,13.703816,0.815491,4.313386,4.767268
1000,0.0,1.346401,0.189195,13.704067,0.79226,4.358272,4.811571
10000,0.0,1.522227,0.195243,13.647488,0.7599,4.491325,4.946657


In [18]:
res.groupby("dose").std(numeric_only=True)

Unnamed: 0_level_0,index,mean_drug_signature,mean_mmd,mean_monge_gap,mean_r2,mean_sinkhorn div,mean_wasserstein
dose,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
10,0.0,0.68025,0.022213,0.346955,0.154174,0.892245,0.913377
100,0.0,0.765484,0.024461,0.355923,0.177977,0.953498,0.975527
1000,0.0,0.954005,0.028888,0.384992,0.205082,1.047437,1.069549
10000,0.0,1.270437,0.034966,0.422368,0.211152,1.246603,1.271008


# Check size chemCPA

In [20]:
# Run in chemCPA environment
import torch

ModuleNotFoundError: No module named 'torch'

In [None]:
state_dict = torch.load("/Users/alicedriessen/Box/otperturb/alice_exp/chemCPA/chemCPA_on_cmonge_sciplex/9drugs_ood/model_checkpoint.pt",
          map_location=torch.device('cpu'))

In [7]:
state_dict[0]

OrderedDict([('encoder.network.0.weight',
              tensor([[-0.0101, -0.0312, -0.0267,  ...,  0.0187, -0.0038,  0.0044],
                      [ 0.0107,  0.0275,  0.0250,  ...,  0.0081, -0.0315,  0.0235],
                      [-0.0184, -0.0274,  0.0194,  ...,  0.0026, -0.0113,  0.0006],
                      ...,
                      [ 0.0015, -0.0307, -0.0128,  ...,  0.0229, -0.0075, -0.0053],
                      [ 0.0041, -0.0214,  0.0213,  ..., -0.0251, -0.0045,  0.0250],
                      [-0.0165,  0.0168,  0.0075,  ..., -0.0071, -0.0314, -0.0280]])),
             ('encoder.network.0.bias',
              tensor([-2.9492e-05, -2.4637e-05, -4.5694e-06, -1.1409e-05,  3.7845e-06,
                      -1.9967e-05, -7.6084e-07, -1.6658e-05,  2.1733e-07, -3.7652e-06,
                       1.4351e-05, -4.0652e-06, -8.6999e-06, -5.9694e-05,  7.9890e-06,
                      -2.8235e-05,  4.0962e-05,  1.5462e-05, -3.5074e-05, -3.4882e-05,
                      -2.3162e-05,  

In [23]:
n_total_params  = sum(p.numel() for p in state_dict[0].values())
n_total_params


1374742