In [None]:
import csv
import glob
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import sem as sem

pd.options.display.max_rows = None  # never truncate rows when a DataFrame is displayed

from argparse import ArgumentParser
import itertools as it

import pipeline
from pipeline.utils import validate
from pipeline import *
from pipeline.utils import validate
from pipeline.analysis import *

import cinnabar

In [None]:
# Analysis protocols to compare. Both share every alchemlyb/MBAR setting and
# differ only in whether statistical-inefficiency subsampling is switched on.
_base_ana_settings = {
    "estimator": "MBAR",
    "method": "alchemlyb",
    "check overlap": True,
    "try pickle": True,
    "save pickle": True,
    "auto equilibration": False,
    "statistical inefficiency": False,
    "truncate lower": 0,
    "truncate upper": 100,
    "name": None,
}
ana_dicts = {
    "plain": dict(_base_ana_settings),
    "subsampling": {**_base_ana_settings, "statistical inefficiency": True},
}

In [None]:
# set the variables
network = "combined"  # lomap rbfenn combined

prot_dict_name = {
    "tyk2": "TYK2",
    "mcl1": "MCL1",
    "p38": "P38α",
    "syk": "SYK",
    "hif2a": "HIF2A",
    "cmet": "CMET",
}
eng_dict_name = {"AMBER": "AMBER22", "SOMD": "SOMD1", "GROMACS": "GROMACS23"}

bench_folder = "/home/anna/Documents/benchmark"

# analysis_network objects for every target and analysis method:
# ana_obj_dict[protein][analysis name from ana_dicts]
ana_obj_dict = {}

for protein in prot_dict_name:
    ana_obj_dict[protein] = {}

    # Everything up to the inner loop depends only on the target, so it is
    # done once per protein instead of once per analysis method.
    if protein in ("syk", "cmet"):
        main_dir = f"/backup/{protein}/neutral"
    else:
        main_dir = f"/backup/{protein}"

    # if need size of protein: try the GROMACS inputs first, then the AMBER
    # ones. `except Exception` (not a bare except) so Ctrl-C still interrupts.
    try:
        prot_mol = BSS.IO.readMolecules(
            [
                f"{bench_folder}/inputs/{protein}/{protein}_prep/{protein}.gro",
                f"{bench_folder}/inputs/{protein}/{protein}_prep/{protein}.top",
            ]
        )[0]
    except Exception:
        prot_mol = BSS.IO.readMolecules(
            [
                f"{bench_folder}/inputs/{protein}/{protein}_parameterised.prm7",
                f"{bench_folder}/inputs/{protein}/{protein}_parameterised.rst7",
            ]
        )[0]
    print(f"no of residues in the protein: {prot_mol.nResidues()}")

    # choose location for the files
    if protein in ("syk", "cmet", "hif2a"):
        # the lomap network
        net_file = f"{main_dir}/execution_model/network_all.dat"
    else:
        net_file = f"{main_dir}/execution_model/network_{network}.dat"

    exp_file = f"{bench_folder}/inputs/experimental/{protein}.yml"
    output_folder = f"{main_dir}/outputs_extracted"

    # ligands folder: "ligands" if it exists, otherwise "ligands_neutral"
    ligands_folder = f"{bench_folder}/inputs/{protein}/ligands"
    if not os.path.isdir(ligands_folder):
        ligands_folder = f"{bench_folder}/inputs/{protein}/ligands_neutral"

    for ana_dict in ana_dicts:
        ana_prot = analysis_protocol(ana_dicts[ana_dict])
        print(protein, ana_dict)

        # prot_file = f"{main_dir}/execution_model/protocol.dat" # no protocol used , name added after if needed
        pipeline_prot = pipeline_protocol(auto_validate=True)
        # pipeline_prot.name("")

        # initialise the network object
        all_analysis_object = analysis_network(
            output_folder,
            exp_file=exp_file,
            net_file=net_file,
            analysis_prot=ana_prot,
            method=pipeline_prot.name(),  # if the protocol had a name
            engines=pipeline_prot.engines(),
        )

        # compute
        all_analysis_object.compute_results()

        # NOTE(review): "single" is not a key of ana_dicts ("plain",
        # "subsampling"), so this branch never runs and file_ext is never
        # suffixed with the analysis name here. Confirm this is intended.
        if ana_dict == "single":
            all_analysis_object.file_ext = all_analysis_object.file_ext + f"_{ana_dict}"

        all_analysis_object.add_ligands_folder(ligands_folder)

        ana_obj_dict[protein][ana_dict] = all_analysis_object

