In [109]:
import pyemma as pm
import mdtraj as md
from molpx.generate import projection_paths
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
from typing import Dict, Union, List
from msmsense.featurizers import dihedrals, distances
from msmsense.bootstrap_cmatrices import get_sub_dict, get_trajs
from functools import partial

In [110]:
def get_feature_dict(df, row_num):
    """Collect the featurizer settings stored in one row of a hyper-parameter table.

    Only columns whose name contains ``'__'`` are considered. The ``'feature'``
    group is extracted first; if its ``'value'`` entry names one of the known
    feature types (``'distances'`` or ``'dihedrals'``), the group of the same
    name is merged into the result.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with one model definition per row.
    row_num
        Index label of the row to read (used with ``.loc``).

    Returns
    -------
    dict
        Whatever ``get_sub_dict`` returns for ``'feature'``, updated with the
        matching feature-specific group.
        NOTE(review): presumably ``get_sub_dict`` strips the ``<group>__``
        prefix from the keys - confirm against msmsense.bootstrap_cmatrices.
    """
    hp_row = df.filter(regex='__', axis=1).loc[row_num, :].to_dict()
    feature_dict = get_sub_dict(hp_row, 'feature')
    # 'value' is re-read on every pass, exactly as two consecutive `if`s would.
    for feature_kind in ('distances', 'dihedrals'):
        if feature_dict['value'] == feature_kind:
            feature_dict.update(get_sub_dict(hp_row, feature_kind))
    return feature_dict

def get_kws_dict(df, row_num, kws):
    """Extract one named group of keyword arguments from a row of the table.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with one model definition per row.
    row_num
        Index label of the row to read (used with ``.loc``).
    kws : str
        Name of the group handed to ``get_sub_dict`` (e.g. ``'tica'``).

    Returns
    -------
    dict
        The result of ``get_sub_dict`` applied to the row's ``'__'`` columns.
    """
    hp_columns = df.filter(regex='__', axis=1)
    hp_row = hp_columns.loc[row_num, :].to_dict()
    return get_sub_dict(hp_row, kws)

def set_proper_dtypes(df):
    """Cast every non-object column to ``int``, except the distances columns.

    The table was saved with integer settings stored as floats; only columns
    whose name matches ``'distances.*'`` are expected to hold genuine floats,
    so those are left alone. Columns of dtype ``object`` are skipped as well.

    The frame is modified in place and the same object is returned.
    """
    true_float_cols = list(df.filter(regex='distances.*', axis=1))
    for name in df.columns.difference(true_float_cols):
        if str(df[name].dtype) == 'object':
            continue
        df[name] = df[name].astype(int)
    return df

def get_trajs_top(traj_dir: Path, protein_dir: str):
    """Locate the trajectory files and a topology file for one protein.

    Searches ``traj_dir`` recursively for directories whose name contains the
    upper-cased ``protein_dir`` and collects the ``.xtc`` and ``.pdb`` files
    directly inside them.

    Parameters
    ----------
    traj_dir : Path
        Root directory to search.
    protein_dir : str
        Protein identifier; it is upper-cased before matching.

    Returns
    -------
    dict
        ``{'trajs': sorted list of .xtc paths, 'top': first .pdb path in sorted order}``.

    Raises
    ------
    IndexError
        If no ``.pdb`` file matches.
    """
    pattern = f"*{protein_dir.upper()}*"
    trajs = sorted(traj_dir.rglob(f"{pattern}/*.xtc"))
    # rglob yields paths in filesystem order; sort so that the topology picked
    # is the same on every run/machine when several .pdb files match.
    tops = sorted(traj_dir.rglob(f"{pattern}/*.pdb"))
    if not tops:
        # Same exception type as indexing an empty list, but with a usable message.
        raise IndexError(f"no .pdb topology matching '{pattern}/*.pdb' found under {traj_dir}")

    return {'trajs': trajs, 'top': tops[0]}
    

