In [None]:
# analysis paper
# import libraries
import seaborn as sns
import numpy as np
import scipy.stats as _stats
from functools import reduce
from pipeline.analysis import *
from pipeline.utils import validate
from pipeline import *
import logging
import networkx as nx
import glob
from scipy.stats import sem as sem
from matplotlib import colormaps
import sys

# sys.path.insert(1, "/home/anna/Documents/code/python/pipeline")
from matplotlib.ticker import MaxNLocator

import warnings

# Silence FutureWarning and RuntimeWarning for the rest of the session.
for _ignored_category in (FutureWarning, RuntimeWarning):
    warnings.simplefilter(action="ignore", category=_ignored_category)
del _ignored_category
# warnings.simplefilter(action='ignore', category=SettingWithCopyWarning)

# Root logger: report ERROR and above only.
logging.getLogger().setLevel(logging.ERROR)


# BSS is not imported explicitly (it arrives via one of the star imports above);
# show which installation was picked up.
print(BSS.__file__)

In [2]:
def check_normal_dist(values):
    # check normally dist
    if len(values) < 50:
        stat, p = _stats.shapiro(values)
    else:
        stat, p = _stats.kstest(values)
    if p < 0.05:
        return True
    else:
        return False


def flatten_comprehension(matrix):
    """Flatten one level of nesting, e.g. [[a, b], [c]] -> [a, b, c]."""
    flat = []
    for row in matrix:
        flat.extend(row)
    return flat

In [5]:
# define the analysis method to use
# Every protocol is MBAR via alchemlyb with the same I/O options; they differ only
# in auto-equilibration, subsampling (statistical inefficiency) and how much of
# each run is kept ("truncate upper"; presumably a percentage of the run, since
# "1ns" keeps 25 — confirm).
_base_protocol = {
    "estimator": "MBAR",
    "method": "alchemlyb",
    "check overlap": True,
    "try pickle": True,
    "save pickle": True,
    "auto equilibration": False,
    "statistical inefficiency": True,
    "truncate lower": 0,
    "truncate upper": 100,
    "name": None,
}

# protocol label -> settings that differ from _base_protocol
_protocol_overrides = {
    "plain": {"statistical inefficiency": False},
    "subsampling": {},
    "1ns": {"truncate upper": 25},
    "2ns": {"truncate upper": 50},
    "3ns": {"truncate upper": 75},
    "autoeq": {"auto equilibration": True},
    # Disabled options (the single-repeat cells further down expect single_0..2):
    # "TI": {"estimator": "TI"},
    # "single_0": {"statistical inefficiency": False},  # likewise single_1, single_2
}

# each protocol gets its own dict (no shared mutable state between protocols)
ana_dicts = {
    label: {**_base_protocol, **changes}
    for label, changes in _protocol_overrides.items()
}
del _base_protocol, _protocol_overrides

In [None]:
# set the variables
network = "combined"  # which network file to use: lomap, rbfenn or combined
# NOTE(review): machine-specific absolute paths; adjust for other machines
bench_folder = "/home/anna/Documents/benchmark"

prot_dict_name = {
    "tyk2": "TYK2",
    "mcl1": "MCL1",
    "p38": "P38α",
    "syk": "SYK",
    "hif2a": "HIF2A",
    "cmet": "CMET",
}
eng_dict_name = {"AMBER": "AMBER22", "SOMD": "SOMD1", "GROMACS": "GROMACS23"}

# all the options: ana_obj_dict[protein][analysis protocol name] -> analysis_network
ana_obj_dict = {}

for protein in prot_dict_name:
    ana_obj_dict[protein] = {}

    # syk and cmet results are kept in a "neutral" subfolder
    if protein in ("syk", "cmet"):
        main_dir = f"/backup/{protein}/neutral"
    else:
        main_dir = f"/backup/{protein}"
    inputs_dir = f"{bench_folder}/inputs/{protein}"

    # Protein structure, only used for the residue count below. It does not
    # depend on the analysis protocol, so it is read once per protein.
    try:
        prot = BSS.IO.readMolecules(
            [
                f"{inputs_dir}/{protein}_prep/{protein}.gro",
                f"{inputs_dir}/{protein}_prep/{protein}.top",
            ]
        )[0]
    except Exception:
        # fall back to the parameterised prm7/rst7 files
        prot = BSS.IO.readMolecules(
            [
                f"{inputs_dir}/{protein}_parameterised.prm7",
                f"{inputs_dir}/{protein}_parameterised.rst7",
            ]
        )[0]
    print(f"no of residues in the protein: {prot.nResidues()}")

    # choose location for the files
    if protein in ("syk", "cmet", "hif2a"):
        # the lomap network
        net_file = f"{main_dir}/execution_model/network_all.dat"
    else:
        net_file = f"{main_dir}/execution_model/network_{network}.dat"

    exp_file = f"{bench_folder}/inputs/experimental/{protein}.yml"
    output_folder = f"{main_dir}/outputs_extracted"

    # ligands folder: use the neutral variant if the plain one does not exist
    ligands_folder = f"{inputs_dir}/ligands"
    if not os.path.isdir(ligands_folder):
        ligands_folder = f"{inputs_dir}/ligands_neutral"

    for ana_dict in ana_dicts:
        ana_prot = analysis_protocol(ana_dicts[ana_dict])
        print(protein, ana_dict)

        # No protocol file is used; a fresh default protocol supplies the method
        # name and the engine list for each analysis object.
        pipeline_prot = pipeline_protocol(auto_validate=True)

        # initialise the network object
        all_analysis_object = analysis_network(
            output_folder,
            exp_file=exp_file,
            net_file=net_file,
            analysis_prot=ana_prot,
            method=pipeline_prot.name(),  # if the protocol had a name
            engines=pipeline_prot.engines(),
        )

        # compute
        all_analysis_object.compute_results()

        # single-repeat protocols are named "single_0", "single_1", ...
        if ana_dict.startswith("single"):
            all_analysis_object.file_ext = all_analysis_object.file_ext + f"_{ana_dict}"

        all_analysis_object.add_ligands_folder(ligands_folder)

        ana_obj_dict[protein][ana_dict] = all_analysis_object

