In [None]:
import os
import pickle as pkl
import numpy as np
import pandas as pd

import anndata as ad
import scanpy as sc
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from perturbot.eval.prediction import get_evals_preds, get_evals

In [None]:
# Type 42 (TrueType) font embedding keeps text in saved PDFs editable.
matplotlib.rcParams.update({'pdf.fonttype': 42})
# Every warning is silenced for the rest of the notebook.
warnings.filterwarnings(action='ignore')

In [None]:
# Methods whose result files (test_<method>.<fold>.pkl) are evaluated below.
# The last three entries are reference rows that several cells drop with `[:-3]`.
methods = (
    ["EOT_ott", "LEOT_ott", "EGW_ott", "EGW_all_ott", "EGWL_ott", "ECOOT", "ECOOTL"]
    + ["VAE", "VAE_label"]
    + ["perfect", "by_conc", "random"]
)

In [None]:
# Subsampled chemical-screen dataset. The hard-coded cluster path is only a default; set the
# CHEMICAL_SCREEN_PKL environment variable to point elsewhere instead of editing this cell.
# NOTE: pickle.load can execute arbitrary code, so only load trusted files.
data_path = os.environ.get(
    "CHEMICAL_SCREEN_PKL",
    "/gpfs/scratchfs01/site/u/ryuj6/OT/data/chemical_screen/chemical_screen_subsampled_2000.pkl",
)
with open(data_path, "rb") as f:
    data = pkl.load(f)
# Reference profile passed to get_evals as norm_Y: the column-wise mean of Xt_dict entry 3.
# NOTE(review): entry 3 is presumably the control condition; confirm.
Y_v = data["Xt_dict"][3].mean(axis=0)

In [None]:
# treatment_idx: label index -> treatment name; conc_idx: dosage code -> concentration label.
treatment_idx = pd.read_csv(
    os.path.join("../../../../data/chemical_screen", "chemical_screen_pca_idx.txt"),
    header=None,
)[0]
conc_idx = pd.read_csv(
    os.path.join("../../../../data/chemical_screen", "concentration_idx_pca.csv")
)["0"]

In [None]:
# Its .var is copied onto the prediction AnnData objects so genes are named.
kinase_used = ad.read_h5ad(os.path.join("../../../../data/chemical_screen", "kinase_used.h5ad"))

In [None]:
# Inspect the label-index -> treatment-name lookup.
treatment_idx

In [None]:
# Prediction metrics per method and test fold (get_evals, normalized by Y_v).
# Folds whose result file is missing are reported and skipped; any other load error is raised.
method_to_tidx_to_eval = {m: {} for m in methods}
for m in methods:
    for i in range(5):
        path = f"test_{m}.{i}.pkl"
        try:
            with open(path, "rb") as f:
                d = pkl.load(f)
        except FileNotFoundError:
            print(f"{path} does not exist")
            continue
        Y_pred = d['pred']['Y_pred']
        Y_true = d['pred']['Y_true']
        method_to_tidx_to_eval[m][i] = get_evals(
            Y_true,
            Y_pred,
            prediction_id="eval",
            full=False,
            agg_method="mean",
            norm_Y=Y_v,
        )

In [None]:
# Column labels for folds 0, 1, 2 and 4 of one method's concatenated metrics.
# The method is named explicitly (the last entry of `methods`), not taken from a leaked loop variable.
cidx = pd.concat(method_to_tidx_to_eval[methods[-1]], axis=1).columns[[0, 1, 2, 4]]

In [None]:
# One Series per method (named after it): each metric averaged over the available folds.
evals = [
    pd.concat(method_to_tidx_to_eval[m], axis=1).mean(axis=1).rename(m)
    for m in methods
]

In [None]:
# One row per method; columns are the metrics returned by get_evals.
pred_eval_df = pd.concat(evals, axis=1).transpose()
pred_eval_df

In [None]:
# First three metric columns only.
pred_eval_df.iloc[:, 0:3]

In [None]:
# Mean rank per method across metrics; the last three (reference) rows are excluded.
_pred_scored = pred_eval_df.iloc[:-3]
pred_ranks = pd.concat(
    [
        _pred_scored.iloc[:, :4].rank(ascending=False),  # first four metrics: largest value -> rank 1
        _pred_scored.iloc[:, [-1]].rank(),               # last metric: smallest value -> rank 1
    ],
    axis=1,
)
pred_ranks.mean(axis=1)

### Matching quality (FOSCTTM, Dfracs)

