# Analysis

In [None]:
import itertools
from typing import List, Tuple, Sequence
import warnings

import holoviews as hv
import h5py
import matplotlib.pyplot as plt
from matplotlib import rc, ticker
from msmtools.flux import tpt
import numpy as np
from scipy.linalg import eig
import seaborn as sns

# Plot settings
# Default colour cycle: eight colours from seaborn's "husl" palette.
sns.set_palette("husl", 8)
# Request Helvetica for all figure text.
rc("font", **{"family": "Helvetica",
              "sans-serif": ["Helvetica"]})
# Keep text as text (not outlines) in SVG output so it stays editable.
rc("svg", **{"fonttype": "none"})
# The same palette as an explicit list; the plotting helpers index into it.
colors = sns.color_palette("husl", 8)
hv.extension("matplotlib")

# NOTE(review): this silences *every* warning for the whole session, including
# NumPy deprecation and complex-cast warnings -- consider narrowing the filter.
warnings.filterwarnings('ignore')

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

## Utility functions

In [2]:
# Execute data.py so that the names it defines become available in this notebook.
# NOTE(review): the script's contents are not visible here -- confirm which of
# the names used below come from it.
%run data.py

In [3]:
def unflatten(source: np.ndarray, lengths: Sequence) -> List[np.ndarray]:
    """
    Splits a flat array into consecutive sub-arrays.
    
    Parameters
    ----------
    source
        Array to be unflattened (sliced along its first axis).
    lengths
        Lengths of the sub-arrays, consumed in order. Either a flat
        sequence of integers or a sequence of integer sequences / arrays
        (e.g. one array of trajectory lengths per simulation round);
        nested groups are flattened in order.
        Should sum to the length of source.
    
    Returns
    -------
    unflat
        Flat list of arrays, one per length, in the order given.
    
    """
    unflat = []
    start = 0
    for group in lengths:
        # np.atleast_1d lets a bare integer behave like a one-element group,
        # so both flat and nested ``lengths`` are accepted.
        for length in np.atleast_1d(group):
            stop = start + int(length)
            unflat.append(source[start:stop])
            start = stop
    return unflat

In [4]:
def triu_inverse(x: np.ndarray, n: int) -> np.ndarray:
    """
    Expands flattened strict upper triangles into full symmetric matrices.
    
    Parameters
    ----------
    x
        Array whose first axis indexes the matrices and whose remaining
        entries are the values above the diagonal, ordered as returned
        by ``np.triu_indices(n, k=1)``.
    n
        Size of each n * n matrix
    
    Returns
    -------
    mat
        Array of shape (length, n, n); symmetric, with a zero diagonal.
    
    """
    rows, cols = np.triu_indices(n, k=1)
    upper = np.zeros((x.shape[0], n, n))
    upper[:, rows, cols] = x
    # Mirror the upper triangle onto the lower one; the diagonal stays zero.
    return upper + np.transpose(upper, (0, 2, 1))

In [5]:
def statdist(X: np.ndarray) -> np.ndarray:
    """
    Calculate the equilibrium distribution of a transition matrix.
    
    Parameters
    ----------
    X
        Row-stochastic transition matrix
    
    Returns
    -------
    mu
        Stationary distribution, i.e. the left eigenvector associated
        with the largest eigenvalue (1 for a stochastic matrix),
        normalised to sum to one. Always returned as a real array.
    
    """
    ev, evec = eig(X, left=True, right=False)
    # Columns of ``evec`` are the left eigenvectors; pick the one belonging
    # to the eigenvalue with the largest real part.
    mu = evec[:, np.argmax(ev.real)]
    # Dividing by the sum fixes both the scale and the arbitrary sign/phase
    # of the eigenvector.
    mu = mu / mu.sum()
    # eig() returns complex eigenvectors whenever any eigenvalue of X is
    # complex; the stationary vector itself is real, so drop the (zero)
    # imaginary part instead of handing a complex array to the caller.
    return np.real(mu)

