In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt



FS_MOL_CHECKOUT_PATH = "../"

os.chdir(FS_MOL_CHECKOUT_PATH)
sys.path.insert(0, FS_MOL_CHECKOUT_PATH)
PAPER_FIGDIR = "../paper/Paper/fig"
from utils_DTI import get_dataset_from_file, bounds_n_mols_task_name
from notebook_utils import get_dataset_target_stats, get_all_datasets_stats, get_kinases_stats
import seaborn as sns
%matplotlib inline

TASK_NAMES = ["DAVIS", "BindingDB_Kd", "BindingDB_Ki", "BindingDB_IC50", "KIBA"]

In [None]:
# Dataset-level stats and per-target stats for every DTI benchmark in TASK_NAMES
# (computed in this order), then the base seaborn theme used by the figures below.
full_df, df_target = (
    get_all_datasets_stats(TASK_NAMES),
    get_dataset_target_stats(TASK_NAMES),
)
sns.set_theme(style="white")

In [None]:
sns.set_style("white")

# full_df / df_target are computed in the previous cell; recomputing them here was redundant.
# One colour per dataset, shared by the KDE (left) and both histogram layers (right). Taking
# the colours explicitly from the current colour cycle keeps the panels aligned instead of
# relying on the axes' separate line and patch property cycles advancing in lockstep.
task_colors = sns.color_palette(n_colors=len(TASK_NAMES))

fig, axes = plt.subplots(1, 2, figsize=(11, 4))
for task_name, color in zip(TASK_NAMES, task_colors):
    task_df = df_target[df_target.task_name == task_name]
    # Left panel: distribution of the positive proportion (Y_bin) across this dataset's tasks.
    sns.kdeplot(task_df, x="Y_bin", label=task_name, ax=axes[0], clip=(0, 0.6), color=color)
    # Right panel: cumulative distribution of the threshold, drawn as a faint filled area
    # plus an opaque outline of the same colour.
    threshold_kws = dict(
        x="threshold",
        ax=axes[1],
        stat="proportion",
        cumulative=True,
        element="poly",
        binrange=(-11, -4),
        color=color,
    )
    sns.histplot(task_df, alpha=0.05, fill=True, label=task_name, **threshold_kws)
    sns.histplot(task_df, alpha=1, fill=False, legend=False, **threshold_kws)

# Axis decoration is loop-invariant, so it is applied once after all datasets are drawn.
axes[0].set_xlabel("Distribution of the tasks' positive proportion\n in each dataset")
axes[0].set_xlim(0, 0.6)
axes[0].legend(bbox_to_anchor=(0.15, -0.27), loc=2, borderaxespad=0.0, ncol=6)
axes[1].set_xlabel("Threshold cumulative distribution (pXC50)\n in each dataset")
axes[1].set_xlim(-10, -4.5)
axes[1].set_ylim(0, 1)
fig.savefig(os.path.join(PAPER_FIGDIR, "DTI_split.pdf"), dpi=fig.dpi, bbox_inches="tight")

In [None]:
# Mean positive proportion (Y_bin) over the tasks of each dataset.
df_target.groupby("task_name")["Y_bin"].mean()

In [None]:
# Kinase annotation of each dataset's unique targets. Targets missing from `iskinase` become
# NaN after the join and are excluded from all three aggregates; "mean" is the kinase
# fraction assuming `iskinase` is boolean/0-1 (TODO confirm in notebook_utils).
iskinase = get_kinases_stats(TASK_NAMES)
(
    full_df[["task_name", "Target_ID"]]
    .drop_duplicates()
    .join(iskinase.set_index(["Target_ID", "task_name"]), on=["Target_ID", "task_name"])
    .groupby(["task_name"])["iskinase"]
    .agg(["mean", "sum", "count"])
)

In [None]:
# Per-dataset mean of Y_bin as a two-column frame (task_name, Y_bin); joined onto the
# results below as the random-classifier reference.
task_avg_pos = full_df.groupby("task_name")["Y_bin"].mean().reset_index()


In [None]:
import pandas as pd

import torch
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np

# Model ordering conventions for the figures (display names, after renaming).
X_ORDER = "adkt clamp q-probe l-probe protonet maml simsearch".split()
hue_order = "clamp l-probe q-probe adkt protonet simsearch".split()

# Model display name -> matplotlib named colour, so a model keeps its colour in every figure.
cmap = dict(
    [
        ("simsearch", "dimgray"),
        ("adkt", "mediumorchid"),
        ("protonet", "dodgerblue"),
        ("l-probe", "gold"),
        ("q-probe", "red"),
        ("maml", "indigo"),
        ("clamp", "yellowgreen"),
    ]
)

In [None]:
# Positive rate per dataset (the backend magic is already set in the configuration cell).
task_avg_pos

In [None]:
# Load the per-(model, dataset) DTI result tables (paths relative to the repository root).
RESULTS_DIR = "TDC_tasks/results/DTI"
model_names = ["adkt", "simplebsl", "linear_probe", "protonet", "simsearch", "clamp"]
# Model identifiers used in result file names -> display names used in the figures.
MODEL_DISPLAY_NAMES = {"simplebsl": "q-probe", "linear_probe": "l-probe"}

