# Analysis Notebook
A copy of this notebook is run to analyse the molecular dynamics simulations. 
The type of MD simulation is specified in the Snakemake rule as a parameter, 
such that it is accessible via: snakemake.params.method.

There are various additional analysis steps that are included in the notebook 
but are not part of the paper. To turn these on, set the `beta_run` parameter to `True`.

There are also some commented out lines in the notebook. These are mainly for 
the purpose of debugging. Some of them are for interactively exploring the
3d structure of the system. These don't work as part of the automated snakemake
workflow, but can be enabled when running a notebook interactively. 

In [2]:
# Decide whether the analysis runs on the shortened trajectory / dihedral /
# weight files written by an earlier full-trajectory run of this notebook.
# Sets `use_shortened` and, when shortened files are used, redirects
# `snakemake.input.traj` to the shortened trajectory.
#
# `os` is imported here because this cell executes before the main import
# cell; without it the cell only works if something else already put `os`
# into the kernel namespace.
import os

if snakemake.config["shortened"]:
    print(
        "Using shortened trajectories and dihedrals. This only works if these were created previously!"
    )
    # All four shortened files have to be present, otherwise the full
    # trajectory has to be analysed first to create them.
    short_files_present = (
        os.path.exists(snakemake.params.traj_short)
        and os.path.exists(snakemake.params.dihedrals_short)
        and os.path.exists(snakemake.params.dPCA_weights_MC_short)
        and os.path.exists(snakemake.params.weights_short)
    )
    if not short_files_present:
        raise FileNotFoundError(
            "Shortened trajectories and dihedrals files do not exist, but config value is set to use shortened files! Switch off the use of shortenend files and first analyse this simulation using the full trajectory!"
        )
    use_shortened = True
    snakemake.input.traj = snakemake.params.traj_short
else:
    use_shortened = False

In [3]:
# Turn on for development
# use_shortened = False
# Load IPython's autoreload extension in mode 2, so that edits to imported
# modules (e.g. the project's `src` package) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# Imports
import os
import sys

import matplotlib
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import mdtraj as md
import numpy as np

# Font sizes used for all static matplotlib figures of this notebook.
SMALL_SIZE = 9
MEDIUM_SIZE = 11
BIGGER_SIZE = 13

# (rc group, settings) pairs, applied in order:
#   font   -> default text size
#   axes   -> axes title size, then x/y label size
#   xtick / ytick -> tick label sizes
#   legend -> legend font size
#   figure -> figure title size
for rc_group, rc_settings in (
    ("font", {"size": MEDIUM_SIZE}),
    ("axes", {"titlesize": BIGGER_SIZE}),
    ("axes", {"labelsize": MEDIUM_SIZE}),
    ("xtick", {"labelsize": SMALL_SIZE}),
    ("ytick", {"labelsize": SMALL_SIZE}),
    ("legend", {"fontsize": MEDIUM_SIZE}),
    ("figure", {"titlesize": BIGGER_SIZE}),
):
    plt.rc(rc_group, **rc_settings)

# Resolution passed to `savefig` for the figures written to disk.
DPI = 600

import scipy.cluster.hierarchy
from scipy.spatial.distance import squareform
import pandas as pd

# Put the current working directory on the module search path so that the
# project-local `src` package imported just below can be found.
sys.path.append(os.getcwd())
import src.dihedrals
import src.pca
import src.noe
import src.Ring_Analysis
import src.stats
from src.pyreweight import reweight
from src.utils import json_load, pickle_dump
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
import nglview as nv
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
import py_rdl
import seaborn as sns

# RDKit notebook rendering options: size of the rendered molecule image
# (presumably width x height in pixels - confirm against the RDKit docs),
# stereo annotations switched on, and a larger annotation font.
IPythonConsole.molSize = (900, 300)  # (450, 150)
IPythonConsole.drawOptions.addStereoAnnotation = True
IPythonConsole.drawOptions.annotationFontScale = 1.5
import tempfile
import io
import svgutils.transform as sg
import svgutils.compose as sc
import scipy.stats as stats
from IPython.display import display, Markdown

In [5]:
# Frame stride applied when reading the trajectory. A stride > 1 speeds up
# preliminary analysis; production runs use 1 (every MD frame).
stride = int(snakemake.config["stride"])
print(f"Using stride {stride} to analyse MD simulations.")

# The compound number and the simulation time (reported in ns further below)
# are taken from the workflow wildcards.
wildcards = snakemake.wildcards
compound_index = int(wildcards.compound_dir)
simtime = float(wildcards.time)

## Compound details

In [6]:
# Introduce the compound this notebook instance is analysing.
display(Markdown(f"This notebook refers to compound {compound_index}."))

# Compound metadata; `multi` is truthy when the reference describes two
# distinct structures in solution.
compound = json_load(snakemake.input.parm)
multi = compound.multi

structure_note = (
    "According to the literature reference, there are two distinct structures in solution."
    if multi
    else "According to the literature reference, there is only one distinct structure in solution."
)
display(Markdown(structure_note))

sequence_note = f"""The sequence of the compound is **{compound.sequence}**. \n
A 2d structure of the compound is shown below."""
display(Markdown(sequence_note))

## Simulation details

In [7]:
# Topology of the full simulation box (protein + solvent + ions ..); no
# coordinates are read at this point.
full_topology = md.load_topology(snakemake.input.top)

# Atom indices of the solute: everything mdtraj classifies as protein plus
# residues named ASH. The protein-only topology is built from the same indices.
protein_atoms = full_topology.select("protein or resname ASH")
protein_topology = full_topology.subset(protein_atoms)

# First frame of the solute. The shortened trajectory is read with the
# protein-only topology; the full trajectory is read with the full box
# topology and then cut down to the solute atoms.
if use_shortened:
    protein = md.load_frame(
        snakemake.params.traj_short, 0, top=protein_topology
    )
else:
    full_box = md.load_frame(snakemake.input.traj, 0, top=full_topology)
    protein = full_box.restrict_atoms(protein_atoms)

atom_listing = (
    f"The following atom numbers are part of the protein: {protein_atoms}"
)
display(Markdown(atom_listing))

