# Comparison Analysis

In [2]:
import matplotlib

%matplotlib inline
# matplotlib.use("Agg")

import mdtraj as md
import numpy as np
import matplotlib.pyplot as plt
import scipy.cluster.hierarchy
from scipy.spatial.distance import squareform
import pandas as pd

# `os` and `sys` are imported explicitly: no import above provides them, so on
# a fresh kernel the line below would raise NameError unless an earlier,
# unseen cell (e.g. an injected preamble) happened to define both names.
import os
import sys

# Put the current working directory on the module search path so that the
# project-local `src` package imported below can be resolved.
sys.path.append(os.getcwd())
import src.dihedrals
from src.noe import compute_NOE_mdtraj, plot_NOE
from src.utils import json_load, dotdict

import src.pyreweight
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA

In [3]:
# Read the stride from the workflow config file.
# The configured value is kept as-is: a hard-coded `stride = 1` placed right
# after this read would silently discard the setting from the config.
stride = int(snakemake.config["stride"])

In [4]:
# The compound under analysis is identified by the `compound` wildcard,
# which is converted to an integer index.
compound_wildcard = snakemake.wildcards.compound
compound_index = int(compound_wildcard)
print(f"Analysing Compound {compound_index}")

In [5]:
# Reference MD data for the dihedral-PCA plots, all read with
# `src.utils.pickle_load` from the snakemake inputs.
_ref_inputs = snakemake.input

# Values used to colour the reference points in the PCA scatter plots.
weights_ref = src.utils.pickle_load(_ref_inputs.ref_md_dPCA_weights)

# PCA object of the reference MD (noted in the original as PCA(n_components=2)).
pca = src.utils.pickle_load(_ref_inputs.ref_md_dPCA)

# Reduced dihedrals of the reference MD, projected with that same PCA object.
dihe_ref_full = src.utils.pickle_load(_ref_inputs.ref_md_red_dihe)
dihe_ref = pca.transform(dihe_ref_full)

In [6]:
# For each of the six conformer-generator inputs (confgen1 .. confgen6):
#   dihe[k]   -> reduced dihedrals projected with the reference `pca`
#   shapes[k] -> normalized principal moments of inertia, one row per frame
dihe = {}
shapes = {}
for idx in range(6):
    ensemble = md.load(snakemake.input[f"confgen{idx+1}"])

    # Project this ensemble with the PCA object loaded for the reference MD.
    dihe[idx] = pca.transform(src.dihedrals.getReducedDihedrals(ensemble))

    # Principal moments = eigenvalues of the inertia tensor of every frame.
    # `eigvalsh` returns them in ascending order, so column 2 is the largest.
    moments = np.linalg.eigvalsh(md.compute_inertia_tensor(ensemble))

    # Normalized ratios (I1/I3, I2/I3) in a single broadcast division.
    shapes[idx] = moments[:, :2] / moments[:, [2]]

In [7]:
# Plot re-weighted PCA plots.
# One panel per conformer-generator input: the reference MD projection
# coloured by `weights_ref`, with the generated conformers overlaid as black
# squares.
fig, axs = plt.subplots(2, 3, sharex="all", sharey="all")
fig.set_size_inches(12, 6)
scat = {}
for i, ax in enumerate(axs.flat):
    # Keep the handle of the coloured reference scatter; it is no longer
    # overwritten by the handle of the overlay drawn right after it.
    scat[i] = ax.scatter(
        dihe_ref[:, 0],
        dihe_ref[:, 1],
        c=weights_ref,
        cmap="Spectral_r",
        marker=".",
        s=0.5,
        vmin=0,
        vmax=8,
        rasterized=True,
    )
    ax.scatter(
        dihe[i][:, 0],
        dihe[i][:, 1],
        marker="s",
        s=6,
        color="black",
        label="Conformer generator",
    )
    ax.set_title(snakemake.wildcards[f"confgen{i+1}"])

# All panels share their axes, so label only the outer ones. The names match
# the PC1/PC2 labels used for the same data in the single-panel figure.
for ax in axs[-1, :]:
    ax.set_xlabel("PC1")
for ax in axs[:, 0]:
    ax.set_ylabel("PC2")

fig.suptitle(
    f"{snakemake.wildcards.confgen_method.capitalize()} conformer generator"
)

# One legend for the whole figure, anchored outside the top-right panel.
lgnd = axs.flat[2].legend(bbox_to_anchor=(1.05, 1), loc="upper left")
fig.savefig(snakemake.output.plot, bbox_inches="tight", dpi=300)

In [8]:
# Reference MD data for the shape plots: the values used to colour the points
# and the shape coordinates themselves (columns 0 and 1 are plotted later).
ref_shape_weights = src.utils.pickle_load(snakemake.input.ref_md_weights)
ref_shape = src.utils.pickle_load(snakemake.input.ref_md_shape)

In [9]:
# Shape plots: one panel per conformer-generator input, showing the reference
# MD shapes (coloured by `ref_shape_weights`) with the generated conformers
# overlaid in black.

import matplotlib.tri as tri

fig, axs = plt.subplots(2, 3, sharex="all", sharey="all")
fig.set_size_inches(12, 6)

