In [59]:
import mdtraj as md
import pyemma as pm
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Optional, Union, Tuple
from pathlib import Path
import pickle
from scipy.stats import entropy
import seaborn as sns

from msmtools.estimation import transition_matrix as _transition_matrix
from msmtools.analysis import timescales as _timescales

In [10]:
def featurizer(hp_dict: Dict, traj_paths: List[str], top_path: str) -> List[np.ndarray]:
    if hp_dict['feature__value'] == 'dihedrals':
        assert hp_dict['dihedrals__which'] == 'all'
        def f(traj: md.Trajectory, **kwargs) -> np.ndarray:
            _, phi = md.compute_phi(traj)
            _, psi = md.compute_psi(traj)
            _, chi1 = md.compute_chi1(traj)
            _, chi2 = md.compute_chi2(traj)
            _, chi3 = md.compute_chi3(traj)
            _, chi4 = md.compute_chi4(traj)
            _, chi5 = md.compute_chi5(traj)
            ftraj = np.concatenate([phi, psi, chi1, chi2, chi3, chi4, chi5], axis=1)
            ftraj = np.concatenate([np.cos(ftraj), np.sin(ftraj)], axis=1)
            return ftraj

    elif hp_dict['feature__value'] == 'distances':
        def f(traj: md.Trajectory, **kwargs):
            scheme = kwargs['distances__scheme']
            transform = kwargs['distances__transform']
            centre = kwargs['distances__centre']
            steepness = kwargs['distances__steepness']
            ftraj, _ = md.compute_contacts(traj, scheme=scheme)
            if transform=='logistic':
                ftraj = 1.0/(1+np.exp(-steepness*(ftraj - centre)))
            return ftraj
    else:
        raise ValueError
    ftrajs = []
    for traj_path in traj_paths:
        traj = md.load(traj_path, top=top_path)
        ftrajs.append(f(traj, **hp_dict))
    return ftrajs


def tica(hp_dict: Dict[str, Union[float, int, str]], ftrajs: List[np.ndarray]) -> List[np.ndarray]:
    lag = hp_dict['tica__lag']
    stride = hp_dict['tica__stride']
    dim = hp_dict['tica__dim']
    tica = pm.coordinates.tica(ftrajs, lag=lag, dim=dim, kinetic_map=True)
    ttrajs = tica.get_output()
    return ttrajs, tica

def kmeans(hp_dict: Dict, ttrajs: List[np.ndarray], seed: int) -> List[np.ndarray]:
    k = hp_dict['cluster__k']
    max_iter = hp_dict['cluster__max_iter']
    stride = hp_dict['cluster__stride']
    kmeans = pm.coordinates.cluster_kmeans(ttrajs, k=k, max_iter=max_iter, stride=stride, fixed_seed=seed, n_jobs=1)
    dtrajs = kmeans.dtrajs
    return dtrajs, kmeans


def its(dtrajs: List[np.ndarray], lags: List[int], nits: int) -> np.ndarray:
    its_obj = pm.msm.timescales_msm(dtrajs=dtrajs, lags=lags, nits=nits)
    return its_obj.timescales


def score(dtrajs: List[np.ndarray], lags: List[int], nits: int) -> np.ndarray:
    all_vs = []
    for lag in lags: 
        m = pm.msm.estimate_markov_model(dtrajs, lag=lag)
        vs = np.array([m.score(dtrajs, score_k=k) for k in range(2, nits+2)])
        vs = vs.reshape(1, -1)
        all_vs.append(vs)
    all_vs = np.concatenate(all_vs, axis=0)
    return all_vs
        


def bootstrap(ftrajs: List[np.ndarray], rng: np.random.Generator) -> List[np.ndarray]:
    probs = np.array([x.shape[0] for x in ftrajs])
    probs = probs/np.sum(probs)
    ix = np.arange(len(ftrajs))
    new_ix = rng.choice(ix,size=len(ftrajs), p=probs, replace=True)
    return [ftrajs[i] for i in new_ix]



def summarise(df):
    df_summary = df.groupby(['hp_ix', 'lag', 'process']).agg(median=(0, lambda x: np.quantile(x, 0.5)),
                                                                   lb=(0, lambda x: np.quantile(x, 0.025)),
                                                                   ub=(0, lambda x: np.quantile(x, 0.975)), 
                                                                   count =(0, lambda x: x.shape[0]-x.isna().sum()))
    return df_summary


def samples_to_summary(samples: np.ndarray, lags: List[int],  hp_ix: int)-> pd.DataFrame: 
    """
    samples=np.ndarray[lagtime, process, bs_sample]
    """
    df = pd.concat({(hp_ix, lags[i], j+2): pd.DataFrame(samples[i, j, :]) for i in range(samples.shape[0]) for j in range(samples.shape[1])})
    df.index.rename(('hp_ix', 'lag', 'process', 'bs_ix'), inplace=True)
    df_summary = summarise(df)
    return df_summary