In [8]:
# Stereo check: write the single-frame solute structure to a temporary PDB
# file, read it back with RDKit, and load the reference molecule for a
# side-by-side comparison of the stereocentres.
t_stereo_check = protein

# Only the path of the temporary file is needed (mdtraj and RDKit open the
# file themselves), so the handle is closed right away and the file is removed
# once RDKit has parsed it. Previously the handle and the file were leaked.
tf = tempfile.NamedTemporaryFile(delete=False)
tf.close()
try:
    t_stereo_check.save_pdb(tf.name)

    # Get 1st frame pdb from tempfile
    post_eq_mol = Chem.MolFromPDBFile(
        tf.name,
        removeHs=False,
        sanitize=False,
    )
finally:
    os.remove(tf.name)

# Get reference mol
mol_ref = Chem.MolFromMol2File(
    snakemake.input.ref_mol,
    removeHs=False,
)

# could compare smiles to automate the stereo-check. Problem: mol2 reference file has wrong bond orders
# (amber does not write those correctly). The ref-pdb file cannot be read b/c geometry is not optimized.
# This leads to funky valences in rdkit. The post-eq pdb file reads fine but then charges etc. dont match
# with the reference (b/c of wrong bond orders). But can manually check that all stereocentres are correct (below)
# The result is kept in a variable for inspection; it is informational only
# and, for the reasons above, not used to fail the notebook.
stereo_smiles_match = Chem.CanonSmiles(
    Chem.MolToSmiles(post_eq_mol)
) == Chem.CanonSmiles(Chem.MolToSmiles(mol_ref))
display(
    Markdown(
        """Following we compare an annotated 2d structure of the compound's starting topology, with the 
                 topology post equilibration"""
    )
)

In [9]:
# Strip all conformers (stored coordinates) from the molecule - presumably so
# that the notebook renders a 2D depiction instead of the 3D geometry. This
# mutates `post_eq_mol` in place.
post_eq_mol.RemoveAllConformers()
display(Markdown("2d structure of the compound post equilibration:"))
# Bare last expression: the molecule is rendered as the cell output.
post_eq_mol

In [10]:
# Strip all conformers (stored coordinates) from the reference molecule -
# presumably so that the notebook renders a 2D depiction. This mutates
# `mol_ref` in place.
mol_ref.RemoveAllConformers()
display(Markdown("2d structure of the compound reference topology:"))
# Bare last expression: the molecule is rendered as the cell output.
mol_ref

In [11]:
# Load the MD trajectory of the solute, align it on its first frame and, for
# GaMD, drop the leading frames that have no reweighting weight.
# Defines `t` (the trajectory used by all later cells) and, for the full
# trajectory, `frames_start`.
display(Markdown("Now we load the MD trajectory."))
if not use_shortened:
    t = md.load(
        snakemake.input.traj,
        top=snakemake.input.top,
        atom_indices=protein_atoms,
        stride=stride,
    )
    print(t)
    # Remove solvent from trajectory
    t = t.restrict_atoms(t.topology.select("protein or resname ASH"))
    t = t.superpose(t, 0)

    # for GaMD, skip equilibration...
    if snakemake.params.method == "GaMD":
        # Number of rows in the (unstrided) weights file. Frames at the start
        # of the trajectory that are not covered by a weight are treated as
        # equilibration and removed.
        weight_lengths = len(np.loadtxt(snakemake.input.weights))
        # `t` was already read with `stride`, so its frame count has to be
        # compared against the number of weights that survive the same stride
        # (the weights are later subsampled as `weight_data[::stride]`).
        # Comparing against the unstrided count gave a wrong (negative) cut
        # for stride > 1. For stride == 1 this is identical to the old logic.
        n_weights_strided = len(range(0, weight_lengths, stride))
        n_skip = t.n_frames - n_weights_strided
        if n_skip < 0:
            raise ValueError(
                f"Trajectory has {t.n_frames} frames (stride {stride}) but the weights file "
                f"provides {n_weights_strided} weights at that stride; cannot align frames and weights."
            )
        # Kept in unstrided frame units, as before.
        frames_start = n_skip * stride
        t = t[n_skip:]
    else:
        frames_start = 0
    print(t)
else:
    stride = 1  # set stride to 1 for shortened files!
    t = md.load(
        snakemake.params.traj_short, top=protein_topology, stride=1
    )
    t = t.restrict_atoms(t.topology.select("protein or resname ASH"))
    t = t.superpose(t, 0)
    print(t)

In [12]:
# Short textual summary of the simulation set-up and the amount of data.
simulation_summary = f"The simulation type is {snakemake.params.method}, {snakemake.wildcards.time} ns. The simulation was performed in {snakemake.wildcards.solvent}."
frame_summary = f"There are a total of {t.n_frames} frames available to analyse."

for message in (simulation_summary, frame_summary):
    display(Markdown(message))

In [13]:
# Prepare the reweighting weights and, when working on the full trajectory,
# the stride used to write the shortened files.
needs_weights = snakemake.params.method != "cMD"

if use_shortened:
    # Shortened weights were written alongside the shortened trajectory.
    if needs_weights:
        weight_data = np.loadtxt(snakemake.params.weights_short)
else:
    # Stride that reduces the trajectory to roughly 10k frames (at least 1).
    stride_short = max(1, int(t.n_frames / 10000))

    # Full weights, subsampled with the same stride as the trajectory.
    if needs_weights:
        weight_data = np.loadtxt(snakemake.input.weights)[::stride]

# Minimum number of frames a cis/trans form needs before it is analysed
# separately (0.1% of all frames). Only relevant if 2 sets of NOE values are
# present.
CIS_TRANS_CUTOFF = int(t.n_frames / 1000)

However, for some of the analysis steps below, only 1% of these frames have been used to ensure better rendering in the browser.

In [14]:
# Interactive plots. Do not require a live jupyter notebook session. Render in jupyter book.

from bokeh.plotting import figure, show, output_notebook
from bokeh.models import (
    Slider,
    CheckboxGroup,
    CustomJS,
    ColumnDataSource,
    CDSView,
    CheckboxButtonGroup,
)
from bokeh.models.filters import CustomJSFilter
from bokeh.layouts import row
from bokeh.transform import factor_cmap
from bokeh.palettes import Category10_10
from bokeh.io import export_svgs

