In [None]:
import plumed
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection
import matplotlib.colors as mcolors
import numpy as np
import MDAnalysis as mda
import pandas as pd
import random
from deeptime.decomposition import TICA
from deeptime.covariance import KoopmanWeightingEstimator
from deeptime.clustering import MiniBatchKMeans
from deeptime.markov import TransitionCountEstimator
from deeptime.markov.msm import MaximumLikelihoodMSM
from deeptime.plots import plot_implied_timescales
from deeptime.util.validation import implied_timescales
import networkx as nx
from copy import deepcopy
from numpy.random import multinomial
import subprocess
import os
import math
from scipy.stats import pearsonr,chisquare
import warnings
import glob
warnings.filterwarnings('ignore')

In [None]:
# Configure files
# Input structure/trajectory files (paths relative to the notebook directory)
whole_xtc = './Conf/traj_comp.xtc'
# Coordinate file used as the MDAnalysis topology when seed frames are extracted
whole_gro = './Conf/step7.gro'
whole_tpr = './Conf/step7.tpr'
# GROMACS (+PLUMED) command run through Docker; it is written into each round's run.sh.
# NOTE(review): `--rm` and `--gpus all` are passed twice; presumably harmless, but redundant.
gmx = 'docker run --gpus all --rm -e CUDA_MPS_PIPE_DIRECTORY=/tmp/nvidia-mps -e CUDA_MPS_LOG_DIRECTORY=/tmp/nvidia-log --rm --gpus all -v $PWD:/workdir -v /tmp/nvidia-mps:/tmp/nvidia-mps -v /tmp/nvidia-log:/tmp/nvidia-log --ipc host -w /workdir registry.cn-hangzhou.aliyuncs.com/linjiahao/gromacs:2023-plumed-avx2-u2204-cu124 gmx'

# RUN setup
# NOTE: `round` shadows Python's builtin round() for the rest of the notebook
# (calculate_nmicro uses np.round, so it is unaffected).
round = 1            # current adaptive-sampling round
total_round = 2      # NOTE(review): not referenced in the visible cells
start = 40           # first split index of round 1 (inclusive)
end = 51             # last split index of round 1 (inclusive)
ligand_number = 4    # NOTE(review): not referenced in the visible cells
ligand_ids = [474,475,476,477] # numbering as in the .tpr; NOTE(review): not referenced in the visible cells
pocket_ids = [91,151,154,155,158,213,380,383] # numbering as in the .tpr; NOTE(review): not referenced in the visible cells
chain_ids = ['A','B','C','D']  # one COLVAR feature file per (round, split, chain)
# TICA parameters (for adaptive sampling)
tica_lagtime = 500   # tICA lag time in frames
dim = None           # passed to TICA(dim=...)
var_cutoff = 0.95    # passed to TICA(var_cutoff=...)
koopman = True       # NOTE(review): the tICA-ITS cell passes False literally instead of this flag
runtICA_njobs = 12   # n_jobs for MiniBatchKMeans
# Implied time scale evaluation
its_lagtimes = [1,2,5,10,20,35,50,100,200,500]       # MSM lag times (frames) for the ITS plot
tica_its_lagtimes = [1,2,5,10,20,35,50,100,200,500]  # tICA lag times (frames) for the ITS plot
n_its = 5            # number of implied timescales drawn
# Markov State Model parameters
msm_lagtime = 500    # MSM lag time in frames
# PCCA parameters
n_metastable_sets = 30  # PCCA macrostates; disconnected sets are labelled from this value upwards
seed_num = 12        # seeds (= new simulations) per adaptive round; later rounds use splits 0..seed_num-1
assignments = []     # placeholder, overwritten by the k-means cell
seed_idx = []        # placeholder, overwritten by the seed-generation cell
num_cvs = 2          # number of tICA components exported as PLUMED CVs

patience = 2                 # NOTE(review): not referenced in the visible cells
convergence_criteria = 0.99  # NOTE(review): not referenced in the visible cells

