In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import datamol as dm

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
import warnings

# Silence warnings for cleaner notebook output, then register tqdm's pandas hooks
# (this is what makes `.progress_apply` available further down).
warnings.filterwarnings("ignore")
tqdm.pandas()

from TDC_tasks.utils_HTS import open_dataset

# HTS tasks whose positive-example diversity is analysed in the cells below.
datasets = [
    "ALDH1", "ESR1_ant", "GBA", "HIV", "MAPK1", "MTORC1",
    "PPARG", "SARSCoV2_Vitro_Touret", "TP53",
    "VDR", "cav3_t-type_calcium_channels_butkiewicz",
    "m1_muscarinic_receptor_agonists_butkiewicz",
    "m1_muscarinic_receptor_antagonists_butkiewicz",
    "orexin1_receptor_butkiewicz",
    "PKM2", "OPRK1",
]
# Destination directory of the paper figures.
PAPER_FIGDIR = "../../paper/Paper/fig"
# Long-format accumulator: one record per ordered pair (i, j), i != j, of positives of a task.
df_easiness_dict = {column: [] for column in ("task_name", "value", "i", "j")}
# Plot label -> pairwise distance matrix between the positives of that task.
dist_mat = {}


%matplotlib inline
for task_name in datasets:
    # Display label: drop the trailing "_<suffix>" token and shorten long receptor/channel names.
    name_pieces = task_name.replace("_", " _").split("_")
    if len(name_pieces) == 1:
        task_name_plot = task_name
    else:
        task_name_plot = "".join(name_pieces[:-1]).replace("m1 muscarinic receptor","Musc. rec.").replace("t-type calcium ", "")
    df = open_dataset(f"data/HTS/{task_name}_preprocessed.csv", "../fs_mol/preprocessing/utils/helper_files", 30000, True)
    # Only positives (Y == 1) are compared; copy so the new "mol" column is not written into a view.
    df = df[df.Y == 1].copy()
    df["mol"] = df.Drug.progress_apply(dm.to_mol)
    dist_pos = dm.pdist(df.mol, n_jobs=1)
    dist_mat[task_name_plot] = dist_pos

    # Every ordered off-diagonal pair (i, j), produced in the same row-major order as a nested
    # `for i: for j:` loop, but vectorised (the Python double loop is O(n^2) interpreter work).
    n_pos = dist_pos.shape[0]
    rows, cols = np.nonzero(~np.eye(n_pos, dtype=bool))
    df_easiness_dict["task_name"].extend([task_name_plot] * len(rows))
    df_easiness_dict["value"].extend(dist_pos[rows, cols].tolist())
    df_easiness_dict["i"].extend(rows.tolist())
    df_easiness_dict["j"].extend(cols.tolist())

In [None]:
# Flatten the pairwise records, shorten task labels for display and turn distances into similarities.
df_easiness = pd.DataFrame(df_easiness_dict)
df_easiness["task_name"] = (
    df_easiness["task_name"]
    .str.replace("antagonists", "ant.", regex=False)
    .str.replace("agonists", "agon.", regex=False)
    .str.replace("receptor", "rec.", regex=False)
    .str.replace("channels", "ch.", regex=False)
)
df_easiness["value"] = 1 - df_easiness["value"]
# Per task: median over positives of the similarity to their most similar other positive.
df_easiness.groupby(["task_name", "i"])["value"].max().groupby(level="task_name").median()

In [None]:
# Median nearest-positive similarity per task.
df_easiness.groupby(["task_name", "i"])["value"].max().groupby(level="task_name").median()

In [None]:
# Upper quartile of the nearest-positive similarity per task.
df_easiness.groupby(["task_name", "i"])["value"].max().groupby(level="task_name").quantile(0.75)

In [None]:
# Mean nearest-positive similarity per task, from least to most redundant task.
df_easiness.groupby(["task_name", "i"])["value"].max().groupby(level="task_name").mean().sort_values()

In [None]:
# One similarity clustermap per task (dendrograms and colour bar hidden). Saved as PDF for the
# paper and as PNG so the composite figure below can load it with matplotlib's imread.
for task in dist_mat:
    cg = sns.clustermap(
        1 - dist_mat[task],
        cmap="viridis",
        vmin=0.0,
        vmax=1,
        figsize=(20, 20),
        yticklabels=False,
        xticklabels=False,
        cbar=None,
    )
    cg.cax.set_visible(False)
    cg.ax_row_dendrogram.set_visible(False)
    cg.ax_col_dendrogram.set_visible(False)
    out_stem = PAPER_FIGDIR + "/hts_tanimot_clustermap_{}".format(task)
    cg.savefig(out_stem + ".pdf")
    cg.savefig(out_stem + ".png")
    # Close (not just clear) the 20x20in figure: plt.clf() left every figure open in memory.
    plt.close(cg.figure)