print(ana_obj_dict)

In [None]:
# make single vs triplicate results
# Needs the "single_0".."single_2" protocols in ana_dicts; any that are not
# enabled there are reported and skipped instead of raising KeyError.
for prot, prot_results in ana_obj_dict.items():
    for r in range(3):
        single_name = f"single_{r}"
        if single_name not in prot_results:
            print(f"skipping {prot} {single_name}: protocol not in ana_dicts")
            continue
        ana_obj = prot_results[single_name]
        # function for single dicts (presumably results from repeat r alone)
        ana_obj.compute_single_repeat_results(repeat=r)

# # error for a perturbation per single run

# uncertainty_dict_single = {}

# for eng in all_analysis_object.engines:
#     uncertainty_dict_single[eng] = {}
#     repeat = 0
#     for file in all_analysis_object._results_repeat_files[eng]:
#         uncertainty_dict_single[eng][repeat] = {}
#         calc_diff_dict = make_dict.comp_results(
#             file, all_analysis_object.perturbations, eng, name=None
#         )

#         for pert in calc_diff_dict.keys():
#             uncertainty_dict_single[eng][repeat][pert] = calc_diff_dict[pert][1]

#         repeat += 1

In [None]:
# plot convergence. only plot if has been computed (should have run for not stats ineff)

# for prot in ana_obj_dict.keys():

#     try:
#         ana_obj = ana_obj_dict[prot]["plain"]
#         ana_obj.compute_convergence(
#             compute_missing=True
#         )
#         ana_obj.plot_convergence()
#     except Exception as e:
#         print(e)
#         print(f"could not for {prot}")

In [None]:
# identify any outliers and plot again if needed above
failed_perts_dict_percen = {}
failed_perts_dict = {}

for prot, prot_results in ana_obj_dict.items():
    failed_perts_dict_percen[prot] = {}
    failed_perts_dict[prot] = {}
    ana_obj = prot_results["plain"]
    print(prot)
    n_perts = len(ana_obj.perturbations)

    for eng in ana_obj.engines:
        # query each engine once and reuse the results below;
        # successful[1] is the success percentage, successful[2] the successful perturbations
        successful = ana_obj.successful_perturbations(eng)
        failed = ana_obj.failed_perturbations(engine=eng)
        failed_percen = 100 - successful[1]

        failed_perts_dict_percen[prot][eng] = failed_percen
        failed_perts_dict[prot][eng] = len(failed)

        print(
            f"failed percentage for {eng}: {failed_percen} ({n_perts - len(successful[2])} / {n_perts})"
        )
        print(f"{eng} failed perturbations: {failed}")
        print(f"{eng} disconnected ligands: {ana_obj.disconnected_ligands(engine=eng)}")
        print(f"outliers {eng}: {ana_obj.get_outliers(threshold=10, name=eng)}")

In [None]:
# inspect the SOMD outliers (|error| threshold 10) for p38
ana_obj = ana_obj_dict["p38"]["subsampling"]
somd_outliers = ana_obj.get_outliers(threshold=10, name="SOMD")
ana_obj.draw_perturbations(somd_outliers)
[(pert, ana_obj.calc_pert_dict["SOMD"][pert]) for pert in somd_outliers]

