In [1]:
import pyemma as pm
import mdtraj as md
from molpx.generate import projection_paths
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
from typing import Dict, Union, List
import msmsense.featurizers as feats
from msmsense.bootstrap_cmatrices import get_sub_dict, get_trajs
from functools import partial

%matplotlib inline

import pickle
import time



# Globals and functions

In [2]:
def distances(trajs, **kwargs):
    return feats.distances(trajs, scheme=kwargs.get('scheme'), transform=kwargs.get('transform'), steepness=kwargs.get('steepness'), centre=kwargs.get('centre'))


def dihedrals(trajs, **kwargs):
    return feats.dihedrals(trajs, which=kwargs.get('which'))


def get_feature_dict(df, row_num):
    row_dict = df.filter(regex='__', axis=1).loc[row_num, :].to_dict()
    feature_dict = get_sub_dict(row_dict, 'feature')
    if feature_dict['value'] == 'distances':
        feature_dict.update(get_sub_dict(row_dict, 'distances'))
    if feature_dict['value'] == 'dihedrals':
        feature_dict.update(get_sub_dict(row_dict, 'dihedrals'))
    return feature_dict

def get_kws_dict(df, row_num, kws):
    row_dict = df.filter(regex='__', axis=1).loc[row_num, :].to_dict()   
    kws_dict = get_sub_dict(row_dict, kws)
    return kws_dict

def set_proper_dtypes(df):
    """
    forgot to save integers as integers. Only the distances feature columns have true floats. 
    """
    potential_integer_cols = df.columns.difference(list(df.filter(regex='distances.*', axis=1)))
    for col in potential_integer_cols:
        if str(df[col].dtype) != 'object':
            df[col] = df[col].astype(int)
    return df

def get_trajs_top(traj_dir: Path, protein_dir: str, rng: Union[np.random.Generator, None]=None):
    trajs = list(traj_dir.rglob(f"*{protein_dir.upper()}*/*.xtc"))
    trajs.sort()
    if rng is not None:
        ix = rng.choice(np.arange(len(trajs)), size=len(trajs), replace=True)
        trajs = [trajs[i] for i in ix]
    
    top = list(traj_dir.rglob(f"*{protein_dir.upper()}*/*.pdb"))[0]
    
    return {'trajs': trajs, 'top': top}
    
    
def get_random_traj(trajs: List[md.Trajectory], num_frames: int, rng: np.random.Generator)-> md.Trajectory: 
    traj_ix = np.arange(len(trajs))
    frame_ix = [np.arange(traj.n_frames) for traj in trajs]
    
    rand_ix = [(ix, rng.choice(frame_ix[ix])) for ix in rng.choice(traj_ix, size=num_frames)]
    rand_traj = md.join([trajs[x[0]][x[1]] for x in rand_ix])
    return rand_traj
    

In [3]:
class MSM(object):
    
    def __init__(self, lag: int,  trajs: List[md.Trajectory], top: md.Trajectory,
                 feature_kws: Dict[str, Union[str, int, float]], tica_kws: Dict[str, Union[str, int, float]], cluster_kws: Dict[str, Union[str, int, float]], seed: int):
        """
        Defines the whole MSM pipeline.
        lag: markov lag time 
        num_evs: number of eigenvectors in VAMP score. This includes stationary distribution. note: all projections are done onto the processes 1 - num_evs, i.e., exclude the stationary distribution (process 0)
        traj_top_paths: dictionary with 'trajs' - list of Paths to trajectories, and 'top' Path to topology file. 
        
        """
        self.lag = lag
        self.trajs = trajs
        self.top = top
        
        if not isinstance(feature_kws, (list, tuple)): 
            self.feature_kws = [feature_kws]
        else:
            self.feature_kws = feature_kws
            
        self.tica_kws = tica_kws
        self.cluster_kws = cluster_kws
        self.featurizer = None
        self._set_featurizer()
        self.seed = seed

        self.ttrajs = None
        self.tica = None
        self.cluster = None
        self.msm = None
        self.paths = None
        
    def _set_featurizer(self):
        
        featurizers = []
        for feature_kws in self.feature_kws: 
            print(feature_kws)
            value = feature_kws.pop('value')
            if value == 'distances':
                featurizers.append(partial(distances, **feature_kws))
            elif value == 'dihedrals':
                featurizers.append(partial(dihedrals, **feature_kws))
            elif value == 'reciprocal_distances': 

                def f(trajs, **kwargs): 
                    dist_trajs = distances(trajs, **kwargs)
                    recip_trajs = [1/x for x in dist_trajs]
                    return recip_trajs

                featurizers.append(partial(f, **feature_kws))

            else:
                raise NotImplementedError('Unrecognized feature')
        
        def featurizer(trajs):
            ftrajs_all = None
            for featurizer in featurizers: 
                
                tmp = featurizer(trajs)
                if tmp[0].ndim == 1: 
                    tmp = [x.reshape(-1, 1) for x in tmp]
                    
                if ftrajs_all is None: 
                    ftrajs_all = tmp
                else:
                    for i in range(len(ftrajs_all)): 
                        ftrajs_all[i] = np.concatenate([ftrajs_all[i], tmp[i]], axis=1)
            return ftrajs_all
        
        self.featurizer = partial(featurizer)
        
                    
    def fit(self):
        ftrajs = self.featurizer(self.trajs)
        self.tica = pm.coordinates.tica(data=ftrajs, **self.tica_kws)
        ttrajs = self.tica.get_output()
        self.ttrajs = ttrajs
        self.cluster = pm.coordinates.cluster_kmeans(data=ttrajs, **self.cluster_kws, fixed_seed=self.seed)
        dtrajs = self.cluster.dtrajs
        self.msm = pm.msm.estimate_markov_model(dtrajs=dtrajs, lag=self.lag)
    