In [None]:
# One result file (LEOT_ott, fold 0), loaded for inspection.
with open("test_LEOT_ott.0.pkl", mode="rb") as fh:
    d = pkl.load(fh)

In [None]:
# Sum of all entries of every array in d["T"]["match"].
total_sum = sum(arr.sum() for arr in d["T"]['match'].values())

In [None]:
# Inspect rel_dfracs for one method/fold. The matching-eval file is read directly so this cell
# does not depend on `rel_dfracs_method_to_tidx` (built by a later cell) or on leaked `m`/`i`.
_inspect_method, _inspect_fold = methods[-1], 4
with open(f"test_{_inspect_method}.{_inspect_fold}.e.pkl", "rb") as fh:
    _inspect_evals = pkl.load(fh)['matching_evals'][0]
_inspect_evals['rel_dfracs']

In [None]:
# Matching metrics read from test_<method>.<fold>.e.pkl. FOSCTTM and dfracs are collapsed to
# one number per fold, then averaged over the five folds; rel_dfracs is kept per fold as stored.
foscttm_method_to_tidx = {m: {} for m in methods}
dfracs_method_to_tidx = {m: {} for m in methods}
rel_dfracs_method_to_tidx = {m: {} for m in methods}
for m in methods:
    for i in range(5):
        with open(f"test_{m}.{i}.e.pkl", "rb") as fh:
            d = pkl.load(fh)
        match_eval = d['matching_evals'][0]
        foscttm_method_to_tidx[m][i] = match_eval['foscttm'].mean()
        rel_dfracs_method_to_tidx[m][i] = match_eval['rel_dfracs']
        fold_dfracs = match_eval['dfracs']
        # A Series is averaged, but a dict takes its max.
        # NOTE(review): confirm this asymmetry is intended.
        if isinstance(fold_dfracs, pd.Series):
            fold_dfracs = fold_dfracs.mean().item()
        elif isinstance(fold_dfracs, dict):
            fold_dfracs = pd.Series(fold_dfracs).max()
        dfracs_method_to_tidx[m][i] = fold_dfracs
    foscttm_method_to_tidx[m] = pd.Series(foscttm_method_to_tidx[m]).mean().item()
    dfracs_method_to_tidx[m] = pd.Series(dfracs_method_to_tidx[m]).mean().item()

In [None]:
# Fold-averaged matching metrics, one row per method.
match_edf = pd.DataFrame({
    "FOSCTTM": pd.Series(foscttm_method_to_tidx),
    "Dfracs": pd.Series(dfracs_method_to_tidx),
})
match_edf

In [None]:
# Mean matching rank per method (reference rows excluded): smallest FOSCTTM -> rank 1,
# largest Dfracs -> rank 1.
_match_scored = match_edf.iloc[:-3]
ranks = pd.concat(
    [_match_scored["FOSCTTM"].rank(), _match_scored["Dfracs"].rank(ascending=False)],
    axis=1,
)
ranks.mean(axis=1)

### Draw UMAP

In [None]:
def draw_umap(d):
    """Build a joint predicted/true AnnData for one test fold and compute its UMAP.

    Nothing is drawn here. The returned object carries ``obsm['X_umap']`` for later plotting.
    The function has no effect on global state. (It used to overwrite the global
    ``treat_idx_to_train_idx`` with the last fold's training keys for every label.)

    Parameters
    ----------
    d : dict
        A loaded ``test_<method>.<fold>.pkl`` result. Uses ``d['pred']['test_Z']``
        (label -> per-cell dosage codes) and ``d['pred']['Y_pred']`` / ``['Y_true']``.
        The rows of ``Y_pred`` and ``Y_true`` are assumed to be stacked in the iteration
        order of ``test_Z``. TODO: confirm against the code that writes these files.

    Returns
    -------
    anndata.AnnData
        Predicted rows (``obs['class'] == 'pred'``) followed by true rows (``'true'``).
        ``obs`` has ``dosage`` (categorical) and ``labs`` (the ``test_Z`` key).
        PCA, neighbors and UMAP are computed with scanpy.
    """
    pred = d['pred']
    Y_preds, Y_trues, Zs, labs = [], [], [], []
    ridx = 0
    for k, Z_k in pred["test_Z"].items():
        rsize = Z_k.shape[0]
        # Take this label's block of rows from the stacked prediction / truth matrices.
        Y_preds.append(pred['Y_pred'][ridx:ridx + rsize, :])
        Y_trues.append(pred['Y_true'][ridx:ridx + rsize, :])
        Zs.append(Z_k)
        labs.extend([k] * rsize)
        ridx += rsize
    dosage = np.concatenate(Zs)
    adata_pred = ad.AnnData(
        X=np.concatenate(Y_preds),
        obs=pd.DataFrame({"dosage": dosage, "labs": labs, "class": "pred"}),
    )
    adata_true = ad.AnnData(
        X=np.concatenate(Y_trues),
        obs=pd.DataFrame({"dosage": dosage, "labs": labs, "class": "true"}),
    )
    adata = ad.concat([adata_pred, adata_true])
    sc.pp.pca(adata)
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    adata.obs.dosage = adata.obs.dosage.astype('category')
    return adata