In [6]:
def unique_sorting(rmsd: np.ndarray) -> np.ndarray:
    """
    Greedily matches rows to columns of an RMSD matrix, lowest RMSD first.
    
    Parameters
    ----------
    rmsd
        Array of shape (n, n) with interstate differences.
        This matrix should be acquired by calculating the RMSD
        between a reference decomposition and a trial decomposition.
    
    Returns
    -------
    sorter
        Integer array of length n; ``sorter[i]`` is the column assigned
        to row ``i``. Every column is used exactly once.
    
    """
    size = rmsd.shape[0]
    
    # -1 is not yet assigned. np.intp (rather than int8) keeps the indices
    # valid for any number of states; int8 wraps around above 127.
    sorter = np.full(size, -1, dtype=np.intp)
    # Boolean mask of columns that are already taken (O(1) membership test).
    taken = np.zeros(size, dtype=bool)
    unassigned = size
    sorted_idx = rmsd.argsort(axis=None)
    
    # We walk through the sorted RMSDs from low to high and assign the 2D indices.
    # If one is already assigned, we just jump to the next one, which will be the next-lowest RMSD.
    for i, j in zip(*np.unravel_index(sorted_idx, (size, size))):
        if sorter[i] < 0 and not taken[j]:
            sorter[i] = j
            taken[j] = True
            unassigned -= 1
            if unassigned == 0:
                # Everything is matched; the remaining pairs cannot change the result.
                break
    return sorter

In [7]:
def idx_to_traj(idx: int, lengths: List[int]) -> Tuple[int, int]:
    """
    Given a trajectory index, find the round and trajectory file number.
    
    Parameters
    ----------
    idx
        Trajectory index, counted consecutively over all rounds
    lengths
        Number of trajectories in each round
    
    Returns
    -------
    round, number
        Simulation round and corresponding trajectory number
        within that round (both zero-based)
    
    Raises
    ------
    IndexError
        If ``idx`` lies beyond the last trajectory of the last round.
    
    """
    ends = np.cumsum(lengths)
    if ends.size == 0 or idx >= ends[-1]:
        raise IndexError(
            "trajectory index {0} is out of range for rounds of size {1}".format(idx, list(lengths)))
    # side="right" sends an index that equals a cumulative sum to the *next*
    # round, i.e. the first trajectory of round k has idx == sum(lengths[:k]).
    rnd = int(np.searchsorted(ends, idx, side="right"))
    first = int(ends[rnd - 1]) if rnd > 0 else 0
    return rnd, int(idx - first)

In [8]:
def renormalize(mat, tol=1e-12, axis=1):
    """
    Rescale a matrix so that its sums along ``axis`` equal one.
    
    Parameters
    ----------
    mat
        Two-dimensional array, e.g. an approximately stochastic matrix.
    tol
        Largest tolerated deviation of any sum from one.
    axis
        Axis along which the entries should sum to one
        (1: rows, 0: columns).
    
    Returns
    -------
    mat
        The input itself if it already satisfies the tolerance,
        otherwise a new array of absolute values divided by their
        sums along ``axis``.
    
    """
    while abs(1.0 - mat.sum(axis=axis)).max() > tol:
        magnitude = abs(mat)
        # keepdims=True keeps the sums aligned with ``axis``. Without it the
        # row sums broadcast along the columns and every column j ends up
        # divided by the sum of row j.
        mat = magnitude / magnitude.sum(axis=axis, keepdims=True)
    return mat