# Configure bokeh so that subsequent `show(...)` calls render inline in the
# notebook output.
output_notebook()

## Convergence of the simulation
### RMSD
To check for convergence of the simulation, we can look at the root mean squared deviation of the atomic positions over the course of the simulation. 

````{margin}
```{note}
Click on the legend to hide some of the lines!
```
````

In [15]:
# RMSD relative to the first frame for three atom selections. The factor of
# 10 converts to Angstrom, and only every 100th value is kept for plotting.
bo = protein_topology.select("protein and (backbone and name O)")
ca = protein_topology.select("name CA")

rmsds = (md.rmsd(t, t, 0) * 10)[::100]
rmsds_ca = (md.rmsd(t, t, 0, atom_indices=ca) * 10)[::100]
rmsds_bo = (md.rmsd(t, t, 0, atom_indices=bo) * 10)[::100]

# x axis: plotted points spread evenly over the simulation time
n_points = len(rmsds_ca)
x = [point / n_points * simtime for point in range(0, n_points)]

# Interactive bokeh figure; clicking a legend entry mutes that line.
fig = figure(
    plot_width=600,
    plot_height=400,
    title="RMSD of different atom types",
    x_axis_label="Simulation time in ns",
    y_axis_label="RMSD in angstrom, relative to first frame",
    sizing_mode="stretch_width",
    toolbar_location=None,
)
for rmsd_values, label, line_color in (
    (rmsds, "all atoms", "black"),
    (rmsds_ca, "C-alpha atoms", "blue"),
    (rmsds_bo, "backbone O atoms", "orange"),
):
    fig.line(
        x,
        rmsd_values,
        line_width=2,
        line_alpha=0.6,
        legend_label=label,
        color=line_color,
        muted_alpha=0.1,
    )
fig.legend.click_policy = "mute"  #'hide'
show(fig)
# TODO: save rmsds as png, instead of manual screenshot https://docs.bokeh.org/en/latest/docs/user_guide/export.html

### Dihedral angles

In [16]:
# Determine whether the compound is treated as two species (cis / trans about
# one omega bond) and, if so, assign every frame to one of the two forms.
# Defines `multiple`, and for two-species compounds also `multi`,
# `distinction`, `cis` and `trans` (frame index arrays).
if compound.multi is not None:
    # Invert the mapping stored on the compound so that it can be indexed by
    # "cis"/"trans" later on. The inversion reads from `compound.multi` instead
    # of re-inverting `multi` in place: re-running this cell interactively used
    # to flip the mapping back and break the later `multi["cis"]` lookups.
    multi = {v: k for k, v in compound.multi.items()}
    multiple = True
    distinction = compound.distinction
    print("Multiple compounds detected")
else:
    multiple = False
#     pickle_dump(snakemake.output.multiple, multiple)
if multiple:  # if Compound.cistrans:
    # Omega dihedral: CA, C of the first residue in `distinction` followed by
    # N, CA of the second one.
    ca_c = t.top.select(f"resid {distinction[0]} and name CA C")
    n_ca_next = t.top.select(f"resid {distinction[1]} and name N CA")
    omega = np.append(ca_c, n_ca_next)
    t_omega_rad = md.compute_dihedrals(t, [omega])
    # Absolute value in degrees: 0..180, with 90 used as the cis/trans divider
    t_omega_deg = np.abs(np.degrees(t_omega_rad))
    plt.plot(t_omega_deg)
    plt.hlines(90, 0, t.n_frames, color="red")
    plt.xlabel("Frames")
    plt.ylabel("Omega 0-1 [°]")
    plt.title(f"Dihedral angle over time. Compound {compound_index}")
    # Frame indices of the two forms (|omega| <= 90 is cis, > 90 is trans)
    cis = np.where(t_omega_deg <= 90)[0]
    trans = np.where(t_omega_deg > 90)[0]
#     pickle_dump(snakemake.output.multiple, (cis, trans))
    # t[trans]
# TODO: save dihedrals as png

````{margin}
```{note}
Click on the legend to hide some of the lines!
```
````

In [17]:
# Residues of the solute, used as legend labels (one omega trace each).
resnames = [t.topology.residue(idx) for idx in range(0, t.n_residues)]

# Only the last array returned by getDihedrals (omega) is needed here.
# Absolute value in degrees, keeping every 100th frame for plotting.
*_, omega = src.dihedrals.getDihedrals(t)
omega_deg = np.abs(np.degrees(omega))[::100]

simtime = float(snakemake.wildcards.time)

colors = src.utils.color_cycle()

# x axis: plotted points spread evenly over the simulation time
n_points = len(omega_deg)
x = [point / n_points * simtime for point in range(0, n_points)]

# Interactive bokeh figure; clicking a legend entry mutes that line.
fig = figure(
    plot_width=600,
    plot_height=400,
    title="Omega dihedral angles over time",
    x_axis_label="Simulation time in ns",
    y_axis_label="Dihedral angle in ˚",
    sizing_mode="stretch_width",
    toolbar_location=None,
)

# One line per residue, coloured from the colour cycle.
for (i, res), col in zip(enumerate(resnames), colors):
    fig.line(
        x,
        omega_deg[:, i],
        line_width=2,
        line_alpha=0.6,
        legend_label=str(res),
        color=col,
        muted_alpha=0.1,
    )

fig.legend.click_policy = "mute"  #'hide'
show(fig)

In [18]:
# fig.plot.output_backend ="svg"
# export_svgs(fig, filename=snakemake.output.omega_plot)

In [19]:
# Compute dihedral angles [Phi] [Psi] [Omega]
# NOTE: this rebinds the notebook-level name `omega`, which earlier cells
# also assign.
phi, psi, omega = src.dihedrals.getDihedrals(t)

## Dimensionality Reductions
The simulation trajectories contain the positions of all atoms. This high-dimensional data ($3 \times N_{\mathrm{atoms}}$ coordinates per frame) is too complex to analyse directly. To get a feeling for the potential energy landscape, we need to apply some kind of dimensionality reduction. Here, we apply PCA (Principal Component Analysis).