In [None]:
# For each held-out treatment label, the labels in the training set of the fold where it
# was held out. Splits are read from one explicitly named method's result files, not from
# a leaked loop variable. NOTE(review): assumes all methods share the same fold splits; confirm.
all_keys = list(range(13))  # treatment label ids 0..12
treat_idx_to_train_idx = {}
split_ref_method = methods[0]

for test_idx in range(5):
    with open(f"test_{split_ref_method}.{test_idx}.pkl", "rb") as f:
        d = pkl.load(f)
    test_keys = list(d['pred']["test_Z"].keys())
    held_out = set(test_keys)
    train_keys = [k for k in all_keys if k not in held_out]
    for k in test_keys:
        treat_idx_to_train_idx[k] = train_keys

In [None]:
def _load_fold_adata(method_name, fold):
    """Load one fold's result file, embed it with draw_umap and tag obs with the fold id."""
    with open(f"test_{method_name}.{fold}.pkl", "rb") as fh:
        fold_result = pkl.load(fh)
    fold_adata = draw_umap(fold_result)
    fold_adata.obs['test_idx'] = fold
    return fold_adata


# method -> list of the five per-fold AnnData objects. Expensive: PCA + neighbors + UMAP per fold.
pred_datas = {
    method: [_load_fold_adata(method, test_idx) for test_idx in range(5)]
    for method in methods
}

In [None]:
def _merge_folds(fold_adatas):
    """Concatenate one method's per-fold AnnData objects and attach treatment / dosage names."""
    merged = ad.concat(fold_adatas)
    merged.obs["treatment"] = merged.obs.labs.map(treatment_idx)
    merged.obs["dosage_"] = merged.obs.dosage.map(conc_idx)
    merged.obs["labs"] = merged.obs["labs"].astype("category")
    return merged


adata_dict = {method_name: _merge_folds(folds) for method_name, folds in pred_datas.items()}

In [None]:

# Joint UMAP of the true cells and the predictions of a few methods.
# Grid: one row per stimulated treatment. Columns:
#   0  : cells of that treatment's training labels (grey), its true cells, predictions by method
#   1  : its true cells coloured by dosage
#   2+ : each plotted method's predicted cells coloured by dosage
plot_methods = ["perfect", "random", "EGWL_ott", "ECOOTL"]
dosage_order = ["100nM", "1uM", "10uM"]

# NOTE(review): `pal` was never assigned in this notebook, so the cell only ran on leftover
# kernel state. seaborn's default palette is assumed; confirm it matches the published colours.
pal = sns.color_palette()
class_pal = {"pred": pal[1], "true": pal[0]}
dosage_pal = {"100nM": pal[0], "1uM": pal[-2], "10uM": pal[3]}


def _drop_legend(axis):
    """Remove the axes' legend if seaborn created one."""
    legend = axis.get_legend()
    if legend is not None:
        legend.remove()


# (label index, treatment name) for every treatment except the two controls. The label index
# is the key used by obs["labs"] and treat_idx_to_train_idx.
plotted_treatments = [
    (lab, name) for lab, name in treatment_idx.items() if name not in ("No stim", "Vehicle")
]
n_cols = 2 + len(plot_methods)
n_rows = len(plotted_treatments)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 26 * n_rows / 11),
                       sharex=True, sharey=True, squeeze=False)

# Embedding input: the true cells once (from the first method's AnnData) plus each method's predictions.
adatas = []
for method in plot_methods:
    adata = adata_dict[method]
    adata.var = kinase_used.var
    if not adatas:
        true_adata = adata[adata.obs["class"] == "true", :].copy()
        true_adata.obs['method'] = "true"
        adatas.append(true_adata)
    pred_adata = adata[adata.obs["class"] == "pred", :].copy()
    pred_adata.obs['method'] = method
    adatas.append(pred_adata)
