# Visualize cell results

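Requires models that have already been trained and evaluated on the cell data: this notebook only reads the `evals.csv` / `evals_weights.csv` files produced by that evaluation.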

In [None]:
import pandas as pd
import numpy as np
import os
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt

from unot.plotting.setup import setup_plt
setup_plt()

In [None]:
# Drugs whose result folders are collected; any other sub-directory of the
# results root is skipped by the loading cell. Kept in alphabetical order.
DRUGS = """
    cisplatin crizotinib dabrafenib dacarbazine dasatinib decitabine
    dexamethasone erlotinib everolimus hydroxyurea imatinib ixazomib
    lenalidomide melphalan midostaurin mln olaparib paclitaxel palbociclib
    panobinostat regorafenib sorafenib staurosporine temozolomide trametinib
    ulixertinib vindesine
""".split()

In [None]:
# create the dataframes containing evaluation metrics
# `results` holds the un-weighted metrics read from each model's evals.csv and
# `results_w` the weighted metrics from evals_weights.csv. Both are keyed by
# (drug, model, exp_name, data); the columns listed below come first, followed
# by any additional metric columns found in the CSVs (e.g. w2_* variants).
RESULT_COLUMNS = ["drug", "model", "exp_name", "data", "l2DS", "enrichment-k50", "enrichment-k100", "mmd", "total_cost", "avg_cost"]
WEIGHTED_RESULT_COLUMNS = ["drug", "model", "exp_name", "data", "mmd_w", "mmd_w_rs", "weights_mean", "weights_std", "total_cost_w", "avg_cost_w"]

# data-split folders and experiment folders that are collected
DATA_SPLITS = {"8h_subm", "24h_subm"}
EXPERIMENTS = {
    "cellot",
    "ubot_gan",
    "nubot",
    "naive",
    "discrete",
    "gaussian_approx",
    "gaussian_approx_unb",
}

# specify the directory where the results are stored
#outroot = Path("../../unot/results/rebuttal/rebuttal/0922_submission")
outroot = Path("../results/cell/1312_reproduced")


def _normalize_keys(data, exp_name, model_name):
    """Return the (data, exp_name, model) labels used to key one model's rows.

    Computed once per model directory so the un-weighted and the weighted row
    of the same model always share identical join keys. The `_rebuttal` and
    `8h` rewrites only fire if DATA_SPLITS is extended to such folders.
    """
    if data.endswith("_rebuttal"):
        data = data.replace("_rebuttal", "")
        model_name = model_name + "_norm"
    if exp_name == "gaussian_approx_unb":
        model_name = "model-gaussian-unb"
    if data == "8h":
        exp_name = exp_name + "_old_data"
    return data, exp_name, model_name


def _read_eval_row(path, *, drug, model, exp_name, data):
    """Read one metrics CSV as a one-row DataFrame tagged with its keys.

    The CSV is read without a header and transposed, i.e. the first column is
    used as metric names (presumably a "name,value" layout — TODO confirm
    against the evaluation script). Returns None when the file is missing.
    """
    if not path.exists():
        return None
    row = pd.read_csv(path, header=None).set_index(0).T
    row["drug"] = drug
    row["model"] = model
    row["exp_name"] = exp_name
    row["data"] = data
    return row


def _assemble(rows, columns):
    """Stack one-row frames into a table whose leading columns are `columns`.

    Uses a single `pd.concat` instead of per-row `DataFrame.append`, which was
    removed in pandas 2.0. Extra metric columns follow in order of appearance;
    listed columns missing from every CSV stay as all-NaN columns.
    """
    if not rows:
        return pd.DataFrame(columns=columns)
    stacked = pd.concat(rows, ignore_index=True)
    extra = [col for col in stacked.columns if col not in columns]
    return stacked.reindex(columns=list(columns) + extra)


result_rows = []
weighted_rows = []

# Layout traversed below:
#   outroot/<drug>/<data>/<exp_name>/model-*/evals_iid_data_space/evals.csv
#                                                                /evals_weights.csv
# Directories are visited in sorted order so the summary tables are reproducible.
for drug_dir in sorted(outroot.iterdir()):
    if drug_dir.name not in DRUGS or not drug_dir.is_dir():
        continue
    for data_dir in sorted(drug_dir.iterdir()):
        if data_dir.name not in DATA_SPLITS or not data_dir.is_dir():
            continue
        for exp_dir in sorted(data_dir.iterdir()):
            if exp_dir.name not in EXPERIMENTS or not exp_dir.is_dir():
                continue
            for model_dir in sorted(exp_dir.iterdir()):
                if not model_dir.name.startswith("model-"):
                    continue
                data, exp_name, model = _normalize_keys(data_dir.name, exp_dir.name, model_dir.name)
                tags = dict(drug=drug_dir.name, model=model, exp_name=exp_name, data=data)
                eval_dir = model_dir / "evals_iid_data_space"

                row = _read_eval_row(eval_dir / "evals.csv", **tags)
                if row is not None:
                    result_rows.append(row)

                row = _read_eval_row(eval_dir / "evals_weights.csv", **tags)
                if row is not None:
                    weighted_rows.append(row)

results = _assemble(result_rows, RESULT_COLUMNS)
results_w = _assemble(weighted_rows, WEIGHTED_RESULT_COLUMNS)

In [None]:
results.loc[results["drug"].eq("dasatinib")].head()

In [None]:
results_w.iloc[:5]

In [None]:
# Write both summary tables into the results root (index included).
for frame, fname in ((results, "evals_summary.csv"), (results_w, "evals_weights_summary.csv")):
    path = outroot / fname
    frame.to_csv(path)

# Visualize Results

In [None]:
# join un-weighted and weighted results
# Left join on the identifying keys: every un-weighted row is kept, weighted
# metrics are attached where a matching row exists.
keys = ["drug", "model", "exp_name", "data"]
rv = results.set_index(keys).join(results_w.set_index(keys)).reset_index()
rv["drug"] = rv["drug"].str.capitalize()

In [None]:
# replace weighted metrics by un-weighted if they don't exist (i.e., for balanced models)
# Each un-weighted W2 column "w2_<suffix>" backfills both "w2_w_<suffix>" and
# "w2_w_rs_<suffix>"; MMD and average cost backfill their weighted versions.
# Keys that are not columns of rv are ignored by fillna.
replace_dict = {}
for col in rv.filter(regex="w2").columns:
    if "_w" in col:
        continue  # already a weighted column
    suffix = col.split("w2_")[-1]
    replace_dict[f"w2_w_{suffix}"] = rv[col]
    replace_dict[f"w2_w_rs_{suffix}"] = rv[col]
replace_dict.update(mmd_w=rv["mmd"], avg_cost_w=rv["avg_cost"])
rv = rv.fillna(replace_dict)

In [None]:
rv.iloc[:5]

In [None]:
# replace names
# Map raw model-directory names to the display labels used in the plots;
# names not listed here are left unchanged.
display_names = {
    "model-nubot": "NubOT",
    "model-cellot": "CellOT",
    "model-gan": "ubOT GAN",
    "model-identity": "Identity",
    "model-control": "Observed",
    "model-ot": "Discrete OT",
    "model-gaussian": "Gaussian Approx",
    "model-gaussian-unb": "Gaussian Approx Unb",
    "model-ubot": "Discrete UBOT",
    "model-nubot_v1_norm": "NubOT_norm",
}
rv["model"] = rv["model"].replace(display_names)

In [None]:
# Bar colours, taken in order for the models listed in the plot's hue_order
# (NubOT, CellOT, ubOT GAN, Identity, Observed, Discrete OT, Gaussian Approx).
# The eighth colour is spare, presumably for "Gaussian Approx Unb".
palette = [
    "#F2545B",
    "#A7BED3",
    "#316594",
    "#cccccc",
    "#b0aeae",
    "#C4B5D0",
    "#966EA7",
    "#8c0e25",
]

In [None]:
# available hyperparameters for w2-distance
# (every column whose name contains "w2_w_", as a pandas Index)
rv.columns[["w2_w_" in str(col) for col in rv.columns]]

In [None]:
# fix negative values
# NOTE(review): this manual patch is disabled. Two things to check before re-enabling it.
# - The row labels (95, 6, 84, ...) likely refer to one particular build of `rv`.
#   They may not line up after the loading cell is re-run.
# - The first line assigns a selection to itself through chained indexing (a temporary copy).
#   It would not modify `rv` even if uncommented.

#rv[(rv["model"] == "NubOT") & (rv["drug"] == "Panobinostat")].loc[:,"w2_w_5_0.95"] = rv[(rv["model"] == "NubOT") & (rv["drug"] == "Panobinostat")].loc[:,"w2_w_5_0.95"]
#rv.loc[95,"w2_w_5_0.95"] = rv.loc[95,"w2_w_2_0.95"]
#rv.loc[6,"w2_w_5_0.95"] = rv.loc[6,"w2_w_4_0.95"]
#rv.loc[84,"w2_w_5_0.95"] = rv.loc[84,"w2_w_4_0.95"]
#rv.loc[120,"w2_w_5_0.95"] = rv.loc[120,"w2_w_4_0.95"]
#rv.loc[354,"w2_w_5_0.95"] = rv.loc[354,"w2_w_4_0.95"]
#rv.loc[444,"w2_w_5_0.95"] = rv.loc[444,"w2_w_4_0.95"]

In [None]:
# Bar plot of one evaluation metric per drug (x axis), one bar per model (hue),
# restricted to a single timestep/data split.

#plt.rcParams["font.size"] = 11
sns.set_context(context='talk', font_scale=1.0)

# y-axis labels for known metrics; "w2_w" applies to every weighted
# Wasserstein column of the form w2_w_<a>_<b> (the w2_w_rs_* variants keep
# their raw column name as label).
labels = {"mmd_w": "Weighted MMD", "w2_w": "Weighted Wasserstein Distance"}

timestep = "8h_subm"

# specify which metric to plot
# metric = "w2_w_1_1.0"
# metric = "w2_w_5_0.95"
metric = "mmd_w"
log = True

sb = rv[rv["data"] == timestep]
# drug names are capitalized when rv is built, so filter on "Panobinostat"
# sb = sb[sb["drug"] != "Panobinostat"]
hue_order = ["NubOT", "CellOT", "ubOT GAN", "Identity", "Observed", "Discrete OT", "Gaussian Approx"]  # , "Gaussian Approx Unb"]

# Create exactly one figure; a second, empty figure would be left open and shown blank.
plt.figure(figsize=(20, 3.6))
g = sns.barplot(data=sb, y=metric, x="drug", hue="model", palette=palette, hue_order=hue_order, log=log)
plt.xticks(rotation=45, ha="right")
plt.legend(bbox_to_anchor=(1, 2), ncol=len(hue_order))

if metric in labels:
    label = labels[metric]
elif metric.startswith("w2_w_") and not metric.startswith("w2_w_rs_"):
    label = labels["w2_w"]
else:
    label = metric
g.set(ylabel=label, xlabel="Drug")

t = timestep.replace("_subm", "")
plt.title(f"Timestep: {t}")

m = metric.replace(".", "-")
# plt.savefig(f"{timestep}_{m}_COMPLETE.pdf", bbox_inches="tight", format="pdf")
plt.show()