### Dihedral PCA

In [20]:
# Dihedral PCA of the trajectory plus the full set of reduced dihedrals.
pca_d, reduced_dihedrals = src.pca.make_PCA(t, "dihedral")
reduced_dihedrals_full = src.dihedrals.getReducedDihedrals(t)

# save pca object & reduced dihedrals
# pickle_dump(snakemake.output.dPCA, pca_d)
# pickle_dump(snakemake.output.dihedrals, reduced_dihedrals_full)
# When analysing the full trajectory, store a strided copy of the reduced
# dihedrals for later runs on shortened files.
if not use_shortened:
    pickle_dump(
        snakemake.params.dihedrals_short,
        reduced_dihedrals_full[::stride_short],
    )

# Reweighting: cMD needs none, every other method uses the loaded weights.
if snakemake.params.method == "cMD":
    d_weights = reweight(reduced_dihedrals, None, "noweight")
else:
    d_weights = reweight(reduced_dihedrals, None, "amdweight_MC", weight_data)

# Arguments shared by the energy-coloured PCA panel in both layouts.
energy_panel_args = (
    reduced_dihedrals,
    "dihedral",
    compound_index,
    d_weights,
    "Energy [kcal/mol]",
)
leading_variance = pca_d.explained_variance_ratio_[:2]

if multiple:
    # Two panels: energy landscape (left) and cis/trans assignment (right).
    fig, axs = plt.subplots(
        1, 2, sharex="all", sharey="all", figsize=(6.7323, 3.2677)
    )
    axs[0] = src.pca.plot_PCA(
        *energy_panel_args,
        fig,
        axs[0],
        explained_variance=leading_variance,
    )
    axs[1] = src.pca.plot_PCA_citra(
        reduced_dihedrals[cis],
        reduced_dihedrals[trans],
        "dihedral",
        compound_index,
        [multi["cis"] + " (cis)", multi["trans"] + " (trans)"],
        fig,
        axs[1],
    )
else:
    # Single panel with the energy landscape only.
    fig, ax = plt.subplots(figsize=(3.2677, 3.2677))
    ax = src.pca.plot_PCA(
        *energy_panel_args,
        fig,
        ax,
        explained_variance=leading_variance,
    )
    fig.tight_layout()

fig.savefig(snakemake.output.pca_dihe, dpi=DPI)

In [21]:
# Convergence check on the dihedral-PCA landscape.
# The plane spanned by the first two components is divided into a
# grid_size x grid_size grid. The cells occupied by the full trajectory are the
# reference; the trajectory is then replayed cumulatively in blocks and, for
# each block, the fraction of reference cells already visited is recorded.
# Defines `occupied_grids`, `avg_occupancy` and `x_data` for the cells below.

pca_values = reduced_dihedrals

# Grid size
grid_size = 10

# A cell counts as occupied when it holds MORE than this many points
occupied_thresh = 1

# Block size (5% of total data, -> 20 blocks)
block_size = int(pca_values.shape[0] * 0.05)

# Number of blocks
num_blocks = pca_values.shape[0] // block_size

# Get the minimum and maximum values for each dimension
min_val, max_val = np.min(pca_values, axis=0), np.max(pca_values, axis=0)

# Calculate grid cell size
cell_size = (max_val - min_val) / grid_size

# Mean of the full simulation (not used by this cell; kept for interactive use)
mean_full = np.mean(pca_values, axis=0)


def _occupied_cells(points):
    """Return a (grid_size, grid_size) array with 1 for every occupied cell.

    A cell [cell_min, cell_max) is occupied when more than `occupied_thresh`
    of `points` fall inside it. The upper edges are exclusive, so points lying
    exactly on the maximum of a component are not assigned to any cell.
    """
    occupied = np.zeros((grid_size, grid_size))
    for i in range(grid_size):
        for j in range(grid_size):
            # Define cell boundaries
            cell_min = min_val + np.array([i * cell_size[0], j * cell_size[1]])
            cell_max = cell_min + cell_size

            in_cell = (
                (points[:, 0] >= cell_min[0])
                & (points[:, 0] < cell_max[0])
                & (points[:, 1] >= cell_min[1])
                & (points[:, 1] < cell_max[1])
            )
            if np.count_nonzero(in_cell) > occupied_thresh:
                occupied[i, j] = 1
    return occupied


# Reference: cells occupied when the whole trajectory is considered
total_filled_cells = int(np.sum(_occupied_cells(pca_values)))
print(f'{total_filled_cells} grid cells are filled (out of the total {grid_size**2})')

# Occupied cells for each cumulative block (block b uses frames 0 .. (b+1)*block_size)
occupied_grids = np.zeros((num_blocks, grid_size, grid_size))
for block in range(num_blocks):
    occupied_grids[block] = _occupied_cells(pca_values[:(block + 1) * block_size])

# Fraction of the reference cells visited after each block
avg_occupancy = np.sum(occupied_grids, axis=(1, 2)) / total_filled_cells

# Plot the visited fraction against cumulative simulation time. An explicit
# figure is used so the plot never lands on a figure left open by another cell.
x_data = (np.array(range(num_blocks)) / 20 * simtime) + simtime/20
fig, ax = plt.subplots()
# Label and x-range follow the actual run instead of being fixed to a
# 2000 ns GaMD simulation.
ax.plot(x_data, avg_occupancy, label=f'{snakemake.params.method} simulation')
ax.axhline(y=0.95, ls='--', color='red', label='95% of occupied cells visited')
ax.set_xlim([0, simtime])
ax.set_xlabel('Cummulative simulation time (ns)')
ax.set_ylabel('Fraction of occupied d-PCA cells filled')
ax.legend()
fig.savefig(snakemake.output.conv_plot, dpi=DPI)

In [22]:
from src.utils import json_dump

# First cumulative block (1-based) in which at least 95% of the finally
# occupied d-PCA cells have been visited, and the matching simulation time
# (each block covers 1/20 of the simulation).
threshhold = 0.95
first_hit = np.flatnonzero(avg_occupancy >= threshhold)[0]
converged_block = first_hit + 1
converged_time = converged_block / 20 * simtime
print(f'95% of the PCA landscape are visited within {converged_time} ns')