In [9]:
def plot_its(its, lags, dt=1.0):
    """
    Plot implied timescales against the lag time on a logarithmic y-axis.
    
    Parameters
    ----------
    its
        Implied timescales, either of shape (n_its, n_lags) or, for
        several models, of shape (n_models, n_its, n_lags). In the
        latter case the mean over the first axis is drawn together
        with a 2.5-97.5 percentile band.
    lags
        Array of lag times, one per entry on the last axis of its.
    dt
        Factor the lags are multiplied with to obtain the x values.
    
    Returns
    -------
    fig
        The matplotlib figure.
    
    """
    has_replicas = its.ndim == 3
    n_processes = its.shape[-2]
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(111)
    
    if has_replicas:
        mean_its = its.mean(axis=0)
        lower, upper = np.percentile(its, q=(2.5, 97.5), axis=0)
    else:
        mean_its = its
    
    times = lags * dt
    # Diagonal t = tau; the region below it is shaded grey.
    ax.semilogy(times, times, color="k")
    floor = ax.get_ylim()[0] * np.ones(len(lags))
    ax.fill_between(times, floor, times, color="k", alpha=0.2)
    
    for k in range(n_processes):
        shade = colors[-(k + 2)]
        ax.plot(times, mean_its[k], marker="o",
                linestyle="dashed", linewidth=1.5, color=shade)
        ax.plot(times, mean_its[k], marker="o", linewidth=1.5, color=shade)
        if has_replicas:
            ax.fill_between(times, lower[k], upper[k],
                            interpolate=True, color=shade, alpha=0.2)
    
    minor_locator = ticker.LogLocator(base=10.0, subs=(0.2, 0.4, 0.6, 0.8), numticks=12)
    ax.set_ylim(1, 1000000)
    ax.set_yticks(10 ** np.arange(7))
    ax.yaxis.set_minor_locator(minor_locator)
    ax.yaxis.set_minor_formatter(ticker.NullFormatter())
    ax.set_xlabel(r"$\tau$ (ns)", fontsize=24)
    ax.set_ylabel(r"$t_i$ (ns)", fontsize=24)
    ax.tick_params(labelsize=24)
    sns.despine(ax=ax)
    return fig

In [10]:
def plot_ck(cke, ckp, lag):
    """
    Plot a Chapman-Kolmogorov test as an n * n grid of panels.
    
    Parameters
    ----------
    cke
        Estimated values of shape (n, n, steps), or
        (n_models, n, n, steps) for several models.
    ckp
        Predicted values, same shape as cke.
    lag
        Spacing between consecutive steps on the x-axis.
    
    Returns
    -------
    fig
        The matplotlib figure.
    
    Notes
    -----
    The x tick labels are scaled with the notebook-level variable ``dt``,
    which must be defined before this function is called.
    
    """
    has_replicas = cke.ndim == 4
    n_states = cke.shape[-2]
    n_steps = cke.shape[-1]
    
    if has_replicas:
        est_mean = cke.mean(axis=0)
        pred_mean = ckp.mean(axis=0)
        est_band = np.percentile(cke, q=(2.5, 97.5), axis=0)
        pred_band = np.percentile(ckp, q=(2.5, 97.5), axis=0)
    else:
        est_mean, pred_mean = cke, ckp
    
    times = np.arange(0, n_steps * lag, lag)
    major_ticks = np.arange(0, n_steps * lag, 2 * lag)
    
    fig, axes = plt.subplots(n_states, n_states,
                             figsize=(4 * n_states, 4 * n_states), sharex=True)
    for row, col in itertools.product(range(n_states), repeat=2):
        ax = axes[row, col]
        if has_replicas:
            below = pred_mean[row, col] - pred_band[0, row, col]
            above = pred_band[1, row, col] - pred_mean[row, col]
            ax.errorbar(times, pred_mean[row, col], yerr=[below, above],
                        linewidth=2, elinewidth=2)
            ax.fill_between(times, est_band[0, row, col], est_band[1, row, col],
                            alpha=0.2, interpolate=True, color=colors[1])
        else:
            ax.plot(times, pred_mean[row, col], linestyle="-", color=colors[0], linewidth=2)
        ax.plot(times, est_mean[row, col], linestyle="--", color=colors[1], linewidth=2)
        
        # Diagonal panels use a window near one, off-diagonal panels one near zero.
        if row == col:
            y_window, label_y = (0.78, 1.02), 0.8
        else:
            y_window, label_y = (-0.02, 0.22), 0.2
        ax.set_ylim(*y_window)
        ax.text(0, label_y, r"{0} $\to$ {1}".format(row, col),
                fontsize=24, verticalalignment="center")
        
        ax.set_xticks(times, minor=True)
        ax.set_xticks(major_ticks)
        ax.set_xticklabels((major_ticks * dt).astype(int))
        ax.tick_params(labelsize=24)
    
    fig.text(0.5, 0.01 * 1.5 * n_states, r"$\tau$ [ns]", ha="center", fontsize=24)
    fig.text(0.01 * 1.5 * n_states, 0.5, r"$P$", va="center", rotation="vertical", fontsize=24)
    fig.subplots_adjust(wspace=0.25)
    return fig