print(ana_obj_dict)

In [None]:
# Per-target table of perturbing atoms / overlap, taken from the subsampled analysis.
pert_overlap_dict = {}
for prot, obj_by_method in ana_obj_dict.items():
    ana_obj = obj_by_method["subsampling"]
    df = ana_obj.perturbing_atoms_and_overlap(read_file=True)
    pert_overlap_dict[prot] = df

In [None]:
# adding the scores: keep only perturbations in each analysed network and
# attach their LOMAP score (NaN when no score is found in either direction)


def _read_lomap_scores(score_file):
    """Parse a network_scores.dat file into {"ligA~ligB": score}.

    The first two comma-separated fields of each line name the ligands and the
    last field is the score. Blank lines are skipped.
    """
    scores = {}
    with open(score_file) as lfile:
        for line in lfile:
            if not line.strip():
                continue
            fields = [field.strip() for field in line.split(",")]
            scores[f"{fields[0]}~{fields[1]}"] = float(fields[-1])
    return scores


for prot in ana_obj_dict.keys():
    print(prot)
    ana_obj = ana_obj_dict[prot]["subsampling"]

    if prot in ("syk", "cmet"):
        main_dir = f"/backup/{prot}/neutral"
    else:
        main_dir = f"/backup/{prot}"

    score_dict = _read_lomap_scores(f"{main_dir}/execution_model/network_scores.dat")

    # boolean-mask filtering instead of dropping rows one by one inside
    # iterrows (quadratic, and .drop(label) removes every row with that label)
    df = pert_overlap_dict[prot]
    df = df[df["perturbation"].isin(ana_obj.perturbations)].copy()

    # the score file may list a pair as A~B or B~A: try forward, then reversed
    reversed_pert = df["perturbation"].map(lambda pert: "~".join(pert.split("~")[1::-1]))
    df["score"] = df["perturbation"].map(score_dict).fillna(reversed_pert.map(score_dict))

    pert_overlap_dict[prot] = df

In [None]:
# NOTE(review): disabled cell, kept for reference only. It uses `main_folder`
# and `perturbations`, neither of which is defined in the visible notebook, so
# it would raise NameError if uncommented as-is.
# # write lomap scores for all of the network
# pl = pipeline.setup.initialise_pipeline()
# # where the ligands for the pipeline are located. These should all be in the same folder in sdf format
# pl.ligands_folder(f"{main_folder}/inputs/{prot}/ligands")
# pl.main_folder(f"{main_folder}/{prot}_benchmark")
# pl.setup_ligands(file_name=f"{main_folder}/{prot}_benchmark/execution_model/combined/ligands.dat")
# pl.setup_network(folder="combined")
# for pert in pl.perturbations:
#     pl.remove_perturbation(pert)
# for pert in perturbations:
#     pl.add_perturbation(pert)

In [None]:
from functools import reduce

# Stack every target's table into one frame with a single concat (pairwise
# concatenation in a reduce copies the growing frame each step). Index labels
# are kept, so they repeat across targets: select rows with boolean masks
# rather than df.drop(labels) on this frame.
all_df = pd.concat(list(pert_overlap_dict.values()), axis=0)
df = all_df
print(len(df))

In [None]:
# histogram of failed run in terms of perturbing atoms
# NOTE(review): disabled. The anti-join below selects rows of `df` whose
# percen_overlap_okay is not >= 0 (e.g. NaN); a boolean mask such as
# df[~(df["percen_overlap_okay"] >= 0)] would select the same rows more simply.

# df_has = df[df["percen_overlap_okay"] >= 0]
# df_none = (
#     pd.merge(df_has, df, how="outer", indicator=True)
#     .query("_merge != 'both'")
#     .drop("_merge", axis=1)
#     .reset_index(drop=True)
# )
# print(len(df_none))
# df_none["perturbing_atoms"].plot.hist(bins=10)
# plt.xlabel("perturbing atoms")

In [None]:
# LOMAP score against overlap quality (coloured by the difference to
# experiment), and against the difference to experiment itself
df_plot = df  # df.dropna()
print(len(df_plot))

ax = df_plot.plot.scatter(
    "score",
    "too_small_avg",
    c="diff_to_exp",
    colormap="viridis",
    vmin=0,
    vmax=5,
)
ax.set(xlabel="lomap score", ylabel="average no. of too small off-diagonals per leg")