In [5]:
# Perturbations that needed special handling, per protein and engine
# (perturbation names are "<ligA>~<ligB>"):
#   failed_dict      - perturbations listed as failed
#   twofs_run_dict   - perturbations run with a 2 fs timestep
#   reverse_run_dict - perturbations run in the reverse direction
failed_dict = {
    "tyk2": {
        "AMBER": ["lig_jmc27~lig_jmc28"],
        "SOMD": [],
        "GROMACS": [],
    },
    "mcl1": {
        "AMBER": [
            "lig_38~lig_44",
            "lig_27~lig_40",
            "lig_38~lig_40",
            "lig_33~lig_40",
            "lig_30~lig_40",
            "lig_27~lig_38",
        ],
        "SOMD": ["lig_38~lig_48"],
        "GROMACS": [],
    },
    "p38": {
        "AMBER": ["lig_2b~lig_2z", "lig_2aa~lig_2z", "lig_2o~lig_2s", "lig_2n~lig_2s"],
        "SOMD": [
            "lig_2u~lig_2x",
            "lig_2b~lig_2l",
            "lig_2a~lig_2q",
            "lig_2ff~lig_2hh",
            "lig_2h~lig_2m",
        ],
        "GROMACS": [],
    },
    "syk": {
        "AMBER": [
            "lig_1~lig_13",
            "lig_33~lig_38",
            "lig_21~lig_37",
            "lig_21~lig_8",
            "lig_28~lig_42",
        ],
        "SOMD": ["lig_1~lig_13"],
        "GROMACS": [],
    },
    "hif2a": {
        "AMBER": ["lig_41~lig_9", "lig_35~lig_36"],
        "SOMD": ["lig_27~lig_8"],
        "GROMACS": [],
    },
    "cmet": {
        "AMBER": ["lig_1~lig_4", "lig_1~lig_6", "lig_10~lig_17"],
        "SOMD": ["lig_1~lig_6"],
        "GROMACS": [],
    },
}
twofs_run_dict = {
    "tyk2": {
        "AMBER": [],
        "SOMD": [],
        "GROMACS": [],
    },
    "mcl1": {
        "AMBER": ["lig_27~lig_59", "lig_33~lig_48", "lig_30~lig_31", "lig_35~lig_36"],
        "SOMD": ["lig_27~lig_40", "lig_27~lig_45", "lig_27~lig_47"],
        "GROMACS": [],
    },
    "p38": {
        "AMBER": [
            "lig_2b~lig_2l",
            "lig_2e~lig_2q",
            "lig_2v~lig_2y",
            "lig_2e~lig_2r",
            "lig_2a~lig_2r",
            "lig_2dd~lig_2hh",
            "lig_2k~lig_2q",
            "lig_2n~lig_2p",
            "lig_2ee~lig_2ff",
            "lig_2h~lig_2n",
            "lig_2k~lig_2l",
        ],
        "SOMD": ["lig_2m~lig_2x", "lig_2t~lig_2x", "lig_2l~lig_2x", "lig_2ff~lig_2m"],
        "GROMACS": ["lig_2t~lig_2x", "lig_2l~lig_2t"],
    },
    "syk": {
        "AMBER": [
            "lig_21~lig_41",
            "lig_37~lig_41",
            "lig_1~lig_41",
            "lig_11~lig_5",
            "lig_29~lig_40",
            "lig_11~lig_34",
            "lig_19~lig_40",
            "lig_10~lig_42",
            "lig_16~lig_17",
            "lig_19~lig_25",
            "lig_23~lig_42",
            "lig_33~lig_40",
            "lig_2~lig_20",
            "lig_2~lig_30",
            "lig_28~lig_35",
            "lig_38~lig_44",
            "lig_6~lig_7",
            "lig_35~lig_42",
            "lig_39~lig_42",
            "lig_39~lig_7",
            "lig_20~lig_30",
            "lig_27~lig_38",
        ],
        "SOMD": ["lig_21~lig_41", "lig_37~lig_41", "lig_38~lig_44"],
        "GROMACS": ["lig_28~lig_42"],
    },
    "hif2a": {
        "AMBER": [
            "lig_3~lig_31",
            "lig_31~lig_41",
            "lig_20~lig_37",
            "lig_15~lig_32",
            "lig_20~lig_40",
            "lig_24~lig_5",
            "lig_32~lig_6",
            "lig_34~lig_35",
            "lig_41~lig_6",
            "lig_19~lig_6",
            "lig_24~lig_37",
            "lig_29~lig_7",
            "lig_14~lig_42",
            "lig_25~lig_26",
            "lig_27~lig_3",
            "lig_28~lig_8",
            "lig_21~lig_25",
            "lig_21~lig_39",
            "lig_33~lig_42",
            "lig_34~lig_6",
            "lig_1~lig_25",
            "lig_14~lig_38",
            "lig_15~lig_6",
            "lig_16~lig_4",
            "lig_21~lig_33",
            "lig_21~lig_42",
            "lig_22~lig_8",
            "lig_25~lig_33",
            "lig_30~lig_6",
            "lig_10~lig_21",
            "lig_17~lig_42",
        ],
        "SOMD": [],
        "GROMACS": [],
    },
    "cmet": {
        # same "~" separator as every other perturbation name
        "AMBER": ["lig_17~lig_6", "lig_4~lig_5", "lig_22~lig_3"],
        "SOMD": [],
        "GROMACS": ["lig_7~lig_9"],
    },
}
reverse_run_dict = {
    "tyk2": {
        "AMBER": [],
        "SOMD": [],
        "GROMACS": [],
    },
    "mcl1": {
        "AMBER": [
            "lig_27~lig_43",
        ],
        "SOMD": [],
        "GROMACS": [],
    },
    "p38": {
        "AMBER": ["lig_2a~lig_2h"],
        "SOMD": [],
        "GROMACS": [],
    },
    "syk": {
        "AMBER": [
            "lig_40~lig_42",
            "lig_40~lig_6",
            "lig_25~lig_29",
            "lig_26~lig_6",
            "lig_36~lig_6",
            "lig_27~lig_44",
            "lig_26~lig_36",
        ],
        "SOMD": [],
        "GROMACS": [],
    },
    "hif2a": {
        "AMBER": [],
        "SOMD": [],
        "GROMACS": [],
    },
    "cmet": {
        "AMBER": [
            "lig_7~lig_9",
            "lig_22~lig_9",
            "lig_1~lig_10",
            "lig_1~lig_17",
            "lig_18~lig_7",
        ],
        "SOMD": ["lig_22~lig_9"],
        "GROMACS": ["lig_22~lig_9"],
    },
}