# save result as json
result = {
    'converged_block': int(converged_block),
    'converged_time': int(converged_time),
}
json_dump(snakemake.output.conv_data, result)


In [23]:
from matplotlib import colors

# One small panel per cumulative block (4 x 5 = 20 panels).
fig, axs = plt.subplots(4, 5)

# Two fixed colours: values below 0.5 (empty cell) are blue, values from 0.5
# to 1 (occupied cell) are yellow.
cmap = colors.ListedColormap(['blue', 'yellow'])
bounds = [0, 0.5, 1]
norm = colors.BoundaryNorm(bounds, cmap.N)

# Draw the occupancy grid of every block; the title gives the upper end of
# the cumulative time window.
panels = axs.flatten()
for i, block_grid in enumerate(occupied_grids):
    panel = panels[i]
    im = panel.imshow(
        block_grid,
        interpolation='nearest',
        origin='lower',
        cmap=cmap,
        norm=norm,
    )
    panel.set_title(f"<{int(x_data[i])} ns")

fig.tight_layout()
# Shared colour bar, built from the last image drawn.
fig.colorbar(im, ax=axs, cmap=cmap, norm=norm, boundaries=bounds, ticks=[0, 1])
fig.savefig(snakemake.output.grid_cells, dpi=DPI)

## NOE

In [24]:
# Read the experimental NOE data. The following cells unpack the result as a
# (trans, cis) pair of tables when `multiple` is set and use it as a single
# table otherwise.
NOE = src.noe.read_NOE(snakemake.input.noe)
# Collects the NOE tables (as dicts) that are written to the JSON result file.
NOE_output = {}

In [25]:
# NOE distances from the MD trajectory WITHOUT reweighting, compared with the
# experimental values. For two-species compounds the cis and trans frames are
# evaluated separately, each only if it has more than CIS_TRANS_CUTOFF frames.
# The resulting tables are written to `snakemake.output.noe_result`.
if multiple:
    fig, axs = plt.subplots(2, 1, figsize=(6.7323, 3.2677))
    NOE_trans, NOE_cis = NOE
    NOE_cis_dict = NOE_cis.to_dict(orient="index")
    NOE_trans_dict = NOE_trans.to_dict(orient="index")
    if len(cis) > CIS_TRANS_CUTOFF:
        NOE_cis["md"], _, _2, NOE_dist_cis, _3 = src.noe.compute_NOE_mdtraj(
            NOE_cis_dict, t[cis]
        )

        # Stored before exploding, i.e. with list-valued entries intact
        NOE_output[f"{multi['cis']}"] = NOE_cis.to_dict(orient="index")
        # Deal with ambigous NOEs
        NOE_cis = NOE_cis.explode("md")
        # and ambigous/multiple values
        NOE_cis = NOE_cis.explode("NMR exp")
        fig, axs[1] = src.noe.plot_NOE(NOE_cis, fig, axs[1])
        axs[1].set_title(f"Compound {multi['cis']} (cis)")
    else:
        print("Cis skipped because no frames are cis.")
    if len(trans) > CIS_TRANS_CUTOFF:
        (
            NOE_trans["md"],
            _,
            _2,
            NOE_dist_trans,
            _3,
        ) = src.noe.compute_NOE_mdtraj(NOE_trans_dict, t[trans])

        # Stored before exploding, i.e. with list-valued entries intact
        NOE_output[f"{multi['trans']}"] = NOE_trans.to_dict(orient="index")
        # Deal with ambigous NOEs
        NOE_trans = NOE_trans.explode("md")
        # and ambigous/multiple values
        NOE_trans = NOE_trans.explode("NMR exp")

        fig, axs[0] = src.noe.plot_NOE(NOE_trans, fig, axs[0])
        axs[0].set_title(f"Compound {multi['trans']} (trans)")
    else:
        # This branch is about the trans form (the message used to say "cis").
        print("Trans skipped because no frames are trans.")
else:
    NOE_dict = NOE.to_dict(orient="index")
    NOE["md"], _, _2, NOE_dist, _3 = src.noe.compute_NOE_mdtraj(NOE_dict, t)

    # Save NOE dict
    NOE_output = {f"{compound_index}": NOE.to_dict(orient="index")}
    # Deal with ambigous NOEs
    NOE = NOE.explode("md")
    # and ambigous/multiple values
    NOE = NOE.explode("NMR exp")
    fig, ax = src.noe.plot_NOE(NOE)
    ax.set_title(f"Compound {compound_index}. NOE without reweighting.", y=1.2)
fig.tight_layout()
# fig.savefig(snakemake.output.noe_plot, dpi=DPI)
# save as .json file
src.utils.json_dump(snakemake.output.noe_result, NOE_output)

In [26]:
# 1d PMF reweighted NOEs
# Same comparison as the unweighted NOE cell, but the MD distances are
# reweighted with `weight_data` (`reweigh_type=1`), which also yields
# lower/upper values and a PMF figure. Skipped for cMD, where no weights
# exist. The reweighted tables overwrite `snakemake.output.noe_result`.

NOE_output = {}