## Data
### Trajectories
The conformational dynamics of K18 are investigated using 100 separate MD simulations, each lasting up to about 150 ns (at most 1501 frames, saved every 0.1 ns).

In [11]:
# Number of frames in each of the 100 trajectories, wrapped in a one-element
# list (a single group), which is the nested layout unflatten() iterates over.
lengths = [np.array([1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501,
       1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501,
       1501, 1321, 1136,  890, 1501,  728, 1501,  625, 1214, 1240, 1501,
        727,  959, 1483,  995, 1339,  878,  994,  940, 1501, 1501,  965,
       1123, 1004,  951, 1172, 1501, 1501,  973, 1501,  428, 1501, 1501,
       1136, 1501, 1501, 1501, 1501, 1501,  607, 1501, 1501, 1501, 1501,
       1501,  542, 1083,  972, 1501, 1501, 1501, 1501, 1501, 1501, 1501,
       1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501,
       1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501, 1501,
       1501])]
# Total number of frames; equals the sum of all entries of `lengths` above.
nframes = 135998

## VAMPNet

### Data preparation
We use the dihedral data as input.

In [12]:
# Load the dihedral array and split it into one array per trajectory.
filename = "data/dihedrals.npy"
input_flat = np.load(filename)
input_data = unflatten(input_flat, lengths)
# NOTE(review): d_flat / ddata load and split the very same file a second
# time -- presumably leftovers from a version with a separate data source;
# confirm whether they are still needed.
d_flat = np.load(filename)
ddata = unflatten(d_flat, lengths)

In [13]:
lag = 200                         # Lag time in trajectory frames (lag * dt = 20 ns)
n_dims = input_data[0].shape[1]   # Input dimension
nres = 125                        # Number of residues
dt = 0.1                          # Trajectory timestep in ns
steps = 6                         # CK test steps
data_source = "data/data_10000_dep3.hdf5"
# Size of the per-attempt arrays; the model loaders below read
# attempts 1 .. attempts - 1 from the file.
attempts = 41

outsizes = np.array([2, 3, 4, 5, 6, 7])
lags = np.array([1, 2, 5, 10, 20, 50, 100, 200, 500, 1000])

# Number of frames in the bootstrap sample: the full data set.
bs_frames = nframes

# Analysis

## Neural network models

In [14]:
# Per-attempt model outputs, keyed by the number of output states n.
# NOTE: only attempts 1 .. attempts - 1 are read below, so slot 0 of every
# array is never written and holds uninitialised np.empty memory. Downstream
# cells must skip it (e.g. with [1:]).
sorters = {n: np.empty((attempts, n), dtype=int) for n in outsizes}
pfs = {n: np.empty((attempts, nframes, n)) for n in outsizes}
pfsn = {n: np.empty((attempts, nframes, n)) for n in outsizes}
koops = {n: np.empty((attempts, n, n)) for n in outsizes}
pis = {n: np.empty((attempts, n)) for n in outsizes}
# Open read-only: nothing is written here, and an explicit mode keeps older
# h5py versions (whose default was append mode) from touching the file.
with h5py.File(data_source, "r") as read:
    store = read["red"]
    for i in range(1, attempts):
        for n in outsizes:
            sorters[n][i] = store["{0}/{1}/sorter".format(i, n)][:]
            pfs[n][i] = store["{0}/{1}/full_sorted".format(i, n)][:]
            # Each state column divided by its sum over all frames.
            pfsn[n][i] = pfs[n][i] / pfs[n][i].sum(axis=0)
            # Reorder rows and columns of the stored matrix with this attempt's sorter.
            koops[n][i] = store["{0}/{1}/k".format(i, n)][:][sorters[n][i]][:, sorters[n][i]]
            pis[n][i] = statdist(koops[n][i])