# Number of listed perturbations per protein and engine, derived from the run
# dicts so new proteins/engines are picked up automatically.
failed_actual_dict = {
    prot: {eng: len(perts) for eng, perts in engines.items()}
    for prot, engines in failed_dict.items()
}
twofs_dict = {
    prot: {eng: len(perts) for eng, perts in engines.items()}
    for prot, engines in twofs_run_dict.items()
}
reverse_dict = {
    prot: {eng: len(perts) for eng, perts in engines.items()}
    for prot, engines in reverse_run_dict.items()
}

# rows = proteins, columns = engines
df_failed = pd.DataFrame(failed_actual_dict).T
df_twofs = pd.DataFrame(twofs_dict).T
df_reverse = pd.DataFrame(reverse_dict).T

In [None]:
# plot the failed perturbations (count per protein, one bar per engine)
ax = df_failed.plot(
    kind="bar",
    color=pipeline.analysis.set_colours(),
    xlabel="Protein System",
    ylabel="number of failed perturbations",
)
# counts are whole numbers, so keep the y ticks on integers
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
# label every bar with its height, just above the bar
for bar in ax.patches:
    height = bar.get_height()
    ax.annotate(str(height), (bar.get_x() * 1.005, height * 1.005))

In [None]:
# total no of perturbations in the benchmark
TOTAL_PERTURBATIONS = 376

# Breakdown of how each perturbation was run, per engine: default 4 fs,
# re-run at 2 fs (twofs_run_dict), re-run reversed (reverse_run_dict), failed.
adaptive_protocol_dict = {}
for eng in ["AMBER", "GROMACS", "SOMD"]:
    n_twofs = np.sum(df_twofs[eng])
    n_reverse = np.sum(df_reverse[eng])
    n_failed = np.sum(df_failed[eng])
    adaptive_protocol_dict[eng] = {
        "4 fs": TOTAL_PERTURBATIONS - (n_reverse + n_twofs + n_failed),
        "2 fs": n_twofs,
        "2 fs reverse": n_reverse,
        "failed": n_failed,
    }
    assert adaptive_protocol_dict[eng]["4 fs"] >= 0, (
        f"{eng}: more special-cased perturbations than {TOTAL_PERTURBATIONS}"
    )

df = pd.DataFrame(adaptive_protocol_dict).T.rename(eng_dict_name)

fig, ax = plt.subplots(figsize=(3.25, 3.25))
df.plot(
    color=["darkslateblue", "purple", "orchid", "lavender"],
    kind="bar",
    ax=ax,
    width=0.8,
)
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
# label every bar with its height
for p in ax.patches:
    ax.annotate(
        str(p.get_height()), (p.get_x() - 0.05, p.get_height() * 1.005), fontsize=7
    )
ax.legend(loc="center right", fontsize=10)
ax.set_xlabel("MD Engine", fontsize=12)
ax.set_ylabel("Number of perturbations", fontsize=12)
ax.tick_params(axis="x", labelsize=10, rotation=45)
ax.tick_params(axis="y", labelsize=10, rotation=0)

In [None]:
# find the max sem: for each protein and engine, the perturbation with the
# largest SEM (calc_pert_dict values are (value, sem, ...))
for prot, prot_results in ana_obj_dict.items():
    ana_obj = prot_results["subsampling"]

    for eng in ana_obj.engines:
        # ignore perturbations whose SEM is nan
        finite = {
            pert: vals
            for pert, vals in ana_obj.calc_pert_dict[eng].items()
            if str(vals[1]) != "nan"
        }
        if not finite:
            print(prot, eng, "no perturbations with a finite SEM")
            continue
        worst = max(finite, key=lambda pert: finite[pert][1])
        print(prot, eng, worst, finite[worst][0], finite[worst][1])

In [None]:
# exclude outliers
threshold = 10
outlier_suffix = f"_outliers{threshold}removed"
for prot in ana_obj_dict.keys():
    for name in ana_dicts.keys():
        print(prot, name)
        ana_obj = ana_obj_dict[prot][name]

        # Tag the output name once per analysis object (not once per engine),
        # and only if this cell has not already been run on it.
        if not ana_obj.file_ext.endswith(outlier_suffix):
            ana_obj.file_ext = ana_obj.file_ext + outlier_suffix
        for eng in ana_obj.engines:
            ana_obj.remove_outliers(threshold=threshold, name=eng)

In [30]:
def val_range(val_list):
    """Return max - min of the first element of each entry in `val_list`.

    Each entry must be indexable with the value of interest at position 0
    (in this notebook, presumably the (value, error) results of one
    perturbation across repeats). Only the first elements are compared, so
    ties or incomparable values in other positions (e.g. a missing error)
    cannot raise.

    Raises:
        ValueError: if `val_list` is empty.
    """
    firsts = [entry[0] for entry in val_list]
    return max(firsts) - min(firsts)

In [None]:
# mean edge range: for each perturbation, the spread (max - min) of the results
# of the three repeats, averaged over all perturbations per engine
for prot in ana_obj_dict.keys():
    print(prot)
    ana_obj = ana_obj_dict[prot]["subsampling"]

    for eng in ana_obj.engines:
        ranges = []
        n_skipped = 0
        for pert in ana_obj._perturbations_dict[eng]:
            try:
                ranges.append(
                    val_range(
                        [ana_obj.calc_repeat_pert_dict[eng][r][pert] for r in [0, 1, 2]]
                    )
                )
            except Exception:
                # presumably the perturbation is missing from a repeat; count it
                # instead of hiding it
                n_skipped += 1
        clean_ranges = [x for x in ranges if str(x) != "nan"]
        print(f"{eng}, {np.mean(clean_ranges):.2f}")
        if n_skipped:
            print(f"  ({n_skipped} perturbations could not be evaluated)")