if snakemake.params.method != "cMD":
    if multiple:
        fig, axs = plt.subplots(2, 1, figsize=(6.7323, 6.7323))
        NOE_trans, NOE_cis = NOE
        NOE_cis_dict = NOE_cis.to_dict(orient="index")
        NOE_trans_dict = NOE_trans.to_dict(orient="index")
        if len(cis) > CIS_TRANS_CUTOFF:
            (
                NOE_cis["md"],
                NOE_cis["lower"],
                NOE_cis["upper"],
                NOE_dist_cis,
                pmf_plot_cis,
            ) = src.noe.compute_NOE_mdtraj(
                NOE_cis_dict,
                t[cis],
                reweigh_type=1,
                slicer=cis,
                weight_data=weight_data,
            )
            # TODO: this should not give an error!

            NOE_output[f"{multi['cis']}"] = NOE_cis.to_dict(orient="index")

            # Deal with ambigous NOEs
            NOE_cis = NOE_cis.explode(["md", "lower", "upper"])
            # and ambigous/multiple values
            NOE_cis = NOE_cis.explode("NMR exp")
            fig, axs[1] = src.noe.plot_NOE(NOE_cis, fig, axs[1])
            axs[1].set_title(f"Compound {multi['cis']} (cis)")
        else:
            print("Cis skipped because no frames are cis.")
        if len(trans) > CIS_TRANS_CUTOFF:
            (
                NOE_trans["md"],
                NOE_trans["lower"],
                NOE_trans["upper"],
                NOE_dist_trans,
                pmf_plot_trans,
            ) = src.noe.compute_NOE_mdtraj(
                NOE_trans_dict,
                t[trans],
                reweigh_type=1,
                slicer=trans,
                weight_data=weight_data,
            )

            NOE_output[f"{multi['trans']}"] = NOE_trans.to_dict(orient="index")
            # Deal with ambigous NOEs
            NOE_trans = NOE_trans.explode(["md", "lower", "upper"])
            # and ambigous/multiple values
            NOE_trans = NOE_trans.explode("NMR exp")
            fig, axs[0] = src.noe.plot_NOE(NOE_trans, fig, axs[0])
            axs[0].set_title(f"Compound {multi['trans']} (trans)")
        else:
            # This branch is about the trans form (the message used to say "cis").
            print("Trans skipped because no frames are trans.")
        src.utils.json_dump(snakemake.output.noe_result, NOE_output)
        fig.tight_layout()
#         fig.savefig(snakemake.output.noe_plot)
    else:
        # Re-read the NOE table: the unweighted cell replaced `NOE` with its
        # exploded version.
        NOE = src.noe.read_NOE(snakemake.input.noe)
        NOE_dict = NOE.to_dict(orient="index")
        (
            NOE["md"],
            NOE["lower"],
            NOE["upper"],
            _,
            pmf_plot,
        ) = src.noe.compute_NOE_mdtraj(
            NOE_dict, t, reweigh_type=1, weight_data=weight_data
        )
        plt.close()
        # Save NOE dict
        NOE_output = {f"{compound_index}": NOE.to_dict(orient="index")}
        # save as .json file
        src.utils.json_dump(snakemake.output.noe_result, NOE_output)

        # Deal with ambigous NOEs
        NOE = NOE.explode(["md", "lower", "upper"])
        # and ambigous/multiple values
        NOE = NOE.explode("NMR exp")
        fig, ax = src.noe.plot_NOE(NOE)
        #         ax.set_title(f"Compound {compound_index}. NOE", y=1.5, pad=0)
        fig.tight_layout()
#         fig.savefig(snakemake.output.noe_plot, dpi=DPI)
else:
    print("cMD - no reweighted NOEs performed.")
# final_figure_axs.append(sg.from_mpl(fig))
# pickle_dump(snakemake.output.noe_dist, NOE_dist)

In [27]:
# Show the current `NOE` object: a single table for one-species compounds, or
# the (trans, cis) pair as read from file when `multiple` is set.
display(NOE)

In [28]:
# matplotlib.rcParams.update(matplotlib.rcParamsDefault)

# Assemble the NOE PMF figure:
#   cMD          -> placeholder figure (no reweighting, hence no PMF)
#   one species  -> the PMF figure returned by the reweighted NOE computation
#   two species  -> cis and trans PMF figures rasterised and stacked
if snakemake.params.method == "cMD":
    fig, ax = plt.subplots()
    ax.text(0.5, 0.5, "not applicable.")
    # fig.savefig(snakemake.output.noe_pmf, dpi=DPI)
elif not multiple:
    pmf_plot.suptitle("NOE PMF plots")
    pmf_plot.tight_layout()
    # pmf_plot.savefig(snakemake.output.noe_pmf)
    fig = pmf_plot
else:
    # A form is only shown if it passed the frame cutoff (and therefore has a
    # PMF figure from the reweighted NOE cell).
    show_cis = len(cis) > CIS_TRANS_CUTOFF
    show_trans = len(trans) > CIS_TRANS_CUTOFF

    # save to image data
    io_cis = io.BytesIO()
    io_trans = io.BytesIO()

    if show_cis:
        # Render to a raw pixel buffer and reshape it to (height, width, channels)
        pmf_plot_cis.savefig(io_cis, format="raw", dpi=pmf_plot_cis.dpi)
        io_cis.seek(0)
        cis_pixels = np.frombuffer(io_cis.getvalue(), dtype=np.uint8)
        img_cis = cis_pixels.reshape(
            (
                int(pmf_plot_cis.bbox.bounds[3]),
                int(pmf_plot_cis.bbox.bounds[2]),
                -1,
            )
        )
        io_cis.close()

    if show_trans:
        pmf_plot_trans.savefig(io_trans, format="raw", dpi=pmf_plot_trans.dpi)
        io_trans.seek(0)
        trans_pixels = np.frombuffer(io_trans.getvalue(), dtype=np.uint8)
        img_trans = trans_pixels.reshape(
            (
                int(pmf_plot_trans.bbox.bounds[3]),
                int(pmf_plot_trans.bbox.bounds[2]),
                -1,
            )
        )
        io_trans.close()

    # Stack the two images: cis on top, trans below.
    fig, axs = plt.subplots(2, 1)
    fig.set_size_inches(16, 30)
    if show_cis:
        axs[0].imshow(img_cis)
        axs[0].axis("off")
        axs[0].set_title("cis")
    if show_trans:
        axs[1].imshow(img_trans)
        axs[1].set_title("trans")
        axs[1].axis("off")
    # fig.suptitle('PMF plots. PMF vs. distance')
    fig.tight_layout()
    # fig.savefig(snakemake.output.noe_pmf, dpi=DPI)
display(fig)

In [53]:
# Compute deviations of experimental NOE values to the MD computed ones
# Result: `NOE_dev[key]` holds one cleaned table per analysed form
# ("cis"/"trans" for two-species compounds, "single" otherwise).
NOE_stats_keys = []
NOE_i = []
NOE_dev = {}

