# MSM workflow
###### Thomas, T.; Yuriev, E.; Chalmers, D. K. "Markov State Model Analysis of Haloperidol Binding to the D3 Dopamine Receptor."

##### Choose working directory

In [None]:
workdir = '/home/msm_data/'
%cd /home/msm_data

In [None]:
import numpy as np

import mdtraj as md
from mdtraj.geometry import compute_distances, compute_angles

import os,sys,shutil
import re

from msmbuilder.cluster import *
from msmbuilder.msm import *
from msmbuilder.dataset import dataset
from msmbuilder.featurizer import *
from msmbuilder import utils
from msmbuilder.decomposition import tICA, PCA
from msmbuilder.lumping import *
from msmbuilder.tpt import mfpts
from msmbuilder import hmm

from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.patches import Ellipse
#plot graphs inline, e.g. don't need plot.show()
%matplotlib inline

from sklearn.utils import check_random_state
from sklearn import mixture
from itertools import product, chain

##### Choose dataset location (relative to working directory)

In [None]:
'''make sure any molecules of interest in system.pdb are whole.
dcd files may be faster to read than xtc due to compression'''
# The glob needs a wildcard to match the fitted trajectory files; this mirrors
# the 'no_fit/*.dcd' pattern used for the raw trajectories.
ds = dataset('center_fit/*.dcd', topology='system.pdb')

In [None]:
'''Fitting trajectories is important for analysis, but raw (non-fit)
trajectories are necessary as starting structures for subsquent runs'''
ds_nofit = dataset('no_fit/*.dcd', topology='system.pdb')

##### Guess pass number based on existing folders

In [None]:
next_pass = 1
while os.path.exists("pass"+str(next_pass).zfill(2)):
    next_pass += 1
this_pass = next_pass-1
pass_folder = "pass"+str(this_pass).zfill(2)+"/"
print(next_pass, this_pass, pass_folder)

# OR

##### Specify manually

In [None]:
pass_folder = "./pass01/"
next_pass, this_pass= None, None

# Featurization

In [None]:
def make_featurization(output_folder=None, ds=None, featurizer=None):
    '''Featurize every trajectory in ``ds`` and store the result on disk.

    Trajectories are read in chunks of 10000 frames to limit memory usage;
    smaller chunks use less memory but featurize slower.

    output_folder: path of the derived 'dir-npy' dataset to create
    ds:            source dataset (provides keys(), iterload() and create_derived())
    featurizer:    object exposing partial_transform(chunk)

    Returns None; the features are written to ``output_folder``.
    '''
    out_dataset = ds.create_derived(output_folder, fmt='dir-npy')
    try:
        for key in ds.keys():
            chunks = [featurizer.partial_transform(chunk)
                      for chunk in ds.iterload(key, chunk=10000)]
            out_dataset[key] = np.concatenate(chunks)
    finally:
        # Close exactly once, after all trajectories have been written
        # (closing inside the loop would break backends that really close).
        out_dataset.close()

## Protein

In [None]:
'''It is safest to process trajectories and verify results before loading into MSMBuilder.
AtomIndices_ca.dat should be 0-indexed, 1 number per line.
This can be conveniently extracted from gromacs ndx files
Alternatively pass a list from mdtraj, e.g. traj.topology.select("name CA")
'''
ref_traj = dataset('system.pdb')[0]
featurizer = RawPositionsFeaturizer(atom_indices=np.loadtxt('AtomIndices_ca.dat',dtype=int), ref_traj=ref_traj)
make_featurization(pass_folder+"protein_rawpos_featurized",featurizer=featurizer, ds=ds)

In [None]:
protein_featurized = dataset(pass_folder+"protein_rawpos_featurized")

## Ligand

In [None]:
'''Contacts.dat is 0 indexed, 1 pair per line
This can be conveniently extracted from gromacs ndx files
Alternatively pass a list from mdtraj'''
# Contact pairs are indices, so read them as integers (same as AtomIndices_ca.dat).
# ndmin=2 keeps a file holding a single pair as shape (1, 2) instead of (2,).
contacts = np.loadtxt('Contacts.dat', dtype=int, ndmin=2)
featurizer = ContactFeaturizer(contacts=contacts, scheme='closest-heavy', ignore_nonprotein=False)
make_featurization(pass_folder+"ligand_contact_featurized",featurizer=featurizer, ds=ds)

### Convert to binary

In [None]:
parent_dataset = dataset(pass_folder+"ligand_contact_featurized")
output_dataset = parent_dataset.create_derived(pass_folder+'bin_contact_featurized')
# Output trajectories are renumbered consecutively from 0, in the order of the parent keys.
for new_key, key in enumerate(parent_dataset.keys()):
    # Binary contact fingerprint: 1 where the feature value is <= 0.6, 0 otherwise
    # (presumably a distance in nm, mdtraj's unit -- confirm against the featurizer).
    output_dataset[new_key] = np.where(parent_dataset[key] > 0.6, 0., 1.)
# Close once, after every trajectory has been written.
output_dataset.close()

In [None]:
bin_contact_featurized = dataset(pass_folder+'bin_contact_featurized')

### Hbond featurization

In [None]:
def calc_hbonds_fingerprint(traj, input_triplets=None, periodic=True):
    '''Return a per-frame 0/1 fingerprint for a fixed list of candidate hydrogen bonds.

    Currently only works for a provided list of hbonds (input_triplets).

    traj:            trajectory with topology information
    input_triplets:  integer array-like of shape (n_hbonds, 3); the angle is
                     measured over all three atoms and the distance between
                     columns 1 and 2 (the H..acceptor pair)
    periodic:        forwarded to compute_angles / compute_distances

    Returns an integer array of shape (n_frames, n_hbonds): 1 where the
    distance is below 0.25 nm AND the angle exceeds 120 degrees, else 0.

    Raises ValueError if the trajectory has no topology or no valid triplets
    are supplied.
    '''
    distance_cutoff = 0.25            # nanometers
    angle_cutoff = 2.0 * np.pi / 3.0  # radians

    if traj.topology is None:
        raise ValueError('calc_hbonds_fingerprint requires that traj contain '
                         'topology information')
    if input_triplets is None:
        raise ValueError('input_triplets must be provided: automatic hydrogen '
                         'bond detection is not implemented')

    angle_triplets = np.asarray(input_triplets, dtype=int)
    if angle_triplets.ndim != 2 or angle_triplets.shape[1] != 3:
        raise ValueError('input_triplets must have shape (n_hbonds, 3)')
    distance_pairs = angle_triplets[:, [1, 2]]  # possible H..acceptor pairs

    angles = compute_angles(traj, angle_triplets, periodic=periodic)
    distances = compute_distances(traj, distance_pairs, periodic=periodic)
    mask = np.logical_and(distances < distance_cutoff, angles > angle_cutoff)

    return np.where(mask, 1, 0)