In [None]:
# spot check: the three repeat results of one mcl1 SOMD perturbation
ana_obj = ana_obj_dict["mcl1"]["subsampling"]
somd_repeats = ana_obj.calc_repeat_pert_dict["SOMD"]
[somd_repeats[r]["lig_27~lig_45"] for r in range(3)]

In [None]:
# calculate the differences in SEM
# Mean SEM per (analysis protocol, protein, engine) and per protocol overall,
# each with its sample std and a normal-approximation 95% CI of the mean.


def _summarise_sems(sem_values):
    """Return (mean, sample std, (lower, upper) 95% CI of the mean, non-nan values)."""
    clean = [x for x in sem_values if str(x) != "nan"]
    mean = np.mean(clean)
    lower_ci, upper_ci = _stats.norm.interval(
        confidence=0.95, loc=mean, scale=_stats.sem(clean)
    )
    return mean, _stats.tstd(clean), (lower_ci, upper_ci), clean


sem_dict = {}
sem_dict_name = {}

for name in ana_dicts:
    all_sems = []
    sem_dict[name] = {}

    for prot in ana_obj_dict.keys():
        sem_dict[name][prot] = {}
        ana_obj = ana_obj_dict[prot][name]

        for eng in ana_obj.engines:
            sems = [val[1] for val in ana_obj.calc_pert_dict[eng].values()]
            all_sems.extend(sems)

            # if not check_normal_dist(...): print(f"{prot} {name} not normally dist")
            summary = _summarise_sems(sems)
            print(prot, name, eng, summary[0], *summary[2])
            sem_dict[name][prot][eng] = summary

    summary = _summarise_sems(all_sems)
    print(name, summary[0], *summary[2])
    sem_dict_name[name] = summary

In [None]:
# calculate the differences in SEM, pooled over all proteins for selected engines
# sem_dict / sem_dict_name are intentionally NOT reset here: they hold the
# per-protein SEM summaries that the MAE cell below reads.
for name in ana_dicts:
    for eng in ["GROMACS"]:
        sem_list = [
            val[1]
            for prot in ana_obj_dict.keys()
            for val in ana_obj_dict[prot][name].calc_pert_dict[eng].values()
        ]
        sem_list = [x for x in sem_list if str(x) != "nan"]
        print(len(sem_list))
        mean = np.mean(sem_list)
        lower_ci, upper_ci = _stats.norm.interval(
            confidence=0.95, loc=mean, scale=_stats.sem(sem_list)
        )
        print(name, eng, mean, lower_ci, upper_ci)

In [None]:
# calc mae perts: ddG MAE vs experiment per protocol, protein and engine
# mae_dict[name][prot][eng] = (MAE, its uncertainty, (lower, upper) interval)

mae_dict = {}

for name in ana_dicts:
    mae_dict[name] = {}

    for prot in ana_obj_dict.keys():
        print(prot, name)

        mae_dict[name][prot] = {}

        ana_obj = ana_obj_dict[prot][name]

        try:
            mae = ana_obj.calc_mae_engines(pert_val="pert", recalculate=False)
        except Exception as e:
            # skip this protein; otherwise `mae` would still hold the previous
            # protein's values and they would be stored under this protein
            print(e)
            print(f"could not compute MAE for {prot} {name}")
            continue

        for eng in ana_obj.engines:
            try:
                mae_dict[name][prot][eng] = tuple(
                    mae[i][eng]["experimental"] for i in range(3)
                )
                stats_string = f"{eng} MAE: {mae[0][eng]['experimental']:.2f} +/- {mae[1][eng]['experimental']:.2f} kcal/mol, "

                sem_stats = sem_dict.get(name, {}).get(prot, {}).get(eng)
                if sem_stats and sem_stats[0]:
                    stats_string += f"SEM: {sem_stats[0]:.2f} +/- {sem_stats[1]:.2f} kcal/mol\n"
                elif name.startswith("single"):
                    # single-repeat protocols ("single_0", ...): report the per-perturbation errors
                    errors = [val[1] for val in ana_obj.calc_pert_dict[eng].values()]
                    stats_string += f"error: {np.mean(errors):.2f} +/- {_stats.tstd(errors):.2f} kcal/mol\n"

                print(stats_string)

            except Exception as e:
                print(e)
                print(f"could not compute for {prot} {name} {eng}")

In [None]:
# graphs based on engine: one panel per engine, bars = analysis protocols per protein
plotting_dict = mae_dict  # mae_dict or sem_dict
stats_name = "ΔΔG MAE"  # MAE or SEM

# display names for the analysis protocols in ana_dicts
method_labels = {
    "plain": "Full data",
    "subsampling": "Subsampling",
    "autoeq": "Auto-equilibration",
    "1ns": "1 ns sampling",
    "2ns": "2 ns sampling",
    "3ns": "3 ns sampling",
}