# Select the tables to analyse; cis/trans only if they passed the frame cutoff.
if not multiple:
    NOE_stats_keys.append("single")
    NOE_i.append(NOE)
else:
    if len(cis) > CIS_TRANS_CUTOFF:
        NOE_stats_keys.append("cis")
        NOE_i.append(NOE_cis)
    if len(trans) > CIS_TRANS_CUTOFF:
        NOE_stats_keys.append("trans")
        NOE_i.append(NOE_trans)

for k, NOE_d in zip(NOE_stats_keys, NOE_i):
    all_exp_zero = (NOE_d["NMR exp"].to_numpy() == 0).all()
    if all_exp_zero:
        # if all exp values are 0: take middle between upper / lower bound as reference value
        NOE_d["NMR exp"] = (NOE_d["upper bound"] + NOE_d["lower bound"]) * 0.5

    # Signed and absolute deviation of MD from the (absolute) experimental value
    reference = np.abs(NOE_d["NMR exp"])
    NOE_d["dev"] = NOE_d["md"] - reference
    NOE_d["abs_dev"] = np.abs(NOE_d["md"] - reference)

    # Remove duplicate values (keep value closest to experimental value):
    # sort by absolute deviation, keep the first row per NOE index, then
    # restore index order with a stable sort.
    by_deviation = NOE_d.sort_values("abs_dev", ascending=True)
    by_deviation.index = by_deviation.index.astype(int)
    is_repeat = by_deviation.index.duplicated(keep="first")
    NOE_d = by_deviation[~is_repeat].sort_index(kind="mergesort")

    # Drop rows with missing values before storing the table
    NOE_dev[k] = NOE_d.dropna()
    NOE_d = NOE_dev[k]

In [54]:
# Compute NOE statistics
# For every analysed form in `NOE_stats_keys`, build a table with the columns
# stat / value / up / low that compares the experimental NOE values with the
# MD-derived ones. Result: `NOE_stats[key]` is a DataFrame with one row per
# statistic.
#
# The rows are collected in a list and turned into a DataFrame once:
# `DataFrame.append` was removed in pandas 2.0, and growing a frame row by row
# copies it on every call.
NOE_stats = {}

for k in NOE_stats_keys:
    NOE_d = NOE_dev[k]
    nmr_exp = NOE_d["NMR exp"]
    md_dist = NOE_d["md"]

    # Statistics that return a (value, upper, lower) triple, evaluated in the
    # same order as before.
    interval_stats = [
        ("MAE", src.stats.compute_MAE(nmr_exp, md_dist)),
        ("MSE", src.stats.compute_MSE(NOE_d["dev"])),
        ("RMSD", src.stats.compute_RMSD(nmr_exp, md_dist)),
        ("RMSD_step", src.stats.compute_RMSD_stepwise(NOE_d, nmr_exp, md_dist)),
        ("pearsonr", src.stats.compute_pearsonr(nmr_exp, md_dist)),
        ("kendalltau", src.stats.compute_kendalltau(nmr_exp, md_dist)),
        ("chisq", src.stats.compute_chisquared(nmr_exp, md_dist)),
    ]
    records = [
        {"stat": stat_name, "value": value, "up": upper, "low": lower}
        for stat_name, (value, upper, lower) in interval_stats
    ]

    # The fulfilled percentage is a single number; its up/low entries are 0.
    fulfilled = src.stats.compute_fulfilled_percentage(NOE_d)
    records.append(
        {
            "stat": "percentage_fulfilled",
            "value": fulfilled,
            "up": 0,
            "low": 0,
        }
    )

    NOE_stats_k = pd.DataFrame(records, columns=["stat", "value", "up", "low"])
    NOE_stats[k] = NOE_stats_k

In [55]:
# Compute statistics for most populated cluster
# if multiple:
#     NOE_stats_keys = ["cis", "trans"]
#     differentiation = {"cis": cis, "trans": trans}
# else:
#     NOE_stats_keys = ["single"]
# if multiple:
#     differentiation = {"cis": cis, "trans": trans}
# n_cluster_traj = {}
# n_cluster_percentage = {}
# n_cluster_index = {}
# remover = []
# for k in NOE_stats_keys:
#     if multiple:
#         # This checks that if an MD trajectory has bot cis/trans states, whether there is a cluster for both cis/trans
#         cluster_in_x = np.in1d(cluster_index, differentiation[k])
#         print(cluster_in_x)
#         if np.all(cluster_in_x == False):
#             # No clusters found for specific cis/trans/other
#             remover.append(k)
#     else:
#         cluster_in_x = np.ones((len(cluster_index)), dtype=bool)
#     cluster_in_x = np.arange(0, len(cluster_index))[cluster_in_x]
#     n_cluster_traj[k] = cluster_traj[cluster_in_x]
#     n_cluster_percentage[k] = np.array(cluster_percentage)[cluster_in_x]
#     n_cluster_index[k] = np.array(cluster_index)[cluster_in_x]
# cluster_traj = n_cluster_traj
# cluster_percentage = n_cluster_percentage
# cluster_index = n_cluster_index
# [NOE_stats_keys.remove(k) for k in remover]

In [56]:
# # Compute statistics for most populated cluster
# NOE_dict = {}
# NOE = src.noe.read_NOE(snakemake.input.noe)
# NOE_n = {}
# if multiple:
#     NOE_trans, NOE_cis = NOE
#     NOE_n["cis"] = NOE_cis
#     NOE_n["trans"] = NOE_trans
#     NOE_dict["cis"] = NOE_cis.to_dict(orient="index")
#     NOE_dict["trans"] = NOE_trans.to_dict(orient="index")
# else:
#     NOE_dict["single"] = NOE.to_dict(orient="index")
#     NOE_n["single"] = NOE


# for k in NOE_stats_keys:
#     # max. populated cluster
#     # NOE = NOE_n.copy()
#     max_populated_cluster_idx = np.argmax(cluster_percentage[k])
#     max_populated_cluster = cluster_traj[k][max_populated_cluster_idx]
#     NOE_n[k]["md"], *_ = src.noe.compute_NOE_mdtraj(
#         NOE_dict[k], max_populated_cluster
#     )
#     # Deal with ambiguous NOEs
#     NOE_n[k] = NOE_n[k].explode("md")
#     # and ambiguous/multiple values
#     NOE_n[k] = NOE_n[k].explode("NMR exp")