In [None]:
def populate_bonds(itp_file, topology):
    """
    Read the bonds from an itp file and apply them to a similarly
    indexed MDTraj topology (e.g. both itp and topology should be ligand only)

    Only the first ``[ bonds ]`` section is read. Parsing stops at the next
    section header of any kind (``[ pairs ]``, ``[ angles ]``, ...).
    Comments introduced by ';' (whole-line or trailing) and blank lines are
    ignored. Atom numbers in the itp file are 1-based and are converted to
    0-based topology indices.

    Returns the same topology object, with the bonds added.
    """
    assert topology.n_bonds == 0,"topology already has bonds"
    in_bonds = False
    with open(itp_file) as f:
        for raw_line in f:
            # Drop ';' comments (also indented or trailing ones) and whitespace.
            line = raw_line.split(';', 1)[0].strip()
            if not line:
                continue
            if line.startswith('['):
                if in_bonds:
                    # Any new section header ends the bonds section.
                    break
                in_bonds = line.strip('[]').strip().lower() == 'bonds'
                continue
            if not in_bonds:
                continue
            atom1, atom2 = line.split()[0:2]
            topology.add_bond(topology.atom(int(atom1)-1), topology.atom(int(atom2)-1))
    return topology

In [None]:
def featurize_hbonds(ds=None, ligresname=None, chosen_hbonds=None, itp_file=None):
    '''Compute the hydrogen-bond fingerprint of every trajectory in ``ds``.

    ds:            dataset of trajectories (provides keys() and iterload())
    ligresname:    mdtraj selection string for the ligand, e.g. 'resname LIG'
    chosen_hbonds: array of atom-index triplets, see calc_hbonds_fingerprint
    itp_file:      ligand itp file used to add ligand bonds to the topology;
                   when omitted, the notebook-level variable ``itpfile`` is used

    Results are stored in the 'hbonds_featurized' dataset inside pass_folder
    (opened in append mode) under the same keys as ``ds``; that dataset is
    returned.
    '''
    if itp_file is None:
        # Backward compatible: earlier versions read the global `itpfile` directly.
        itp_file = itpfile
    hbonds_featurized = dataset(pass_folder+'hbonds_featurized', mode='a', fmt='dir-npy')
    for key in ds.keys():
        print(key)
        trajectory = []
        for traj in ds.iterload(key, chunk=4001):
            ligand_top = traj.topology.subset(traj.topology.select(ligresname))
            protein_top = traj.topology.subset(traj.topology.select('not '+ligresname))
            ligand_top = populate_bonds(itp_file, ligand_top)
            # NOTE(review): joining protein + ligand assumes the ligand atoms come
            # last in the original topology, so atom indices are unchanged -- confirm.
            traj.topology = protein_top.join(ligand_top)
            trajectory.append(calc_hbonds_fingerprint(traj, input_triplets=chosen_hbonds))

        hbonds_featurized[key] = np.concatenate(trajectory)
    return hbonds_featurized

In [None]:
"""
Use this cell to manually specify hydrogen bonds (by index)
"""

#Indices for all acceptor/donor combinations between ligand and inwardly oriented non-loop residues
#Protonated amine is excluded as an acceptor
#No back-bone hydrogen bonds as loops are excluded
chosen_hbonds = np.vstack([[57, 58, 2730],
 [57, 58, 2766],
 [1547, 1548, 2730],
 [1547, 1548, 2766],
 [1555, 1556, 2730],
 [1555, 1556, 2766],
 [1579, 1580, 2730],
 [1579, 1580, 2766],
 [2193, 2194, 2730],
 [2193, 2194, 2766],
 [2220, 2221, 2730],
 [2220, 2221, 2766],
 [2220, 2222, 2730],
 [2220, 2222, 2766],
 [2353, 2354, 2730],
 [2353, 2354, 2766],
 [2361, 2362, 2730],
 [2361, 2362, 2766],
 [2384, 2385, 2730],
 [2384, 2385, 2766],
 [2438, 2439, 2730],
 [2438, 2439, 2766],
 [2730, 2729, 57],
 [2730, 2729, 559],
 [2730, 2729, 560],
 [2730, 2729, 761],
 [2730, 2729, 762],
 [2730, 2729, 1547],
 [2730, 2729, 1555],
 [2730, 2729, 1579],
 [2730, 2729, 2188],
 [2730, 2729, 2193],
 [2730, 2729, 2219],
 [2730, 2729, 2220],
 [2730, 2729, 2353],
 [2730, 2729, 2361],
 [2730, 2729, 2384],
 [2730, 2729, 2438],
 [2739, 2737, 57],
 [2739, 2737, 559],
 [2739, 2737, 560],
 [2739, 2737, 761],
 [2739, 2737, 762],
 [2739, 2737, 1547],
 [2739, 2737, 1555],
 [2739, 2737, 1579],
 [2739, 2737, 2188],
 [2739, 2737, 2193],
 [2739, 2737, 2219],
 [2739, 2737, 2220],
 [2739, 2737, 2353],
 [2739, 2737, 2361],
 [2739, 2737, 2384],
 [2739, 2737, 2438]])

In [None]:
hbonds_featurized = featurize_hbonds(ds=ds, ligresname='resname LIG', chosen_hbonds=chosen_hbonds)

### Combine ligand features

In [None]:
output_dataset = bin_contact_featurized.create_derived(pass_folder+'bin_contact_hbond_featurized')
# Append the hydrogen-bond columns to the binary contact columns, frame by frame.
# NOTE(review): assumes both datasets use the same keys for the same trajectory.
for key in bin_contact_featurized.keys():
    output_dataset[key] = np.column_stack((bin_contact_featurized[key],hbonds_featurized[key]))
# Close once, after every trajectory has been written.
output_dataset.close()

In [None]:
# pass_folder already ends with '/', so no extra separator is needed.
ligand_featurized = dataset(pass_folder+'bin_contact_hbond_featurized')

# tICA Protein

In [None]:
'''Choice of lag time is generally robust, take your best guess.
Overestimate n_components and then decide correct number based on graphs'''