# Triangle with corners at (1, 1), (0.5, 0.5) and (0, 1); these are the
# corners annotated "sphere", "disk" and "rod" below.
corners = np.array([[1, 1], [0.5, 0.5], [0, 1]])
triangle = tri.Triangulation(corners[:, 0], corners[:, 1])

# A single refiner yields both the plain outline (no subdivision) and the
# finer inner mesh (two subdivision levels).
refiner = tri.UniformTriRefiner(triangle)
outline = refiner.refine_triangulation(subdiv=0)
trimesh = refiner.refine_triangulation(subdiv=2)

# Corner annotations as (x, y, text).
corner_labels = ((0, 1.01, "rod"), (0.85, 1.01, "sphere"), (0.44, 0.45, "disk"))
# Fully transparent points as (x, y, size); presumably they only stretch the
# data limits so the corner annotations are not clipped.
padding_points = ((0, 1.05, 0.1), (1.05, 1.05, 0.5), (0.5, 0.45, 0.5))

scat = {}
for i, ax in enumerate(axs.flatten()):
    # Reference MD shapes; the handle is kept for the colorbar below.
    scat[i] = ax.scatter(
        ref_shape[:, 0],
        ref_shape[:, 1],
        c=ref_shape_weights,
        marker=".",
        cmap="Spectral_r",
        s=0.5,
        vmin=0,
        vmax=8,
        rasterized=True,
    )
    # Conformer-generator shapes on top.
    generated = shapes[i]
    ax.scatter(generated[:, 0], generated[:, 1], marker=".", c="black", s=4)
    # NOTE(review): this label (and the y-label set after the loop) appears
    # to be hidden by `axis("off")` below - confirm whether that is intended.
    ax.set_xlabel(r"$I_{1}/I_{3}$")
    ax.triplot(trimesh, "--", color="grey")
    ax.triplot(outline, "k-")
    for label_x, label_y, label_text in corner_labels:
        ax.text(label_x, label_y, label_text)
    for pad_x, pad_y, pad_size in padding_points:
        ax.scatter(pad_x, pad_y, alpha=0, s=pad_size)
    ax.axis("off")
    ax.set_title(snakemake.wildcards[f"confgen{i+1}"])
axs.flatten()[0].set_ylabel("$I_{2}/I_{3}$")

# Shared colorbar taken from the first panel's reference scatter.
colorbar = fig.colorbar(
    scat[0], ax=axs, label="kcal/mol", location="left", anchor=(8.5, 0)
)
fig.savefig(snakemake.output.shape_plot, bbox_inches="tight", dpi=300)

In [39]:
# Index of the conformer-generator ensemble shown in the single-panel
# comparison figure; it keys into `dihe` / `shapes` (filled for 0..5) and
# selects the wildcard `confgen{fig_to_plot+1}`.
fig_to_plot = 0

In [58]:
# Two-panel figure for the ensemble selected by `fig_to_plot`:
# left = dihedral-PCA projection, right = shape plot.
# Reuses `trimesh` and `outline` created for the shape-grid figure.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
pca_ax, shape_ax = axs

# Scatter settings common to both reference-MD layers.
ref_scatter_style = dict(
    cmap="Spectral_r", marker=".", s=0.5, vmin=0, vmax=8, rasterized=True
)

# --- left panel: PCA projection -------------------------------------------
pca_ax.scatter(
    dihe_ref[:, 0], dihe_ref[:, 1], c=weights_ref, **ref_scatter_style
)
selected_dihe = dihe[fig_to_plot]
method_name = snakemake.wildcards.confgen_method.capitalize()
ensemble_name = snakemake.wildcards[f"confgen{fig_to_plot+1}"]
pca_ax.scatter(
    selected_dihe[:, 0],
    selected_dihe[:, 1],
    marker="s",
    s=6,
    color="black",
    label=f"{method_name} {ensemble_name}",
)
pca_ax.legend()
pca_ax.set_xlabel("PC1")
pca_ax.set_ylabel("PC2")

# --- right panel: shape plot ----------------------------------------------
# The handle of the reference scatter feeds the colorbar below.
scat = shape_ax.scatter(
    ref_shape[:, 0], ref_shape[:, 1], c=ref_shape_weights, **ref_scatter_style
)
selected_shapes = shapes[fig_to_plot]
shape_ax.scatter(
    selected_shapes[:, 0], selected_shapes[:, 1], marker=".", c="black", s=4
)
shape_ax.triplot(trimesh, "--", color="grey")
shape_ax.triplot(outline, "k-")
# Corner annotations of the shape triangle.
for label_x, label_y, label_text in (
    (0, 1.01, "rod"),
    (0.85, 1.01, "sphere"),
    (0.44, 0.45, "disk"),
):
    shape_ax.text(label_x, label_y, label_text)
# Fully transparent points; presumably they only stretch the data limits so
# the annotations are not clipped.
for pad_x, pad_y, pad_size in ((0, 1.05, 0.1), (1.05, 1.05, 0.5), (0.5, 0.45, 0.5)):
    shape_ax.scatter(pad_x, pad_y, alpha=0, s=pad_size)
shape_ax.axis("off")

colorbar = fig.colorbar(
    scat, ax=axs, label="kcal/mol", location="left", anchor=(8, 0)
)

fig.savefig(snakemake.output.single_comp_plot, bbox_inches="tight", dpi=300)