ax = df_plot.plot.scatter(
    "score",
    "percen_overlap_okay",
    c="diff_to_exp",
    colormap="viridis",
    vmin=0,
    vmax=5,
)
ax.set(xlabel="lomap score", ylabel="percentage of okay overlap")

ax = df_plot.plot.scatter("score", "diff_to_exp", c="too_small_avg", colormap="viridis")
ax.set(xlabel="lomap score", ylabel="difference to experimental value (kcal/mol)");

In [None]:
import scipy.stats as _stats

In [12]:
# display.max_rows is None (set at the top), so show only the first rows
pert_overlap_dict["mcl1"].head(10)

Unnamed: 0,perturbation,engine,perturbing_atoms,percen_overlap_okay,too_small_avg,diff_to_exp,error,score
0,lig_53~lig_67,AMBER,,50.0,0.5,1.296012,0.070076,0.29396
1,lig_53~lig_67,SOMD,,100.0,0.0,0.340584,0.063599,0.29396
2,lig_53~lig_67,GROMACS,,83.333333,0.333333,2.646186,0.589695,0.29396
3,lig_26~lig_45,AMBER,,100.0,0.0,1.344754,0.088107,0.30036
4,lig_26~lig_45,SOMD,,83.333333,0.166667,6.404486,0.331257,0.30036
5,lig_26~lig_45,GROMACS,,16.666667,2.666667,4.58857,0.965678,0.30036
6,lig_23~lig_26,AMBER,,100.0,0.0,1.336267,0.043992,0.32494
7,lig_23~lig_26,SOMD,,100.0,0.0,2.604751,0.189239,0.32494
8,lig_23~lig_26,GROMACS,,16.666667,2.333333,4.582526,0.955716,0.32494
9,lig_60~lig_67,AMBER,,100.0,0.0,1.824067,0.068327,0.32494


In [None]:
# Per target and engine: mean, std and 95 % CI of the average number of
# too-small off-diagonals. The per-target frame is bound to `prot_df`, not
# `df`, so this cell does not replace the combined `df` used by later cells.
for prot, prot_df in pert_overlap_dict.items():
    print(prot)
    for eng in ana_obj_dict[prot]["subsampling"].engines:
        small_list = list(
            prot_df.loc[prot_df["engine"] == eng, "too_small_avg"].dropna()
        )
        lower_ci, upper_ci = _stats.norm.interval(
            confidence=0.95, loc=np.mean(small_list), scale=_stats.sem(small_list)
        )
        print(
            f"{eng}, avg: {np.mean(small_list):.2f}, std: {np.std(small_list):.2f}, 95% CI: {lower_ci:.2f}, {upper_ci:.2f}"
        )

In [None]:
# Per engine, pooled over all targets: mean, std and 95 % CI of the difference
# to experiment, ignoring missing values and outliers above 10 kcal/mol.
# A boolean mask is used because df's index labels repeat across targets, so
# df.drop(labels) would also remove rows of the selected engine.
for eng in df["engine"].unique():
    diffs = df.loc[df["engine"] == eng, "diff_to_exp"]
    small_list = list(diffs[diffs <= 10])  # NaN <= 10 is False, so NaN is excluded
    lower_ci, upper_ci = _stats.norm.interval(
        confidence=0.95, loc=np.mean(small_list), scale=_stats.sem(small_list)
    )
    print(eng, np.mean(small_list), np.std(small_list), lower_ci, upper_ci)

In [None]:
# One panel per engine: LOMAP score vs average number of too-small
# off-diagonals, coloured by the difference to experiment. Each engine is
# drawn on its own axes (plt.scatter would always target the last-created one).
plot_engines = list(df["engine"].dropna().unique())
n_panels = max(len(plot_engines), 1)
fig, axes = plt.subplots(
    nrows=1,
    ncols=n_panels,
    figsize=(5 * n_panels, 5),
    sharex=True,
    sharey=True,
    squeeze=False,
)

for eng, pos in zip(plot_engines, axes[0]):
    # boolean mask: df's index labels repeat across targets
    df_plot = df[df["engine"] == eng]
    print(len(df_plot))

    # pos.set_yscale("log")
    points = pos.scatter(
        df_plot["score"],
        df_plot["too_small_avg"],
        c=df_plot["diff_to_exp"],
        vmin=0,
        vmax=5,
        cmap="plasma",
    )
    cbar = fig.colorbar(points, ax=pos)
    cbar.set_label("Difference to experimental\n value (kcal/mol)")
    pos.set_ylabel("Average number of off-diagonals\n <0.03 for the perturbation")
    pos.set_xlabel("LOMAP score")
    pos.set_title(eng_dict_name.get(eng, eng))