In [125]:
class MSM(object):
    """Featurize -> TICA -> k-means -> Markov state model pipeline.

    The estimation objects are created by PyEMMA (``pm``) in :meth:`fit` and
    stored on the instance as ``tica``, ``cluster`` and ``msm``; all three are
    ``None`` until :meth:`fit` has been called.
    """

    def __init__(self, lag: int, num_procs: int, traj_top_paths: Dict[str, Union[Path, List[Path]]], feature_kws: Dict[str, Union[str, int, float]], tica_kws: Dict[str, Union[str, int, float]], cluster_kws: Dict[str, Union[str, int, float]]):
        """Store the configuration and build the featurizer.

        Parameters
        ----------
        lag : int
            Lag time passed to ``pm.msm.estimate_markov_model``.
        num_procs : int
            Stored as ``self.num_proc``; not used by any method of this class.
        traj_top_paths : dict
            Passed unchanged to ``get_trajs`` in :meth:`fit`.
            NOTE(review): presumably ``{'trajs': [...], 'top': path}`` as
            produced by ``get_trajs_top`` - confirm against ``get_trajs``.
        feature_kws : dict
            Must contain ``'value'`` (``'distances'`` or ``'dihedrals'``); the
            remaining entries are bound as keyword arguments of the featurizer.
        tica_kws : dict
            Keyword arguments for ``pm.coordinates.tica``.
        cluster_kws : dict
            Keyword arguments for ``pm.coordinates.cluster_kmeans``.

        Raises
        ------
        KeyError
            If ``feature_kws`` has no ``'value'`` entry.
        NotImplementedError
            If ``feature_kws['value']`` is not a supported feature type.
        """
        self.lag = lag
        self.num_proc = num_procs
        self.traj_top_paths = traj_top_paths
        self.feature_kws = feature_kws
        self.tica_kws = tica_kws
        self.cluster_kws = cluster_kws
        self.featurizer = None
        self._set_featurizer()

        # Populated by fit().
        self.tica = None
        self.cluster = None
        self.msm = None

    def _set_featurizer(self):
        """Bind the feature-specific keyword arguments to the chosen featurizer."""
        # Work on a copy so the caller's dict keeps its 'value' entry.
        feature_kws = self.feature_kws.copy()
        feature = feature_kws.pop('value')

        if feature == 'distances':
            self.featurizer = partial(distances, **feature_kws)
        elif feature == 'dihedrals':
            self.featurizer = partial(dihedrals, **feature_kws)
        else:
            raise NotImplementedError('Unrecognized feature')

    def fit(self):
        """Run the full pipeline and store the fitted TICA, clustering and MSM objects."""
        trajs = get_trajs(self.traj_top_paths)
        ftrajs = self.featurizer(trajs)
        self.tica = pm.coordinates.tica(data=ftrajs, **self.tica_kws)
        ttrajs = self.tica.get_output()
        self.cluster = pm.coordinates.cluster_kmeans(data=ttrajs, **self.cluster_kws)
        dtrajs = self.cluster.dtrajs
        self.msm = pm.msm.estimate_markov_model(dtrajs=dtrajs, lag=self.lag)

    def projection_paths(self, num_points: int=100):
        """Placeholder for generating paths along the projected coordinates.

        The method was declared without a body, which made the whole class a
        syntax error. Until it is written it fails loudly instead.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError('projection_paths is not implemented yet')
        

        
        
            
        
    

In [126]:
# Load the table of model definitions (one row per model) and cast the
# integer-valued columns back to int via set_proper_dtypes.
# NOTE(review): the path is relative, so this depends on the kernel's working directory.
mod_defs = set_proper_dtypes(pd.read_hdf('../results/best_hps_per_feature.h5', key='best_hps_per_feature'))

In [127]:
# Root directory that get_trajs_top searches recursively for .xtc / .pdb files.
# NOTE(review): machine-specific absolute path; consider making it configurable.
traj_dir = Path('/Volumes/REA/MD/12FF/strided/')

In [128]:
# Pick one row of the model-definition table and pull out everything needed to build the model.
row_num = 0
# get_trajs_top upper-cases protein_dir itself, so this .upper() is redundant but harmless.
protein_dir = mod_defs.protein_dir[row_num].upper()
lag = mod_defs.lag[row_num]
# NOTE(review): num_procs is read from column `k` - confirm that `k` really is a process count.
# MSM only stores this value (as `num_proc`) and never uses it.
num_procs = mod_defs.k[row_num]

# Trajectory/topology paths and the per-stage keyword arguments for the selected row.
traj_dict = get_trajs_top(traj_dir, protein_dir)
feat_kws = get_feature_dict(mod_defs, row_num)
tica_kws = get_kws_dict(mod_defs, row_num, 'tica')
cluster_kws = get_kws_dict(mod_defs, row_num, 'cluster')

model = MSM(lag = lag, num_procs=num_procs, traj_top_paths=traj_dict, feature_kws=feat_kws, tica_kws=tica_kws, cluster_kws=cluster_kws)

In [129]:
# Run the pipeline defined in MSM.fit (featurize, TICA, k-means, MSM estimation).
# NOTE(review): the recorded output below contains status messages that MSM.fit, as
# written in this notebook, does not print - presumably they come from the imported
# helpers or from an earlier version of the class; re-run on a fresh kernel to confirm.
model.fit()

Loaded 164 trajectories
Featurized trajectories


calculate covariances:   0%|                                                                                  …

TICA finished


getting output of TICA:   0%|                                                                                 …

initialize kmeans++ centers:   0%|                                                                            …

kmeans iterations:   0%|                                                                                      …

getting output of KmeansClustering:   0%|                                                                     …

clutered trajectories
Fitted MSM