In [None]:
# State populations of every attempt, one figure per number of states.
for n in [2,3,4,5,6]:
    global_sorter = np.arange(n)
    pm = pis[n][:, global_sorter]
    # Slot 0 is never filled by the loader, so it is dropped here.
    pm = pm[1:]
    
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    for i in range(n):
        # One marker per attempt; the count follows the data instead of a
        # hard-coded 40 that would have to match attempts - 1.
        ax.plot(np.repeat([i + 1], pm.shape[0]), pm[:, i],
                c=colors[i], marker="o", linewidth=0, alpha=0.5)

    bp = ax.boxplot(pm, sym="", whis=[5, 95], meanline=False, widths=0.5,
                    patch_artist=False, medianprops=dict(color="k"), showmeans=False,
                    meanprops=dict(marker="o", markersize=10, markeredgecolor="k",
                                markerfacecolor="white", alpha=0.5))

    ax.set_ylim(0, 1)
    ax.set_ylabel("Probability", fontsize=24, labelpad=10)
    ax.set_xlabel("State", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)

    # plt.savefig("fig_pdf/pops-boxplot-{0}.pdf".format(n), transparent=True, bbox_inches="tight")

In [None]:
n = 5
global_sorter = np.arange(n)
# Slot 0 of pis[n] is never filled by the loader (it reads attempts
# 1 .. attempts - 1 into np.empty arrays), so it must not enter the statistics.
pm = pis[n][1:].mean(axis=0)[global_sorter]
pv = np.percentile(pis[n][1:], q=(2.5, 97.5), axis=0)[:, global_sorter]
# One figure per state, with that state's bar highlighted in colour.
for i in range(n):
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    cols = [(0.8, 0.8, 0.8)] * n
    cols[i] = colors[i]
    # Error bars span the 2.5-97.5 percentile range around the mean.
    ax.bar(np.arange(n), pm, yerr=[pm - pv[0], pv[1] - pm], color=cols, capsize=8)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.0, 0.5, 1.0])
    ax.set_xticks(np.arange(n))
    ax.set_ylabel("P", fontsize=32, labelpad=10)
    ax.tick_params(labelsize=32)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)

    # plt.savefig("figs_dpi800/pops-fine-{0}-{1}.png".format(n, i), dpi=800, transparent=True, bbox_inches="tight")

### Implied timescales

In [None]:
# Read-only access; nothing is written to the file here.
with h5py.File(data_source, "r") as store:
    for n in [5]:
        # np.stack needs a real sequence (list/tuple), not a generator.
        # NOTE(review): this reads attempts 0 .. attempts - 1, whereas the model
        # loader starts at attempt 1 -- confirm that attempt 0 is meant to be included.
        its = np.stack([store["red/{0}/{1}/its".format(i, n)][:] for i in range(attempts)])
        # Drop the longest lag; lags[:-1] keeps the x values in step with that.
        its = np.delete(its, -1, axis=-1)
        fig = plot_its(its, lags[:-1], dt=dt)
        # Dashed marker at the lag time used for the models (20 ns).
        x_position = 20
        plt.axvline(x=x_position, color='gray', linestyle='--')
        label = r"$\tau$=20ns"
        plt.text(x_position + 1, 100000, label, fontsize=18, color='black')

        # plt.savefig("fig_pdf/its-{0}.pdf".format(n), bbox_inches="tight", transparent=True)

### Chapman-Kolmogorov Test