all_adata = ad.concat(adatas)
sc.pp.pca(all_adata)
sc.pp.neighbors(all_adata)
sc.tl.umap(all_adata)

tdata = all_adata[all_adata.obs["class"] == "true", :]
pdata = all_adata[all_adata.obs["class"] == "pred", :]
adf = pd.DataFrame({"UMAP1": tdata.obsm["X_umap"][:, 0], "UMAP2": tdata.obsm["X_umap"][:, 1],
                    "class": tdata.obs["class"], "dosage": tdata.obs["dosage_"], "labs": tdata.obs["labs"]})

for row, (lab, treatment) in enumerate(plotted_treatments):
    row_ax = ax[row]
    ttdata = tdata[tdata.obs.treatment == treatment, :]
    tdf = pd.DataFrame({"UMAP1": ttdata.obsm["X_umap"][:, 0],
                        "UMAP2": ttdata.obsm["X_umap"][:, 1],
                        "class": ttdata.obs["class"],
                        "dosage": ttdata.obs["dosage_"]})
    ptdata = pdata[pdata.obs.treatment == treatment, :]
    pdf = pd.DataFrame({"UMAP1": ptdata.obsm["X_umap"][:, 0],
                        "UMAP2": ptdata.obsm["X_umap"][:, 1],
                        "class": ptdata.obs["class"],
                        "dosage": ptdata.obs["dosage_"],
                        "method": ptdata.obs['method']})

    # Column 0: training-label background, true cells, and predictions by method.
    sns.scatterplot(data=adf.loc[adf.labs.isin(treat_idx_to_train_idx[lab]), :],
                    x="UMAP1", y="UMAP2", color="lightgrey", ax=row_ax[0], s=5, label="train",
                    edgecolor=None, rasterized=True)
    sns.scatterplot(data=tdf, x="UMAP1", y="UMAP2", color=pal[0],
                    ax=row_ax[0], s=5, edgecolor=None, rasterized=True, label="true")
    sns.scatterplot(data=pdf, x="UMAP1", y="UMAP2", hue="method",
                    ax=row_ax[0], s=5, edgecolor=None, rasterized=True)
    row_ax[0].set_title(treatment)
    _drop_legend(row_ax[0])

    # Column 1: true cells by dosage.
    sns.scatterplot(data=tdf, x="UMAP1", y="UMAP2", hue="dosage", palette=dosage_pal, s=5,
                    hue_order=dosage_order, ax=row_ax[1], edgecolor=None, rasterized=True)
    _drop_legend(row_ax[1])
    row_ax[1].set_title("True")

    # Columns 2+: each method's predictions by dosage.
    for col, method in enumerate(plot_methods, start=2):
        sns.scatterplot(data=pdf.loc[pdf["method"] == method], x="UMAP1", y="UMAP2",
                        palette=dosage_pal, s=5, hue_order=dosage_order,
                        ax=row_ax[col], edgecolor=None, rasterized=True, hue="dosage")
        _drop_legend(row_ax[col])
        row_ax[col].set_title(f"{method}")

handles, labels = ax[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, bbox_to_anchor=(1, 0.5), loc="lower left")
handles2, labels2 = ax[0, 2].get_legend_handles_labels()
fig.legend(handles2, labels2, bbox_to_anchor=(1, 0.5), loc="upper left")
plt.setp(ax, box_aspect=1)
plt.tight_layout()

# The figure legends sit right of the figure (x=1), so a tight bbox is needed to keep them in the PDF.
fig.savefig("all_umaps.pdf", bbox_inches="tight")

### Activation-gene scores: true vs. predicted cells

In [None]:
# Activation gene set used by sc.tl.score_genes in the gene-score cells.
activation_genes = ['TNFRSF18',
 'TNFRSF4',
 'IL12RB2',
 'LMNA',
 'RRM2',
 'DUSP2',
 'GBE1',
 'ZBED2',
 'IER3',
 'LTA',
 'CD109',
 'TNFAIP3',
 'SYTL3',
 'GARS',
 'SNHG15',
 'NAMPT',
 'HILPDA',
 'DUSP4',
 'RNF19A',
 'NINJ1',
 'IL2RA',
 'DDIT4',
 'PGAM1',
 'MICAL2',
 'SLC43A3',
 'SLC3A2',
 'LAG3',
 'LINC02341',
 'GNA15',
 'ZBTB32',
 'MIR155HG',
 'PIM3',
 'GK']