def _engine_column(stats, name, engine, pick):
    """One-column frame (rows = proteins) of pick(entry) for `engine` under protocol `name`.

    Entries of stats[name][protein][engine] are tuples as built by the MAE/SEM
    cells; `pick` selects the part to plot (index 0 = value, 2 = (lower, upper)).
    """
    return (
        pd.DataFrame(stats[name])
        .applymap(pick)
        .rename(prot_dict_name, axis=1)
        .T[[engine]]
        .rename({engine: method_labels.get(name, name)}, axis=1)
    )


def _merge_on_index(frames):
    """Join frames side by side on their shared (protein) index."""
    return reduce(
        lambda left, right: pd.merge(left, right, left_index=True, right_index=True),
        frames,
    )


fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(20, 20), sharex=True, sharey=True)
# NOTE(review): the engine order comes from whichever `ana_obj` an earlier cell left behind
for engine, pos in zip(ana_obj.engines, axes):
    df = _merge_on_index(
        [_engine_column(plotting_dict, name, engine, lambda x: x[0]) for name in ana_dicts]
    )

    half_widths = []
    for name in ana_dicts:
        ci = _engine_column(plotting_dict, name, engine, lambda x: x[2])
        # symmetric error bar = half the width of the (lower, upper) interval
        half_widths.append((ci.applymap(lambda x: x[1]) - ci.applymap(lambda x: x[0])) / 2)
    df_err = _merge_on_index(half_widths)

    print(df)
    print(engine)
    print(df.mean())
    print(df.sem())
    print(df_err)

    # one colour per protocol, spread evenly over the plasma colormap
    # (linspace also copes with a single column)
    my_cmap = plt.get_cmap("plasma")
    colors = [my_cmap(x) for x in np.linspace(0, 1, len(df.columns))]

    df.plot(
        kind="bar",
        color=colors,
        yerr=df_err,
        title=eng_dict_name[engine],
        ax=pos,
        xlabel="Protein System",
        ylabel=f"{stats_name} (kcal/mol)",
    )
# fig.suptitle(f'{stats_name} perturbations for LOMAP/RBFENN-score')

In [None]:
# graph for each engine and each method: bars = "<engine> <protocol>" per protein
plotting_dict = mae_dict  # mae_dict or sem_dict
stats_name = "ddG MAE"  # MAE or SEM

df_list = []
df_err_list = []
for name in ana_dicts:
    stats = pd.DataFrame(plotting_dict[name])
    # rows = proteins, columns = engines; suffix the columns with the protocol so
    # the per-protocol frames can be merged without column-name clashes
    df_list.append(stats.applymap(lambda x: x[0]).T.add_suffix(f" {name}"))

    ci = stats.applymap(lambda x: x[2]).T
    # symmetric error bar = half the width of the (lower, upper) interval
    half_width = (ci.applymap(lambda x: x[1]) - ci.applymap(lambda x: x[0])) / 2
    df_err_list.append(half_width.add_suffix(f" {name}"))

df = reduce(
    lambda left, right: pd.merge(left, right, left_index=True, right_index=True),
    df_list,
)
df_err = reduce(
    lambda left, right: pd.merge(left, right, left_index=True, right_index=True),
    df_err_list,
)

# one colour per column, spread evenly over the plasma colormap
my_cmap = plt.get_cmap("plasma")
colors = [my_cmap(x) for x in np.linspace(0, 1, len(df.columns))]

# draw on a new figure of our own (not on axes left over from earlier cells)
fig, ax = plt.subplots(figsize=(20, 20))
df.plot(
    kind="bar",
    color=colors,
    yerr=df_err,
    title=f"{stats_name} for each engine and analysis protocol",
    ax=ax,
    xlabel="Protein System",
    ylabel=f"{stats_name} (kcal/mol)",
)

In [None]:
# One graph for a single analysis method, comparing the engines for each protein.
plotting_dict = mae_dict  # mae_dict or sem_dict
stats_name = "ddG MAE"  # MAE or SEM
name = "subsampling"

# plotting_dict[name][protein][engine] -> after .T rows are proteins, columns engines.
# assumes each entry is (value, ..., (ci_lower, ci_upper)) — TODO confirm
stats_frame = pd.DataFrame(plotting_dict[name]).T
df = stats_frame.applymap(lambda x: x[0])
ci_frame = stats_frame.applymap(lambda x: x[2])
# symmetric error bar = half-width of the confidence interval
df_err = (ci_frame.applymap(lambda x: x[1]) - ci_frame.applymap(lambda x: x[0])) / 2

ax = df.plot(
    kind="bar",
    color=pipeline.analysis.set_colours(),
    xlabel="protein system",
    ylabel=f"{stats_name} (kcal/mol)",
    yerr=df_err,
)
# Limits are set on the Axes; pass `top=` as well if a fixed upper limit is wanted.
ax.set_ylim(bottom=0);

In [None]:
# Pairwise t-tests (from summary statistics) between engines, per protein,
# on the "subsampling" MAE values.
results_dict = mae_dict["subsampling"]
# NOTE(review): 40 observations per engine is hard-coded — confirm it matches the
# number of values each MAE was computed from.
n_obs = 40

