In [2]:
# Confgen NOE notebook

# Standard library (os/sys were used below without being imported).
import copy
import os
import sys

import matplotlib

#%matplotlib inline
# matplotlib.use("Agg")

import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy
from scipy import stats
from scipy.optimize import minimize
from scipy.spatial.distance import squareform
from sklearn import metrics, utils

# Make the project-local `src` package importable; assumes the kernel's working
# directory is the project root (TODO confirm for the Snakemake run).
sys.path.append(os.getcwd())
import src.noe
import src.stats
from src.pyreweight import reweight
from src.utils import json_load, pickle_dump

# Integer id of the compound being analysed, taken from the Snakemake wildcard.
compound_index = int(snakemake.wildcards.compound)

In [3]:
# Debug aid: list the kernel's working directory (IPython shell escape, not
# plain Python). Nothing later in the notebook uses this output.
!ls

In [4]:
# Load inputs: conformers (as an mdtraj trajectory), experimental NOE data,
# compound metadata and per-conformer energies.
chem_info_t = md.load(snakemake.input.pdb)
NOE_original = src.noe.read_NOE(snakemake.input.noe)
compound = json_load(snakemake.input.parm)
energies = np.loadtxt(snakemake.input.energies)

# Detect cis/trans. A single truthiness test drives both the message and the
# branch, so a falsy-but-not-None `multi` can no longer print "one structure"
# and then be treated as a two-structure compound (or crash on `.items()`).
multi = compound.multi
if multi:
    print(
        "According to the literature reference, there are two distinct structures in solution."
    )
    multi = {v: k for k, v in multi.items()}  # invert the mapping (value -> key)
    multiple = True
    distinction = compound.distinction
    print("Multiple compounds detected")
    # Omega dihedral built from atoms CA, C of residue distinction[0] and N, CA of
    # residue distinction[1] (atom order as returned by top.select).
    ca_c = chem_info_t.top.select(f"resid {distinction[0]} and name CA C")
    n_ca_next = chem_info_t.top.select(f"resid {distinction[1]} and name N CA")
    omega = np.append(ca_c, n_ca_next)
    t_omega_rad = md.compute_dihedrals(chem_info_t, [omega])
    t_omega_deg = np.abs(np.degrees(t_omega_rad))
    # Show the dihedral for all conformers; the red line is the cis/trans cut.
    plt.plot(t_omega_deg)
    plt.hlines(90, 0, chem_info_t.n_frames, color="red")
    plt.xlabel("Frames")
    plt.ylabel("Omega 0-1 [°]")
    plt.title(f"Dihedral angle over time. Compound {compound_index}")
    # One dihedral -> one column, so the row indices are trajectory frame indices.
    cis = np.where(t_omega_deg <= 90)[0]
    trans = np.where(t_omega_deg > 90)[0]
else:
    print(
        "According to the literature reference, there is only one distinct structure in solution."
    )
    multiple = False
    multiple = False

In [9]:
# Display the path of the input conformer PDB file (inspection only).
snakemake.input.pdb

In [11]:
# Score every conformer on its own against the experimental NOEs. In the
# cis/trans case each frame is compared with the NOE table of its own isomer.
rmsd = []
mae = []
mse = []
fulfilled = []
results = {}
cis_frames = set(cis.tolist()) if multiple else set()
for frame in range(chem_info_t.n_frames):
    # Fresh copy per frame: the "md" column written below must not leak into
    # NOE_original.
    if multiple:
        NOE_trans, NOE_cis = copy.deepcopy(NOE_original)
        NOE = NOE_cis if frame in cis_frames else NOE_trans
    else:
        NOE = copy.deepcopy(NOE_original)
    NOE_dict = NOE.to_dict(orient="index")
    current_conformer = chem_info_t[frame]

    NOE["md"], *_ = src.noe.compute_NOE_mdtraj(NOE_dict, current_conformer)
    # Ambiguous NOEs (several MD distances) and multiple experimental values
    # are expanded into separate rows.
    NOE = NOE.explode("md").explode("NMR exp")

    # NOE_test aliases NOE here, so the columns added below also appear on NOE.
    NOE_test = NOE
    if (NOE_test["NMR exp"].to_numpy() == 0).all():
        # No experimental distances given: use the bound midpoint as reference.
        NOE_test["NMR exp"] = (
            NOE_test["upper bound"] + NOE_test["lower bound"]
        ) * 0.5
    NOE_test["dev"] = NOE_test["md"] - np.abs(NOE_test["NMR exp"])
    NOE_test["abs_dev"] = np.abs(NOE_test["dev"])

    # Per NOE index keep the row closest to experiment, then restore order
    # and drop rows containing NaN.
    by_closeness = NOE_test.sort_values("abs_dev", ascending=True)
    by_closeness.index = by_closeness.index.astype(int)
    first_occurrence = ~by_closeness.index.duplicated(keep="first")
    NOE_test = by_closeness[first_occurrence].sort_index(kind="mergesort").dropna()

    # n_bootstrap=1: a single iteration gives the plain point estimate.
    RMSD, *_ = src.stats.compute_RMSD(
        NOE_test["NMR exp"], NOE_test["md"], n_bootstrap=1
    )
    MAE, *_ = src.stats.compute_MAE(
        NOE_test["NMR exp"], NOE_test["md"], n_bootstrap=1
    )
    MSE, *_ = src.stats.compute_MSE(NOE_test["dev"], n_bootstrap=1)
    rmsd.append(RMSD)
    mae.append(MAE)
    mse.append(MSE)
    fulfilled.append(src.stats.compute_fulfilled_percentage(NOE_test))