In [None]:
# TICA functions         
def read_features(round,start,end,seed_num,chain_ids,tica_lagtime,supplement):
    """Load per-chain PLUMED COLVAR feature trajectories for rounds 1..round.

    Round 1 covers splits ``start..end`` (inclusive); every later round covers
    splits ``0..seed_num-1``. For each (round, split, chain), the function reads
    ``./CV/COLVAR_round{r}_split{split}_chain{chain}``. It drops the ``time``
    column and the raw dihedral columns (names starting with phi/psi/chi/omega),
    keeping the sin/cos ones. A trajectory is kept only if it has more than
    ``1.6 * tica_lagtime`` frames. Files that cannot be read are skipped, and
    their count is printed.

    If ``supplement`` is true, an extra "stitched" trajectory is built for every
    seed of rounds >= 2 and every chain. It consists of the (at most
    ``tica_lagtime - 1``) frames of the parent trajectory just before the seed
    frame, followed by the first ``tica_lagtime`` frames of the child trajectory
    started from that seed. ``Seed/round{r}_seed.txt`` holds one
    ``(sim_index, frame)`` row per seed. Here ``sim_index`` counts simulations
    in the order round-1 splits, then round-2 splits, ..., which is the order
    used by ``write_gmxfile``. Seeds whose parent or child trajectory was not
    loaded are skipped.

    Returns:
        data: list of float32 arrays, one (frames, features) array per kept trajectory.
        traj: the corresponding pandas DataFrames.
        data_supp: list of stitched float32 arrays (empty unless ``supplement``).
    """
    n_chains = len(chain_ids)
    n_first = end - start + 1  # number of round-1 simulations

    def split_range(r):
        # Round 1 uses the user-given splits; adaptive rounds are numbered 0..seed_num-1.
        return (start, end) if r == 1 else (0, seed_num - 1)

    traj = []
    skipped = []
    # (round, split, chain position) -> index into traj/data, so that the
    # supplement step finds the right trajectory even if some were dropped.
    position = {}
    for r in range(1, round + 1):
        first, last = split_range(r)
        for split in range(first, last + 1):
            for chain_pos, chain in enumerate(chain_ids):
                path = f'./CV/COLVAR_round{r}_split{split}_chain{chain}'
                try:
                    load = plumed.read_as_pandas(path).drop(columns=['time'])
                except Exception:
                    # Best effort: missing/unreadable COLVAR files are skipped.
                    skipped.append(path)
                    continue
                # Remove all dihedral angles, only keep sin/cos dihedrals
                dihedrals = [column for column in load.columns
                             if column[:3] in ('phi', 'psi', 'chi') or column[:5] == 'omega']
                load = load.drop(columns=dihedrals)
                if len(load) > 1.6*tica_lagtime:
                    position[(r, split, chain_pos)] = len(traj)
                    traj.append(load)
    if skipped:
        print(f'read_features: skipped {len(skipped)} unreadable COLVAR file(s)')

    # data is the time-series COLVAR in numpy.ndarrays format
    data = [frame_df.to_numpy(dtype='float32') for frame_df in traj]

    data_supp = []
    if not supplement:
        return data, traj, data_supp

    def sim_key(sim_index):
        # Global simulation index -> (round, split), matching write_gmxfile's ordering.
        if sim_index < n_first:
            return 1, start + sim_index
        offset = sim_index - n_first
        return 2 + offset // seed_num, offset % seed_num

    for r in range(2, round + 1):
        # ndmin=2 keeps a single-seed file as one row instead of a flat vector
        seeds = np.loadtxt(f'Seed/round{r}_seed.txt', ndmin=2).astype(int)
        for seed_pos, (sim_index, frame) in enumerate(seeds):
            if frame == 0:
                continue  # no history before the seed frame
            parent_round, parent_split = sim_key(sim_index)
            for chain_pos in range(n_chains):
                pre_at = position.get((parent_round, parent_split, chain_pos))
                post_at = position.get((r, seed_pos, chain_pos))
                if pre_at is None or post_at is None:
                    continue
                # The seed frame itself is the child's frame 0, so the parent part stops before it.
                pre = data[pre_at][max(0, frame - tica_lagtime + 1):frame, :]
                post = data[post_at][:tica_lagtime, :]
                data_supp.append(np.concatenate([pre, post]))

    return data, traj, data_supp
def run_TICA(data,data_supp,lagtime,dim=None,var_cutoff=None,koopman=True):
    """Fit tICA on the regular plus supplementary trajectories and project them.

    When ``koopman`` is true (the default, matching the previous behaviour), the
    fit is reweighted with deeptime's ``KoopmanWeightingEstimator``. If
    reweighting raises, an unweighted fit is used and the reason is printed.

    Returns:
        tica: the fitted deeptime tICA model.
        tica_output: ``tica.transform(data)``.
        tica_output_concat: ``tica_output`` concatenated along the first axis.
        tica_output_supp: list with one projection per supplementary trajectory.
    """
    # list() so the concatenation is a list join even if an argument is an array/tuple
    data_syn = list(data) + list(data_supp)
    estimator = TICA(lagtime=lagtime,dim=dim,var_cutoff=var_cutoff)
    tica = None
    if koopman:
        try:
            koopman_estimator = KoopmanWeightingEstimator(lagtime=lagtime)
            reweighting_model = koopman_estimator.fit(data_syn).fetch_model()
            tica = estimator.fit(data_syn, weights=reweighting_model).fetch_model()
        except Exception as exc:
            # Fall back to plain tICA instead of aborting the pipeline.
            print(f'Can\'t perform Koopman Reweighting TICA, try normal TICA ({exc})')
    if tica is None:
        tica = estimator.fit(data_syn).fetch_model()
    tica_output = tica.transform(data)
    tica_output_concat = np.concatenate(tica_output)
    tica_output_supp = [tica.transform(data_supp_i) for data_supp_i in data_supp]
    return tica,tica_output,tica_output_concat,tica_output_supp
    