for protein in ["tyk2", "mcl1", "p38"]:
    # engines that actually have results for this protein
    engines = list(results_dict[protein])
    # each unordered pair once: no self-comparisons, no reversed duplicates.
    # TODO: no multiple-comparison correction is applied to these p-values.
    for eng1, eng2 in it.combinations(engines, 2):
        # assumes index 0 is the MAE and index 1 its standard deviation — TODO confirm
        results = _stats.ttest_ind_from_stats(
            mean1=results_dict[protein][eng1][0],
            std1=results_dict[protein][eng1][1],
            nobs1=n_obs,
            mean2=results_dict[protein][eng2][0],
            std2=results_dict[protein][eng2][1],
            nobs2=n_obs,
        )
        print(protein, eng1, eng2, results)

In [None]:
# Collect the absolute ddG error |calculated - experimental| of every perturbation
# for each analysis method, MD engine and protein, then box-plot them.
# "single_<i>" methods use the i-th individual repeat of the "subsampling" analysis.
methods = list(ana_dicts.keys()) + ["single_0", "single_1", "single_2"]
# Engine list taken from the analysis objects themselves rather than from whatever
# `ana_obj` was left over from an earlier cell (it is reassigned inside the loop).
engines = ana_obj_dict[next(iter(ana_obj_dict))]["subsampling"].engines

net_ana_method_dict = {"method": [], "engine": [], "protein": [], "value": []}

for method in methods:
    for eng in engines:
        overall_dg_list = []

        for prot in ana_obj_dict.keys():
            if "single" in method:
                ana_obj = ana_obj_dict[prot]["subsampling"]
                repeat_idx = int(method.split("_")[-1])
                use_dict = ana_obj.calc_repeat_pert_dict[eng][repeat_idx]
            else:
                ana_obj = ana_obj_dict[prot][method]
                use_dict = ana_obj.calc_pert_dict[eng]

            dg_list = []
            for key, calc_value in use_dict.items():
                # skip perturbations not listed in this engine's _perturbations_dict
                if key not in ana_obj._perturbations_dict[eng]:
                    print(f"{key} not in pert dict")
                    continue
                dg_list.append(abs(calc_value[0] - ana_obj.exper_pert_dict[key][0]))

            # one list per (method, engine, protein); flattened for plotting below
            net_ana_method_dict["method"].append([method] * len(dg_list))
            net_ana_method_dict["engine"].append([eng] * len(dg_list))
            net_ana_method_dict["protein"].append([prot] * len(dg_list))
            net_ana_method_dict["value"].append(list(dg_list))
            overall_dg_list.append(dg_list)

        # every error counts towards the mean, including exact zeros
        print(
            f"{eng} {method} mean is {np.mean(flatten_comprehension(overall_dg_list))}"
        )


plotting_dict = {
    "method": flatten_comprehension(net_ana_method_dict["method"]),
    "MD engine": flatten_comprehension(net_ana_method_dict["engine"]),
    "MAE ddG (kcal/mol)": flatten_comprehension(net_ana_method_dict["value"]),
    "Protein": flatten_comprehension(net_ana_method_dict["protein"]),
}

df = pd.DataFrame(plotting_dict)
ax = sns.boxplot(df, x="MD engine", y="MAE ddG (kcal/mol)", hue="method")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5));

In [None]:
# Distribution of the absolute ddG errors, coloured by MD engine.
# sns.displot is figure-level and returns a FacetGrid (not an Axes); it places
# its own legend, so no ax.legend call is needed here.
grid = sns.displot(df, x="MAE ddG (kcal/mol)", hue="MD engine")

In [None]:
# Bar plot of the mean absolute ddG error per MD engine and method with 95 % CIs,
# then read the plotted heights and CI limits back into a summary table.

# Explicit category orders taken from the plotted data, so the bars can be paired
# with (method, engine) labels without relying on leftover objects.
method_order = list(pd.unique(df["method"]))
engine_order = list(pd.unique(df["MD engine"]))

ax = sns.barplot(
    df,
    x="MD engine",
    y="MAE ddG (kcal/mol)",
    hue="method",
    order=engine_order,
    hue_order=method_order,
    errorbar=("ci", 95),
)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

# The pairing below relies on seaborn adding bars (and one error-bar line each)
# hue level by hue level, i.e. method-major then engine.
# NOTE(review): assumes every (method, engine) combination has data — confirm.
val_dict = {method: {} for method in method_order}
for patch, line, (method, engine) in zip(
    ax.patches, ax.lines, it.product(method_order, engine_order)
):
    xy = line.get_xydata()
    ci_low, ci_high = xy[0][1], xy[1][1]  # y of the error bar's two end points
    val_dict[method][engine] = f"{patch.get_height():.2f} ({ci_low:.2f}, {ci_high:.2f})"

df_val = pd.DataFrame(val_dict)
df_val.T

In [None]:
# Auto-equilibration: how much of each run is discarded, per protein and engine.

# Recalculate equilibration times where possible; failures are reported and skipped.
for prot in ana_obj_dict.keys():
    try:
        print(prot)
        ana_obj = ana_obj_dict[prot]["autoeq"]
        ana_obj.compute_equilibration_times(compute_missing=False)
    except Exception as e:
        print(e)
        print(f"could not for {prot}")


def _eq_fractions(ana_obj, eng):
    """Return the usable equilibration "mean" values of one engine as a 1D float array.

    Values from every perturbation are flattened first, so perturbations with
    different numbers of entries cannot produce a ragged array; None and NaN
    entries are then dropped.
    """
    raw = [
        entry["mean"]
        for pert_times in ana_obj.eq_times_dict[eng].values()
        for entry in pert_times.values()
    ]
    values = np.array([v for v in raw if v is not None], dtype=float)
    return values[~np.isnan(values)]