In [None]:
# Per engine: R^2 between the difference to experiment (x) and the average
# number of too-small off-diagonals (y). Only rows where both values are known
# and the difference is at most 10 kcal/mol are used, so x and y stay paired.
for eng in df["engine"].unique():
    eng_rows = df[(df["engine"] == eng) & (df["diff_to_exp"] <= 10)].dropna(
        subset=["diff_to_exp", "too_small_avg"]
    )
    x = list(eng_rows["diff_to_exp"])
    y = list(eng_rows["too_small_avg"])
    # pass the two different series; x=y would always give a perfect R^2
    res = stats_engines.compute_stats(x=x, y=y, statistic="R2")
    print(eng, res)

In [None]:
# All engines: mean too_small_avg over rows whose difference to experiment is
# not above 5 kcal/mol, then how many of those rows have 100 % okay overlap.
df_plot = df.reset_index()
df_test = df_plot[~(df_plot["diff_to_exp"] > 5)]
print(df_test["too_small_avg"].mean())

n_total = len(df_test)
n_full_overlap = int((df_test["percen_overlap_okay"] == 100).sum())
print(n_full_overlap, n_total, 100 * n_full_overlap / n_total)

In [None]:
# AMBER, perturbations within 5 kcal/mol of experiment: for each overlap
# threshold r, the percentage of them with percen_overlap_okay >= r.
# Rows with no diff_to_exp are excluded, so no row falls in both this group
# and the "> 5" group.
x = []
y = []
# for eng in ana_obj.engines:
eng = "AMBER"
df_plot = df[df["engine"] == eng].reset_index()  # mask: df's labels repeat
df_test = df_plot[df_plot["diff_to_exp"] <= 5]
n_total = len(df_test)
print(eng, df_test["too_small_avg"].mean())

for r in range(0, 100, 20):
    # the same ">= r" criterion is used for the printed count and the curve
    n_at_least_r = int((df_test["percen_overlap_okay"] >= r).sum())
    print(eng, n_at_least_r, n_total, 100 * n_at_least_r / n_total)
    x.append(r)
    y.append(100 * n_at_least_r / n_total)

In [None]:
# AMBER, perturbations more than 5 kcal/mol from experiment: for each overlap
# threshold r, the percentage of them with percen_overlap_okay >= r.
# Strictly "> 5" and rows with no diff_to_exp excluded, so no row falls in
# both this group and the "<= 5" group.
x2 = []
y2 = []
# for eng in ana_obj.engines:
eng = "AMBER"
df_plot = df[df["engine"] == eng].reset_index()  # mask: df's labels repeat
df_test = df_plot[df_plot["diff_to_exp"] > 5]
n_total = len(df_test)
print(eng, df_test["too_small_avg"].mean())

for r in range(0, 100, 20):
    # the same ">= r" criterion is used for the printed count and the curve
    n_at_least_r = int((df_test["percen_overlap_okay"] >= r).sum())
    print(eng, n_at_least_r, n_total, 100 * n_at_least_r / n_total)
    x2.append(r)
    y2.append(100 * n_at_least_r / n_total)

In [None]:
# x, y come from the "diff to exp <= 5" AMBER cell, x2, y2 from the "> 5" one
fig, ax = plt.subplots()
ax.plot(x, y, label="diff to exp ≤ 5 kcal/mol")
ax.plot(x2, y2, label="diff to exp > 5 kcal/mol")
ax.set_xlabel("overlap threshold r (% okay overlap)")
ax.set_ylabel("% of perturbations with okay overlap ≥ r")
ax.legend();

In [None]:
from sklearn import preprocessing
from sklearn.metrics import r2_score
from sklearn.linear_model import LinearRegression

In [None]:
# Min-max scale diff_to_exp and too_small_avg and score one against the other.
# Rows missing either value are dropped *before* scaling so the two columns
# stay paired. Dedicated names avoid overwriting the global `x` used for plots.
scaled_cols = ["diff_to_exp", "too_small_avg"]
paired = df[scaled_cols].dropna()
min_max_scaler = preprocessing.MinMaxScaler()
scaled = pd.DataFrame(min_max_scaler.fit_transform(paired.values), columns=scaled_cols)