rmsd, mae, mse, fulfilled = map(np.array, (rmsd, mae, mse, fulfilled))

In [12]:
# Display the NOE table of the last conformer scored by the per-conformer loop
# (exploded, with "dev"/"abs_dev" columns, before deduplication); inspection only.
NOE

In [13]:
# Best single conformer per group (cis/trans or single): store the best
# per-conformer metrics and plot the NOEs of the best-fulfilling conformer.
if multiple:
    groups = [
        (key, frames)
        for key, frames in (("cis", cis), ("trans", trans))
        if len(frames) > 0
    ]
else:
    groups = [("single", np.arange(0, chem_info_t.n_frames))]
conformer_indices = [frames for _, frames in groups]
dict_key = [key for key, _ in groups]
for key, ci in groups:
    results[key] = {
        "fulfil": {"best": max(fulfilled[ci])},
        "rmsd": {"best": min(rmsd[ci])},
        "mae": {"best": min(mae[ci])},
    }


def _best_frame(frames):
    """Trajectory frame index of the conformer in `frames` with most NOEs fulfilled.

    np.argmax over `fulfilled[frames]` returns a position inside `frames`; it is
    mapped back through `frames` so it can index `chem_info_t` directly.
    """
    return int(frames[np.argmax(fulfilled[frames])])


def _noe_for_frame(noe_table, frame):
    """Add the MD distances of conformer `frame` to `noe_table` (in place) and
    return it with multi-valued "md"/"NMR exp" cells exploded into rows."""
    noe_table["md"], *_ = src.noe.compute_NOE_mdtraj(
        noe_table.to_dict(orient="index"), chem_info_t[frame]
    )
    return noe_table.explode("md").explode("NMR exp")


# Fresh NOE data from disk so earlier cells cannot have altered it.
NOE = src.noe.read_NOE(snakemake.input.noe)
if not multiple:
    NOE = _noe_for_frame(NOE, _best_frame(np.arange(chem_info_t.n_frames)))
    fig, ax = src.noe.plot_NOE(NOE)
else:
    NOE_trans, NOE_cis = NOE
    fig, ax = plt.subplots(2, 1)
    ax[0].set_title("cis")
    ax[1].set_title("trans")
    if len(cis) > 0:
        NOE_cis = _noe_for_frame(NOE_cis, _best_frame(cis))
        fig, ax[0] = src.noe.plot_NOE(NOE_cis, fig, ax[0])
    if len(trans) > 0:
        NOE_trans = _noe_for_frame(NOE_trans, _best_frame(trans))
        fig, ax[1] = src.noe.plot_NOE(NOE_trans, fig, ax[1])
    fig.tight_layout()
fig.savefig(snakemake.output.best_NOE_plot, dpi=300)

In [14]:
# Plot distributions of the per-conformer NOE statistics: one row of four
# violins (RMSD, MAE, MSE, % fulfilled) per group.
if multiple:
    fig, axs = plt.subplots(2, 4)
else:
    fig, axs = plt.subplots(1, 4)
fig.set_size_inches(8, 4)