def calculate_nmicro(data_concat):
    """Heuristic number of k-means microstates for the frames in ``data_concat``.

    The count grows with log10 of the number of rows (frames), and it is never
    below 100.
    """
    n_frames = data_concat.shape[0]
    heuristic = np.round(0.6 * np.log10(n_frames / 1000) * 1000 + 50)
    return int(max(100, heuristic))
    
def run_kmeans(tica_output,tica_output_concat,tica_output_supp,n_microstates,n_jobs):
    """Cluster tICA projections into ``n_microstates`` microstates with mini-batch k-means.

    The clustering is fitted on ``tica_output_concat`` only. Supplementary
    projections are then assigned to the same cluster centres.

    Returns:
        assignments: the flat assignments reshaped into rows of
            ``tica_output.shape[1]`` frames (one row per trajectory when
            ``tica_output`` is a stacked array of equal-length trajectories).
        assignments_concat: flat assignment vector for ``tica_output_concat``.
        assignments_supp: list of assignment vectors, one per supplementary projection.
    """
    clustering = MiniBatchKMeans(n_clusters=n_microstates, batch_size=10000, max_iter=100,
                                 init_strategy='kmeans++', n_jobs=n_jobs)
    cluster_model = clustering.fit(tica_output_concat).fetch_model()
    assignments_concat = cluster_model.transform(tica_output_concat)
    frames_per_traj = tica_output.shape[1]
    assignments = assignments_concat.reshape(-1, frames_per_traj)
    assignments_supp = [cluster_model.transform(projection) for projection in tica_output_supp]
    return assignments, assignments_concat, assignments_supp

def evaluate_its(lagtimes,n_its,round,dtrajs=None):
    """Plot MSM implied timescales over ``lagtimes`` and save the figure.

    Parameters:
        lagtimes: lag times (frames) at which an MSM is estimated.
        n_its: number of implied timescales drawn.
        round: current round; used only in the output file name.
        dtrajs: discrete trajectories to estimate from. If omitted, the
            notebook-global ``assignments`` reshaped to rows of
            ``tica_output.shape[1]`` frames is used (the previous behaviour).

    Saves ``./figures/its-round{round}.png`` and returns None.

    Raises:
        ValueError: if ``lagtimes`` is empty.
    """
    if dtrajs is None:
        # Notebook globals produced by the tICA / k-means cell.
        dtrajs = assignments.reshape(-1, tica_output.shape[1])
    if len(lagtimes) == 0:
        raise ValueError('evaluate_its needs at least one lag time')
    models = []
    for lagtime in lagtimes:
        counts = TransitionCountEstimator(lagtime=lagtime, count_mode='sliding').fit_fetch(dtrajs)
        models.append(MaximumLikelihoodMSM().fit_fetch(counts))
    # Timescales are computed once, after all models have been estimated.
    its_data = implied_timescales(models)
    fig, ax = plt.subplots(1, 1)
    plot_implied_timescales(its_data, n_its=n_its, ax=ax)
    ax.set_yscale('log')
    ax.set_title('Implied timescales')
    ax.set_xlabel('lag time (steps)')
    ax.set_ylabel('timescale (steps)')
    os.makedirs('./figures', exist_ok=True)
    fig.savefig(f'./figures/its-round{round}.png',dpi=600)
    return None