# NOTE(review): r2_score(y_true, y_pred) treats scaled too_small_avg as a
# prediction of scaled diff_to_exp; it is not symmetric, is not the squared
# correlation coefficient, and can be negative.
r2_score(scaled["diff_to_exp"], scaled["too_small_avg"])

In [None]:
# Difference to experiment vs average number of too-small off-diagonals,
# coloured by LOMAP score, with a least-squares line through the rows where
# both values are known.
df_plot = df  # .dropna()

# ax.set_xscale("log")
# ax.set_yscale("log")
fig, ax = plt.subplots()
points = ax.scatter(
    df_plot["diff_to_exp"],
    df_plot["too_small_avg"],
    c=df_plot["score"],
)

# drop NaNs jointly so x and y stay paired; sort so the line is drawn left to right
fit_rows = df_plot[["diff_to_exp", "too_small_avg"]].dropna().sort_values("diff_to_exp")
if len(fit_rows) >= 2:
    # sklearn expects X as a 2-D array of shape (n_samples, n_features)
    x_fit = fit_rows[["diff_to_exp"]].to_numpy()
    model = LinearRegression().fit(x_fit, fit_rows["too_small_avg"].to_numpy())
    ax.plot(x_fit[:, 0], model.predict(x_fit))

# ax.scatter does not set pyplot's current image, so pass the mappable explicitly
cbar = fig.colorbar(points, ax=ax)
cbar.set_label("LOMAP-score")
ax.set_ylabel("Average number of off-diagonals < 0.03")
ax.set_xlabel("Diff to exp")
# these will obviously correlate
# these will obviously correlate

In [None]:
# Difference to experiment vs too_small_avg, coloured by LOMAP score.
# Log axes are available but currently switched off (see commented lines).
df_plot = df  # .dropna()

fig, ax = plt.subplots()
# ax.set_yscale("log")
# ax.set_xscale("symlog")
# ax.invert_xaxis()
points = ax.scatter(
    df_plot["diff_to_exp"],
    df_plot["too_small_avg"],
    c=df_plot["score"],
)
cbar = fig.colorbar(points, ax=ax)
cbar.set_label("score")
ax.set_ylabel("too_small_avg")
ax.set_xlabel("diff_to_exp")

In [None]:
# exclude outliers as needed: drop rows at least OUTLIER_CUTOFF kcal/mol from
# experiment. Rows without a diff_to_exp value are kept. A boolean mask
# replaces the outer-merge anti-join, which matched on every column and
# reordered the rows.
OUTLIER_CUTOFF = 20
df_out = df[~(df["diff_to_exp"] >= OUTLIER_CUTOFF)].reset_index(drop=True)

# df_out = df

df_out.plot.scatter(
    "perturbing_atoms", "diff_to_exp", c="percen_overlap_okay", colormap="viridis"
)

In [None]:
# One histogram per descriptor, with a hand-picked bin count for each column.
# The loop variable is `n_bins`: naming it `bin` would shadow the builtin
# bin() for every later cell in the notebook.
columns = ["perturbing_atoms", "percen_overlap_okay", "too_small_avg"]
bins = [6, 3, 5]
for column, n_bins in zip(columns, bins):
    fig, ax = plt.subplots()
    df[column].plot.hist(bins=n_bins, ax=ax)
    ax.set_title(column)

In [None]:
# Per-engine view: the three pairwise scatter plots of perturbing atoms,
# okay-overlap percentage and too-small off-diagonals, each coloured by the third.
eng = "GROMACS"
df2 = df[df["engine"] == eng]
df_plot = df2.dropna()
for x_col, y_col, colour_col in (
    ("perturbing_atoms", "percen_overlap_okay", "too_small_avg"),
    ("perturbing_atoms", "too_small_avg", "percen_overlap_okay"),
    ("percen_overlap_okay", "too_small_avg", "perturbing_atoms"),
):
    df_plot.plot.scatter(x_col, y_col, c=colour_col, colormap="viridis")

In [None]:
# Okay-overlap percentage vs number of perturbing atoms: one figure per engine,
# then all engines overlaid on a single axes.
engs = ["AMBER", "SOMD", "GROMACS"]

col_dict = pipeline.utils.set_colours()

eng_dict = {eng: df[df["engine"] == eng] for eng in engs}

for eng in engs:
    df_plot = eng_dict[eng].dropna()
    ax = df_plot.plot.scatter(
        "perturbing_atoms", "percen_overlap_okay", c=col_dict[eng]
    )

