In [None]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
# Annotation scheme; used as the file-name prefix of the annotated CSV and of
# every TopN output/plot written below.
annotation = "flex"

In [None]:
# The *cluster_annotated.csv and *nc_annotated.csv files are identical, so the
# "nc" one is loaded once and reused for both prefixes.
annotated_path = f"{annotation}nc_annotated.csv"
df = pd.read_csv(annotated_path)
df

In [None]:
def load_scores(df, model, annotation, prefix):
    """Load the CNN scores of the three test folds and merge them into ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Annotated poses; must contain the ``pocket``, ``protein``, ``ligand``
        and ``rank`` columns used as merge keys.
    model : str
        Model directory name under ``../training/<tfname>/``.
    annotation : str
        Annotation prefix of the score files; an empty string selects the
        ``"ligand"`` training directory.
    prefix : str
        ``"nc"`` or ``"cluster"``; selects the score files and their layout.

    Returns
    -------
    pandas.DataFrame
        Inner join of ``df`` with a numeric ``CNNscore`` column.

    Raises
    ------
    ValueError
        If ``prefix`` is neither ``"nc"`` nor ``"cluster"``.
    """
    # Columns of the space-separated .out files to discard. Column 0 is the CNN
    # score; the ligand path is column 4 for "nc" files (no affinity column) and
    # column 5 for "cluster" files.
    columns_to_drop = {
        "nc": [1, 2, 3, 5, 6, 7],
        "cluster": [1, 2, 3, 4, 6, 7, 8],
    }
    if prefix not in columns_to_drop:
        raise ValueError(f"prefix must be 'nc' or 'cluster', got {prefix!r}")

    tfname = annotation if annotation else "ligand"

    print(tfname, model, prefix)

    score = pd.concat(
        [
            pd.read_csv(
                f"../training/{tfname}/{model}/{annotation}{prefix}test{fold}.out",
                sep=" ",
                header=None,
            )
            for fold in range(3)
        ]
    )

    score = (
        score.drop(columns=columns_to_drop[prefix])
        # Only one of columns 4/5 survives the drop; it holds the ligand path
        .rename(columns={0: "CNNscore", 4: "ligname", 5: "ligname"})
        # Comment lines in the .out files parse as rows containing NaN
        .dropna()
        .reset_index(drop=True)
    )
    # A comment line with text in column 0 makes pandas read the whole column as
    # strings; convert so that CNNscore compares/sorts numerically.
    score["CNNscore"] = pd.to_numeric(score["CNNscore"])

    def getid(row):
        # Path component 1 is the pocket. The file name is "_"-separated: field 0
        # is the protein, field 2 the ligand, and the last field is one character
        # followed by the pose rank and the ".gninatypes" extension.
        path_parts = row["ligname"].split("/")
        fields = path_parts[-1].split("_")
        rank = int(fields[-1][1:].replace(".gninatypes", ""))
        return (path_parts[1], fields[0], fields[2], rank)

    score[["pocket", "protein", "ligand", "rank"]] = score.apply(getid, axis=1, result_type="expand")
    score = score.drop(columns="ligname")

    return df.merge(score, on=["pocket", "protein", "ligand", "rank"])