def _violin_row(row_axes, subset, fulfil_title):
    """Draw the four metric violins for the conformers selected by `subset`."""
    # Raw strings: "\A" is not a valid escape sequence in a normal string.
    panels = (
        (rmsd, r"RMSD [$\AA$]", "RMSD"),
        (mae, r"MAE [$\AA$]", "MAE"),
        (mse, r"MSE [$\AA$]", "MSE"),
        (fulfilled, "% NOE fulfilled [1/100 %]", fulfil_title),
    )
    for panel_ax, (values, ylabel, title) in zip(row_axes, panels):
        panel_ax.violinplot(values[subset], showmeans=True)
        panel_ax.set_ylabel(ylabel)
        panel_ax.set_title(title)


if multiple:
    # Title is set even when one of the isomers has no conformers.
    fig.suptitle(
        f"Compound {snakemake.wildcards.compound}. {snakemake.wildcards.confgen.capitalize()}. top:cis, bottom:trans"
    )
    if len(cis) > 0:
        _violin_row(axs[0], cis, "fulfilled NOEs")
    if len(trans) > 0:
        _violin_row(axs[1], trans, "% NOE fulfilled")
else:
    fig.suptitle(
        f"Compound {snakemake.wildcards.compound.capitalize()}. {snakemake.wildcards.confgen.capitalize()}"
    )
    _violin_row(axs, slice(None), "% NOE fulfilled")
for ax in axs.flatten():
    ax.get_xaxis().set_visible(False)
fig.tight_layout()
fig.savefig(snakemake.output.NOE_violin_plot, dpi=300)

In [15]:
# Bundle analysis
bundle_sizes = [1, 3, 5, 10, 30]


def bundle_analysis(indices, NOE=None, regular_average=False, weights=None):
    """
    Score a bundle of conformers jointly against the experimental NOEs.

    The NOE distances of all conformers in the bundle are combined by
    ``src.noe.compute_NOE_mdtraj`` (NOE averaging by default, a simple mean
    with ``regular_average=True``) and compared with experiment.

    Parameters
    ----------
    indices : sequence of int
        Frame indices into ``chem_info_t``. For cis/trans compounds the NOE
        table is chosen from the isomer of ``indices[0]``, so all indices are
        expected to belong to the same isomer.
    NOE : optional
        NOE data as returned by ``src.noe.read_NOE`` (a ``(trans, cis)`` pair
        when ``multiple``). Defaults to ``NOE_original``. A deep copy is used,
        so the passed object is never modified.
    regular_average : bool, default False
        Use ``reweigh_type=3`` (simple mean) instead of ``reweigh_type=0``
        (NOE averaging).
    weights : optional
        Per-conformer weights forwarded to ``compute_NOE_mdtraj`` as
        ``weight_data``; ``None`` presumably means unweighted (confirm in src.noe).

    Returns
    -------
    dict
        Bundle statistics under the keys "rmsd", "mae", "mse" and "fulfil".
    """
    # Private copy: the "md" column assigned below would otherwise be written
    # into NOE_original (or the caller's frames) and leak into later calls.
    NOE = copy.deepcopy(NOE_original if NOE is None else NOE)
    if multiple:
        NOE_trans, NOE_cis = NOE
        NOE = NOE_cis if indices[0] in cis else NOE_trans
    NOE_dict = NOE.to_dict(orient="index")

    # select conformers
    current_conformer = chem_info_t[indices]
    reweigh_type = 3 if regular_average else 0
    NOE["md"], *_ = src.noe.compute_NOE_mdtraj(
        NOE_dict, current_conformer, reweigh_type=reweigh_type, weight_data=weights
    )
    # Ambiguous NOEs and multiple experimental values -> one row per value.
    NOE = NOE.explode("md").explode("NMR exp")

    if (NOE["NMR exp"].to_numpy() == 0).all():
        # No experimental distances given: use the bound midpoint as reference.
        NOE["NMR exp"] = (NOE["upper bound"] + NOE["lower bound"]) * 0.5
    NOE["dev"] = NOE["md"] - np.abs(NOE["NMR exp"])
    NOE["abs_dev"] = np.abs(NOE["dev"])

    # Per NOE index keep the row closest to experiment, restore the original
    # order and drop rows containing NaN.
    NOE_test = NOE.sort_values("abs_dev", ascending=True)
    NOE_test.index = NOE_test.index.astype(int)
    NOE_test = NOE_test[~NOE_test.index.duplicated(keep="first")].sort_index(
        kind="mergesort"
    )
    NOE_test = NOE_test.dropna()

    # n_bootstrap=1: a single iteration gives the plain point estimate.
    RMSD, *_ = src.stats.compute_RMSD(
        NOE_test["NMR exp"], NOE_test["md"], n_bootstrap=1
    )
    MAE, *_ = src.stats.compute_MAE(
        NOE_test["NMR exp"], NOE_test["md"], n_bootstrap=1
    )
    MSE, *_ = src.stats.compute_MSE(NOE_test["dev"], n_bootstrap=1)
    fulfil = src.stats.compute_fulfilled_percentage(NOE_test)
    return {"rmsd": RMSD, "mae": MAE, "mse": MSE, "fulfil": fulfil}