# Keep only the genes present in the panel, in their listed order. kinase_used.var_names is
# used directly instead of whichever `adata` happened to be left in the kernel; the
# prediction AnnData objects copy their .var from kinase_used.
_panel_genes = set(kinase_used.var_names)
genes = [g for g in activation_genes if g in _panel_genes]

In [None]:
# Inspect obs of the predicted cells (`pdata`, the class == "pred" subset of the joint UMAP AnnData).
pdata.obs

In [None]:
def draw_genescores(trt, ax):
    """Draw activation-gene score violins for one treatment, split by dosage.

    Each method's AnnData is scored with ``sc.tl.score_genes`` using the global ``genes``
    list, which writes ``obs['score']``. True cells are taken once, from the "perfect"
    AnnData. Predicted cells are taken from every method in the list below.

    Parameters
    ----------
    trt : str
        Treatment name, matched against ``obs['treatment']``.
    ax : matplotlib.axes.Axes
        Target axes. The seaborn legend is removed from it.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``.

    Side effects: for each method used, sets ``adata_dict[method].var`` to
    ``kinase_used.var`` and adds or overwrites ``obs['score']``.
    """
    frames = []
    for method in ("perfect", "random", "EGWL_ott", "ECOOTL"):
        adata = adata_dict[method]
        adata.var = kinase_used.var
        # Score the full AnnData. (Scoring the per-treatment view, as before, was discarded work.)
        sc.tl.score_genes(adata, genes)
        in_trt = adata.obs.treatment == trt
        if method == "perfect":
            true_obs = adata[in_trt & (adata.obs['class'] == 'true'), :].obs.copy()
            true_obs["group"] = "true"
            frames.append(true_obs)
        pred_obs = adata[in_trt & (adata.obs['class'] == 'pred'), :].obs.copy()
        pred_obs["group"] = method
        frames.append(pred_obs)
    df = pd.concat(frames)
    sns.violinplot(data=df, x="dosage_", order=["100nM", "1uM", "10uM"], y="score", hue="group",
                   hue_order=["true", "perfect", "random", "ECOOTL", "EGWL_ott"], ax=ax, linewidth=0.5)
    ax.get_legend().remove()
    ax.set_title(trt)
    return ax