# Collect frames in a list and concatenate once (concat inside the loop is quadratic).
result_frames = []
missing_paths = []
for model_name in model_names:
    for task_name in TASK_NAMES:
        path = os.path.join(RESULTS_DIR, f"{model_name}_{task_name}_results.csv")
        if not os.path.exists(path):
            missing_paths.append(path)
            continue
        result_frames.append(
            pd.read_csv(path, index_col=0).assign(
                model=MODEL_DISPLAY_NAMES.get(model_name, model_name),
                task=task_name,
            )
        )

if missing_paths:
    # Not every model has to be run on every dataset, but report the gaps instead of
    # skipping them silently.
    print(f"{len(missing_paths)} result file(s) not found, e.g. {missing_paths[0]}")
if not result_frames:
    raise FileNotFoundError(f"No DTI result files found under {RESULTS_DIR!r} (cwd: {os.getcwd()!r})")

# Attach each dataset's positive rate (Y_bin) for the delta metrics computed below.
results_df = pd.concat(result_frames).join(task_avg_pos.set_index("task_name"), on="task")

In [None]:
# Sanity check: model display names present in the loaded results.
results_df["model"].unique()

In [None]:
# Average repeated runs: group on every column that is not a metric ("au*" scores, runtime).
group_cols = [col for col in results_df.columns if not (col.startswith("au") or col == "runtime")]
# dropna=False: result tables of different models need not share every column, so the
# concatenated frame can hold NaN in a grouping column; the default dropna=True would then
# silently discard all runs of those models.
run_df = results_df.groupby(group_cols, dropna=False).mean().reset_index()

# Delta metrics relative to the dataset's positive rate Y_bin (joined from task_avg_pos);
# "_neg" metrics score the negative class, whose rate is 1 - Y_bin.
# NOTE(review): the class rate is the chance level of AUPRC, but AUROC's chance level is 0.5;
# delta-auroc* is computed the same way here — confirm this is intended before using it.
for metric in ["auroc", "auroc_neg", "auprc", "auprc_neg"]:
    baseline = (1 - run_df["Y_bin"]) if "neg" in metric else run_df["Y_bin"]
    run_df[f"delta-{metric}"] = run_df[metric] - baseline

In [None]:
# Normalise raw model identifiers to display names (a no-op when the results were already
# renamed at load time).
run_df["model"] = run_df["model"].map(
    lambda name: name.replace("linear_probe", "l-probe").replace("simplebsl", "q-probe")
)

In [None]:
#same with barplots
%matplotlib inline
x_order = ["KIBA", "BindingDB_Kd", "DAVIS"]
hue_order = [
    "clamp",
    "l-probe",
    "q-probe",
    "adkt",
    "protonet",
    "simsearch",
    #"maml",
    #"multitask",
]

metric = "delta-auprc"
fig, ax= plt.subplots(1, 1, figsize=(5, 2.2))

sns.set_theme(style="whitegrid")
sns.barplot(
    x="task_name",
    y=metric,
    hue="model",
    data=run_df,
    ax=ax,
    saturation=0.7,
    capsize=.4,
    palette=cmap,
    hue_order=hue_order,
    order=x_order,
    err_kws={"color": ".5", "linewidth": 1.5, "alpha":0.5},
)
hfont = {'fontname':'Helvetica'}
ax.set_title("Performances on DTI tasks")
ax.annotate('', xy=(.9, 1), xycoords='axes fraction', xytext=(.1, 1),
            arrowprops=dict(arrowstyle="-|>", color='lightblue', linewidth=6),)
ax.text(-0.3,0.38,"Prior shift", c= "lightslategrey")
ax.legend(loc='lower center')
ax.set_xlabel("")
ax.set_ylabel("$\Delta AUPRC$")

ax.set_ylim(0.16,)
ax.legend(bbox_to_anchor=(-0.05, -0.18), loc=2, borderaxespad=0., ncol=3)



fig.savefig(PAPER_FIGDIR + "/barplot_kn_DTI.pdf",bbox_inches='tight')
plt.show()

In [None]:
#same with barplots
%matplotlib inline
x_order = ["BindingDB_Ki", "BindingDB_IC50"]
hue_order = [
    "clamp",
    "l-probe",
    "q-probe",
    "adkt",
    "protonet",
    "simsearch",
    #"maml",
    #"multitask",
]

metric = "delta-auprc"
fig, ax= plt.subplots(1, 1, figsize=(5, 2))

sns.set_theme(style="whitegrid")
sns.barplot(
    x="task_name",
    y=metric,
    hue="model",
    data=run_df,
    ax=ax,
    saturation=0.7,
    capsize=.4,
    palette=cmap,
    hue_order=hue_order,
    order=x_order,
    err_kws={"color": ".5", "linewidth": 1.5, "alpha":0.5},
)

ax.set_title("Performances on DTI tasks")
ax.legend(loc='lower center')
ax.set_xlabel("")
ax.set_ylabel("$\Delta AUPRC$")

ax.set_ylim(0.15,)
ax.legend(bbox_to_anchor=(-0.1, -0.3), loc=2, borderaxespad=0., ncol=3)

fig.savefig(PAPER_FIGDIR + "/barplot_full_DTI.pdf",bbox_inches='tight')
plt.show()