In [None]:
def topN(df, nmax, n_pockets_expected=91):
    """Compute TopN success rates for the smina, gnina and best-RMSD rankings.

    For every pocket and every protein (target) in it, poses are ranked three
    ways: by ``score`` ascending (smina), by ``CNNscore`` descending (gnina) and
    by ``rmsd`` ascending (best). A target is a success at N if at least one of
    its top-N poses has ``annotation == 1``. Success is expressed as a
    percentage of targets per pocket, then averaged over pockets.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``pocket``, ``protein``, ``score``, ``CNNscore``, ``rmsd``
        and ``annotation`` columns.
    nmax : int
        Largest N evaluated; N runs from 1 to ``nmax``.
    n_pockets_expected : int or None, optional
        Sanity check on the number of pockets in ``df``. Defaults to 91 (one
        pocket was removed from the training set for lack of actives); pass
        ``None`` to skip the check.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(nmax, 4)``: columns N, smina, gnina, best (%).

    Raises
    ------
    AssertionError
        If the pocket count differs from ``n_pockets_expected``.
    ValueError
        If ``df`` contains no pockets.
    """
    # Ranking name -> (column to sort by, ascending?)
    rankings = {
        "smina": ("score", True),
        "gnina": ("CNNscore", False),
        "best": ("rmsd", True),
    }

    # Sum over pockets of the per-pocket success percentage, indexed by N - 1
    totals = {name: np.zeros(nmax) for name in rankings}
    n_pockets = 0

    for _, pocket_group in df.groupby("pocket"):
        # Targets of this pocket with a good pose in the top N, indexed by N - 1
        hits = {name: np.zeros(nmax) for name in rankings}
        n_targets = 0

        for _, target_group in pocket_group.groupby("protein"):
            for name, (column, ascending) in rankings.items():
                ranked = target_group.sort_values(by=column, ascending=ascending)
                good = ranked["annotation"].to_numpy()[:nmax] == 1
                if good.any():
                    # First good pose at 0-based position p => success for all N > p
                    hits[name][np.argmax(good):] += 1
            n_targets += 1

        for name in rankings:
            totals[name] += hits[name] / n_targets * 100

        n_pockets += 1

    if n_pockets_expected is not None:
        assert n_pockets == n_pockets_expected, (
            f"expected {n_pockets_expected} pockets, found {n_pockets}"
        )
    if n_pockets == 0:
        raise ValueError("topN() received a DataFrame with no pockets")

    # TopN of targets, averaged per pocket
    return np.array([
        np.arange(1, nmax + 1),
        totals["smina"] / n_pockets,
        totals["gnina"] / n_pockets,
        totals["best"] / n_pockets,
    ]).T

In [None]:
# Largest N for which TopN success rates are computed (N = 1..nmax)
nmax = 10

In [None]:
# Output prefix -> model-directory suffix: "cluster" scores come from the
# "-noaffinity" models, "nc" scores from the "-noaffinity-nostratified" ones.
suffix_for_prefix = {"cluster": "-noaffinity", "nc": "-noaffinity-nostratified"}

# DataFrame.to_csv does not create missing directories
os.makedirs("TopN", exist_ok=True)

for model in ["default2017", "default2018", "dense"]:
    for prefix, suffix in suffix_for_prefix.items():
        modelname = f"{model}{suffix}"
        df_score = load_scores(df, modelname, annotation, prefix)

        # "" keeps every pose; "nocrystal_" removes rank 0 (the crystal pose)
        for crystal in ["", "nocrystal_"]:
            poses = df_score[df_score["rank"] != 0] if crystal == "nocrystal_" else df_score
            t = topN(poses, nmax)

            df_top = pd.DataFrame(t, columns=["N", "smina", "gnina", "best"])
            df_top["annotation"] = annotation
            df_top["prefix"] = prefix
            df_top["model"] = model
            df_top.to_csv(f"TopN/{crystal}{modelname}_{annotation}{prefix}.csv", index=False)

In [None]:
# Non-stratified ("-noaffinity-nostratified") results were written with the "nc"
# prefix; it is spelled out here instead of relying on a leftover loop variable.
# smina and best rankings do not use CNNscore, so they are kept from the
# default2017 file only; the annotation/prefix/model metadata columns are dropped.
crystal_nostratified = pd.concat(
    [
        pd.read_csv(f"TopN/{cnn}-noaffinity-nostratified_{annotation}nc.csv", index_col="N")
        .loc[:, ["smina", "gnina", "best"] if cnn == "default2017" else ["gnina"]]
        .rename(columns={"gnina": cnn})
        for cnn in ["default2017", "default2018", "dense"]
    ],
    axis=1,
)
crystal_nostratified