In [16]:
# Random choice: for each group and bundle size, average the bundle statistics
# over N_RANDOM_DRAWS randomly drawn bundles.
N_RANDOM_DRAWS = 10
# Fixed seed so the random baseline is reproducible between runs.
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)

if multiple:
    groups = [
        (key, frames)
        for key, frames in (("cis", cis), ("trans", trans))
        if len(frames) > 0
    ]
else:
    groups = [("single", np.arange(0, chem_info_t.n_frames))]
conformer_indices = [frames for _, frames in groups]
dict_key = [key for key, _ in groups]

for key, ci in groups:
    random_choice_fulfil = []
    random_choice_rmsd = []
    random_choice_mae = []
    for bundle_size in bundle_sizes:
        # With fewer conformers than the bundle size, use all available ones.
        bundle_size = min(bundle_size, len(ci))
        draws = []
        for _draw in range(N_RANDOM_DRAWS):
            indices_selection = np.sort(rng.choice(ci, bundle_size, replace=False))
            # One analysis per bundle; it returns all statistics at once.
            draws.append(bundle_analysis(indices_selection))
        random_choice_fulfil.append(np.mean([d["fulfil"] for d in draws]))
        random_choice_rmsd.append(np.mean([d["rmsd"] for d in draws]))
        random_choice_mae.append(np.mean([d["mae"] for d in draws]))
    results[key]["bundle-size"] = bundle_sizes

    results[key]["fulfil"]["random"] = random_choice_fulfil
    results[key]["rmsd"]["random"] = random_choice_rmsd
    results[key]["mae"]["random"] = random_choice_mae

In [17]:
# Lowest energy conformers: for each group, score bundles made of its
# `bundle_size` lowest-energy conformers.
if multiple:
    groups = [
        (key, frames)
        for key, frames in (("cis", cis), ("trans", trans))
        if len(frames) > 0
    ]
else:
    groups = [("single", np.arange(0, chem_info_t.n_frames))]
conformer_indices = [frames for _, frames in groups]
dict_key = [key for key, _ in groups]

for key, ci in groups:
    energy_choice_fulfil = []
    energy_choice_rmsd = []
    energy_choice_mae = []
    # Frames of this group ordered by increasing energy (loop-invariant).
    frames_by_energy = ci[np.argsort(energies[ci])]
    for bundle_size in bundle_sizes:
        # Slicing clamps to the number of available conformers.
        min_energies_indices = frames_by_energy[:bundle_size]
        scores = bundle_analysis(min_energies_indices)
        energy_choice_fulfil.append(scores["fulfil"])
        energy_choice_rmsd.append(scores["rmsd"])
        energy_choice_mae.append(scores["mae"])

    results[key]["fulfil"]["low_energy"] = energy_choice_fulfil
    results[key]["rmsd"]["low_energy"] = energy_choice_rmsd
    results[key]["mae"]["low_energy"] = energy_choice_mae

In [18]:
# LICUV: for each group, score bundles made of the `bundle_size` conformers
# with the highest individual fulfilled-NOE percentage.
if multiple:
    groups = [
        (key, frames)
        for key, frames in (("cis", cis), ("trans", trans))
        if len(frames) > 0
    ]
else:
    groups = [("single", np.arange(0, chem_info_t.n_frames))]
conformer_indices = [frames for _, frames in groups]
dict_key = [key for key, _ in groups]