# check equilibration times
eq_dict_avg = {}
std_dict_avg = {}
sem_dict_avg = {}
overall_array = {}  # engine -> values pooled over all proteins

for prot in ana_obj_dict.keys():
    print(prot)

    eq_dict_avg[prot] = {}
    std_dict_avg[prot] = {}
    sem_dict_avg[prot] = {}

    ana_obj = ana_obj_dict[prot]["autoeq"]

    for eng in ana_obj.engines:
        fractions = _eq_fractions(ana_obj, eng)
        # .get: an engine need not be present for every protein
        overall_array[eng] = np.concatenate(
            [overall_array.get(eng, np.array([])), fractions]
        )

        sem_dict_avg[prot][eng] = _stats.sem(fractions)
        std_dict_avg[prot][eng] = _stats.tstd(fractions)
        eq_dict_avg[prot][eng] = np.mean(fractions)

# x100 for percentages; transpose so rows are proteins and columns engines
df = pd.DataFrame.from_dict(eq_dict_avg).transpose() * 100
df_sem = pd.DataFrame.from_dict(sem_dict_avg).transpose() * 100
df_std = pd.DataFrame.from_dict(std_dict_avg).transpose() * 100
print(df)

# 95 % normal-approximation CI of the pooled mean for each engine
dict_lower = {}
dict_higher = {}
for eng, values in overall_array.items():
    mean = np.mean(values)
    lower_ci, upper_ci = _stats.norm.interval(
        confidence=0.95,
        loc=mean,
        scale=_stats.sem(values),
    )
    print(eng, mean, lower_ci, upper_ci)
    dict_lower[eng] = [lower_ci * 100]
    dict_higher[eng] = [upper_ci * 100]

In [None]:
# Plot the auto-equilibration averages: per protein (mean over engines),
# per protein and engine, and per engine (mean over proteins).
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5), sharex=False, sharey=True)

col_dict = {
    "AMBER22": "orange",
    "SOMD1": "darkturquoise",
    "GROMACS23": "orchid",
    "experimental": "midnightblue",
}
eq_ylabel = "Average amount of run\ndiscarded by auto-equilibration (%)"

# Relabel copies instead of renaming `df` in place, and relabel the SEM frame the
# same way: pandas matches a DataFrame `yerr` to the data by index/column labels,
# so an un-renamed SEM frame would give no error bars on the renamed bars.
df_plot = df.rename(index=prot_dict_name, columns=eng_dict_name)
df_sem_plot = df_sem.rename(index=prot_dict_name, columns=eng_dict_name)

df_plot.mean(axis=1).plot.bar(
    color="purple",
    yerr=df_plot.sem(axis=1),
    xlabel="Protein System",
    ylabel=eq_ylabel,
    ax=axes[0],
)

df_plot.plot.bar(
    color=pipeline.analysis.set_colours().values(),
    yerr=df_sem_plot,
    xlabel="Protein System",
    ylabel=eq_ylabel,
    ax=axes[1],
)

engine_means = df_plot.mean()
# colour by engine name where col_dict knows it, otherwise by position in col_dict
engine_colours = [
    col_dict.get(label, fallback)
    for label, fallback in zip(engine_means.index, col_dict.values())
]
engine_means.plot.bar(
    color=engine_colours,
    yerr=df_plot.sem(),
    xlabel="MD engine",
    ylabel=eq_ylabel,
    ax=axes[2],
)
axes[2].tick_params(axis="x", rotation=20);

In [None]:
# Stacked bars per protein: the bar height is the mean over engines, split between
# engines in proportion to each engine's share of the row total; the black error
# bar is the SEM over engines.
if "protein" in df.columns:
    df = df.set_index("protein")

engine_share = df.div(df.sum(axis=1), axis=0)
ax = engine_share.mul(df.mean(axis=1), axis=0).plot(
    kind="bar",
    stacked=True,
    color=pipeline.analysis.set_colours(),
    xlabel="protein system",
    ylabel="Average amount of run\ndiscarded by auto-equilibration (%)",
)
# pandas places bar categories at integer positions 0..n-1, so pass those positions
# explicitly instead of relying on string-category unit conversion to line up.
ax.errorbar(
    x=np.arange(len(df)),
    y=df.mean(axis=1).to_numpy(),
    yerr=df.sem(axis=1).to_numpy(),
    ecolor="black",
    fmt="none",
);

In [None]:
# Histograms from the combined analysis object.
# TODO: build a list of the single-run uncertainties and check whether they are
#       the same across engines; it could be fed into the histogram function and
#       also averaged for comparison.
# TODO: abstract the histogram (and other) plotting functions so arbitrary data
#       and colours can be passed in.
all_analysis_object.plot_histogram_sem(pert_val="pert")

all_analysis_object.plot_histogram_runs()

In [None]:
# Scatter plots of the computed dG for each protein ("subsampling" analysis):
# all engines together, then one plot per engine, and per engine again with
# use_cinnabar=True.
for prot in ana_obj_dict.keys():
    ana_obj = ana_obj_dict[prot]["subsampling"]

    ana_obj.plot_scatter_dG()

    for eng in ana_obj.engines:
        ana_obj.plot_scatter_dG(engines=eng)
        ana_obj.plot_scatter_dG(engines=eng, use_cinnabar=True)