In [None]:
tica_prot = tICA(n_components=10,lag_time=10).fit(protein_featurized)
tica_prot_transformed = tica_prot.transform(protein_featurized)

In [None]:
plt.plot(tica_prot.timescales_, 'o-')

In [None]:
plt.plot(tica_prot.timescales_[:-1]/tica_prot.timescales_[1:], 'o-')

In [None]:
tica_prot = tICA(n_components=1,lag_time=10).fit(protein_featurized)
tica_prot_transformed = tica_prot.transform(protein_featurized)

# tICA Ligand

In [None]:
'''Choice of lag time is generally robust, take your best guess.
Overestimate n_components and then decide correct number based on graphs'''

In [None]:
tica_lig = tICA(n_components=10,lag_time=10).fit(ligand_featurized)
tica_lig_transformed = tica_lig.transform(ligand_featurized)

In [None]:
plt.plot(tica_lig.timescales_, 'o-')

In [None]:
plt.plot(tica_lig.timescales_[:-1]/tica_lig.timescales_[1:], 'o-')

In [None]:
tica_lig = tICA(n_components=3,lag_time=10).fit(ligand_featurized)
tica_lig_transformed = tica_lig.transform(ligand_featurized)

# Combine tICAs

In [None]:
tica_transformed = []
for dims, more_dims in zip(tica_lig_transformed, tica_prot_transformed):
    tica_transformed.append(np.column_stack((dims, more_dims)))

In [None]:
print(tica_transformed[0].shape)

In [None]:
utils.dump(tica_transformed, pass_folder+"tica_transformed")

In [None]:
tica_transformed = utils.load(pass_folder+"tica_transformed")

# Clustering

##### Choose either Hierarchical Agglomerative Clustering (Much slower, high memory usage) or KMeans (very fast, slightly worse) 

In [None]:
'''clustered is equivalent to clustering.labels_'''
# Pass the list of per-trajectory arrays directly: wrapping it in np.asarray
# fails (or builds an object array) when trajectories differ in length.
clustering = KMeans(n_clusters=20, n_jobs=-1).fit(tica_transformed)
clustered = clustering.transform(tica_transformed)

In [None]:
utils.dump(clustered, pass_folder+"clustered")
utils.dump(clustering, pass_folder+"clustering")

In [None]:
clustering = utils.load(pass_folder+"clustering")
clustered = utils.load(pass_folder+"clustered")

In [None]:
def plot_clusters(clustering=None, tica_transformed=None, chosen_states='all'):
    '''Plot, for each chosen state, a histogram of the Euclidean distance in
    tica-space between every cluster member and the cluster center.
    Useful to check cluster quality.

    clustering:       fitted clustering object with labels_ (one label array
                      per trajectory) and cluster_centers_
    tica_transformed: list of per-trajectory arrays, shape (n_frames, n_dims)
    chosen_states:    'all' or an iterable of state indices

    Histograms are laid out three per figure; a new figure is started after
    every third state.
    '''
    labels = clustering.labels_
    n_states = max(max(seq) for seq in labels) + 1
    n_states_2 = len(np.unique(np.concatenate(labels)))
    assert n_states == n_states_2, "Must have non-empty, zero-indexed, consecutive states: found %d states and %d unique states." % (n_states, n_states_2)

    if isinstance(chosen_states, str) and chosen_states == 'all':
        chosen_states = range(n_states)

    def member_deviations(state):
        # Distance of every frame assigned to `state` from its cluster center,
        # in trajectory order. Trajectories are handled one at a time so they
        # may have different lengths.
        center = np.asarray(clustering.cluster_centers_[state])
        deviations = []
        for trj, seq in enumerate(labels):
            frames = np.where(np.asarray(seq) == state)[0]
            if len(frames):
                members = np.asarray(tica_transformed[trj])[frames]
                deviations.append(np.sqrt(np.sum((members - center)**2, axis=1)))
        return np.concatenate(deviations) if deviations else np.array([])

    columns = 3
    fig = None
    for plot_index, state in enumerate(chosen_states):
        position = plot_index % columns + 1
        if position == 1:
            # Row is full (or this is the first state): show it and start a new figure.
            if fig is not None:
                plt.show()
            fig = plt.figure(figsize=(18,6))

        subplot = fig.add_subplot(1, columns, position)
        subplot.set_title("State"+str(state))
        subplot.set_xlim(0,2)
        subplot.hist(member_deviations(state), bins=20)

    plt.show()

In [None]:
plot_clusters(clustering, tica_transformed)

# Construct MSM