for key, ci in groups:
    licuv_choice_fulfil = []
    licuv_choice_rmsd = []
    licuv_choice_mae = []
    # Frames of this group ordered by increasing fulfilled percentage
    # (loop-invariant).
    frames_by_fulfilled = ci[np.argsort(fulfilled[ci])]
    for bundle_size in bundle_sizes:
        # Take the best-fulfilling conformers, clamped to the group size.
        max_fulfill_indices = frames_by_fulfilled[-min(bundle_size, len(ci)):]
        scores = bundle_analysis(max_fulfill_indices)
        licuv_choice_fulfil.append(scores["fulfil"])
        licuv_choice_rmsd.append(scores["rmsd"])
        licuv_choice_mae.append(scores["mae"])

    results[key]["fulfil"]["LICUV"] = licuv_choice_fulfil
    results[key]["rmsd"]["LICUV"] = licuv_choice_rmsd
    results[key]["mae"]["LICUV"] = licuv_choice_mae

In [19]:
# NAMFIS, adapted from Riniker:2022
def Namfis(indices):
    """
    Fit NAMFIS population weights to a set of conformers (adapted from Riniker:2022).

    Per-conformer NOE distances are computed. Weights in [0, 1] that sum to 1
    are then optimised with SLSQP so that the weighted average distances match
    the experimental reference. Deviations are scaled by distance-dependent
    error estimates, and each deviation is limited to ``error + tolerance``.

    Parameters
    ----------
    indices : sequence of int
        Frame indices into ``chem_info_t``. For cis/trans compounds the NOE
        table is chosen from the isomer of ``indices[0]``.

    Returns
    -------
    list of (int, float)
        ``(frame_index, weight)`` for all conformers, by decreasing weight.
        ``frame_index`` is an element of ``indices`` (a frame of
        ``chem_info_t``), so it can be passed directly to ``bundle_analysis``.
    """
    # The reference NOE table is loop-invariant: select it (and its dict) once.
    if multiple:
        NOE_trans, NOE_cis = NOE_original
        NOE_template = NOE_cis if indices[0] in cis else NOE_trans
    else:
        NOE_template = NOE_original
    NOE_dict = NOE_template.to_dict(orient="index")

    # extract NOE distances for every conformer
    NOEs = []
    for current_conformer in chem_info_t[indices]:
        # Private copy: the "md" column must not be written into NOE_original.
        NOE = copy.deepcopy(NOE_template)
        NOE["md"], *_ = src.noe.compute_NOE_mdtraj(NOE_dict, current_conformer)
        # Ambiguous NOEs and multiple experimental values -> one row per value.
        NOE = NOE.explode("md").explode("NMR exp")
        if (NOE["NMR exp"].to_numpy() == 0).all():
            # No experimental distances given: use the bound midpoint instead.
            NOE["NMR exp"] = (NOE["upper bound"] + NOE["lower bound"]) * 0.5
        NOE["dev"] = NOE["md"] - np.abs(NOE["NMR exp"])
        NOE["abs_dev"] = np.abs(NOE["dev"])
        # Per NOE index keep the row closest to experiment, restore order.
        NOE = NOE.sort_values("abs_dev", ascending=True)
        NOE.index = NOE.index.astype(int)
        NOE = NOE[~NOE.index.duplicated(keep="first")].sort_index(
            kind="mergesort"
        )
        NOE = NOE.dropna()
        NOEs.append(NOE["md"].values)

    # explode() leaves object dtype; give the optimiser plain float arrays.
    distances_ce = np.asarray(NOEs, dtype=float)
    # NOTE(review): reference values come from the last conformer's table. With
    # several experimental values per NOE, the kept value depends on that
    # conformer's md distance.
    ref_distances_ce = np.asarray(NOE["NMR exp"].values, dtype=float)

    # set NAMFIS parameters
    tolerance = 3.0
    # Distance-dependent error scale: 0.4 by default, smaller for short distances.
    errors_ce = np.full(len(ref_distances_ce), 0.4)
    errors_ce[ref_distances_ce < 3.5] = 0.3
    errors_ce[ref_distances_ce < 3.0] = 0.2
    errors_ce[ref_distances_ce < 2.5] = 0.1

    # Define NAMFIS objective
    def objective(w):  # w is weights
        deviation = ref_distances_ce - np.average(
            distances_ce, weights=w, axis=0
        )
        deviation /= errors_ce
        #     deviation = np.heaviside(deviation, 0) * deviation #only penalise upper violation
        #     return np.sum(deviation**2) #squared deviation
        return np.linalg.norm(deviation)  # square rooted

    cons = [
        # weights add up to 1
        {"type": "eq", "fun": lambda w: np.sum(w) - 1},
        # |weighted average - reference| must stay within error + tolerance
        {
            "type": "ineq",
            "fun": lambda w: (errors_ce + tolerance)
            - np.absolute(
                np.average(distances_ce, weights=w, axis=0) - ref_distances_ce
            ),
        },
    ]

    #     cons += [ #does not allow only upper violations
    #                 {'type':'ineq','fun': lambda w: ref_distances_ce - np.average(distances_ce, weights = w, axis = 0) - tolerance}
    #     ]

    n_conf = len(distances_ce)
    # Random start, each weight drawn uniformly from [0, 2/n_conf] (expected sum 1).
    weights = np.random.uniform(low=0, high=1, size=n_conf) / n_conf * 2

    # Run optimisation
    out = minimize(
        objective,
        weights,
        constraints=tuple(cons),
        bounds=tuple((0, 1) for _ in range(n_conf)),  # each weight constraint
        method="SLSQP",
    )

    if not out["success"]:
        # NOTE(review): `logger` is not defined in this notebook; presumably it is
        # injected by the Snakemake preamble along with `snakemake` - confirm.
        logger.error("NAMFIS failed: {}".format(out["message"]))

    weights = out["x"]
    order = np.argsort(-1 * weights)
    # Map positions within `indices` back to trajectory frame indices; before,
    # subset positions were returned and misused as frames for cis/trans groups.
    frames = np.asarray(indices)[order]
    return list(zip([int(f) for f in frames], weights[order]))