def evaluate_tica_its(tica_its_lagtimes,dim,var_cutoff,koopman,n_its,round,features=None):
    """Plot tICA implied timescales over ``tica_its_lagtimes`` and save the figure.

    One tICA model is fitted per lag time on ``features``, which defaults to the
    notebook-global ``data`` returned by ``read_features``. If ``koopman`` is
    true, each fit is Koopman-reweighted. If reweighting raises, that lag time
    falls back to an unweighted fit.

    (Previously this called ``run_TICA(data, lag, dim, var_cutoff, koopman)``,
    which does not match ``run_TICA``'s signature and shifts every argument.)

    Saves ``./figures/tica-its-round{round}.png`` and returns None.
    """
    if features is None:
        features = data  # notebook global set by read_features
    tica_models = []
    for lag in tica_its_lagtimes:
        estimator = TICA(lagtime=lag, dim=dim, var_cutoff=var_cutoff)
        model = None
        if koopman:
            try:
                weights = KoopmanWeightingEstimator(lagtime=lag).fit(features).fetch_model()
                model = estimator.fit(features, weights=weights).fetch_model()
            except Exception as exc:
                print(f'Koopman reweighting failed at lag {lag} ({exc}); using unweighted tICA')
        if model is None:
            model = estimator.fit(features).fetch_model()
        tica_models.append(model)
    its_data = implied_timescales(tica_models)
    fig, ax = plt.subplots(1, 1)
    plot_implied_timescales(its_data, n_its=n_its, ax=ax)
    ax.set_yscale('log')
    ax.set_title('Implied timescales')
    ax.set_xlabel('lag time (steps)')
    ax.set_ylabel('timescale (steps)')
    os.makedirs('./figures', exist_ok=True)
    fig.savefig(f'./figures/tica-its-round{round}.png',dpi=600)
    return None
    
def build_MSM(msm_lagtime,assignments,assignments_supp):
    """Estimate a maximum-likelihood MSM from regular plus supplementary discrete trajectories.

    ``assignments`` is iterated row by row (one discrete trajectory per row).
    ``assignments_supp`` (a list) is appended after it. Transitions are counted
    with a sliding window at lag ``msm_lagtime``.

    Returns:
        counts: the transition count model.
        msm: the maximum-likelihood MSM estimated from ``counts``.
    """
    count_estimator = TransitionCountEstimator(lagtime=msm_lagtime, count_mode='sliding')
    count_model = count_estimator.fit_fetch(list(assignments) + assignments_supp)
    markov_model = MaximumLikelihoodMSM().fit_fetch(count_model)
    return count_model, markov_model
    
def run_PCCA(msm,n_metastable_sets):
    """Coarse-grain ``msm`` into ``n_metastable_sets`` metastable sets via ``msm.pcca`` and return the result."""
    return msm.pcca(n_metastable_sets=n_metastable_sets)

### Adaptive seeding functions
def fix_disconnected(counts,n_microstates,msm,pcca):
    """Extend PCCA labels and stationary probabilities to all ``n_microstates`` microstates.

    Connected sets come from ``counts.connected_sets(connectivity_threshold=0,
    directed=True, sort_by_population=True)``. Every set except the first is
    treated as disconnected. Each disconnected set gets its own macrostate label
    (``n_metastable_sets`` + set index, using the notebook-global
    ``n_metastable_sets``) and stationary probability 0. The remaining
    microstates take, in increasing microstate order, successive entries of
    ``pcca.assignments`` and ``msm.stationary_distribution``.
    NOTE(review): this assumes the MSM was estimated on exactly the first set,
    with its states in increasing microstate order — confirm.

    Returns:
        (number of disconnected sets, per-microstate macrostate labels,
         per-microstate stationary probabilities)
    """
    components = counts.connected_sets(connectivity_threshold=0,directed=True,sort_by_population=True)
    orphan_sets = components[1:]
    orphan_label = {}
    for offset, component in enumerate(orphan_sets):
        for state in component:
            orphan_label[state] = n_metastable_sets + offset

    pcca_assignments = np.zeros(n_microstates,dtype=int)
    stationary_distribution = np.zeros(n_microstates,)
    position = 0  # next index into the connected-set MSM/PCCA arrays
    for state in range(n_microstates):
        if state not in orphan_label:
            pcca_assignments[state] = pcca.assignments[position]
            stationary_distribution[state] = msm.stationary_distribution[position]
            position += 1
            continue
        pcca_assignments[state] = orphan_label[state]
        stationary_distribution[state] = 0

    return len(orphan_sets),pcca_assignments,stationary_distribution
    