#     def score(self, use_new_trajs=False, new_trajs = None,  score_ks = [2, 3, 4]): 
#         if use_new_trajs:
#             ftrajs = self.featurizer(trajs)
#             ttrajs = self.tica.transform(ftrajs)
#             dtrajs = self.cluster.transform(ttrajs)
#         else: 
#             dtrajs = self.msm.discrete_trajectories_active
#         for k in score_ks: 
#             score = self.msm.
            
        

In [4]:
traj_dir = Path('/Volumes/REA/MD/12FF/strided/')
all_mod_defs = set_proper_dtypes(pd.read_hdf('../best_hps_per_feature.h5', key='best_hps_per_feature'))
seed = 12098345


In [5]:
protein_dir = '1FME'

# Introduction

This notebook does boostraps the CK Test for my four selected models


# Set model definitions. 

In [6]:
mod_defs = all_mod_defs.loc[all_mod_defs.protein_dir==protein_dir.lower(), :].copy()
mod_defs.sort_values(by=['protein', 'hp_rank'], inplace=True)
mod_defs.reset_index(inplace=True, drop=True)


model_definitions = {}

for row_num in [0, 1, 2, 3]:

    lag = mod_defs.lag[row_num]
    feat_kws = get_feature_dict(mod_defs, row_num)
    tica_kws = get_kws_dict(mod_defs, row_num, 'tica')
    cluster_kws = get_kws_dict(mod_defs, row_num, 'cluster')
    
    model_definitions[row_num] = dict(lag=lag, feature_kws=feat_kws, tica_kws=tica_kws, cluster_kws=cluster_kws)


# model_definitions[4] = dict(lag=1, 
#                         feature_kws=[dict(value='distances', scheme='closest-heavy', transform='linear'), 
#                                          dict(value='reciprocal_distances', scheme='closest-heavy', transform='linear'), 
#                                         dict(value='dihedrals', which='all')],
#                         tica_kws=dict(lag=10, var_cutoff=0.95, stride=1, commute_map=True, kinetic_map=False), 
#                         cluster_kws=dict(k=1000, max_iter=1000, stride=1))




# model_definitions[5] = dict(lag=50, 
#                        feature_kws=dict(value='distances', scheme='ca', transform='linear'), 
#                        tica_kws=dict(lag=46, dim=2, commute_map=False, kinetic_map=True, stride=1), 
#                        cluster_kws=dict(k=102, max_iter=1000, stride=1))


# Reselect lag and num dominant processes

The lag and num dominant processes were selected using bulk methods.  Let's re-select them here. 

In [8]:
traj_paths = get_trajs_top(traj_dir, protein_dir)
traj_paths_str = dict(top=str(traj_paths['top']), trajs=[str(x) for x in traj_paths['trajs']])

top = md.load(str(traj_paths['top']))
trajs = [md.load(str(x), top=top) for x in traj_paths['trajs']]

In [7]:

model_definition = model_definitions[row_num].copy()


for i in range(n_bs_samples):
#     Fit model
    model = MSM(trajs=trajs, top=top, seed=seed, **model_definition)
    model.fit()
    models[row_num] = model


SyntaxError: invalid syntax (2148907440.py, line 1)