In [20]:
# NAMFIS run: fit weights once per group, then score bundles of the
# highest-weighted conformers with their NAMFIS weights.
if multiple:
    groups = [
        (key, frames)
        for key, frames in (("cis", cis), ("trans", trans))
        if len(frames) > 0
    ]
else:
    groups = [("single", np.arange(0, chem_info_t.n_frames))]
conformer_indices = [frames for _, frames in groups]
dict_key = [key for key, _ in groups]

for key, ci in groups:
    namfis_choice_fulfil = []
    namfis_choice_rmsd = []
    namfis_choice_mae = []
    namfis_results = Namfis(ci)
    for bundle_size in bundle_sizes:
        # With fewer conformers than the bundle size, use all available ones.
        bundle_size = min(bundle_size, len(ci))
        top = namfis_results[:bundle_size]
        indices = [frame for frame, _ in top]
        # A single conformer is scored unweighted.
        weights = None if bundle_size == 1 else [w for _, w in top]
        scores = bundle_analysis(indices, regular_average=False, weights=weights)
        namfis_choice_fulfil.append(scores["fulfil"])
        namfis_choice_rmsd.append(scores["rmsd"])
        namfis_choice_mae.append(scores["mae"])

    results[key]["fulfil"]["NAMFIS"] = namfis_choice_fulfil
    results[key]["rmsd"]["NAMFIS"] = namfis_choice_rmsd
    results[key]["mae"]["NAMFIS"] = namfis_choice_mae

In [21]:
# Display the collected per-group metrics (inspection only).
results

In [22]:
# Write the per-group metrics (best / random / low_energy / LICUV / NAMFIS) to
# the Snakemake output file as JSON.
src.utils.json_dump(snakemake.output.fulfilled, results)

In [23]:
# Fulfilled-NOE percentage vs. bundle size for every selection strategy, one
# panel per group. Read from `results` so that, for cis/trans compounds, both
# groups are shown, not only the last one processed by the loops above.
bundle_sizes_plot = [str(i) for i in bundle_sizes]
strategies = (
    ("random", "random", "o"),
    ("low_energy", "min-energy", "x"),
    ("LICUV", "LICUV", "."),
    ("NAMFIS", "NAMFIS", "+"),
)
fig, axs = plt.subplots(1, len(results), squeeze=False)
for ax, (group, group_results) in zip(axs[0], results.items()):
    for strategy, label, marker in strategies:
        ax.scatter(
            bundle_sizes_plot,
            group_results["fulfil"][strategy],
            label=label,
            marker=marker,
        )
    ax.set_title(group)
    ax.set_xlabel("Bundle size")
    ax.set_ylabel("% NOE fulfilled [1/100 %]")
    ax.legend()
fig.tight_layout()
fig.savefig(snakemake.output.bundle_plot, dpi=300)