def count_macro(seed_num,n_macro_disconnected,pcca_assignments,assignments_concat,round):
    """Pick ``seed_num`` seed microstates, favouring rarely visited states.

    Sampling happens at two levels, both with inverse-count weighting:
      1. Each macrostate observed in ``pcca_assignments[assignments_concat]`` gets
         seeds multinomially, with probability proportional to 1 / (its frame count).
      2. Within a macrostate, its seeds are spread over its microstates with
         probability proportional to 1 / (microstate frame count). Microstates
         that never occur in ``assignments_concat`` get probability 0.

    Frame counts are indexed by microstate id (``np.bincount``). The previous
    version indexed them by position in ``np.unique``, which misaligns (or
    raises) as soon as a microstate is unvisited. Random draws use the global
    ``numpy.random`` state.

    ``n_macro_disconnected`` and ``round`` are accepted for interface
    compatibility and are not used.

    Returns:
        list of microstate ids (Python ints), one entry per seed, grouped by
        macrostate label and then microstate id.
    """
    pcca_assignments = np.asarray(pcca_assignments)
    assignments_concat = np.asarray(assignments_concat)
    # Macrostate label of every frame (microstate id -> macrostate via indexing)
    macro_timeseries = pcca_assignments[assignments_concat]
    # Macrostate seeding
    unique_macro, counts_macro = np.unique(macro_timeseries, return_counts=True)
    prob_macro = (1 / counts_macro) / np.sum(1 / counts_macro)
    macrostate_seed = multinomial(seed_num, prob_macro)

    # Microstate seeding: frame counts indexed by microstate id
    counts_micro = np.bincount(assignments_concat, minlength=len(pcca_assignments))
    seed_idx = []
    for macro_label, n_macro_seeds in zip(unique_macro, macrostate_seed):
        if n_macro_seeds == 0:
            continue
        # zero the counts of microstates outside this macrostate so they are never picked
        counts_in_macro = np.where(pcca_assignments == macro_label, counts_micro, 0)
        inverse_counts = np.divide(1.0, counts_in_macro,
                                   out=np.zeros(counts_in_macro.shape),
                                   where=counts_in_macro > 0)
        prob_micro = inverse_counts / inverse_counts.sum()
        microstate_seed = multinomial(n_macro_seeds, prob_micro)
        for micro_id, n_micro_seeds in enumerate(microstate_seed):
            seed_idx.extend([micro_id] * int(n_micro_seeds))
    return seed_idx