In [22]:
def plot_ck_rev(cke, ckp, lag, order=(4, 2, 3, 1, 0)):
    """
    Plot a Chapman-Kolmogorov test with the states shown in a custom order.
    
    Parameters
    ----------
    cke
        Estimated values of shape (n, n, steps), or
        (n_models, n, n, steps) for several models.
    ckp
        Predicted values, same shape as cke.
    lag
        Spacing between consecutive steps on the x-axis.
    order
        State indices in display order: the panel in grid row r and
        column c shows the transition order[r] -> order[c] and is
        labelled "r+1 -> c+1". Must not contain more than n entries.
        The default reproduces the ordering used for the five-state model.
    
    Returns
    -------
    fig
        The matplotlib figure.
    
    Notes
    -----
    The x tick labels are scaled with the notebook-level variable ``dt``,
    which must be defined before this function is called.
    
    """
    from matplotlib.lines import Line2D
    
    multi = cke.ndim == 4
    n = cke.shape[-2]
    steps = cke.shape[-1]

    if multi:
        ckem = cke.mean(axis=0)
        ckpm = ckp.mean(axis=0)
        ckep = np.percentile(cke, q=(2.5, 97.5), axis=0)
        ckpp = np.percentile(ckp, q=(2.5, 97.5), axis=0)
    else:
        ckem = cke
        ckpm = ckp

    fig, axes = plt.subplots(n, n, figsize=(4 * n, 4 * n), sharex=True)

    x = np.arange(0, steps * lag, lag)
    major_ticks = np.arange(0, steps * lag, 2 * lag)

    # (row, col) is the panel position, (i, j) the pair of states drawn there.
    for row, i in enumerate(order):
        for col, j in enumerate(order):
            ax = axes[row, col]
            if multi:
                ax.errorbar(x, ckpm[i, j], yerr=[ckpm[i, j] - ckpp[0, i, j], ckpp[1, i, j] - ckpm[i, j]],
                            linewidth=2, elinewidth=2)
                ax.fill_between(x, ckep[0, i, j], ckep[1, i, j],
                                alpha=0.2, interpolate=True, color=colors[1])
            else:
                ax.plot(x, ckpm[i, j], linestyle="-", color=colors[0], linewidth=2)
            ax.plot(x, ckem[i, j], linestyle="--", color=colors[1], linewidth=2)
            
            # Diagonal panels use a window near one, off-diagonal panels one near zero.
            if i == j:
                ax.set_ylim(0.78, 1.02)
                ax.text(0, 0.8, r"{0} $\to$ {1}".format(row + 1, col + 1), fontsize=24, verticalalignment="center")
            else:
                ax.set_ylim(-0.02, 0.22)
                ax.text(0, 0.2, r"{0} $\to$ {1}".format(row + 1, col + 1), fontsize=24, verticalalignment="center")
            ax.set_xticks(x, minor=True)
            ax.set_xticks(major_ticks)
            ax.set_xticklabels((major_ticks * dt).astype(int))
            ax.tick_params(labelsize=24)

    legend_elements = [
        Line2D([0], [0], color=colors[0], lw=2, linestyle='-', label="Predicted $K^n(\\tau)$"),
        Line2D([0], [0], color=colors[1], lw=2, linestyle='--', label="Estimated $K(n\\tau)$"),
    ]
    fig.legend(handles=legend_elements, loc='lower center', ncol=len(legend_elements),bbox_to_anchor=(0.237,0.022),  fontsize=24)

    fig.text(0.5, 0.01 * 1.5 * n, r"$\tau$ (ns)", ha="center", fontsize=24)
    fig.text(0.01 * 1.5 * n, 0.5, r"$P$", va="center", rotation="vertical", fontsize=24)
    fig.subplots_adjust(wspace=0.25)
    return fig

In [None]:
# Read-only access; nothing is written to the file here.
with h5py.File(data_source, "r") as store:
    for n in [5]:
        global_sorter = np.arange(n)
        # Each attempt's arrays are reordered with that attempt's sorter before
        # stacking. np.stack needs a real sequence (list), not a generator.
        cke = np.stack([store["red/{0}/{1}/cke".format(i, n)][:][sorters[n][i]][:, sorters[n][i]]
                        for i in range(1, attempts)])[:, global_sorter][:, :, global_sorter]
        ckp = np.stack([store["red/{0}/{1}/ckp".format(i, n)][:][sorters[n][i]][:, sorters[n][i]]
                        for i in range(1, attempts)])[:, global_sorter][:, :, global_sorter]
        # Use the configured lag instead of repeating the literal 200.
        fig = plot_ck_rev(cke, ckp, lag=lag)

        # plt.savefig("fig_pdf/ck-{0}.pdf".format(n), bbox_inches="tight", transparent=True)

## Kinetics
### Koopman operators

In [None]:
n = 5
global_sorter = np.arange(n)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# Mean and standard deviation over the attempts; slot 0 is skipped
# because the loader never fills it.
koop_stats = (
    ("$P$", koops[n][1:].mean(axis=0)),
    (r"$\sigma(P)$", koops[n][1:].std(axis=0)),
)
for ax, (title, stat) in zip(axes, koop_stats):
    mat = stat[global_sorter][:, global_sorter]
    ax.matshow(mat, vmin=0.0, vmax=0.02, interpolation="none", cmap="GnBu")
    # Print every matrix element into its cell.
    for i, j in itertools.product(range(n), repeat=2):
        ax.text(j, i, "{0:2.4f}".format(mat[i, j]), ha="center", va="center", fontsize=12)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)

    # plt.savefig("figs_dpi800/p-{0}.png".format(n), dpi=800, bbox_inches="tight", transparent=True)