In [None]:
# Explicit figure/axes so this plot never draws onto a figure left open elsewhere
os.makedirs("plots", exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots()
sns.lineplot(data=crystal_nostratified, ax=ax)
ax.set(xlabel="N", ylabel="TopN (%)", title=f"TopN, crystal, nostratified ({annotation})")
fig.savefig(f"plots/TopN-crystal-nostratified-{annotation}.pdf")
fig.savefig(f"plots/TopN-crystal-nostratified-{annotation}.png")

In [None]:
# Stratified ("-noaffinity") results were written with the "cluster" prefix. The
# original cell used the leftover loop variable `prefix`, which ends as "nc" and
# therefore pointed at files that the export loop never writes.
# smina and best rankings do not use CNNscore, so they are kept from the
# default2017 file only; the annotation/prefix/model metadata columns are dropped.
crystal_stratified = pd.concat(
    [
        pd.read_csv(f"TopN/{cnn}-noaffinity_{annotation}cluster.csv", index_col="N")
        .loc[:, ["smina", "gnina", "best"] if cnn == "default2017" else ["gnina"]]
        .rename(columns={"gnina": cnn})
        for cnn in ["default2017", "default2018", "dense"]
    ],
    axis=1,
)
crystal_stratified

In [None]:
# Explicit figure/axes so this plot never draws onto a figure left open elsewhere
os.makedirs("plots", exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots()
sns.lineplot(data=crystal_stratified, ax=ax)
ax.set(xlabel="N", ylabel="TopN (%)", title=f"TopN, crystal, stratified ({annotation})")
fig.savefig(f"plots/TopN-crystal-stratified-{annotation}.pdf")
fig.savefig(f"plots/TopN-crystal-stratified-{annotation}.png")

In [None]:
# Non-stratified ("-noaffinity-nostratified") results without the crystal pose
# were written with the "nocrystal_" prefix and the "nc" suffix; the suffix is
# spelled out here instead of relying on a leftover loop variable.
# smina and best rankings do not use CNNscore, so they are kept from the
# default2017 file only; the annotation/prefix/model metadata columns are dropped.
nocrystal_nostratified = pd.concat(
    [
        pd.read_csv(f"TopN/nocrystal_{cnn}-noaffinity-nostratified_{annotation}nc.csv", index_col="N")
        .loc[:, ["smina", "gnina", "best"] if cnn == "default2017" else ["gnina"]]
        .rename(columns={"gnina": cnn})
        for cnn in ["default2017", "default2018", "dense"]
    ],
    axis=1,
)
nocrystal_nostratified

In [None]:
# Explicit figure/axes so this plot never draws onto a figure left open elsewhere
os.makedirs("plots", exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots()
sns.lineplot(data=nocrystal_nostratified, ax=ax)
ax.set(xlabel="N", ylabel="TopN (%)", title=f"TopN, nocrystal, nostratified ({annotation})")
fig.savefig(f"plots/TopN-nocrystal-nostratified-{annotation}.pdf")
fig.savefig(f"plots/TopN-nocrystal-nostratified-{annotation}.png")

In [None]:
# Stratified ("-noaffinity") results without the crystal pose were written with
# the "nocrystal_" prefix and the "cluster" suffix. The original cell used the
# leftover loop variable `prefix` ("nc"), which names files never written.
# smina and best rankings do not use CNNscore, so they are kept from the
# default2017 file only; the annotation/prefix/model metadata columns are dropped.
nocrystal_stratified = pd.concat(
    [
        pd.read_csv(f"TopN/nocrystal_{cnn}-noaffinity_{annotation}cluster.csv", index_col="N")
        .loc[:, ["smina", "gnina", "best"] if cnn == "default2017" else ["gnina"]]
        .rename(columns={"gnina": cnn})
        for cnn in ["default2017", "default2018", "dense"]
    ],
    axis=1,
)
nocrystal_stratified

In [None]:
# Explicit figure/axes so this plot never draws onto a figure left open elsewhere
os.makedirs("plots", exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots()
sns.lineplot(data=nocrystal_stratified, ax=ax)
ax.set(xlabel="N", ylabel="TopN (%)", title=f"TopN, nocrystal, stratified ({annotation})")
fig.savefig(f"plots/TopN-nocrystal-stratified-{annotation}.pdf")
fig.savefig(f"plots/TopN-nocrystal-stratified-{annotation}.png")