def write_gmxfile(round,start,end,seed_num,whole_gro,assignments,seed_idx,dumpseed=True,n_chains=4):
    """Write seed structures and a GROMACS run script for the next adaptive round.

    If ``dumpseed`` is true (the default), the round counter is advanced. Then,
    for every microstate id in ``seed_idx``:
      - a random (row, frame) with that assignment is picked from
        ``assignments`` (rows = per-chain trajectories, columns = frames);
      - the row is converted to a simulation index ``row // n_chains``, since
        every simulation contributes ``n_chains`` consecutive rows;
      - that frame is written to ``./round{round}_unbiased/split_{i}.gro``.

    Simulations are indexed in the order round-1 splits ``start..end``, then
    splits ``0..seed_num-1`` of each later round. The chosen
    ``[simulation, frame]`` pairs go to ``./Seed/round{round}_seed.txt``.

    In every case ``./round{round}_unbiased/run.sh`` is written. It runs
    ``grompp`` and ``mdrun`` for splits ``start..end``; for a freshly dumped
    round these are ``0..len(seed_idx)-1``, matching the .gro files. The
    command prefix is the notebook-global ``gmx``.

    Returns:
        (round, conf_seed): the (possibly incremented) round and the list of
        ``[simulation_index, frame]`` seeds (empty if ``dumpseed`` is false).
    """
    conf_seed = []
    if dumpseed:
        last_round = round
        round += 1
        os.makedirs('./Seed', exist_ok=True)
        os.makedirs(f'./round{round}_unbiased', exist_ok=True)

        # .xtc of every finished simulation, in global simulation-index order
        xtc_paths = []
        for round_i in range(1, last_round + 1):
            first, last = (start, end) if round_i == 1 else (0, seed_num - 1)
            xtc_paths += [f'./round{round_i}_unbiased/split_{i}.xtc' for i in range(first, last + 1)]

        for seed in seed_idx:
            rows, frames = np.where(assignments == seed)
            pick = np.random.randint(rows.shape[0])
            conf_seed.append(np.array([rows[pick] // n_chains, frames[pick]]))

        # Load only the trajectories that actually provide a seed, each once.
        universes = {}
        for i, (sim, frame) in enumerate(conf_seed):
            if sim not in universes:
                universes[sim] = mda.Universe(whole_gro, xtc_paths[sim])
            universe = universes[sim]
            universe.atoms.write(f'./round{round}_unbiased/split_{i}.gro',
                                 frames=universe.trajectory[frame:frame+1])

        np.savetxt(f'./Seed/round{round}_seed.txt',conf_seed,fmt='%s')
        # The new round's structures are split_0 .. split_{n-1}
        start, end = 0, len(conf_seed) - 1

    os.makedirs(f'./round{round}_unbiased', exist_ok=True)
    with open(f'./round{round}_unbiased/run.sh','w') as f:
        # Quote the multi-word docker command so bash assigns it as one string;
        # double quotes let $PWD expand when the assignment runs.
        f.write(f'gmx="{gmx}"\n')
        f.write(f'start={start}\n')
        f.write(f'end={end}\n')
        f.write('for i in `seq $start $end`\n')
        f.write('do\n')
        f.write(f'    $gmx grompp -f ./Conf/step7_production.mdp -p ./Conf/topol.top -c ./round{round}_unbiased/split_$i.gro -o ./round{round}_unbiased/split_$i.tpr -n ./Conf/index.ndx \n')
        f.write(f'    $gmx mdrun -deffnm ./round{round}_unbiased/split_$i -ntmpi 1 -ntomp 20 -nb gpu -pme gpu -pmefft gpu -bonded gpu -gpu_id 0 -cpi -v -pin on -update gpu\n')
        f.write('done\n')

    return round,conf_seed

In [None]:
### Machine learning CV related functions
def tica_cv_print(feature_dats,traj,tica,num_cvs,chain_ids,round,start,end,seed_num,driver=True):
    """Write PLUMED inputs computing the first ``num_cvs`` tICA components and optionally run ``plumed driver``.

    For the chain at position ``k`` of ``chain_ids``, the PLUMED file contains:
      - a MOLINFO line;
      - the lines of ``feature_dats[k]`` except its first and last line (the last
        is presumably the original PRINT — TODO confirm what the first is);
      - one ``COMBINE`` action per component ``tica{j}``, with
        ARG = ``traj[k].columns``, COEFFICIENTS = row ``j`` of
        ``tica.singular_vectors_left.T`` and PARAMETERS = ``tica.mean_0``;
      - a PRINT of ``tica0..tica{num_cvs-1}`` to
        ``./tica-cv/tica_COLVAR_round{r}_split{split}_chain{c}``.

    The driver file is rewritten for every (round, split, chain) because its
    PRINT target differs. If ``driver`` is true, ``plumed driver`` is run on
    ``./round{r}_unbiased/split_{split}.xtc`` right after writing.

    Returns:
        (``traj[0].columns``, ``tica.singular_vectors_left.T``)

    Raises:
        RuntimeError: if ``plumed driver`` exits with a non-zero status.
    """
    feature_lines = []
    for feature_dat in feature_dats:
        with open(feature_dat,'r') as g:
            feature_lines.append(g.readlines())

    components = tica.singular_vectors_left.T
    # The COMBINE definitions do not depend on round/split: build them once per chain.
    combine_lines = []
    for chaini in range(len(chain_ids)):
        columns = list(traj[chaini].columns)
        arg_string = ','.join(columns)
        parameters_string = ','.join(str(tica.mean_0[fi]) for fi in range(len(columns)))
        chain_combine = []
        for j in range(num_cvs):
            coeff_string = ','.join(str(value) for value in components[j])
            chain_combine.append(f'tica{j}: COMBINE ARG={arg_string} COEFFICIENTS={coeff_string} PARAMETERS={parameters_string} PERIODIC=NO\n')
        combine_lines.append(chain_combine)
    # Print every exported component, not a fixed tica0,tica1 pair.
    print_args = ','.join(f'tica{j}' for j in range(num_cvs))

    os.makedirs('./tica-cv', exist_ok=True)
    for r in range(1, round + 1):
        first, last = (start, end) if r == 1 else (0, seed_num - 1)
        for split in range(first, last + 1):
            for chaini, chain in enumerate(chain_ids):
                driver_file = f'./tica-cv/tica-driver_round{r}_chain{chain}.dat'
                with open(driver_file,'w') as f:
                    f.write('MOLINFO STRUCTURE=./Conf/step7.pdb\n')
                    f.writelines(feature_lines[chaini][1:-1])   # remove PRINT argument
                    f.writelines(combine_lines[chaini])
                    f.write(f'PRINT ARG={print_args} STRIDE=1 FILE=./tica-cv/tica_COLVAR_round{r}_split{split}_chain{chain}\n')
                if driver:
                    xtc = f'./round{r}_unbiased/split_{split}.xtc'
                    result = subprocess.run(['plumed', 'driver', '--mf_xtc', xtc, '--plumed', driver_file],
                                            capture_output=True, text=True)
                    if result.returncode != 0:
                        raise RuntimeError(f'plumed driver failed on {xtc} with {driver_file}: {result.stderr[-2000:]}')
    return traj[0].columns,tica.singular_vectors_left.T
def cv_projection_2D(round,start,end,seed_num,conf_seed,chain_ids,cv1,cv2,dir='tica-cv',prefix='tica_',seed=True,savefig=False,diff1=True,diff2=True):
    """Scatter two COLVAR columns for every trajectory of rounds 1..round, optionally marking seeds.

    Files are read from ``./{dir}/{prefix}COLVAR_round{r}_split{split}_chain{chain}``.
    Round 1 uses splits ``start..end``; later rounds use ``0..seed_num-1``.
    When ``diff1``/``diff2`` is true, the per-chain column name is formed by
    inserting the chain id before the last three characters of ``cv1``/``cv2``.

    With ``seed``, each ``[sim_index, frame]`` in ``conf_seed`` is drawn in
    black for every chain. ``sim_index`` counts simulations as round-1 splits,
    then ``seed_num`` splits per later round. Only seeds in rounds 1..round are
    drawn.

    The axes are labelled TIC1/TIC2 for ``tica0``/``tica1``, otherwise with the
    given names. When ``savefig`` is true, the figure is saved to
    ``./figures/round{round}_ticacv.png``.
    """
    CV1 = cv1
    CV2 = cv2

    def column(name, chain, per_chain):
        return name[:-3] + chain + name[-3:] if per_chain else name

    def read_colvar(r, split, chain):
        return plumed.read_as_pandas(f'./{dir}/{prefix}COLVAR_round{r}_split{split}_chain{chain}')

    for r in range(1, round + 1):
        first, last = (start, end) if r == 1 else (0, seed_num - 1)
        for split in range(first, last + 1):
            for chain in chain_ids:
                COLVAR = read_colvar(r, split, chain)
                plt.scatter(COLVAR[column(CV1, chain, diff1)], COLVAR[column(CV2, chain, diff2)], alpha=0.3)

    if seed:
        n_first = end - start + 1  # round-1 simulations precede the adaptive ones
        for sim_index, frame in conf_seed:
            if sim_index < n_first:
                seed_round, seed_split = 1, start + sim_index
            else:
                offset = sim_index - n_first
                seed_round, seed_split = 2 + offset // seed_num, offset % seed_num
            if seed_round > round:
                continue
            for chain in chain_ids:
                point = read_colvar(seed_round, seed_split, chain).iloc[frame]
                plt.scatter(point[column(CV1, chain, diff1)], point[column(CV2, chain, diff2)], c='black')

    plt.xlabel('TIC1' if CV1 == 'tica0' else f'{CV1}', fontsize=15)
    plt.ylabel('TIC2' if CV2 == 'tica1' else f'{CV2}', fontsize=15)
    if savefig:
        plt.savefig(f'./figures/round{round}_ticacv.png',dpi=600,transparent=True)
    plt.show()

def cv_projection_1D(round,start,end,seed_num,chain_ids,cv='dtotal',dir='tica-cv',prefix='tica_',seed=True,savefig=False,diff=True,time_divisor=20):
    """Plot one COLVAR column against time for every trajectory of rounds 1..round.

    Files are read from ``./{dir}/{prefix}COLVAR_round{r}_split{split}_chain{chain}``.
    Round 1 uses splits ``start..end``; later rounds use ``0..seed_num-1``.
    When ``diff`` is true, the per-chain column name is formed by inserting the
    chain id before the last three characters of ``cv``.

    The x axis is ``time / time_divisor`` and is labelled "MD Time (ns)".
    NOTE(review): confirm the divisor of 20 matches the COLVAR time units.
    The y label is "Distance (nm)" when ``cv == 'dtotal'``; previously this
    check used the chain-modified name and so missed that case when ``diff``
    was true.

    ``seed`` is accepted for symmetry with ``cv_projection_2D`` and is unused.
    With ``savefig``, the x range is clipped to 0..50 and the figure is saved
    to ``./figures/round{round}_{cv}.png``.
    """
    CV = cv
    for r in range(1, round + 1):
        first, last = (start, end) if r == 1 else (0, seed_num - 1)
        for split in range(first, last + 1):
            for chain in chain_ids:
                COLVAR = plumed.read_as_pandas(f'./{dir}/{prefix}COLVAR_round{r}_split{split}_chain{chain}')
                column = CV[:-3] + chain + CV[-3:] if diff else CV
                plt.plot(COLVAR['time']/time_divisor, COLVAR[column], alpha=1.0)
    plt.xlabel('MD Time (ns)',fontsize=15)
    if CV == 'dtotal':
        plt.ylabel('Distance (nm)',fontsize=15)
    else:
        plt.ylabel(f'{CV}',fontsize=15)
    if savefig:
        plt.xlim(0,50)
        plt.savefig(f'./figures/round{round}_{CV}.png',dpi=600,transparent=True)
    plt.show()

def interpret_cv(arg_string,coeff_string,num_cvs,savefig=False,top_n=8):
    """Bar-plot the ``top_n`` largest-magnitude feature weights of the first ``num_cvs`` tICA components.

    Parameters:
        arg_string: feature names; a list, or anything with ``tolist()`` (e.g. a pandas Index).
        coeff_string: 2-D array-like with ``tolist()``; row ``i`` holds the weights of
            component ``i``, aligned with ``arg_string``.
        num_cvs: number of components to plot.
        savefig: if true, each plot is saved to ``./figures/TIC {i+1}.png``.
        top_n: number of features shown per component (default 8, as before).

    The selected feature names are printed for each component.
    """
    arg = arg_string if isinstance(arg_string, list) else arg_string.tolist()
    coeff = coeff_string.tolist()
    for i in range(num_cvs):
        # stable sort by |weight|, largest first
        ranked = sorted(zip(coeff[i], arg), key=lambda pair: abs(pair[0]), reverse=True)[:top_n]
        coeff_toprint = [weight for weight, _ in ranked]
        arg_toprint = [name for _, name in ranked]
        print(arg_toprint)
        plt.figure(figsize=(6,6))  # plt.plot(figsize=...) never created a sized figure
        plt.bar(arg_toprint,coeff_toprint)
        plt.title(f'TIC {i+1}',fontsize=16)
        plt.xlabel('Components',fontsize=16)
        plt.ylabel('Weights',fontsize=16)
        plt.xticks(range(len(arg_toprint)),rotation=45,fontsize=12)
        plt.tight_layout()
        if savefig:
            plt.savefig(f'./figures/TIC {i+1}.png',dpi=600,transparent=True)
        plt.show()

    return

In [None]:
# The tICA implied-timescale scan needs the feature trajectories (`data`), which the
# tICA cell below would otherwise define only later; load them here so the
# notebook also works on a fresh kernel.
data,traj,data_supp = read_features(round,start,end,seed_num,chain_ids,tica_lagtime,supplement=True)
evaluate_tica_its(tica_its_lagtimes,dim,var_cutoff,False,n_its,round)

In [None]:
#Perform tICA
# Load per-chain COLVAR features, plus stitched supplementary trajectories built from the seed files
data,traj,data_supp = read_features(round,start,end,seed_num,chain_ids,tica_lagtime,supplement=True)       
# Fit tICA on regular + supplementary trajectories and project both
tica,tica_output,tica_output_concat,tica_output_supp = run_TICA(data,data_supp,tica_lagtime,dim,var_cutoff)
# K-means clustering
# The number of microstates grows with log10 of the frame count (floor of 100)
n_microstates = calculate_nmicro(tica_output_concat)
assignments,assignments_concat,assignments_supp = run_kmeans(tica_output,tica_output_concat,tica_output_supp,n_microstates,runtICA_njobs)

In [None]:
# MSM implied-timescale check across its_lagtimes (saved to ./figures/its-round{round}.png)
evaluate_its(its_lagtimes,n_its,round)
# MSM at msm_lagtime from regular + supplementary discrete trajectories, then PCCA coarse-graining
counts,msm = build_MSM(msm_lagtime,assignments,assignments_supp)
pcca = run_PCCA(msm,n_metastable_sets)

In [None]:
# Seed generation
# Give every microstate a macrostate label (disconnected sets get extra labels), then
# draw seed_num microstates with inverse-count weighting at macro and micro level
n_macro_disconnected,pcca_assignments,stationary_distribution = fix_disconnected(counts,n_microstates,msm,pcca)
seed_idx = count_macro(seed_num,n_macro_disconnected,pcca_assignments,assignments_concat,round)

In [None]:
# Write files and Run adaptive MD
# NOTE: write_gmxfile (with its default dumpseed=True) returns round + 1, so from here on
# `round` is the round about to be simulated; later cells pass `round-1` for the finished rounds.
round,conf_seed = write_gmxfile(round,start,end,seed_num,whole_gro,assignments,seed_idx)

In [None]:
# Feature definitions per chain come from the round-1 PLUMED inputs of the first split
# (derived from the config instead of a hard-coded split40 / A-D list).
feature_dats = [f'./CV/round1_split{start}_chain{chain}.dat' for chain in chain_ids]
arg_string,coeff_string = tica_cv_print(feature_dats,traj,tica,num_cvs,chain_ids,round-1,start,end,seed_num,True)
cv_projection_2D(round-1,start,end,seed_num,[],chain_ids,cv1='tica0',cv2='tica1',dir='./tica-cv',prefix='tica_',seed=False,savefig=True,diff1=False,diff2=False)

In [None]:
# Bar charts of the largest-magnitude feature weights of each exported tICA component
interpret_cv(arg_string,coeff_string,num_cvs,savefig=True)

In [None]:
# Distance CV over time for all finished rounds. cv_projection_1D has no conf_seed
# parameter (unlike cv_projection_2D), so chain_ids is the fifth positional argument;
# the former extra `[]` shifted chain_ids onto `cv` and raised a TypeError.
cv_projection_1D(round-1,start,end,seed_num,chain_ids,cv='dtotal',dir='CV',prefix='',savefig=True,diff=False,seed=False)