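# Comparison Analysis

Compares each simulation with the reference run. Each simulation's dihedral features are projected onto the reference PCA and reweighted. The result is plotted next to the reference landscape, with each simulation's starting point marked by a black "x".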

In [2]:
import matplotlib

%matplotlib inline
# matplotlib.use("Agg")

import mdtraj as md
import numpy as np
import matplotlib.pyplot as plt
import scipy.cluster.hierarchy
from scipy.spatial.distance import squareform
import pandas as pd

sys.path.append(os.getcwd())
import src
from src.noe import compute_NOE_mdtraj, plot_NOE
from src.utils import json_load, dotdict

import src.pyreweight
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA

In [3]:
# Stride used below to subsample each per-frame aMD weight file.
# NOTE(review): the Snakemake config value is still parsed (so a missing or
# non-integer "stride" entry fails here), but it is then overridden with 1.
# Presumably the dihedral inputs are not strided; confirm before using the
# configured value, or weights and frames will no longer line up.
config_stride = int(snakemake.config["stride"])
stride = 1

In [6]:
# Load the reference simulation: its dihedral features, the PCA model used to
# project every simulation, and the per-frame values that colour the reference
# panel. Only pickles produced by this pipeline should be loaded here.
ref_inputs = snakemake.input
ref_dihedrals = src.utils.pickle_load(ref_inputs.ref_dih[0])
ref_pca = src.utils.pickle_load(ref_inputs.ref_pca[0])
ref_pca_weights = src.utils.pickle_load(ref_inputs.ref_pca_weights[0])

# Reference frames expressed in the PCA coordinates.
ref_trafo_dihedrals = ref_pca.transform(ref_dihedrals)

In [8]:
# Project every simulation's dihedral features onto the reference PCA and
# reweight each one with its own aMD weight file.
# The dihedral and weight inputs are paired by position (index i), so both
# lists must have the same length; fail early instead of with a KeyError later.
if len(snakemake.input.dihedrals) != len(snakemake.input.weights):
    raise ValueError(
        f"got {len(snakemake.input.dihedrals)} dihedral files but "
        f"{len(snakemake.input.weights)} weight files; they are paired by position"
    )

trafo_dihedrals = {}
dihedrals = {}
for i, dihe in enumerate(snakemake.input.dihedrals):
    dihedrals[i] = src.utils.pickle_load(dihe)
    trafo_dihedrals[i] = ref_pca.transform(dihedrals[i])

weight_data = {}
weights = {}
for i, wei in enumerate(snakemake.input.weights):
    # Per-frame weights, subsampled with the same stride as configured above.
    weight_data[i] = np.loadtxt(wei)[::stride]
    # NOTE(review): the weight-file path itself is passed as the second
    # argument; confirm this is what src.pyreweight.reweight expects.
    weights[i] = src.pyreweight.reweight(
        trafo_dihedrals[i], wei, "amdweight_MC", weight_data[i]
    )

In [19]:
# Retrieve the cluster ids (starting frames in the reference simulation) from
# the input file names, which look like ".../<id>_dihedrals.dat".
# A helper is used instead of a module-level `for id in ...` loop, so the
# builtin `id` is not shadowed for the rest of the kernel session.
DIHEDRALS_SUFFIX = "_dihedrals.dat"


def _cluster_id_from_path(path):
    """Return the integer cluster id encoded in a ``<dir>/<id>_dihedrals.dat`` path.

    The suffix is stripped only if present (equivalent to ``str.removesuffix``,
    which needs Python 3.9); the last "/"-separated component must be an int.
    """
    stem = path[: -len(DIHEDRALS_SUFFIX)] if path.endswith(DIHEDRALS_SUFFIX) else path
    return int(stem.split("/")[-1])


ids = [_cluster_id_from_path(path) for path in snakemake.input.dihedrals]
ids

In [26]:
# Starting point of each simulation: its cluster's frame in the reference run,
# already in PCA coordinates. `ids` is ordered like `trafo_dihedrals`.
start_structs = [ref_trafo_dihedrals[cluster_id] for cluster_id in ids]

# Shared colour range (kcal/mol, as labelled on the colorbar) so panels are comparable.
min_ = 0
max_ = 8

figs_per_row = 4
import math

# One panel for the reference plus one per simulation; the last row may be partly empty.
n_panels = len(trafo_dihedrals) + 1
n_rows = math.ceil(n_panels / figs_per_row)
fig, axs = plt.subplots(n_rows, figs_per_row, sharex="all", sharey="all")
fig.set_size_inches(8, 8)
axes_flat = axs.flatten()

scat = {}
scat[0] = axes_flat[0].scatter(
    ref_trafo_dihedrals[:, 0],
    ref_trafo_dihedrals[:, 1],
    c=ref_pca_weights,
    marker=".",
    cmap="Spectral_r",
    s=0.5,
    vmin=min_,
    vmax=max_,
    rasterized=True
)
axes_flat[0].set_title("Reference")
for i in range(len(trafo_dihedrals)):
    ax = axes_flat[i + 1]
    scat[i + 1] = ax.scatter(
        trafo_dihedrals[i][:, 0],
        trafo_dihedrals[i][:, 1],
        c=weights[i],
        marker=".",
        cmap="Spectral_r",
        s=0.5,
        vmin=min_,
        vmax=max_,
        rasterized=True
    )
    # Mark where this simulation was started.
    ax.plot(start_structs[i][0], start_structs[i][1], marker="x", color="black")
    ax.set_title(f"Cluster {ids[i]}")

# Hide the unused slots in the last row. sharex suppresses x tick labels on all
# rows except the bottom one, so restore them on the panel directly above each gap.
for j in range(n_panels, n_rows * figs_per_row):
    axes_flat[j].set_visible(False)
    if j >= figs_per_row:
        axes_flat[j - figs_per_row].tick_params(labelbottom=True)

colorbar = fig.colorbar(
    scat[0], ax=axs, label="kcal/mol", location="left", anchor=(1.5, 0)
)
fig.savefig(snakemake.output.dih_pca_comparison, bbox_inches="tight", dpi=300)

In [None]:
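# Colour key: blue marks a starting point taken after equilibration,
# red marks one taken before equilibration.
# NOTE(review): the comparison figure above draws every starting point as a
# black "x", so this key does not match it yet.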