In [None]:
def plot_implied_timescales(clustering=None, n_timescales=10, max_lagtime=None, log=True, n_points=10):
    '''Creates markov models for a range of lag times and plots their implied timescales

    clustering:   fitted clustering object; clustering.labels_ is used as input
    n_timescales: maximum number of timescales to plot
    max_lagtime:  upper bound (exclusive) of the lag times tested, in frames; required
    log:          use a logarithmic y axis
    n_points:     approximate number of lag times to test

    The shaded area marks timescales shorter than the lag time.
    Raises ValueError if max_lagtime is missing or smaller than 2.
    '''
    if max_lagtime is None or max_lagtime < 2:
        raise ValueError("max_lagtime must be an integer >= 2")
    # A step of 0 (max_lagtime < n_points) would make range() fail.
    timescale_step = max(1, max_lagtime // n_points)
    lag_times = list(range(1, max_lagtime, timescale_step))

    models = [
        MarkovStateModel(lag_time=lag_time, verbose=True, ergodic_cutoff='on',
                         n_timescales=n_timescales).fit(clustering.labels_)
        for lag_time in lag_times
    ]
    timescales = [m.timescales_ for m in models]
    # Models may report fewer timescales than requested; plot what all of them share.
    n_timescales = min(n_timescales, min(len(ts) for ts in timescales))
    timescales = np.array([ts[:n_timescales] for ts in timescales])
    plt.cla()
    # Start at index 0: timescales_ holds relaxation timescales only, so the
    # first entry is the slowest process and must be plotted as well.
    for i in range(n_timescales):
        plt.plot(lag_times, timescales[:, i], 'o-')

    plt.fill_between(range(max_lagtime), range(max_lagtime), 1)
    if log:
        plt.semilogy()
    plt.show()

In [None]:
plot_implied_timescales(clustering, max_lagtime=700, n_timescales=10)

In [None]:
'''Choose a lag time of 1 during early phases to ensure states are not excised'''
MSM = MarkovStateModel(lag_time=400, n_timescales=10, ergodic_cutoff='on').fit(clustering.labels_)

# Plot tICA free-energy surface

In [None]:
def plot_tica(tica_transformed=None, HMM=None, clustering=None, dims='all', states=None, trajectories=None):
    '''Plot -log(histogram counts) surfaces for pairs of tICA dimensions.

    Provide clustering OR HMM (clustering takes precedence if both are given);
    a model is only needed when states are plotted.
    tica_transformed: list of per-trajectory arrays, shape (n_frames, n_dims)
    dims: 'all' or list of indices. All combinations will be plotted as separate graphs
    states: None, 'all' or list of indices. Cluster center of each specified state is plotted
    trajectories: list of indices (refer to ds.glob_matches). Trajectory path is plotted as a line white->blue

    Panels are drawn two per figure.
    Raises ValueError if states are requested without an HMM or clustering.
    '''
    if trajectories is None:
        trajectories = []

    if isinstance(dims, str) and dims == 'all':
        dims = range(len(tica_transformed[0][0]))
    dims = list(dims)

    cluster_centers = None
    if HMM is not None:
        cluster_centers = HMM.means_
    if clustering is not None:
        cluster_centers = clustering.cluster_centers_

    if isinstance(states, str) and states == 'all':
        if cluster_centers is None:
            raise ValueError("states='all' requires an HMM or a clustering")
        states = range(len(cluster_centers))
    if states is not None:
        states = list(states)
        if states and cluster_centers is None:
            raise ValueError("plotting states requires an HMM or a clustering")
    if cluster_centers is not None:
        cluster_centers = np.asarray(cluster_centers)

    # Concatenate once; every panel reads two columns of this array.
    all_frames = np.concatenate(tica_transformed)

    def draw_panel(ax, dim1, dim2):
        counts, xedges, yedges = np.histogram2d(all_frames[:, dim1], all_frames[:, dim2], bins=50)
        # Empty bins give -log(0) = inf; they fall outside the contour levels.
        with np.errstate(divide='ignore'):
            F = -np.log(counts)
        extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
        ax.set_title("tIC "+str(dim1)+" vs. tIC "+str(dim2))
        #adjust levels to define contour lines
        ax.contourf(F.T, cmap=plt.cm.hot, extent=extent, levels=np.linspace(-9,0,10))

        if states:
            ax.plot(cluster_centers[states, dim1], cluster_centers[states, dim2], "o", zorder=1)

        for traj in trajectories:
            u = tica_transformed[traj][:, dim1]
            v = tica_transformed[traj][:, dim2]
            points = np.array([u, v]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lc = LineCollection(segments, cmap=plt.get_cmap('Blues'))
            # Colour each segment by its frame index (white -> blue over time).
            lc.set_array(np.arange(len(u)))
            ax.add_collection(lc)
            ax.axis('auto')

    dim_pairs = [(dim1, dim2) for pos, dim1 in enumerate(dims) for dim2 in dims[pos+1:]]
    axes = None
    for nplots, (dim1, dim2) in enumerate(dim_pairs):
        if nplots % 2 == 0:
            fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(18,9))
        draw_panel(axes[nplots % 2], dim1, dim2)
        if nplots % 2 == 1:
            plt.show()

    if len(dim_pairs) % 2 == 1:
        # Odd number of panels: hide the unused axis and show the last figure.
        axes[1].set_visible(False)
        plt.show()

In [None]:
plot_tica(tica_transformed, clustering=clustering, HMM=None, states='all', trajectories=[], dims='all')

# Analyse MSM and choose metric that will determine next-pass structures

## Calculate simple connectivity

In [None]:
def calc_transmin(MSM=None, n_results=10):
    '''Calculates the number of transitions into and out of each state
    returns the lower number (in vs out) for each state
    Effectively identifies poor connectivity on an individual state basis

    MSM:       model exposing countsmat_; defaults to the notebook-level MSM,
               looked up at call time so a refitted model is picked up
    n_results: how many of the least connected states to print

    Returns a float array with one entry per MSM state.
    '''
    if MSM is None:
        try:
            MSM = globals()['MSM']
        except KeyError:
            raise ValueError("no MSM given and no global MSM is defined")

    counts = np.asarray(MSM.countsmat_)
    self_transitions = np.diagonal(counts)
    trans_out = counts.sum(axis=1) - self_transitions
    trans_in = counts.sum(axis=0) - self_transitions
    trans_min = np.minimum(trans_out, trans_in).astype(float)

    print(trans_min.argsort()[:n_results])
    return trans_min

In [None]:
trans_min = calc_transmin(MSM)

In [None]:
# Frame count per cluster label, computed once instead of once per printed state.
# NOTE(review): trans_min is indexed by MSM state while these counts are indexed
# by raw cluster label; the two only agree if the MSM did not trim any state.
state_populations = np.unique(np.concatenate(clustering.labels_), return_counts=True)[1]
for state in trans_min.argsort()[:10]:
    print("State "+str(state)+" population = "+str(state_populations[state]))

## Calculate MFPT

In [None]:
# Mean first passage times between all pairs of MSM states.
mfpt_matrix = mfpts(MSM)
'''Specify solvent state manually if necessary'''
# The state of the very first frame of the first trajectory is taken as the
# solvent (unbound) state; MSM.mapping_ converts the raw cluster label into
# the MSM's internal state index.
solvent_state = MSM.mapping_[clustering.labels_[0][0]]
# Row of the solvent state (presumably MFPTs *from* the solvent state to
# every other state -- confirm against the msmbuilder mfpts documentation).
mfpt_list = mfpt_matrix[solvent_state]

# The ten states with the largest MFPT in that row.
# NOTE(review): these are MSM-space indices, not raw cluster labels.
mfpt_states = (-mfpt_list).argsort()[:10]
print(mfpt_states)
chosen_states_mfpt = mfpt_states

In [None]:
def advanced_mfpt(transmat, start_state=None, sink_states=None, n_iter=100, lag_time=1, random_state=None):
    '''Simulates movement through the transition matrix and returns a list of first passage times
    between start_state (int) and sink_states (list) for calculating mean, median and distribution
    of first passage times
    mfpt is calculated in time units of the transmat, it is returned multiplied by lag_time,
    choose a number that will scale to your desired units
    With sufficient iterations will result in same mfpt as calculated from transition path theory

    Parameters:
    transmat      row-stochastic transition matrix, shape (n_states, n_states)
    start_state   index of the state every simulation starts in
    sink_states   state index, or list of indices, at which a simulation stops
    n_iter        number of independent simulations
    lag_time      factor applied to the step counts
    random_state  None, an int seed or a np.random.RandomState, for reproducibility

    Returns:
    array of first passage times (one per iteration)

    Raises ValueError if start_state or sink_states is missing.
    The simulation does not terminate if no sink state can be reached.

    Can easily be modded to return a list of transitions, for movie making etc.
    '''
    if start_state is None:
        raise ValueError("start_state must be given")
    if sink_states is None:
        raise ValueError("sink_states must be given")
    sinks = set(np.atleast_1d(sink_states).tolist())
    if not sinks:
        raise ValueError("sink_states must contain at least one state")

    transmat = np.asarray(transmat, dtype=float)
    # Cumulative probabilities of the transitions out of each state, so that
    # np.searchsorted can pick a transition from a uniform random number.
    transprob = np.cumsum(transmat, axis=1)
    last_state = transmat.shape[0] - 1

    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    nstep_list = []
    for _ in range(n_iter):
        state = start_state
        nsteps = 0
        while state not in sinks:
            nsteps += 1
            # side='right' never selects a zero-probability transition; the clamp
            # guards against a row summing to slightly less than 1.
            new_state = np.searchsorted(transprob[state], rng.random_sample(), side='right')
            state = min(int(new_state), last_state)
        nstep_list.append(nsteps)
    return np.asarray(nstep_list)*lag_time

In [None]:
# NOTE(review): no variable named `HMM` is assigned before this cell in the
# notebook; a fitted model exposing `transmat_` must exist first -- confirm
# which model is intended. lag_time=1/30 rescales step counts into the desired time unit.
nstep_list = advanced_mfpt(HMM.transmat_, start_state=17, sink_states=[13], lag_time=1/30)

In [None]:
print(np.median(nstep_list))
print(np.mean(nstep_list))
plt.hist(np.asarray(nstep_list))

## Calculate per-state contribution to eigenvalue uncertainty

## Test connectivity

In [None]:
def join_clusters(counts_matrix, final_clusters=None, populations=None):
    '''
    Finds the highest minimum value between forward and backward transitions...
    Merges most connected pairs of states together and reports 
    the <final_clusters> least connected states
    
    Also note that due to discretization error, there is a bias towards joining
    clusters into the largest blob rather than forming multiple aggregates.
    This bias only really emerges after the largest connectivity issues are fixed

    counts_matrix:  square transition count matrix (not modified)
    final_clusters: number of lumps to stop at; required
    populations:    unused, kept for backward compatibility

    Returns (reduced_matrix, cluster_record): the count matrix of the lumps
    and an object array holding, for each lump, the list of original state
    indices it contains. Merging stops early when no pair of lumps has
    transitions in both directions.
    '''
    if final_clusters is None:
        raise ValueError("final_clusters must be given")

    counts_matrix = np.array(counts_matrix)  # work on a copy
    diagonal = np.diagonal(counts_matrix).copy()
    np.fill_diagonal(counts_matrix, 0)

    n_clusters = counts_matrix.shape[0]
    cluster_record = [[i] for i in range(n_clusters)]

    while n_clusters > final_clusters:
        # Connectivity of a pair = the weaker of its two transition directions.
        min_matrix = np.minimum(counts_matrix, counts_matrix.T)
        x, y = np.unravel_index(np.argmax(min_matrix), min_matrix.shape)
        if min_matrix[x, y] <= 0:
            # Nothing left to merge; merging anyway would fold state 0 into
            # itself and silently discard it.
            print("No bidirectionally connected lumps left; stopping at "+str(n_clusters)+" lumps")
            break

        #print("merging "+str(y)+" into "+str(x))

        counts_matrix[x] += counts_matrix[y]
        counts_matrix[y] = 0
        counts_matrix[:, x] += counts_matrix[:, y]
        counts_matrix[:, y] = 0

        # Transitions between x and y are now self-transitions of the lump.
        diagonal = diagonal + np.diagonal(counts_matrix)
        diagonal[x] += diagonal[y]
        diagonal[y] = 0
        np.fill_diagonal(counts_matrix, 0)

        cluster_record[x] = cluster_record[x] + cluster_record[y]
        cluster_record[y] = [-1]

        n_clusters -= 1

        print("*LUMP*")
        for lump in cluster_record:
            if len(lump) > 1:
                print(lump)

    np.fill_diagonal(counts_matrix, diagonal)

    # Drop the rows/columns of states that were merged into another lump.
    keep = [i for i, lump in enumerate(cluster_record) if lump != [-1]]
    reduced_matrix = counts_matrix[np.ix_(keep, keep)]

    # Always return an object array of lists, whether or not anything was
    # merged; filled element-wise so equal-length lumps are not turned into
    # a 2-D array.
    kept_record = np.empty(len(keep), dtype=object)
    for position, i in enumerate(keep):
        kept_record[position] = cluster_record[i]

    return reduced_matrix, kept_record

In [None]:
def find_gateways(counts_matrix, cluster_record):
    '''Identifies the individual state in each "lump" which is most connected to the rest of the network.

    counts_matrix:  square transition count matrix over the individual states
    cluster_record: sequence (list or object array) of lumps, each a list of
                    state indices, as returned by join_clusters

    Every lump except the largest one is examined. For each, the member with
    the single largest outgoing transition count to a state outside the lump
    is chosen. If a lump has no transitions leaving it, its first member is
    returned.

    Returns an integer array with one gateway state per examined lump,
    ordered by increasing lump size.
    '''
    counts_matrix = np.asarray(counts_matrix)
    lumps = [list(lump) for lump in cluster_record]
    size_order = np.argsort([len(lump) for lump in lumps])
    bad_lumps = [lumps[i] for i in size_order[:-1]]  # all but the largest lump

    chosen_states = []
    for lump in bad_lumps:
        # Reset per lump so a disconnected lump cannot inherit the previous gateway.
        gateway = lump[0]
        transmax = 0
        for cluster in lump:
            outgoing = counts_matrix[cluster].copy()
            outgoing[lump] = 0  # ignore transitions that stay inside the lump
            strongest = np.amax(outgoing)
            if strongest > transmax:
                transmax = strongest
                gateway = cluster
        chosen_states.append(gateway)

    return np.asarray(chosen_states, dtype=int)

In [None]:
reduced_matrix, cluster_record = join_clusters(MSM.countsmat_, 21)

In [None]:
for i in cluster_record:
    print(i)

In [None]:
chosen_states_connect = find_gateways(MSM.countsmat_, cluster_record)
chosen_states_connect

# Draw Samples

In [None]:
def draw_samples(clustering, n_samples, featurized=None, selection_criteria='random', chosen_states='all', random_state=None):
    """Sample (trajectory, frame) pairs from each chosen state.

    Parameters
    ----------
    clustering : Object
        should have property clustering.labels_ which corresponds
        to the cluster assignment of every frame, and
        clustering.cluster_centers_ for every criterion except 'random'

    n_samples : int
        How many samples to return from each state

    featurized : list of arrays, optional
        Featurization used for calculating distance to cluster centers.
        Required for every criterion except 'random'.

    selection_criteria : string
        Should be one of 'median', '80percentile', 'closest', 'random'
            median: Random members closer to the cluster center than the median of all cluster members
            80percentile: Random members closer to the cluster center than 80% of all cluster members
            closest: Returns the n_samples closest cluster members to the cluster center
                (cycling through the members if the state has fewer than n_samples)
            random: random sampling with replacement

    chosen_states : 'all' or iterable of int
        States (raw cluster labels) to sample from

    random_state : None, int or np.random.RandomState, optional
        None uses numpy's global random stream.

    Returns
    -------
    selected_pairs_by_state : np.array, dtype=int, shape=(n_states, n_samples, 2)
        selected_pairs_by_state[state] gives an array of selected (trj, frame)
        pairs from the specified state.

    Raises
    ------
    ValueError
        For an unknown selection_criteria, a missing featurization, or a
        chosen state without any member.

    See Also
    --------
    utils.map_drawn_samples : Extract conformations from MD trajectories by index.

    """
    valid_criteria = ('median', '80percentile', 'closest', 'random')
    if selection_criteria not in valid_criteria:
        raise ValueError("selection_criteria must be one of %s, got %r" % (valid_criteria, selection_criteria))
    needs_distance = selection_criteria != 'random'
    if needs_distance and featurized is None:
        raise ValueError("featurized is required for selection_criteria=%r" % selection_criteria)

    if random_state is None:
        rng = np.random  # numpy's global random stream
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    sequences = clustering.labels_
    n_states = max(max(seq) for seq in sequences) + 1
    n_states_2 = len(np.unique(np.concatenate(sequences)))
    assert n_states == n_states_2, "Must have non-empty, zero-indexed, consecutive states: found %d states and %d unique states." % (n_states, n_states_2)

    if isinstance(chosen_states, str) and chosen_states == 'all':
        chosen_states = range(n_states)

    selected_pairs_by_state = []
    for state in chosen_states:
        pairs = np.array(
            [(trj, frame)
             for trj, seq in enumerate(sequences)
             for frame in np.where(np.asarray(seq) == state)[0]],
            dtype=int).reshape(-1, 2)
        if len(pairs) == 0:
            raise ValueError("State %s has no members to sample from" % (state,))

        if needs_distance:
            members = np.array([featurized[trj][frame] for trj, frame in pairs])
            deviation = np.linalg.norm(members - clustering.cluster_centers_[state], axis=1)

        if selection_criteria == 'closest':
            # np.resize truncates, or cycles when the state is smaller than n_samples.
            picks = pairs[np.resize(np.argsort(deviation), n_samples)]
        else:
            if selection_criteria == 'median':
                candidates = pairs[deviation < np.median(deviation)]
            elif selection_criteria == '80percentile':
                candidates = pairs[deviation < np.percentile(deviation, 20)]
            else:
                candidates = pairs
            if len(candidates) == 0:
                # Tiny states: the strict cut-off excludes every member.
                candidates = pairs
            picks = candidates[rng.choice(len(candidates), size=n_samples)]

        selected_pairs_by_state.append(picks)

    return np.array(selected_pairs_by_state)


In [None]:
def output_data(file_path=None, MSM=None, selected_pairs_by_state=None, ds=None):
    '''This function was designed to output data in the format recognised by the old MSMexplorer program
    There has been a msmexplorer python package developed for visualizing the data which is probably better.
    However the MSMexplorer program lets you arrange states with the mouse, which is nice

    file_path: output directory; it is created and must not exist yet
    MSM:       model exposing populations_, countsmat_ and transmat_
    selected_pairs_by_state: per state, a sequence of (trajectory, frame) pairs;
               defaults to the notebook-level variable of the same name
    ds:        dataset the frames are read from; defaults to the notebook-level ds

    Writes Populations.dat, tCounts.mtx and tProb.mtx (1-based "row col value"
    lines for the non-zero entries, after a "rows cols n_nonzero" header) and
    one State<N>.xtc per state holding its sampled frames.
    '''
    # Backward compatible: earlier versions read these two names from the notebook globals.
    if selected_pairs_by_state is None:
        selected_pairs_by_state = globals()['selected_pairs_by_state']
    if ds is None:
        ds = globals().get('ds')

    def write_sparse_matrix(path, matrix):
        # Header line, then one line per non-zero entry with 1-based indices.
        matrix = np.asarray(matrix)
        with open(path, "w") as handle:
            handle.write(' '.join(map(str, np.shape(matrix))) + ' ' + str(np.count_nonzero(matrix)) + '\n')
            for (x, y), value in np.ndenumerate(matrix):
                if value != 0:
                    handle.write(str(x+1) + ' ' + str(y+1) + ' ' + str(value) + "\n")

    os.makedirs(file_path)
    np.savetxt(file_path+'/Populations.dat', MSM.populations_, delimiter='\n')
    write_sparse_matrix(file_path+"/tCounts.mtx", MSM.countsmat_)
    write_sparse_matrix(file_path+"/tProb.mtx", MSM.transmat_)

    for state, samples in enumerate(selected_pairs_by_state):
        traj = None
        for pair_num, (traj_num, frame) in enumerate(samples):
            print(pair_num)
            snapshot = ds.get(traj_num)[frame]
            # Only the first frame starts the trajectory; a failing join now raises
            # instead of silently discarding the frames collected so far.
            traj = snapshot if traj is None else traj.join(snapshot)
        if traj is None:
            print("State "+str(state)+" has no samples, nothing written")
            continue
        filename = str(file_path)+'/State' + str(state) + '.xtc'
        traj.save_xtc(filename)
        print(state)

In [None]:
selected_pairs_by_state = draw_samples(clustering, 10, tica_transformed)
print(selected_pairs_by_state.shape)

### Make sure MSM has not trimmed states and use MSM.mapping_ if required

In [None]:
output_data(file_path=pass_folder+"MSM_pass"+str(this_pass).zfill(2)+"_"+str(MSM.lag_time), MSM=MSM)

# Setup next pass

Make sure chosen_states is not in MSM space

MSMs may excise states with poor connectivity, changing the numbering

In [None]:
# draw_samples expects (clustering, n_samples, featurized=..., selection_criteria=..., chosen_states=...).
# The states chosen from MSM.countsmat_ are MSM indices, so map them back to raw
# cluster labels first (MSM.mapping_ maps raw label -> MSM index).
# Swap in chosen_states_mfpt here to seed from the MFPT analysis instead.
msm_to_cluster = {msm_index: label for label, msm_index in MSM.mapping_.items()}
chosen_states = [msm_to_cluster[state] for state in chosen_states_connect]
selected_pairs_by_state = draw_samples(clustering, 5, featurized=tica_transformed, selection_criteria='median', chosen_states=chosen_states)

In [None]:
'''Only adds starting structure, all other files must be in a generic
simulation folder 'blank_sim/' in the workdir'''
next_pass_folder = "pass"+str(next_pass).zfill(2)+"/"
os.makedirs(next_pass_folder)
for index,sample in enumerate(np.vstack(selected_pairs_by_state)):
    %cd $workdir
    traj,frame = sample
    sim_num = index
    run_name = 'd3_hlp_'+str(next_pass)+'_'+str(sim_num)
    run_folder = next_pass_folder+run_name+"/"
    shutil.copytree('blank_sim/', run_folder)
    ds.get(traj)[frame].save_pdb(run_folder+'system.pdb')
    %cd $run_folder
    !gmx grompp -f em1.mdp -po em1_out.mdp -c system.pdb -p system.top -o em1.tpr -maxwarn 1
    !gmx mdrun -deffnm em1
    !gmx grompp -f md1.mdp -po md1_out.mdp -c em1.gro -p system.top -o md1.tpr -maxwarn 1
%cd $workdir

# HMM

In [None]:
def plot_HMM_timescales(tica_transformed=None, n_states=None, max_lagtime=2000, 
                        n_timescales=None, n_init=10, n_em_iter=10, n_steps=3):
    '''Implied timescales but for HMMs. This can easily take forever, especially if you
    want to make quality models.
    Also note that HMMs have more variables than MSM and this function fixes them all, 
    and only varies lag_time

    tica_transformed: list of per-trajectory arrays used as HMM input
    n_states:         number of hidden states (required, >= 2)
    max_lagtime:      upper bound (exclusive) of the lag times tested
    n_timescales:     number of timescales to plot, at most n_states-1 (default: all)
    n_init, n_em_iter: forwarded to hmm.GaussianHMM as n_init and n_iter
    n_steps:          approximate number of lag times to test

    Returns the list of final fit log-likelihoods, one per lag time.
    Raises ValueError if n_states is missing or smaller than 2.
    '''
    if n_states is None or n_states < 2:
        raise ValueError("n_states must be an integer >= 2")
    # A step of 0 (max_lagtime < n_steps) would make range() fail.
    timescale_step = max(1, max_lagtime // n_steps)
    lag_times = range(1, max_lagtime, timescale_step)
    sequences = tica_transformed
    if n_timescales is None:
        n_timescales = n_states-1
    else:
        n_timescales = min(n_timescales, n_states-1)

    hmm_timescales = np.zeros((0, n_timescales))
    log_likelihoods = []
    print("lag_time: ")
    for lag_time in lag_times:
        # Sub-sample every trajectory at the lag time, once per possible offset.
        strided_data = [s[i::lag_time] for s in sequences for i in range(lag_time)]
        HMM = hmm.GaussianHMM(n_states=n_states, n_init=n_init, n_iter=n_em_iter).fit(strided_data)
        timescales = np.asarray(HMM.timescales_) * lag_time
        # Keep exactly n_timescales columns: truncate, or pad with NaN if the
        # model reports fewer, so every row matches the buffer width.
        row = np.full(n_timescales, np.nan)
        n_available = min(n_timescales, len(timescales))
        row[:n_available] = timescales[:n_available]
        hmm_timescales = np.vstack((hmm_timescales, row))
        log_likelihoods.append(HMM.fit_logprob_[-1])
        print(" "+str(lag_time))

    for i in range(n_timescales):
        plt.plot(lag_times, hmm_timescales[:, i])

    #print log_likelihoods
    plt.show()
    return log_likelihoods

In [None]:
def _diag_gaussian_log_density(X, means, variances):
    """Log-density of every row of X under each diagonal-covariance Gaussian.

    Returns an array of shape (n_frames, n_states). Written with dot products
    so no (n_frames, n_states, n_features) temporary is allocated.
    """
    X = np.asarray(X, dtype=float)
    means = np.asarray(means, dtype=float)
    variances = np.asarray(variances, dtype=float)
    n_dim = X.shape[1]
    return -0.5 * (
        n_dim * np.log(2 * np.pi)
        + np.sum(np.log(variances), axis=1)
        + np.sum(means ** 2 / variances, axis=1)
        - 2 * np.dot(X, (means / variances).T)
        + np.dot(X ** 2, (1.0 / variances).T)
    )


def HMM_draw_samples(HMM, tica_transformed, n_samples, random=False, chosen_states='all'):
    """Draw (trajectory, frame) samples for each hidden state of a Gaussian HMM.

    Every frame is assigned to the state whose diagonal Gaussian
    (HMM.means_, HMM.vars_) gives it the highest log-density. Samples are then
    drawn uniformly, with replacement, from the candidate frames of each state.

    Parameters
    ----------
    HMM : object
        Fitted model; only its means_ and vars_ attributes (both indexed
        [state, feature]) are read.
    tica_transformed : sequence of array-like, each (n_frames_i, n_features)
        Featurized trajectories the frames are drawn from.
    n_samples : int
        How many samples to return from each state.
    random : {False, True, -1}
        False (default): only frames closer to the state mean (Euclidean
        distance) than the median distance of that state are candidates.
        -1: only frames farther than the 95th percentile are candidates.
        Anything else (e.g. True): every frame of the state is a candidate.
        If a filter leaves no candidate (single-frame state, tied distances),
        all frames of the state are used instead.
    chosen_states : 'all' or iterable of int
        States to sample; 'all' means every state.

    Returns
    -------
    selected_pairs_by_state : np.array, dtype=int, shape=(n_chosen_states, n_samples, 2)
        selected_pairs_by_state[i] gives an array of randomly selected
        (trj, frame) pairs from the i-th chosen state.

    Raises
    ------
    AssertionError
        If some state index below the largest assigned one has no frames.
    ValueError
        If a chosen state has no frames assigned to it.

    See Also
    --------
    utils.map_drawn_samples : Extract conformations from MD trajectories by index.

    """
    features = [np.asarray(x, dtype=float) for x in tica_transformed]
    means = np.asarray(HMM.means_, dtype=float)

    # Most likely state of every frame, one label array per trajectory.
    sequences = [
        _diag_gaussian_log_density(x, means, HMM.vars_).argmax(1) for x in features
    ]

    assigned = np.concatenate(sequences)
    n_states = assigned.max() + 1
    n_states_2 = len(np.unique(assigned))
    assert n_states == n_states_2, "Must have non-empty, zero-indexed, consecutive states: found %d states and %d unique states." % (n_states, n_states_2)

    # isinstance guard: comparing an array of states with 'all' is elementwise.
    if isinstance(chosen_states, str) and chosen_states == 'all':
        chosen_states = range(n_states)

    selected_pairs_by_state = []
    for state in chosen_states:
        pair_blocks = []
        deviation_blocks = []
        for trj, (x, labels) in enumerate(zip(features, sequences)):
            frames = np.where(labels == state)[0]
            pair_blocks.append(
                np.column_stack((np.full(len(frames), trj, dtype=int), frames))
            )
            # Euclidean distance of each candidate frame to the state mean.
            deviation_blocks.append(
                np.sqrt(np.sum((x[frames] - means[state]) ** 2, axis=1))
            )
        pairs = np.concatenate(pair_blocks)
        pair_deviation = np.concatenate(deviation_blocks)

        if len(pairs) == 0:
            raise ValueError("No frames are assigned to state %d; cannot draw samples." % state)

        # `==` (not `is`) on purpose: 0 keeps behaving like False, as before.
        if random == False:  # noqa: E712
            keep = pair_deviation < np.median(pair_deviation)
        elif random == -1:
            keep = pair_deviation > np.percentile(pair_deviation, 95)
        else:
            keep = None
        # An empty filter result would make the draw below impossible, so
        # fall back to all frames of the state in that case.
        if keep is not None and keep.any():
            pairs = pairs[keep]

        # np.random is the same global generator check_random_state(None) returned.
        selected_pairs_by_state.append(
            [pairs[np.random.choice(len(pairs))] for _ in range(n_samples)]
        )
    return np.array(selected_pairs_by_state)



In [None]:
#test timescales for a given number of states
# Fits 5-state HMMs over the default lag-time scan and plots their implied
# timescales; the returned log-likelihood list is not assigned to a variable.
plot_HMM_timescales(tica_transformed, n_states=5)

In [None]:
#test n_states for a given lagtime
# Fit one GaussianHMM per candidate number of states at a fixed lag time and
# record the last fit_logprob_ entry of each fit in logprob_list.
lag_time=1000
# Only every 10th offset (0, 10, 20, ...) is used here, so each trajectory
# contributes lag_time/10 sub-sampled copies instead of lag_time.
strided_data = [s[i::lag_time] for s in tica_transformed for i in range(0,lag_time,10)]
logprob_list = []
for n_states in [2,4,6,8]:
    print("n_states = "+str(n_states))
    HMM = hmm.GaussianHMM(n_states=n_states, n_init=200, n_iter=50).fit(strided_data)
    # logprob_list is in the same order as the candidate list above.
    logprob_list.append(HMM.fit_logprob_[-1])

In [None]:
# Plot the recorded log-likelihoods; with a single argument the x-axis is the
# list index (0, 1, 2, ...), not the number of states that was tested.
plt.plot(logprob_list)

In [None]:
#When you decide on the number of states and lag time for your HMM, this will build it and write output
n_states = 8
lag_time = 700

file_path = pass_folder+"/HMM/"
# exist_ok lets this cell be re-run without failing on the existing folder.
os.makedirs(file_path, exist_ok=True)
# Optional: reload a previously saved tICA projection instead of recomputing it.
#tica_transformed = utils.load(pass_folder+"tica_transformed")
# One sub-sampled copy of every trajectory per offset, so all frames are used.
strided_data = [s[l::lag_time] for s in tica_transformed for l in range(lag_time)]
print("Starting HMM")
HMM = hmm.GaussianHMM(n_states=n_states, n_init=30, n_iter=100).fit(strided_data)
utils.dump(HMM, file_path+"HMM")
print("HMM created")

# Transition matrix as text: a header line "rows cols n_nonzero", then one
# "row col value" line (1-based indices) per non-zero entry.
tProb = HMM.transmat_
with open(file_path+"/tProb.mtx", "w") as tProbFile:
    header = ' '.join(map(str, np.shape(tProb))) + ' ' + str(np.count_nonzero(tProb)) + '\n'
    tProbFile.write(header)
    for (x, y), value in np.ndenumerate(tProb):
        if value != 0:
            tProbFile.write(str(x+1) + ' ' + str(y+1) + ' ' + str(value) + "\n")

np.savetxt(file_path+'/Populations.dat', HMM.populations_, delimiter='\n')

# One representative frame per state, saved as Mean<state>.xtc.
selected_pairs_by_state = HMM.draw_centroids(np.asarray(tica_transformed))
# First returned element holds the (trajectory, frame) pairs; squeeze drops
# its length-1 axes.
selected_pairs_by_state = selected_pairs_by_state[0].squeeze()
for state, pair in enumerate(selected_pairs_by_state):
#    pair[0] = mapping[pair[0]]
    samples = utils.map_drawn_samples([[pair]], ds)
    print(state)
    filename = str(file_path)+'/Mean' + str(state) + '.xtc'
    samples[0].save_xtc(filename)

# Ten sampled frames per state, joined and saved as State<state>.xtc.
selected_pairs_by_state = HMM_draw_samples(HMM, tica_transformed, 10)
for state, samples in enumerate(selected_pairs_by_state):
    traj = None
    for pair_num, pair in enumerate(samples):
        print(pair_num)
        traj_num,frame = pair
#        traj_num = mapping[traj_num]
        snapshot = ds.get(traj_num)[frame]
        # Explicit first-frame check instead of a bare except, so a failing
        # join or load is reported rather than silently resetting traj.
        if traj is None:
            traj = snapshot
        else:
            traj = traj.join(snapshot)
    filename = str(file_path)+'/State' + str(state) + '.xtc'
    traj.save_xtc(filename)
    print(state)