In [17]:
# def pick_namfis(num_conf, noe, pre_selection = None, tolerance = .0):
#     """Pick conformers based on NMR analysis of molecular flexibility in solution (NAMFIS).
#     NAMFIS is a constrained optimisation scheme: starting by assigning every conformer an equal weight factor (the weights of all conformers sum to 1),
#     the weights are optimised to give a conformer ensemble that obeys the NOE measurements (within experimental tolerance).
#     Parameters
#     ----------
#     num_conf : int
#         Number of conformers required.
#     noe : customETKDG.NOE object
#         NOE object containing chemically equivalent hydrogen information.
#     pre_selection : list, optional
#         A subset of conformer indices in a list from which NAMFIS is run in order to prevent convergence issues, by default None meaning all conformers are considered.
#     tolerance : float, optional
#         Allow some overall violation (e.g. due to experimental uncertainty) of specified bounds in the constrained minimisation, by default 0 meaning no tolerance.
#     Returns
#     -------
#     list
#         A list of conformer indices that have the highest weights.
#     """
#     from scipy.optimize import minimize


#     MAX_CONF_LIMIT = 200
#     if pre_selection is None:
#         pre_selection = range(self.GetNumConformers())
#     if len(pre_selection) > MAX_CONF_LIMIT:
#         logger.warning("Number of conformers exceed {}, NAMFIS might not converge.".format(MAX_CONF_LIMIT))

#     distance_matrix_for_each_conformer = np.array([Chem.Get3DDistanceMatrix(self, i) for i in pre_selection])

#     df = noe.add_noe_to_mol(self, remember_chemical_equivalence = True).distance_upper_bounds

#     distances = distance_matrix_for_each_conformer[:, df.idx1, df.idx2]
#     ref_distances = np.array(df.distance)

#     # define error scale factor for distances in different ranges
#     errors = np.ones(len(ref_distances)) * 0.4
#     errors[ref_distances < 6.0] = 0.4
#     errors[ref_distances < 3.5] = 0.3
#     errors[ref_distances < 3.0] = 0.2
#     errors[ref_distances < 2.5] = 0.1

#     #### ce means chemical equivalent
#     distances_ce = np.split(distances, np.unique(noe.chemical_equivalence_list, return_index=True)[1][1:], axis = 1) #here I group the distances by their chemical equivalence track
#     distances_ce = np.stack([np.mean(d, axis = 1) for d in distances_ce], axis = 1)


#     ref_distances_ce = np.split(ref_distances, np.unique(noe.chemical_equivalence_list, return_index=True)[1][1:])
#     ref_distances_ce = np.stack([np.mean(d) for d in ref_distances_ce])

#     errors_ce = np.split(errors, np.unique(noe.chemical_equivalence_list, return_index=True)[1][1:])
#     errors_ce = np.stack([np.mean(d) for d in errors_ce])

#     def objective(w): #w is weights
#         deviation  = ref_distances_ce - np.average(distances_ce, weights = w, axis = 0)
#         deviation /= errors_ce
#     #         deviation = np.heaviside(deviation, 0) * deviation #only penalise upper violation
#         return np.sum(deviation**2) #squared deviation
#     #         return np.linalg.norm(deviation) #square rooted

#     cons = [{'type':'eq','fun': lambda w: np.sum(w) - 1}] #weights add up to 1


#     cons += [ #does not allow any violation
#             {'type':'ineq','fun': lambda w:  (errors_ce + tolerance) - np.absolute(np.average(distances_ce, weights = w, axis = 0) - ref_distances_ce)}
#     ]

#     #     cons += [ #does not allow only upper violations
#     #                 {'type':'ineq','fun': lambda w: ref_distances_ce - np.average(distances_ce, weights = w, axis = 0) - tolerance}
#     #     ]

#     weights = np.random.uniform(low = 0, high = 1, size = len(pre_selection)) #uniform weights at start

#     out = minimize(
#         objective,
#         weights,
#         constraints = tuple(cons),
#         bounds = tuple((0,1) for _ in range(len(weights))), #each weight constraint
#         method='SLSQP')

#     if not out["success"]:
#         logger.error("NAMFIS failed: {}".format(out["message"]))

#     weights = out["x"]

#     return list(zip([int(i) for i in np.argsort(-1 * weights)[:num_conf]], weights[np.argsort(weights * -1)[:num_conf]]))