# Overlay: label each engine explicitly. Passing col_dict to plt.legend would
# assign labels in dict-key order, which need not match the plotted engines.
ax1 = None
for eng in engs:
    df_plot = eng_dict[eng].dropna()
    ax1 = df_plot.plot.scatter(
        "perturbing_atoms", "percen_overlap_okay", c=col_dict[eng], label=eng, ax=ax1
    )
ax1.legend(loc="upper right");

In [None]:
# Difference to experiment vs number of perturbing atoms, per engine and
# overlaid, excluding rows at least 5 kcal/mol from experiment.
engs = ["AMBER", "SOMD", "GROMACS"]
eng_dict = {}

col_dict = pipeline.utils.set_colours()

for eng in engs:
    df2 = df[df["engine"] == eng]
    # boolean mask instead of an outer-merge anti-join on all columns;
    # rows with no diff_to_exp are kept here (the .dropna() below removes them)
    eng_dict[eng] = df2[~(df2["diff_to_exp"] >= 5)].reset_index(drop=True)

for eng in engs:
    df_plot = eng_dict[eng].dropna()
    ax = df_plot.plot.scatter("perturbing_atoms", "diff_to_exp", c=col_dict[eng])

# Overlay with an explicit label per engine so the legend matches the colours.
ax1 = None
for eng in engs:
    df_plot = eng_dict[eng].dropna()
    ax1 = df_plot.plot.scatter(
        "perturbing_atoms", "diff_to_exp", c=col_dict[eng], label=eng, ax=ax1
    )
ax1.legend(loc="upper right");

Checking which perturbations have bad overlap

In [None]:
# NOTE(review): `main_folder` is not defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel — define it before running.
file = f"{main_folder}/extracted/mcl1/perturbing_overlap.dat"


def _has_full_overlap(overlap_text):
    """Return True when the overlap field is numerically 100 (% okay overlap).

    Numeric comparison also accepts "100" or "100.00"; non-numeric fields count
    as not okay.
    """
    try:
        return float(overlap_text) == 100.0
    except ValueError:
        return False


# Read the file once. Fields used: 0 = perturbation, 1 = engine,
# 3 = okay-overlap value. Lines with fewer than four fields are skipped.
with open(file, "r") as f:
    overlap_rows = []
    for line in f:
        fields = [field.strip() for field in line.split(",")]
        if len(fields) > 3:
            overlap_rows.append((fields[0], fields[1], fields[3]))

eng_dict_ok = {"AMBER": None, "SOMD": None, "GROMACS": None}
eng_dict_not = {"AMBER": None, "SOMD": None, "GROMACS": None}

for engine in ["SOMD", "AMBER", "GROMACS"]:
    print(engine)
    # dicts used as insertion-ordered sets: unique perturbations, first-seen order
    okay_seen = {}
    not_seen = {}
    for pert, eng, overlap_okay in overlap_rows:
        if eng != engine:
            continue
        if _has_full_overlap(overlap_okay):  #  or overlap == 50
            okay_seen[pert] = None
        else:
            not_seen[pert] = None
    perts_okay = list(okay_seen)
    perts_not = list(not_seen)

    # a perturbation can be in both lists if its rows disagree; to make them exclusive:
    # for pert in perts_not:
    #     if pert in perts_okay:
    #         perts_okay.remove(pert)

    print(len(perts_okay))
    print(perts_okay)
    print(len(perts_not))
    print(perts_not)
    print(" ")

    eng_dict_ok[engine] = perts_okay
    eng_dict_not[engine] = perts_not


# perturbations with full overlap in both AMBER and SOMD (AMBER order kept)
somd_okay = set(eng_dict_ok["SOMD"])
both_perts = [pert for pert in eng_dict_ok["AMBER"] if pert in somd_okay]

print(len(both_perts))
print(both_perts)
print(" ")

# every SOMD perturbation that is not fully okay in both engines
# (can repeat a perturbation that is in both SOMD lists)
both_set = set(both_perts)
all_perts = eng_dict_ok["SOMD"] + eng_dict_not["SOMD"]
not_okay = [pert for pert in all_perts if pert not in both_set]

print(len(not_okay))
print(not_okay)
print(" ")

In [None]:
# Write the perturbations that were not fully okay in both AMBER and SOMD to a
# new network file, using a protocol with 16 lambda windows.
prot = pipeline_protocol()
prot.num_lambda(16)

pert_dict = {}
for pert in not_okay:
    ligands = pert.split("~")
    pert_dict[(ligands[0], ligands[1])] = "None"

write_network(pert_dict, prot, "new_mcl1_network_both2.dat")