#     # Remove duplicate values (keep value closest to experimental value)
#     NOE_test = NOE_n[k]
#     if (NOE_test["NMR exp"].to_numpy() == 0).all():
#         # if all exp values are 0: take middle between upper / lower bound as reference value
#         NOE_test["NMR exp"] = (
#             NOE_test["upper bound"] + NOE_test["lower bound"]
#         ) * 0.5
#     NOE_test["dev"] = NOE_test["md"] - np.abs(NOE_test["NMR exp"])
#     NOE_test["abs_dev"] = np.abs(NOE_test["md"] - np.abs(NOE_test["NMR exp"]))

#     NOE_test = NOE_test.sort_values("abs_dev", ascending=True)
#     NOE_test.index = NOE_test.index.astype(int)
#     NOE_test = NOE_test[~NOE_test.index.duplicated(keep="first")].sort_index(
#         kind="mergesort"
#     )

#     # drop NaN values:
#     NOE_test = NOE_test.dropna()
#     # Compute metrics now
#     # Compute NOE statistics, since no bootstrap necessary, do a single iteration.. TODO: could clean this up further to pass 0, then just return the value...
#     RMSD, *_ = src.stats.compute_RMSD(
#         NOE_test["NMR exp"], NOE_test["md"], n_bootstrap=1
#     )
#     MAE, *_ = src.stats.compute_MAE(
#         NOE_test["NMR exp"], NOE_test["md"], n_bootstrap=1
#     )
#     MSE, *_ = src.stats.compute_MSE(NOE_test["dev"], n_bootstrap=1)
#     fulfil = src.stats.compute_fulfilled_percentage(NOE_test)
#     # insert values
#     values = [MAE, MSE, RMSD, None, None, None, fulfil]
#     NOE_stats[k].insert(4, "most-populated-1", values)

# # If there are no cis/trans clusters, still write a column 'most-populated-1', but fill with NaN
# for k in remover:
#     values = [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]
#     NOE_stats[k].insert(4, "most-populated-1", values)

In [58]:
# Display the NOE statistics stored under every key of NOE_stats, then export
# the whole collection via src.utils.json_dump.
for k, noe_table in NOE_stats.items():
    display(noe_table)
    # Convert df to dict for export. The entry is replaced in place, so when
    # this cell is executed a second time the value is already a plain dict;
    # only convert objects that still offer `to_dict`, otherwise a re-run
    # would fail with an AttributeError.
    if hasattr(noe_table, "to_dict"):
        NOE_stats[k] = noe_table.to_dict()
# Save to the path given by the Snakemake output `noe_stats`
src.utils.json_dump(snakemake.output.noe_stats, NOE_stats)

In [59]:
# plt.rc("font", size=MEDIUM_SIZE)  # controls default text sizes
# plt.rc("axes", titlesize=BIGGER_SIZE)  # fontsize of the axes title
# plt.rc("axes", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels
# plt.rc("xtick", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels
# plt.rc("ytick", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels
# plt.rc("legend", fontsize=MEDIUM_SIZE)  # legend fontsize
# plt.rc("figure", titlesize=BIGGER_SIZE)  # fontsize of the figure title


# if multiple:
#     fig, axs = plt.subplots(2, 1)
#     if len(cis) > CIS_TRANS_CUTOFF:
#         # cis
#         axs[0].scatter(NOE_dev["cis"]["NMR exp"], NOE_dev["cis"]["md"])
#         axs[0].set_ylabel("MD")
#         axs[0].set_xlabel("Experimental NOE value")
#         axs[0].axline((1.5, 1.5), slope=1, color="black")
#         axs[0].set_title("Experimental vs MD derived NOE values - cis")

#     if len(trans) > CIS_TRANS_CUTOFF:
#         # trans
#         axs[1].scatter(NOE_dev["trans"]["NMR exp"], NOE_dev["trans"]["md"])
#         axs[1].set_ylabel("MD")
#         axs[1].set_xlabel("Experimental NOE value")
#         axs[1].axline((1.5, 1.5), slope=1, color="black")
#         axs[1].set_title("Experimental vs MD derived NOE values - trans")
#     fig.tight_layout()
#     fig.savefig(snakemake.output.noe_stat_plot)
# else:
#     plt.scatter(NOE_dev["single"]["NMR exp"], NOE_dev["single"]["md"])
#     if snakemake.params.method != "cMD":
#         plt.scatter(
#             NOE_dev["single"]["NMR exp"],
#             NOE_dev["single"]["upper"],
#             marker="_",
#         )
#         plt.scatter(
#             NOE_dev["single"]["NMR exp"],
#             NOE_dev["single"]["lower"],
#             marker="_",
#         )
#     plt.ylabel("MD")
#     plt.xlabel("Experimental NOE value")
#     plt.axline((1.5, 1.5), slope=1, color="black")
#     plt.title("Experimental vs MD derived NOE values")
#     plt.tight_layout()
#     plt.savefig(snakemake.output.noe_stat_plot)

In [None]:
# # is the mean deviation significantly different than 0? if pvalue < 5% -> yes! We want: no! (does not deviate from exp. values)
# if multiple:
#     if len(cis) > CIS_TRANS_CUTOFF:
#         print(stats.ttest_1samp(NOE_dev["cis"]["dev"], 0.0))
#     if len(trans) > CIS_TRANS_CUTOFF:
#         print(stats.ttest_1samp(NOE_dev["trans"]["dev"], 0.0))
# else:
#     print(stats.ttest_1samp(NOE_dev["single"]["dev"], 0.0))

In [60]:
# if multiple:
#     if len(cis) > CIS_TRANS_CUTOFF:
#         print(stats.describe(NOE_dev["cis"]["dev"]))
#     if len(trans) > CIS_TRANS_CUTOFF:
#         print(stats.describe(NOE_dev["trans"]["dev"]))
# else:
#     print(stats.describe(NOE_dev["single"]["dev"]))

In [61]:
# Print a completion marker for the notebook run.
print("Done")