In [None]:
# Activation-gene score violins, one panel per treatment, four panels per row.
# Treatments come explicitly from the "perfect" AnnData instead of a leftover global `adata`.
gs_treatments = adata_dict["perfect"].obs.treatment.unique()
gs_ncols = 4
gs_nrows = max(1, -(-len(gs_treatments) // gs_ncols))  # ceiling division
fig, ax = plt.subplots(gs_nrows, gs_ncols, figsize=(4 * gs_ncols, 4 * gs_nrows), squeeze=False)
for i, trt in enumerate(gs_treatments):
    draw_genescores(trt, ax[i // gs_ncols, i % gs_ncols])
# Hide grid cells left over when the treatment count is not a multiple of gs_ncols.
for unused_ax in ax.flat[len(gs_treatments):]:
    unused_ax.set_visible(False)
handles, labels = ax[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.tight_layout()
# The legend is anchored right of the figure, so a tight bbox keeps it in the PDF.
fig.savefig("activation_gene_scores.pdf", bbox_inches="tight")

### Running time (validation, per epsilon)

In [None]:
# Mean validation wall time per (method, epsilon), averaged over every logged run in the five
# CV folds (val_CV_<method>.<fold>.pkl). The last three `methods` entries are not timed.
ot_epsilons = [0.01, 1e-3, 1e-4, 1e-5]
epsilons = list(ot_epsilons)
val_time_dict = {m: {} for m in methods}
for m in methods[:-3]:
    # Load every available fold's validation log for this method once.
    fold_logs = []
    for i in range(5):
        path = f"val_CV_{m}.{i}.pkl"
        try:
            with open(path, "rb") as f:
                fold_logs.append(pkl.load(f)['log'])
        except FileNotFoundError:
            print(f"{path} does not exist")
    # Non-VAE methods use the fixed epsilon grid. VAE logs use their own keys, read from the
    # first available fold (previously from a fold index leaked by another cell).
    if "VAE" in m:
        epsilons = list(fold_logs[0].keys()) if fold_logs else []
    else:
        epsilons = list(ot_epsilons)
    eps_times = {eps: [] for eps in epsilons}
    for fold_log in fold_logs:
        for eps in epsilons:
            for run in fold_log[eps].values():
                if 'time' in run:
                    eps_times[eps].append(run['time'])
                else:
                    # Nested sub-runs: the run's time is the sum of its sub-run times.
                    eps_times[eps].append(sum(sub_run['time'] for sub_run in run.values()))
    for eps, times in eps_times.items():
        # Skip epsilons with no recorded runs; they appear as NaN in perf_df.
        if times:
            val_time_dict[m][eps] = sum(times) / len(times)

In [None]:
# Timing table: the first four epsilon rows and the first seven method columns (the entries of
# `methods` before "VAE").
perf_df = pd.DataFrame(val_time_dict).iloc[:4, :7]

In [None]:
# Epsilon keys left by the timing loop, i.e. those of the last method it timed.
epsilons

In [None]:
# Runtime ratio ECOOTL / EGWL_ott for each epsilon row.
perf_df["ECOOTL"].div(perf_df["EGWL_ott"])

In [None]:
# Validation runtime against epsilon, with time on a log scale.
fig, ax = plt.subplots()
# Plot against row position (use_index=False) so the evenly spaced ticks below line up with the
# points. With a numeric epsilon index, pandas would put the points at x=0.01, 1e-3, ... instead.
perf_df.plot(logy=True, use_index=False, xlabel='eps', ylabel='time (s)', ax=ax)
ax.get_legend().remove()
handles, labels = ax.get_legend_handles_labels()
ax.set_xticks(range(len(perf_df)), perf_df.index.tolist())
fig.legend(handles, labels, loc="center left", bbox_to_anchor=(0.9, 0.5))
fig.savefig("time_complexity.pdf", bbox_inches="tight")

### Matching matrices

In [None]:
# Dosage label -> colour for the row/column annotations of the matching heatmaps.
dosage_pal2 = dict(zip(["100nM", "1uM", "10uM"], ["yellow", "orange", "red"]))

In [None]:
from matplotlib.gridspec import GridSpec  # NOTE(review): currently unused in this cell
# Heatmap of each treatment's predicted matrix d["T"]["pred"][trt_key], averaged over the folds that
# contain it. Rows and columns are reordered by the row clustering and annotated with dosage
# colours. Saved as match_matrix/<treatment>.<method>.svg.
os.makedirs("match_matrix", exist_ok=True)
for m in ["EGWL_ott", "ECOOTL"]:
    # d["T"]["pred"] from every fold, loaded once per method rather than once per treatment.
    fold_preds = []
    for j in range(5):
        with open(f"test_{m}.{j}.pkl", "rb") as f:
            fold_preds.append(pkl.load(f)["T"]["pred"])
    for trt_key in treatment_idx.index.tolist():
        # A treatment appears only in the folds where it was held out.
        mats = [fold_pred[trt_key] for fold_pred in fold_preds if trt_key in fold_pred]
        if not mats:
            print(f"No predicted matrix for treatment {trt_key} ({m}); skipping")
            continue
        mat = sum(mats) / len(mats)

        # The same dosage colours annotate rows and columns; assumes mat is square
        # (cells x cells of this treatment). TODO confirm.
        z = data["Zs_dict"]["dosage"][trt_key]
        cols = pd.Series(z).map(conc_idx).map(dosage_pal2)

        # Throwaway tiny clustermap, used only to get the clustered row order.
        g = sns.clustermap(mat, col_cluster=False,
                           row_cluster=True, linewidths=0, cmap='coolwarm',
                           row_colors=cols.values, col_colors=cols.values, figsize=(0.1, 0.1))
        reordered_ind = g.dendrogram_row.reordered_ind
        plt.close(g.fig)
        g = sns.clustermap(mat[reordered_ind, :][:, reordered_ind], col_cluster=False,
                           row_cluster=False, linewidths=0, cmap='coolwarm',
                           row_colors=cols.values[reordered_ind], col_colors=cols.values[reordered_ind],
                           vmax=np.quantile(mat, 0.8), figsize=(4, 4), rasterized=True)
        g.ax_row_dendrogram.set_visible(False)
        g.ax_col_dendrogram.set_visible(False)
        g.ax_cbar.set_position([1, 0.5, 0.02, 0.1])
        ax = g.ax_heatmap
        ax.set_xticks([])
        ax.set_yticks([])
        g.fig.savefig(f"match_matrix/{treatment_idx[trt_key]}.{m}.svg", bbox_inches="tight")
        # Close each figure after saving so dozens of figures don't accumulate in memory.
        plt.close(g.fig)