### Mean first passage times

In [34]:
# Rates and mean first passage times between every ordered pair of states.
# Slot 0 of each array stays zero because attempts are processed from 1.
mfpts, rates = {}, {}
for n in outsizes:
    mfpts[n] = np.zeros((attempts, n, n))
    rates[n] = np.zeros((attempts, n, n))
    for i in range(1, attempts):
        # The renormalised matrix does not depend on (u, v): compute it once
        # per attempt instead of once per state pair.
        koop = renormalize(koops[n][i])
        for u, v in itertools.permutations(range(n), 2):
            f = tpt(koop, [u], [v])
            rates[n][i, u, v] = f.rate
            # Steps of `lag` frames -> ns (dt) -> microseconds (0.001).
            mfpts[n][i, u, v] = f.mfpt * lag * dt * 0.001
In [None]:
n = 5
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
global_sorter = np.arange(n)
# Mean and standard deviation over the attempts; slot 0 is skipped
# because no MFPT is computed for it.
mfpt_stats = (
    (r"$\mathrm{MFPT}$ [µs]", mfpts[n][1:].mean(axis=0)),
    (r"$\sigma(\mathrm{MFPT})$", mfpts[n][1:].std(axis=0)),
)
for ax, (title, stat) in zip(axes, mfpt_stats):
    mat = stat[global_sorter][:, global_sorter]
    im = ax.matshow(mat, vmin=0.0, vmax=60, interpolation="nearest", cmap="GnBu")
    # Print every matrix element into its cell.
    for i, j in itertools.product(range(n), repeat=2):
        ax.text(j, i, "{0:2.2f}".format(mat[i, j]), ha="center", va="center", fontsize=12)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)

# plt.savefig("figs_dpi800/mfpt-red-tmp-{0}.png".format(n), dpi=800, bbox_inches="tight", transparent=True)

### Timescales

In [23]:
# Implied timescales of every attempt, keyed by the number of output states.
timescales = {}
# Read-only access; nothing is written to the file here.
with h5py.File(data_source, "r") as read:
    store = read["red"]
    for n in outsizes:
        # np.stack needs a real sequence (list), not a generator. The second
        # axis is reversed and the values are scaled by 1e-3.
        # NOTE(review): this reads attempts 0 .. attempts - 1, whereas the model
        # loader starts at attempt 1 -- confirm that attempt 0 is meant to be included.
        timescales[n] = np.stack([store["{0}/{1}/its".format(i, n)][:]
                                  for i in range(attempts)])[:, ::-1] * 1e-3

In [None]:
# Distribution of the implied timescales over all attempts.
for n in [5]:
    fig = plt.figure(figsize=(n - 1, 4))
    ax = fig.add_subplot(111)
    # Second-to-last entry on the lag axis.
    # NOTE(review): confirm that this is the intended lag time.
    ts_at_lag = timescales[n][:, :, -2]
    for i in range(n - 1):
        ax.plot(np.full(attempts, i + 1), ts_at_lag[:, i],
                c=colors[1:][i], marker="o", linewidth=0, alpha=0.5)
    bp = ax.boxplot(ts_at_lag, sym="", whis=[5, 95], meanline=False, widths=0.5,
                    patch_artist=False, medianprops=dict(color="k"), showmeans=False,
                    meanprops=dict(marker="o", markeredgecolor="k", markerfacecolor="white"))
    ax.set_ylim(0, 50)
    ax.set_yticks(np.arange(0, 51, 10))
    ax.set_ylabel(r"$t_{implied}$ (µs)", fontsize=24, labelpad=10)
    ax.set_xlabel(r"Timescale", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)

    # plt.savefig("fig_pdf/timescales-boxplot-3-{0}.pdf".format(n), bbox_inches="tight", transparent=True)