def get_all_projections(msm: pm.msm.MaximumLikelihoodMSM, num_procs: int, dtrajs: List[np.ndarray]) -> List[np.ndarray]:
    """ Project dtrajs onto first num_proc eigenvectors excluding stationary distribution. i.e., if num_proc=1 then project onto the slowest eigenvector only. 
    All projections ignore the stationary distribution
    """
    evs = msm.eigenvectors_right(num_procs+1)
    active_set = msm.active_set
    NON_ACTIVE_PROJ_VAL = 0 # if the state is not in the active set, set the projection to this value. 
    NON_ACTIVE_IX_VAL = -1
    evs = evs[:, 1:] # remove the stationary distribution
    proj_trajs = []
    for dtraj in dtrajs:
        all_procs = []
        for proc_num in range(num_procs):

            tmp = np.ones(dtraj.shape[0], dtype=float)
            tmp[:] = NON_ACTIVE_PROJ_VAL

            for i in range(dtraj.shape[0]):
                x = msm._full2active[dtraj[i]]
                if x != NON_ACTIVE_IX_VAL:
                    tmp[i] = evs[x, proc_num]
                tmp = tmp.reshape(-1, 1)

            all_procs.append(tmp)
        all_procs = np.concatenate(all_procs, axis=1)
        proj_trajs.append(all_procs)

    return proj_trajs

In [18]:
protein = '1fme'
hp_ix = 81
n_procs = 2
lag = 41

seed = 49587
rng = np.random.default_rng(seed)
n_bootstraps = 100
nits=25

hps = pd.read_hdf('../data/msms/hpsample.h5')
top_path = f'/home/rob/Data/DESRES/DESRES-Trajectory_{protein.upper()}-0-protein/{protein.upper()}-0-protein/protein.pdb'
traj_paths = list(Path('/home/rob/Data/DESRES/').rglob(f'*{protein.upper()}*/**/*.xtc'))
traj_paths = [str(x) for x in traj_paths]
traj_paths.sort()
assert traj_paths

In [36]:

ftrajs_all = featurizer(hps.iloc[hp_ix, :].to_dict(), traj_paths, top_path)

# for i in range(n_bootstraps):
#     print(i, end=', ')
ftrajs = bootstrap(ftrajs_all, rng)
assert len(ftrajs) == len(ftrajs_all)
ttrajs, tica_mod = tica(hps.iloc[hp_ix, :].to_dict(), ftrajs)
dtrajs, kmeans_mod = kmeans(hps.iloc[hp_ix, :].to_dict(), ttrajs, seed)
mod = pm.msm.estimate_markov_model(dtrajs, lag=lag)
ptrajs = get_all_projections(mod, n_procs, dtrajs)

In [37]:
index = pd.MultiIndex.from_tuples([(traj_paths[i], j) for i in range(len(traj_paths)) for j in range(ptrajs[i].shape[0])])



In [53]:
ptrajs_all = np.concatenate(ptrajs, axis=0)
ptrajs_df = pd.DataFrame(ptrajs_all, index=index, columns=[f"psi_{i+2}" for i in range(n_procs)])

In [55]:
def mixing_ent(x): 
    x = np.abs(x)
    return entropy(x)

ptrajs_df['mixing'] = ptrajs_df.apply(mixing_ent, axis=1)

In [78]:
ptrajs_df['psi_2_cat'] = pd.qcut(ptrajs_df['psi_2'], q=50)
ptrajs_df['psi_3_cat'] = pd.qcut(ptrajs_df['psi_3'], q=50)

In [87]:
list(ptrajs_df.query('mixing<0.6').groupby('psi_2_cat').sample().index)

[('/home/rob/Data/DESRES/DESRES-Trajectory_1FME-0-protein/1FME-0-protein/1FME-0-protein-034.xtc',
  493),
 ('/home/rob/Data/DESRES/DESRES-Trajectory_1FME-0-protein/1FME-0-protein/1FME-0-protein-034.xtc',
  1218),
 ('/home/rob/Data/DESRES/DESRES-Trajectory_1FME-0-protein/1FME-0-protein/1FME-0-protein-071.xtc',
  1149),
 ('/home/rob/Data/DESRES/DESRES-Trajectory_1FME-1-protein/1FME-1-protein/1FME-1-protein-022.xtc',
  448),
 ('/home/rob/Data/DESRES/DESRES-Trajectory_1FME-0-protein/1FME-0-protein/1FME-0-protein-055.xtc',
  1208),
 ('/home/rob/Data/DESRES/DESRES-Trajectory_1FME-1-protein/1FME-1-protein/1FME-1-protein-003.xtc',
  1678),
 ('/home/rob/Data/DESRES/DESRES-Trajectory_1FME-0-protein/1FME-0-protein/1FME-0-protein-055.xtc',
  1886),
 ('/home/rob/Data/DESRES/DESRES-Trajectory_1FME-0-protein/1FME-0-protein/1FME-0-protein-033.xtc',
  1196),
 ('/home/rob/Data/DESRES/DESRES-Trajectory_1FME-0-protein/1FME-0-protein/1FME-0-protein-038.xtc',
  1306),
 ('/home/rob/Data/DESRES/DESRES-Traject

In [7]:
# df = target_ts.copy(deep=True)

# df = df.droplevel(level=0)

# for i in range(100):
#     df2 = source_ts.copy(deep=True)
#     df2 = df2.loc[(i, slice(None), slice(None)), :]
#     df2 = df2.droplevel(level=0)
#     tmp = df.merge(df2, left_index=True, right_index=True, how='left')
#     # print(tmp.head())
#     diff = np.abs(tmp['median_x']/tmp['median_y']).mean()
#     print(i, diff)
    