In [None]:
FIGSIZE = 11
import matplotlib.image as mpimg

fig = plt.figure(figsize=(FIGSIZE, FIGSIZE))
subfigs = fig.subfigures(1, 2, width_ratios=[0.8, 2.2])

# Left: per task, distribution of each positive's similarity to its closest positive neighbour.
axsLeft = subfigs[0].subplots(1, 1, sharey=True)

closest_neighb = df_easiness.groupby(["task_name","i"]).max().reset_index().sort_values("value")
sns.violinplot(
    closest_neighb,
    x="value",
    y="task_name",
    hue="task_name",
    saturation=0.5,
    palette=sns.color_palette("husl", 8),
    ax=axsLeft,
    bw_adjust=.4,
    cut=0,
    order=closest_neighb.groupby("task_name")["value"].median().sort_values().index.tolist(),
)
axsLeft.set_xlabel("Tanimoto similarity to the closest\n neighbor between positive examples")
axsLeft.set_ylabel("")

# Right: clustermap thumbnails, filled column by column in an n_figs x n_figs grid.
n_figs = 4
axsRight = subfigs[1].subplots(n_figs, n_figs,)

LEGACY_CLUSTERMAP_DIR = "../paper/Paper/fig"
for i, task in enumerate(dist_mat.keys()):
    png_name = "hts_tanimot_clustermap_{}.png".format(task)
    # Prefer the PNG stored with the other paper figures (PAPER_FIGDIR); fall back to the
    # relative directory this cell previously hard-coded, which differs from PAPER_FIGDIR.
    candidates = [PAPER_FIGDIR + "/" + png_name, LEGACY_CLUSTERMAP_DIR + "/" + png_name]
    png_path = next((path for path in candidates if os.path.exists(path)), candidates[-1])
    img = mpimg.imread(png_path)
    ax = axsRight[i % n_figs, i // n_figs]
    ax.imshow(img)
    ax.set_axis_off()
    ax.set_title(task)

fig.savefig(PAPER_FIGDIR + "/htd_tanimot_sim_cplx.pdf",bbox_inches='tight')

In [None]:
import os
import sys

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
from matplotlib.ticker import FormatStrFormatter


In [None]:
%matplotlib inline

In [None]:
# Import every result CSV in results/HTS (MAML runs are skipped).
RESULT_KEYS = ["task_name", "task_prop", "task_size", "model"]
result_frames = []

for file in tqdm(sorted(os.listdir("results/HTS"))):
    # Read each file once (it used to be parsed twice: once for the column check, once for use).
    df_tmp = pd.read_csv(os.path.join("results/HTS", file), index_col=0)
    if "query_prop" not in df_tmp:
        # Files without a query_prop column are reported and skipped.
        print(file)
        continue
    if "maml" in file:
        continue

    # Express each evaluated cut-off as a percentage of the largest n_mols within its
    # (task_name, task_prop, task_size, model) group.
    df_tmp["n_mols_max"] = df_tmp.groupby(RESULT_KEYS)["n_mols"].transform("max")
    df_tmp["top-%"] = df_tmp["n_mols"] / df_tmp["n_mols_max"] * 100
    # Keep the small cut-offs used in the plots and the >95% points (used later as "top-100%").
    result_frames.append(df_tmp[(df_tmp["top-%"] < 30) | (df_tmp["top-%"] > 95)])

# Concatenate once instead of growing a frame inside the loop (quadratic copying).
df = pd.concat(result_frames) if result_frames else pd.DataFrame()

# 11,32,37,64
df.groupby("task_name").model.unique()

In [None]:
# Tasks kept for the benchmark plots. The order matters: plot_hitrate_lp_double_fig puts the
# first eight in the left column of panels and the rest in the right column.
TASKS = [
    "TP53", "OPRK1", "HIV", "GBA",
    "ESR1_ant", "cav3_t-type_calcium_channels_butkiewicz", "PKM2", "VDR",
    "MTORC1", "m1_muscarinic_receptor_antagonists_butkiewicz", "PPARG", "orexin1_receptor_butkiewicz",
    "MAPK1", "ALDH1", "m1_muscarinic_receptor_agonists_butkiewicz", "SARSCoV2_Vitro_Touret",
]

df = df.loc[df["task_name"].isin(TASKS)]

In [None]:
from tqdm import tqdm
tqdm.pandas()  # (re)register tqdm's pandas progress_apply hooks

hue_order = [
    "clamp",
    "linear_probe",
    "simplebsl",
    "adkt",
    "protonet",
    "simsearch",
]

In [None]:
# Fixed colour per model, shared by every plot below.
cmap = dict([
    ("simsearch", "black"),
    ("adkt", "mediumorchid"),
    ("protonet", "dodgerblue"),
    ("l-probe", "gold"),
    ("q-probe", "red"),
    ("maml", "indigo"),
    ("clamp", "olive"),
])

plt.rcParams["text.usetex"] = False

# Short, paper-facing model names (matching the keys of cmap).
df["model"] = df["model"].apply(
    lambda model_name: model_name.replace("linear_probe", "l-probe").replace("simplebsl", "q-probe")
)

In [None]:
from matplotlib.ticker import FormatStrFormatter
sns.set_style("white")

def plot_hitrate_lp_double_fig(ylim_task, FIG_SIZE=2.5, fig_multiplier=2.5, prop=0.05, bbox_to_anchor=(.5, -8.1)):
    """Plot hit rate against top-k% for every task, split across two side-by-side subfigures.

    The left subfigure shows ``TASKS[:8]`` and the right one ``TASKS[8:]``. Rows are tasks and
    columns are support-set sizes (``task_size``), in ascending order, identical for both halves.
    A dashed grey line in each panel marks ``100 * mean(query_prop)`` of that panel.

    Args:
        ylim_task: mapping from the shortened task label (as used for the y-label) to a
            ``(ymin, ymax)`` pair; tasks missing from it keep automatic limits.
        FIG_SIZE: base panel size in inches.
        fig_multiplier: extra horizontal stretch of the figure.
        prop: ``task_prop`` value to plot.
        bbox_to_anchor: legend anchor, relative to the top-left panel of the left subfigure.

    Returns:
        The created matplotlib figure.
    """
    halves = (TASKS[:8], TASKS[8:])
    n_rows = max(len(half) for half in halves)
    n_cols = df.task_size.nunique()
    # One shared, sorted column order: `unique()` follows row order, so the two halves could
    # otherwise end up with differently ordered (mislabelled-looking) columns.
    task_sizes = sorted(df.loc[df["task_prop"] == prop, "task_size"].unique())

    fig_glob = plt.figure(figsize=(FIG_SIZE * n_cols * fig_multiplier, FIG_SIZE * df.task_name.nunique()))
    subfigs = fig_glob.subfigures(1, 2, width_ratios=[1, 1])

    for i_plot, considered_tasks in enumerate(halves):
        fig = subfigs[i_plot]
        # squeeze=False keeps `axes` 2-D even for a single row/column.
        axes = fig.subplots(n_rows, n_cols, sharex=True, squeeze=False)
        fig.subplots_adjust(hspace=0.1)

        # Copy before rescaling to percent so the assignment never targets a view of `df`.
        run_df = df[
            (df["top-%"] <= 30) & (df["task_prop"] == prop) & (df.task_name.isin(considered_tasks))
        ].copy()
        run_df["hitrate"] *= 100

        for i, task_name in enumerate(considered_tasks):
            if len(task_name.split("_")) == 1:
                task_name_plot = task_name
            else:
                task_name_plot = "".join(task_name.replace("_", " _").split("_")[:-1]).replace("m1 muscarinic receptor","Musc. rec.").replace("t-type calcium ", "")

            for j, task_size in enumerate(task_sizes):
                ax = axes[i, j]
                panel_df = run_df[(run_df.task_name == task_name) & (run_df.task_size == task_size)]
                if not panel_df.empty:
                    sns.lineplot(
                        data=panel_df,
                        x="top-%",
                        y="hitrate",
                        hue="model",
                        ax=ax,
                        palette=cmap,
                        errorbar=None,
                        # A single legend: only on the top-left panel of the left subfigure.
                        legend=(i + j + i_plot) == 0,
                    )
                    # Reference line: hit rate of random selection (mean query_prop, in %).
                    ax.axhline(100 * panel_df.query_prop.mean(), alpha=0.5, color="grey", linestyle="--")

                ax.locator_params(axis='y', nbins=3)
                ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
                if task_name_plot in ylim_task:
                    ax.set_ylim(ylim_task[task_name_plot][0], ylim_task[task_name_plot][1])
                ax.set_xlim(1, 14)
                ax.set_ylabel("")
                ax.set_xlabel("")

                axes[0, j].set_title(r"$|\mathcal{S}|$ = " + "{}".format(task_size))
                axes[i, 0].set_ylabel("{}\nHitrate (%)".format(task_name_plot))
                axes[-1, j].set_xlabel("top-k%")

                # Only the first column keeps its y tick labels.
                if j != 0:
                    ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

                ax.grid()
        if i_plot == 0:
            axes[0, 0].legend(bbox_to_anchor=bbox_to_anchor, loc=2, borderaxespad=0., ncol=3)

    return fig_glob




In [None]:
%matplotlib inline
# y-limits per panel, keyed by the shortened task label used as y-label (some labels end with a space).
ylim_task = {
    "cav3 channels ":          (2, 10),
    "orexin1 receptor ":       (0.5, 5),
    "SARSCoV2 Vitro ":         (3, 12),
    "Musc. rec. agonists ":    (0.4, 2.5),
    "Musc. rec. antagonists ": (0.8, 5),
    "HIV":                     (4.5, 30),
    "ALDH1":                   (4, 8),
    "ESR1 ":                   (1, 10),
    "GBA":                     (0, 5),
    "MAPK1":                   (0, 3),
    "MTORC1":                  (0.2, 1.8),
    "PPARG":                   (0, 5),
    "TP53":                    (1.5, 10),
    "VDR":                     (2.5, 15),
    "PKM2":                    (0, 7),
    "OPRK1":                   (0, 2),
}
plot_hitrate_lp_double_fig(ylim_task, FIG_SIZE=1.0, fig_multiplier=3, bbox_to_anchor=(0.3, -8.1))
plt.savefig(PAPER_FIGDIR + "/hts_005_lineplot.pdf")



In [None]:
# Same y-limits as the task_prop = 0.05 figure; this version plots task_prop = 0.1.
ylim_task = {
    "cav3 channels ":          (2, 10),
    "orexin1 receptor ":       (0.5, 5),
    "SARSCoV2 Vitro ":         (3, 12),
    "Musc. rec. agonists ":    (0.4, 2.5),
    "Musc. rec. antagonists ": (0.8, 5),
    "HIV":                     (4.5, 30),
    "ALDH1":                   (4, 8),
    "ESR1 ":                   (1, 10),
    "GBA":                     (0, 5),
    "MAPK1":                   (0, 3),
    "MTORC1":                  (0.2, 1.8),
    "PPARG":                   (0, 5),
    "TP53":                    (1.5, 10),
    "VDR":                     (2.5, 15),
    "PKM2":                    (0, 7),
    "OPRK1":                   (0, 2),
}

plot_hitrate_lp_double_fig(ylim_task, FIG_SIZE=1.0, fig_multiplier=3, bbox_to_anchor=(0.3, -8.1), prop=0.1)
plt.savefig(PAPER_FIGDIR + "/hts_01_lineplot.pdf")



In [None]:
HITRATES = list(range(1,16)) + [100]
BIN_KEYS = ["task_name", "task_size", "model", "random_seed", "task_prop"]

# For each requested top-k% and each run (BIN_KEYS), keep the evaluated cut-off(s) closest to k
# and average their hit rates. `df` itself is left unchanged (no helper columns are written).
new_df = []
for hitrate in tqdm(HITRATES):
    dist = (df["top-%"] - hitrate).abs()
    min_dist = df.assign(_dist=dist).groupby(BIN_KEYS)["_dist"].transform("min")
    closest = df[dist == min_dist]
    # Ties (several cut-offs at the same distance) are averaged.
    tmp = closest.groupby(BIN_KEYS)["hitrate"].mean().reset_index()
    new_df.append(
        tmp[
            ["task_name", "task_size", "model", "hitrate", "random_seed", "task_prop"]
        ].rename(columns={"hitrate": f"top-{hitrate}%"})
    )



In [None]:
# Wide table: one row per run (BIN_KEYS), one "top-k%" column per entry of new_df.
BIN_KEYS = ["task_name", "task_size", "model", "random_seed", "task_prop"]
df_bins = pd.DataFrame()
for position, df_hit in enumerate(new_df):
    if position == 0:
        # Seed with the first frame. (Testing `df_bins.empty` instead would wrongly restart the
        # accumulation whenever an inner join below yielded zero rows.)
        df_bins = df_hit
        continue
    df_bins = df_bins.join(
        df_hit.set_index(BIN_KEYS),
        on=BIN_KEYS,
        how="inner"
    )
    print(df_bins.shape)

In [None]:
# Drop the 100% reference cut-off (used only as a baseline); idempotent, unlike HITRATES[:-1],
# which removed one more element on every re-run.
HITRATES = [hitrate for hitrate in HITRATES if hitrate != 100]

In [None]:
df_bins["model"].unique()

In [None]:
# Cut-offs for which the relative gain over the full-library hit rate (top-100%) is computed.
# This name was previously used without being defined (NameError on a fresh kernel); the
# values match the delta columns renamed for the simsearch baseline below.
main_HITRATES = [1, 5, 10, 15]

for hitrate in main_HITRATES:
    df_bins[f"delta-top-{hitrate}%"] = (df_bins[f"top-{hitrate}%"] - df_bins["top-100%"]) / df_bins["top-100%"]

# Remove columns added by a previous run of this cell so it can be re-run safely.
df_bins = df_bins.drop(columns=[c for c in df_bins.columns if c.endswith(("_simsearch", "-simsearch"))])

#delta-simsearch is the same but comparing all models to the simsearch baseline
hitrates_simsearch = (
    df_bins[df_bins.model == "simsearch"]
    .drop(columns=["model"])
    .rename(columns={f"delta-top-{h}%": f"delta-top-{h}%-simsearch" for h in main_HITRATES})
)

# Match on task_prop as well: without it, each row is paired with the simsearch row of every
# task_prop present, duplicating rows (which inflates the sample size seen by bootstrapped CIs).
SIMSEARCH_KEYS = ["task_name", "task_size", "random_seed", "task_prop"]
df_bins = df_bins.join(
    hitrates_simsearch.set_index(SIMSEARCH_KEYS),
    on=SIMSEARCH_KEYS,
    rsuffix="_simsearch"
)

In [None]:
#Ranking of the models
HITRATES = [*range(1, 16)]  # top-1% .. top-15% cut-offs used for ranking

In [None]:
from autorank import autorank, plot_stats, create_report, latex_table

# Rank models on each (task_name, task_size); pivot_table averages over task_prop and keeps one
# row per random seed, which autorank treats as the repeated measurements.
rank_frames = []
for hitrate in HITRATES:

    print(f"Ranking for top-{hitrate}% hitrate")
    df_hitrate = df_bins.pivot_table(
        index=["task_name", "task_size", "random_seed"], columns="model", values=f"top-{hitrate}%"
    )
    df_hitrate.columns.name = None
    df_hitrate = df_hitrate.reset_index()

    df_ranks = {
        "task_name": [],
        "task_size": [],
        "model": [],
        "rank": [],
    }

    for task_name in tqdm(df.task_name.unique()):
        for task_size in df.task_size.unique():
            df_to_rank = df_hitrate[
                (df_hitrate.task_name == task_name) & (df_hitrate.task_size == task_size)
            ].drop(columns=["random_seed", "task_size", "task_name"])
            if df_to_rank.empty:
                # This (task, size) pair has no results: skip it rather than hand autorank an empty table.
                continue
            ranks = autorank(df_to_rank, alpha=0.0125, verbose=False).rankdf.meanrank
            for model, rank in ranks.items():
                df_ranks["task_name"].append(task_name)
                df_ranks["task_size"].append(task_size)
                df_ranks["model"].append(model)
                df_ranks["rank"].append(rank)

    df_ranks = pd.DataFrame(df_ranks)
    df_ranks["top-%"] = hitrate
    rank_frames.append(df_ranks)

# Concatenate once rather than growing the frame inside the loop.
df_ranks_full = pd.concat(rank_frames) if rank_frames else pd.DataFrame()

In [None]:
# Tasks grouped under "TDC"; the remaining tasks are plotted separately (the *_lit figures).
TDC = [
    "HIV",
    "SARSCoV2_Vitro_Touret",
    "orexin1_receptor_butkiewicz",
    "m1_muscarinic_receptor_agonists_butkiewicz",
    "m1_muscarinic_receptor_antagonists_butkiewicz",
    "cav3_t-type_calcium_channels_butkiewicz",
]

In [None]:
%matplotlib inline
#barplot rank
SIZES = [16, 32, 64,128]


FIGSIZE = 3.5

def plot_ranking_topk(df, topks, ylim=(5.5, 2.1), alpha=0.3):
    """Plot each model's mean rank against support-set size, one panel per top-k% cut-off.

    Args:
        df: long-format ranks with columns ``task_name``, ``task_size``, ``model``, ``rank`` and
            ``top-%`` (the layout of ``df_ranks_full``).
        topks: four top-k% values, laid out row-major on a 2x2 grid.
        ylim: ``(bottom, top)`` y-limits applied after inverting the axis (best rank on top).
        alpha: opacity of the all-model mean-rank layer.

    Returns:
        The created matplotlib figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=(FIGSIZE, FIGSIZE), sharex=True, sharey=True)
    fig.subplots_adjust(hspace=0.3)
    hue_order = ["clamp", "q-probe", "l-probe", "adkt", "protonet", "simsearch"]
    for i, top in enumerate(topks):
        ax = axes[i // 2, i % 2]
        # Short task label; muscarinic tasks keep a third token so agonists/antagonists stay distinct.
        # `.assign` builds a new frame (no assignment into a filtered view).
        df_ranks_this_plot = df[df["top-%"] == top].assign(
            task_name_plot=lambda d: d.task_name.apply(
                lambda x: "".join(
                    x.replace("receptor_", "").replace("_", " _").split("_")[:3]
                    if "muscarinic" in x
                    else x.replace("_", " _").split("_")[:2]
                )
            )
        )

        # One point per (size, model, task). numeric_only drops the string task_name column,
        # on which a plain .mean() raises in pandas >= 2.
        meanrank = (
            df_ranks_this_plot.groupby(["task_size", "model", "task_name_plot"])
            .mean(numeric_only=True)
            .reset_index()
        )
        # Faint layer with 90% CIs across tasks.
        sns.pointplot(
            data=meanrank,
            x="task_size",
            y="rank",
            hue="model",
            ax=ax,
            palette=cmap,
            hue_order=hue_order,
            errorbar=("ci", 90),
            alpha=0.1,
            capsize=.1,
            scale=.5,
            legend=False,
        )
        # Mean rank of every model.
        sns.pointplot(
            data=meanrank,
            x="task_size",
            y="rank",
            hue="model",
            ax=ax,
            palette=cmap,
            hue_order=hue_order,
            errorbar=None,
            scale=0.6,
            alpha=alpha,
            legend=False,
        )
        # q-probe drawn opaque on top; this layer also provides the legend.
        sns.pointplot(
            data=meanrank[meanrank.model == "q-probe"],
            x="task_size",
            y="rank",
            hue="model",
            ax=ax,
            palette=cmap,
            hue_order=hue_order,
            errorbar=None,
            scale=0.6,
            alpha=1,
        )
        # Rank 1 is best: invert so better models appear higher.
        ax.invert_yaxis()
        ax.set_ylim(ylim[0], ylim[1])
        if i == 3:
            ax.legend(bbox_to_anchor=(0, -0.65), loc="center", borderaxespad=0., ncol=3)
        else:
            ax.get_legend().remove()
        if i // 2 == 1:
            ax.set_xlabel(r"$|\mathcal{S}|$")
        else:
            ax.set_xlabel("")
        ax.set_ylabel("Mean rank")
        ax.set_title("top-{}%".format(top))

    return fig


In [None]:
plot_ranking_topk(df_ranks_full.loc[df_ranks_full["task_name"].isin(TDC)], topks=[1, 5, 10, 15])
plt.savefig(PAPER_FIGDIR + "/hts_ranking.pdf", dpi=300, bbox_inches="tight")

In [None]:
plot_ranking_topk(df_ranks_full.loc[~df_ranks_full["task_name"].isin(TDC)], topks=[1, 5, 10, 15])
plt.savefig(PAPER_FIGDIR + "/hts_ranking_lit.pdf", dpi=300, bbox_inches="tight")

In [None]:
plot_ranking_topk(df_ranks_full, topks=[1, 5, 10, 15], ylim=(5.3, 2.))
plt.savefig(PAPER_FIGDIR + "/hts_ranking_full.pdf", dpi=300, bbox_inches="tight")

In [None]:
# barplot top 10% hitrate
%matplotlib inline
# Panel order of the intro bar charts (filled row-major into a 2x3 grid by the caller).
TASKS_ORDER = [
    "m1_muscarinic_receptor_agonists_butkiewicz",
    "orexin1_receptor_butkiewicz",
    "SARSCoV2_Vitro_Touret",
    "HIV",
    "m1_muscarinic_receptor_antagonists_butkiewicz",
    "cav3_t-type_calcium_channels_butkiewicz",
]

def plot_barplot_hitrate(
    axes,
    fig,
    prop = 0.05,
    size = 64,
    top = 5,
    bbox_to_anchor=(.5, -.1)
):
    """Draw one bar chart per task in ``TASKS_ORDER`` comparing the models' top-k% hit rates.

    Args:
        axes: flat sequence of at least ``len(TASKS_ORDER)`` matplotlib axes.
        fig: figure or subfigure that receives the shared x/y labels.
        prop: NOTE(review): currently unused — the ``task_prop`` filter is disabled, so the bars
            pool every task proportion present in ``df_bins``.
        size: support-set size (``task_size``) to plot.
        top: top-k% cut-off; column ``top-{top}%`` of ``df_bins`` is plotted, in percent.
        bbox_to_anchor: legend anchor relative to the first axis.
    """
    col = "top-{}%".format(top)
    hue_order_intro = ["clamp", "q-probe", "l-probe", "adkt", "protonet", "simsearch"]

    # Work on a copy so rescaling to percent and renaming models never writes into df_bins.
    df_this_plot = df_bins[df_bins.task_size == size].copy()
    df_this_plot[col] = df_this_plot[col] * 100
    df_this_plot["model"] = df_this_plot["model"].apply(
        lambda x: x.replace("linear_probe", "l-probe").replace("simplebsl", "q-probe")
    )

    for i, task_name in enumerate(TASKS_ORDER):
        ax = axes[i]
        task_df = df_this_plot[df_this_plot.task_name == task_name]
        sns.barplot(
            data=task_df,
            x="model",
            y=col,
            hue="model",
            ax=ax,
            palette=cmap,
            hue_order=hue_order_intro,
            errorbar=("ci", 90),
            err_kws={"color": ".5", "linewidth": .5},
            capsize=.4,
            legend=i == 0,
            order=hue_order_intro,
            gap=0.1,
        )
        # Models are identified by colour and the legend, so the x tick labels are hidden.
        ax.tick_params(bottom=False)
        ax.set(xticklabels=[], xlabel=None)

        if task_name == "HIV":
            task_name_plot = task_name
        else:
            task_name_plot = "".join(task_name.replace("_", " _").split("_")[:-1]).replace("m1 muscarinic receptor","Musc. rec.").replace("t-type calcium ", "").replace("nists",".")
        ax.set_title(task_name_plot)
        # Top of the y-axis at 90% of this task's largest per-seed value (may clip tall bars/whiskers).
        ax.set_ylim(0, task_df[col].max() * 0.9)
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.locator_params(axis='y', nbins=3)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
    axes[0].legend(bbox_to_anchor=bbox_to_anchor, loc=2, borderaxespad=0., ncol=6, prop={'size': 12})
    fig.supylabel("top {}% hitrate (%)".format(top,))
    fig.supxlabel("Few-shot models trained with {} datapoints".format(size,))

def plot_meanrank(ax):
    """Plot the models' mean rank against support-set size for the TDC tasks on ``ax``.

    Ranks from ``df_ranks_full`` at cut-offs top-1%..top-15% are averaged per
    (size, model, task). The faint layer carries 90% CIs. The probes and CLAMP are redrawn on top
    for emphasis. The y-axis is inverted so the best rank is highest.
    """
    df_ranks_this_plot = df_ranks_full[
        (df_ranks_full["top-%"] <= 15) & (df_ranks_full.task_name.isin(TDC))
    ].copy()
    df_ranks_this_plot["task_name_plot"] = df_ranks_this_plot.task_name.apply(lambda x: "".join(x.replace("receptor_", "").replace("_", " _").replace("nists",".").split("_")[:3] if "muscarinic" in x else x.replace("_", " _").split("_")[:2]))
    # numeric_only drops the string task_name column (a plain .mean() raises on it in pandas >= 2).
    meanrank = (
        df_ranks_this_plot.groupby(["task_size", "model", "task_name_plot"])
        .mean(numeric_only=True)
        .reset_index()
    )

    hue_order = ["clamp","q-probe","l-probe",  "adkt",  "protonet", "simsearch"]
    sns.pointplot(data = meanrank, x="task_size", y="rank", hue = "model", ax=ax, palette=cmap, hue_order=hue_order, errorbar=("ci", 90), alpha=0.1, capsize=.1, scale=.7, legend=False)
    sns.pointplot(data = meanrank, x="task_size", y="rank", hue = "model", ax=ax, palette=cmap, hue_order=hue_order, errorbar=None, legend=False, alpha=0.2)
    sns.pointplot(data = meanrank[meanrank.model.isin(["q-probe","l-probe",])], x="task_size", y="rank", hue = "model", ax=ax, palette=cmap, hue_order=hue_order, errorbar=None, legend=False,)
    sns.pointplot(data = meanrank[meanrank.model.isin(["clamp",])], x="task_size", y="rank", hue = "model", ax=ax, palette=cmap, hue_order=hue_order, errorbar=None, legend=False, alpha = 0.5)

    # Rank 1 is best: invert so better models appear higher.
    ax.invert_yaxis()
    ax.set_ylim(5,2.2)

    ax.set_xlabel("Number of labelled points in the HTS task")
    ax.set_ylabel("Mean ranks of the models")


In [None]:

FIGSIZE = 11
fig = plt.figure(figsize=(FIGSIZE, FIGSIZE / 3.5))
subfigs = fig.subfigures(1, 2, width_ratios=[1.7, 1])

# Left: per-task bar charts of the top-5% hit rate with |S| = 16.
axsLeft = subfigs[0].subplots(2, 3)
subfigs[0].subplots_adjust(wspace=0.5, hspace=0.3)
plot_barplot_hitrate(axsLeft.flatten(), subfigs[0], top=5, size=16, bbox_to_anchor=(0.3, -1.8))

# Right: mean rank of the models against support-set size.
axsRight = subfigs[1].subplots(1, 1)
plot_meanrank(axsRight)

fig.savefig(PAPER_FIGDIR + "/intro1.pdf", bbox_inches="tight")

In [None]:

FIGSIZE = 11
fig = plt.figure(figsize=(FIGSIZE, FIGSIZE / 3.5))
subfigs = fig.subfigures(1, 2, width_ratios=[1.7, 1])

# Left: per-task bar charts of the top-5% hit rate with |S| = 32.
axsLeft = subfigs[0].subplots(2, 3)
subfigs[0].subplots_adjust(wspace=0.5, hspace=0.3)
plot_barplot_hitrate(axsLeft.flatten(), subfigs[0], top=5, size=32, bbox_to_anchor=(0.3, -1.8))

# Right: mean rank of the models against support-set size.
axsRight = subfigs[1].subplots(1, 1)
plot_meanrank(axsRight)

fig.savefig(PAPER_FIGDIR + "/intro2.pdf", bbox_inches="tight")

In [None]:
%matplotlib inline
#barplot rank
SIZES = [16, 32, 64,128]


FIGSIZE = 3.5

def plot_ranking_topk_inv(df, ylim=(5.5, 2.1), alpha=0.3, dotsize=.2):
    """Plot each model's mean rank against the top-k% cut-off, one panel per support-set size.

    Args:
        df: long-format ranks with columns ``task_name``, ``task_size``, ``model``, ``rank`` and
            ``top-%`` (the layout of ``df_ranks_full``).
        ylim: ``(bottom, top)`` y-limits applied after inverting the axis (best rank on top).
        alpha: opacity of the all-model mean-rank layer.
        dotsize: seaborn ``scale`` used for every point layer.

    Returns:
        The created matplotlib figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=(FIGSIZE, FIGSIZE), sharex=True, sharey=True)
    fig.subplots_adjust(hspace=0.3)
    hue_order = ["clamp", "q-probe", "l-probe", "adkt", "protonet", "simsearch"]
    for i, size in enumerate([16, 32, 64, 128]):
        ax = axes[i // 2, i % 2]
        # Short task label; muscarinic tasks keep a third token so agonists/antagonists stay distinct.
        # `.assign` builds a new frame (no assignment into a filtered view).
        df_ranks_this_plot = df[df["task_size"] == size].assign(
            task_name_plot=lambda d: d.task_name.apply(
                lambda x: "".join(
                    x.replace("receptor_", "").replace("_", " _").split("_")[:3]
                    if "muscarinic" in x
                    else x.replace("_", " _").split("_")[:2]
                )
            )
        )

        # One point per (top-k%, model, task). numeric_only drops the string task_name column,
        # on which a plain .mean() raises in pandas >= 2.
        meanrank = (
            df_ranks_this_plot.groupby(["top-%", "model", "task_name_plot"])
            .mean(numeric_only=True)
            .reset_index()
        )
        sns.pointplot(
            data=meanrank,
            x="top-%",
            y="rank",
            hue="model",
            ax=ax,
            palette=cmap,
            hue_order=hue_order,
            errorbar=None,
            alpha=0.1,
            capsize=.1,
            scale=dotsize,
            legend=False,
        )
        sns.pointplot(
            data=meanrank,
            x="top-%",
            y="rank",
            hue="model",
            ax=ax,
            palette=cmap,
            hue_order=hue_order,
            errorbar=None,
            scale=dotsize,
            alpha=alpha,
            legend=False,
        )
        # q-probe drawn opaque on top; this layer also provides the legend.
        sns.pointplot(
            data=meanrank[meanrank.model == "q-probe"],
            x="top-%",
            y="rank",
            hue="model",
            ax=ax,
            palette=cmap,
            hue_order=hue_order,
            errorbar=None,
            scale=dotsize,
            alpha=1,
        )
        # Rank 1 is best: invert so better models appear higher.
        ax.invert_yaxis()
        ax.set_ylim(ylim[0], ylim[1])
        if i == 3:
            ax.legend(bbox_to_anchor=(0, -0.65), loc="center", borderaxespad=0., ncol=3)
        else:
            ax.get_legend().remove()
        if i // 2 == 1:
            ax.set_xlabel("top-k%")
        else:
            ax.set_xlabel("")
        ax.set_ylabel("Mean rank")
        ax.set_title(r"|$\mathcal{S}$| = " + "{}".format(size))
        # Categorical positions 0, 4, 9, 14 -> top-1/5/10/15% when the cut-offs are 1..15
        # (NOTE(review): assumes every cut-off 1..15 is present in `df`).
        ax.set_xticks([0, 4, 9, 14])
        ax.grid()

    return fig


In [None]:
plot_ranking_topk_inv(df_ranks_full.loc[df_ranks_full["task_name"].isin(TDC)], dotsize=.4)
plt.savefig(PAPER_FIGDIR + "/hts_ranking_inv.pdf", dpi=300, bbox_inches="tight")

In [None]:
plot_ranking_topk_inv(df_ranks_full.loc[~df_ranks_full["task_name"].isin(TDC)], dotsize=.4)
plt.savefig(PAPER_FIGDIR + "/hts_ranking_inv_lit.pdf", dpi=300, bbox_inches="tight")

In [None]:
plot_ranking_topk_inv(df_ranks_full, ylim=(4.9, 2), dotsize=.4)
plt.savefig(PAPER_FIGDIR + "/hts_ranking_inv_full.pdf", dpi=300, bbox_inches="tight")