In [None]:
import plumed
from matplotlib import pyplot as plt
import matplotlib
import numpy as np
import MDAnalysis as md
from MDAnalysis.analysis import distances
import pandas as pd
import itertools
import random
import deeptime
from deeptime.decomposition import TICA
from deeptime.covariance import KoopmanWeightingEstimator
from deeptime.clustering import MiniBatchKMeans
from deeptime.markov import TransitionCountEstimator
from deeptime.markov.msm import MaximumLikelihoodMSM
from deeptime.markov.msm import BayesianMSM
from deeptime.plots import plot_implied_timescales
from deeptime.util.validation import implied_timescales
from deeptime.plots.chapman_kolmogorov import plot_ck_test
import networkx as nx
from copy import deepcopy
import torch
from torch import nn, optim, autograd
from torch.nn import functional as F
from torch.utils.data.dataset import random_split
from snrv import Snrv, load_snrv
import math

import warnings
warnings.filterwarnings('ignore')

In [None]:
# USE THIS TO CLEAR CUDA MEMORY
# Collect Python garbage first so tensors kept alive only by reference cycles
# are freed, then ask the CUDA caching allocator to release the unused blocks.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Check that the `plumed` executable can be found from this kernel's shell
# (IPython `!` shell escape).
!plumed

In [None]:
# Record the PyTorch version in the notebook output for reproducibility.
torch_version = torch.__version__
print(torch_version)

In [None]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

#### 1. Feature Selection

##### 1.1 Select features and write PLUMED input files

In [None]:
# Feature selection functions

def select_dihedrals(universe, dihedral_type, start_res, end_res):
    """Choose which backbone / side-chain dihedrals to featurize per residue.

    Parameters
    ----------
    universe : MDAnalysis.Universe
        Structure; only ``universe.residues.resids`` and
        ``universe.residues.resnames`` are read.
    dihedral_type : container of str
        Any of 'phi', 'psi', 'omega', 'chi1', 'chi2', 'chi3', 'chi4'.
    start_res, end_res : int
        Inclusive residue-id range to consider.

    Returns
    -------
    dict[str, list[int]]
        Maps each requested dihedral name to the resids that have it. Keys
        appear in the order phi, psi, omega, chi1, chi2, chi3, chi4 (requested
        ones only). Resids in the range that are absent from ``universe`` are
        skipped.

    Notes
    -----
    phi and omega are skipped for the first residue of ``universe`` and psi for
    the last one. Side-chain rules are by residue name.
    NOTE(review): with several chains sharing resids the name lookup keeps the
    last residue for a duplicated resid — confirm this is single-chain.
    """
    resids = list(universe.residues.resids)
    resname_by_resid = dict(zip(resids, universe.residues.resnames))
    first_resid, last_resid = resids[0], resids[-1]
    # Look names up by resid rather than by the offset ``resid - resids[0]``,
    # which picked the wrong residue (or wrapped around through negative
    # indices) whenever resids were not contiguous from the first one.
    candidates = [i for i in range(start_res, end_res + 1) if i in resname_by_resid]

    # (dihedral name, predicate "residue i has this dihedral"), in output order.
    rules = (
        ('phi', lambda i: i != first_resid),
        ('psi', lambda i: i != last_resid),
        ('omega', lambda i: i != first_resid),
        ('chi1', lambda i: resname_by_resid[i] not in ['GLY', 'ALA']),
        ('chi2', lambda i: resname_by_resid[i] not in ['GLY', 'ALA', 'CYS', 'SER', 'THR', 'VAL']),
        ('chi3', lambda i: resname_by_resid[i] in ['ARG', 'GLN', 'GLU', 'LYS', 'MET']),
        ('chi4', lambda i: resname_by_resid[i] in ['ARG', 'LYS']),
    )

    dihedrals = {}
    for name, has_dihedral in rules:
        if name in dihedral_type:
            dihedrals[name] = [i for i in candidates if has_dihedral(i)]
    return dihedrals

In [None]:
traj_stride = 1                               # stride for trajectories output

u = md.Universe('traj_and_dat/input.pdb')
nitrogen = u.select_atoms('name N')
oxygen = u.select_atoms('name O')
calpha = u.select_atoms('name CA')
cbeta = u.select_atoms('name CB')
cgamma = u.select_atoms('name CG or name CG1 or name CG2')
# NOTE(review): PLUMED's ATOMS= takes 1-based atom serial numbers; these are
# MDAnalysis `.ids`, presumably read from the PDB serial column — confirm.
nitrogen_id = nitrogen.ids
oxygen_id = oxygen.ids
calpha_id = calpha.ids
cbeta_id = cbeta.ids
cgamma_id = cgamma.ids
# Candidate distance features. Dict order fixes the pair{k} numbering below.
pairs = {
    'hbond': list(itertools.product(nitrogen_id, oxygen_id)),   # every (N, O) atom pair
    'CA': list(itertools.combinations(calpha_id, 2)),
    'CB': list(itertools.combinations(cbeta_id, 2)),
    'CG': list(itertools.combinations(cgamma_id, 2)),
}
dihedrals = select_dihedrals(u, ['phi', 'psi', 'chi1', 'chi2', 'chi3', 'chi4'], 1, 20)


def write_plumed_features(path, distance_pairs, dihedral_resids, stride, add_exp_distance):
    """Write a PLUMED input that prints distance and dihedral features to COLVAR.

    Parameters
    ----------
    path : str
        PLUMED input file to (over)write.
    distance_pairs : dict[str, list[tuple[int, int]]]
        Atom-id pairs. One ``DISTANCE`` is emitted per pair, labelled
        ``pair1, pair2, ...`` consecutively across all groups in dict order.
    dihedral_resids : dict[str, list[int]]
        Dihedral name -> resids (as returned by ``select_dihedrals``). Each
        entry becomes ``TORSION ATOMS=@<name>-<resid>`` plus its sin and cos.
        A blank line separates the dihedral types.
    stride : int
        ``STRIDE`` of the final ``PRINT`` action.
    add_exp_distance : bool
        Also emit ``epair<k> = exp(-pair<k>)`` for every distance.
    """
    lines = ['MOLINFO STRUCTURE=input.pdb']
    count = 0
    for group in distance_pairs.values():
        for atom1, atom2 in group:
            count += 1
            lines.append('pair{count}: DISTANCE ATOMS={atom1},{atom2}'.format(count=count, atom1=atom1, atom2=atom2))
            if add_exp_distance:
                # exp(-x) is unambiguous Lepton syntax for e^{-x}.
                lines.append('epair{count}: CUSTOM ARG=pair{count} FUNC=exp(-x) PERIODIC=NO'.format(count=count))
    for dihedral, resids in dihedral_resids.items():
        for resid in resids:
            label = '{dihedral}-{resid}'.format(dihedral=dihedral, resid=resid)
            lines.append('{label}: TORSION ATOMS=@{label}'.format(label=label))
            lines.append('sin{label}: CUSTOM ARG={label} FUNC=sin(x) PERIODIC=NO'.format(label=label))
            lines.append('cos{label}: CUSTOM ARG={label} FUNC=cos(x) PERIODIC=NO'.format(label=label))
        lines.append('')  # blank line after each dihedral type
    lines.append('PRINT ARG=* STRIDE={stride} FILE=COLVAR'.format(stride=stride))
    with open(path, 'w') as f:
        f.write('\n'.join(lines) + '\n')


# TICA is linear in its inputs, so exp(-d) terms are added as extra nonlinear
# distance features for it.
write_plumed_features('traj_and_dat/features_tica.dat', pairs, dihedrals, traj_stride, add_exp_distance=True)
# Features for SRV and RCFlow: raw distances only (no exp(-d) terms).
write_plumed_features('traj_and_dat/features.dat', pairs, dihedrals, traj_stride, add_exp_distance=False)

#### 2. Load featurized trajectories (no $e^{-d}$ terms; dihedrals kept only as their sin/cos) and train FMRC

##### 2.1 FMRC code and read features

In [None]:
# RCflow
class GaussianPrior(nn.Module):
    """Diagonal Gaussian N(mean, std^2) used as the source distribution of the flow.

    ``mean_value`` / ``std_value`` are per-feature arrays (RCFlow.fit passes
    the column mean / std of the training data). They are registered as
    buffers: they are never trained, but they follow ``.to(device)`` and are
    stored in ``state_dict()`` as ``mean_`` and ``std_``.
    """

    def __init__(self, mean_value, std_value):
        super().__init__()
        self.register_buffer('mean_', torch.tensor(mean_value, dtype=torch.float32))
        self.register_buffer('std_', torch.tensor(std_value, dtype=torch.float32))

    def forward(self, size):
        """Return ``mean_ + std_ * eps`` with ``eps ~ N(0, 1)`` of shape ``size``.

        The noise is drawn directly on the prior's device rather than on the
        CPU followed by a copy.
        """
        noise = torch.randn(size, device=self.mean_.device)
        return self.mean_ + self.std_ * noise

    def sample_like(self, x):
        """Draw a prior sample with the same shape as ``x``, placed on ``x``'s device."""
        return self.forward(x.size()).to(x.device)


class RCFlow(nn.Module):
    """Flow-matching reaction coordinate (FMRC / RCFlow) model.

    An encoder maps a feature vector ``x`` to a ``latent_size``-dimensional
    coordinate ``r(x)``. Two conditional flow-matching vector fields are
    trained on time-lagged pairs ``(x, y)`` (``y`` observed ``lagtime`` frames
    after ``x``):

    * ``L_vector_field`` (lumpability) predicts the velocity transporting prior
      noise to ``y`` from ``(r(x), y_t, t)``;
    * ``D_vector_field`` (decomposability) predicts the velocity transporting
      prior noise to ``x`` from ``(r(y), x_t, t)``.

    The training objective is the sum of both flow-matching losses.

    Parameters
    ----------
    input_size : int
        Number of input features per frame.
    latent_size : int
        Dimension of the learned coordinate ``r``.
    encoder_hidden_size, hidden_size : int
        Hidden width of the encoder / of both vector-field networks.
    encode_state_label : bool
        If True, a softmax is applied to the encoder output.
    hidden_depth : int
        Number of hidden (Linear + activation [+ BatchNorm]) layers per network;
        at least one is always built.
    activation : nn.Module
        Activation module. The same instance is reused in every layer, so it
        should be stateless (e.g. ``nn.ReLU()``).
    sigma : float
        Std of the Gaussian noise added to interpolated samples; acts as a
        regularizer on the flow.
    learning_rate, lr_decay, lr_decay_stepsize : float, float, int
        Adam learning rate and StepLR ``gamma`` / ``step_size`` (epochs).
    val_frac : float
        Fraction of time-lagged pairs held out for validation.
    batch_size : int
        Minibatch size.
    batchnorm : bool
        Insert ``BatchNorm1d`` after every hidden activation.
    n_epochs : int
        Number of epochs run by :meth:`fit`.
    device : str or torch.device
        Device holding the networks and the prior.
    """

    def __init__(self, input_size, latent_size, encoder_hidden_size, encode_state_label, hidden_size, hidden_depth,
                 activation, sigma, learning_rate, lr_decay, lr_decay_stepsize, val_frac, batch_size, batchnorm,
                 n_epochs, device):
        super().__init__()

        # Neural network related
        self.input_size = input_size               # No. of features
        self.latent_size = latent_size
        self.encoder_hidden_size = encoder_hidden_size
        self.hidden_size = hidden_size
        self.hidden_depth = hidden_depth
        self.activation = activation
        self.batchnorm = batchnorm
        self.encode_state_label = encode_state_label
        self.sigma = sigma                         # width of the noise added to x_t (regularization)

        # Training related
        self.learning_rate = learning_rate
        self.lr_decay = lr_decay
        self.lr_decay_stepsize = lr_decay_stepsize
        self.val_frac = val_frac
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.device = device

        # Populated by fit()
        self.prior = None
        self.optimizer = None
        self.scheduler = None
        self.train_loss = None
        self.validation_loss = None

        # Encoder x --> r(x)
        encoder_layers = self._mlp_layers(self.input_size, self.encoder_hidden_size, self.latent_size)
        if self.encode_state_label:
            # Explicit dim (the implicit-dim form is deprecated); -1 is the feature axis.
            encoder_layers.append(nn.Softmax(dim=-1))
        self.encoder = nn.Sequential(*encoder_layers).to(self.device)

        # Both vector fields map [r, z_t, t] -> velocity in feature space.
        field_input_size = self.latent_size + self.input_size + 1
        # Lumpability vector field u(r(x), y_t, t)
        self.L_vector_field = nn.Sequential(
            *self._mlp_layers(field_input_size, self.hidden_size, self.input_size)).to(self.device)
        # Decomposability vector field u(r(y), x_t, t)
        self.D_vector_field = nn.Sequential(
            *self._mlp_layers(field_input_size, self.hidden_size, self.input_size)).to(self.device)

    def _mlp_layers(self, in_features, hidden_features, out_features):
        """Layer list: [Linear, activation, (BatchNorm1d)] x max(hidden_depth, 1), then a Linear output."""
        layers = []
        width = in_features
        for _ in range(max(self.hidden_depth, 1)):
            layers.append(nn.Linear(width, hidden_features))
            layers.append(self.activation)
            if self.batchnorm:
                layers.append(nn.BatchNorm1d(hidden_features))
            width = hidden_features
        layers.append(nn.Linear(width, out_features))
        return layers

    def encode(self, x):
        """Map features ``x`` (..., input_size) to coordinates ``r(x)`` (..., latent_size).

        For the lumpability loss ``x`` is the earlier frame; for the
        decomposability loss the later frame ``y`` is encoded.
        """
        return self.encoder(x)

    def sample_from_prior(self, x):
        """Draw prior noise shaped like ``x``; x and y share the same prior.

        Raises
        ------
        RuntimeError
            If no prior has been set yet (``fit`` sets it).
        """
        prior = getattr(self, 'prior', None)
        if prior is None:
            raise RuntimeError('RCFlow prior is not set; call fit() first.')
        return prior.sample_like(x)

    def sample_t(self, x):
        """Sample one flow time per row, t ~ U[0, 1), shape (batch_size, 1)."""
        return torch.rand_like(x[:, :1])

    def sample_x_t(self, x, t, prior_sample):
        """Point on the straight path prior_sample -> x at time t, plus N(0, sigma^2) noise."""
        return x * t + (1 - t) * prior_sample + torch.randn_like(x) * self.sigma

    def data_vector_field(self, x, prior_sample):
        """Target velocity of the straight path: d x_t / dt = x - prior_sample."""
        return x - prior_sample

    def _flow_matching_loss(self, r, target, t, vector_field):
        """Mean squared error between ``vector_field(r, target_t, t)`` and the path velocity to ``target``."""
        prior_sample = self.sample_from_prior(target)
        target_t = self.sample_x_t(target, t, prior_sample)
        v_t = self.data_vector_field(target, prior_sample)
        u_t = vector_field(torch.cat((r, target_t, t), dim=-1))
        return torch.mean(torch.sum((u_t - v_t) ** 2, -1))

    def L_loss(self, x, y, t, rx=None):
        """Lumpability loss: flow to ``y`` conditioned on ``r(x)`` (``rx`` if precomputed)."""
        if rx is None:
            rx = self.encode(x)
        return self._flow_matching_loss(rx, y, t, self.L_vector_field)

    def D_loss(self, x, y, t, ry=None):
        """Decomposability loss: flow to ``x`` conditioned on ``r(y)`` (``ry`` if precomputed)."""
        if ry is None:
            ry = self.encode(y)
        # Must use the decomposability network; sharing L_vector_field here
        # would leave D_vector_field untrained.
        return self._flow_matching_loss(ry, x, t, self.D_vector_field)

    def _minibatch_loss(self, minibatch_data):
        """Total loss L + D on a batch of time-lagged pairs shaped (batch, 2, input_size)."""
        x = minibatch_data[:, 0, :].to(self.device)
        y = minibatch_data[:, 1, :].to(self.device)
        t = self.sample_t(x)
        return self.L_loss(x, y, t) + self.D_loss(x, y, t)

    def fit(self, data, lagtime, detect_anomaly=False):
        """Train encoder and both vector fields for ``self.n_epochs`` epochs.

        Parameters
        ----------
        data : list of np.ndarray
            One (n_frames, input_size) array per trajectory (usually TICA output).
        lagtime : int
            Lag, in frames, between the two members of each training pair.
        detect_anomaly : bool
            Run inside ``torch.autograd.set_detect_anomaly`` (slow; debugging only).

        Per-epoch mean losses are stored in ``train_loss`` / ``validation_loss``.
        The model is left in eval mode.
        """
        # Per-feature Gaussian prior from all frames
        data_concat = np.concatenate(data)
        self.prior = GaussianPrior(np.mean(data_concat, axis=0), np.std(data_concat, axis=0)).to(self.device)

        # Time-lagged pairs, shape (n_pairs, 2, n_features): (X_t, X_t+tau)
        dataset = create_timelagged_dataset(data, lagtime)

        n_pairs = len(dataset)
        train_size = int((1 - self.val_frac) * n_pairs)
        val_size = n_pairs - train_size
        train_data, val_data = random_split(dataset, [train_size, val_size])
        train_loader = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=self.batch_size)

        self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.lr_decay_stepsize,
                                                         gamma=self.lr_decay)

        train_loss = []
        validation_loss = []

        with torch.autograd.set_detect_anomaly(detect_anomaly):
            for epoch in range(1, self.n_epochs + 1):
                # Training
                self.train()
                batch_losses = []
                for minibatch_data in train_loader:
                    loss = self._minibatch_loss(minibatch_data)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    batch_losses.append(loss.item())
                train_loss_epoch = np.mean(batch_losses) if batch_losses else float('nan')
                train_loss.append(train_loss_epoch)

                # learning rate decay
                self.scheduler.step()

                # Validation (eval mode so BatchNorm, if any, uses running statistics)
                self.eval()
                with torch.no_grad():
                    val_losses = [self._minibatch_loss(b).item() for b in val_loader]
                validation_loss_epoch = np.mean(val_losses) if val_losses else float('nan')
                validation_loss.append(validation_loss_epoch)

                print('Epoch {}: Train loss = {:.4f}, Validation loss = {:.4f}'.format(
                    epoch, train_loss_epoch, validation_loss_epoch))

        self.train_loss = train_loss
        self.validation_loss = validation_loss

        return None

    def transform(self, data_concat):
        """Encode frames (array-like, (n_frames, input_size)) into RC space; returns a NumPy array.

        Runs in eval mode without autograd and restores the previous mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x = torch.tensor(data_concat, dtype=torch.float32).to(self.device)
                r = self.encode(x).cpu().numpy()
        finally:
            self.train(was_training)
        return r

    def save_model(self, filepath):
        """Pickle the whole module with ``torch.save``."""
        torch.save(self, filepath)
        return None

    def plot_loss(self, sim_idx):
        """Plot per-epoch train/validation loss to figures/round{sim_idx - 1}_rcflow_loss.png."""
        fig, ax = plt.subplots()
        ax.plot(self.train_loss, label='train')
        ax.plot(self.validation_loss, label='validation')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.savefig('figures/round{sim_idx}_rcflow_loss.png'.format(sim_idx=sim_idx-1), dpi=600)

# We want to organize time-series trajectories of TICA eigenvectors into time-lagged pairs according to lagtime.
# NOTE(review): the next cell defines these helpers again; keep both copies in sync or drop one.
def create_timelagged_dataset(tica_output, lagtime):
    """Stack every (frame i, frame i + lagtime) pair of each trajectory.

    Parameters
    ----------
    tica_output : iterable of np.ndarray
        One array per trajectory, frames along axis 0 (typically
        (n_frames, n_features)). 1-D trajectories are treated as one feature.
    lagtime : int
        Non-negative lag in frames. Trajectories not longer than ``lagtime``
        contribute no pairs.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (n_pairs, 2, n_features); ``[:, 0]`` is X_t and
        ``[:, 1]`` is X_{t+lagtime}. Empty (shape (0,)) if no pairs exist.

    Raises
    ------
    ValueError
        If ``lagtime`` is negative.
    """
    if lagtime < 0:
        raise ValueError('lagtime must be non-negative, got {}'.format(lagtime))
    pair_blocks = []
    for traj in tica_output:
        traj = np.asarray(traj)
        if traj.ndim == 1:
            traj = traj[:, None]
        n_pairs = traj.shape[0] - lagtime
        if n_pairs <= 0:
            continue
        # Vectorized slicing instead of one vstack per frame.
        pair_blocks.append(np.stack([traj[:n_pairs], traj[lagtime:lagtime + n_pairs]], axis=1))
    if not pair_blocks:
        return torch.tensor(np.array([]), dtype=torch.float32)
    return torch.tensor(np.concatenate(pair_blocks), dtype=torch.float32)

def minmax_normalization(data, axis=0):
    """Rescale ``data`` to [0, 1] along ``axis`` (per column for the default axis=0).

    Parameters
    ----------
    data : np.ndarray
        Array to normalize.
    axis : int
        Axis along which min and max are taken.

    Returns
    -------
    np.ndarray
        ``(data - min) / (max - min)``. Slices with zero range map to 0
        instead of NaN.
    """
    data = np.asarray(data)
    data_min = data.min(axis=axis, keepdims=True)
    data_range = data.max(axis=axis, keepdims=True) - data_min
    # Constant slices: numerator is 0 there, so any non-zero divisor gives 0.
    data_range[data_range == 0] = 1
    return (data - data_min) / data_range

def calculate_nmicro(data_concat):
    """Number of k-means microstates to use for ``data_concat`` (frames along axis 0).

    HTMD adaptive-bandit heuristic: round(0.6 * log10(n_frames / 1000) * 1000 + 50),
    never fewer than 100.
    https://github.com/Acellera/htmd/blob/master/htmd/adaptive/adaptivebandit.py
    """
    n_frames = data_concat.shape[0]
    heuristic = np.round(0.6 * np.log10(n_frames / 1000) * 1000 + 50)
    return int(max(100, heuristic))

In [None]:
# We want to organize time-series trajectories of TICA eigenvectors into time-lagged pairs according to lagtime.
def create_timelagged_dataset(tica_output, lagtime):
    """Stack every (frame i, frame i + lagtime) pair of each trajectory.

    Parameters
    ----------
    tica_output : iterable of np.ndarray
        One array per trajectory, frames along axis 0 (typically
        (n_frames, n_features)). 1-D trajectories are treated as one feature.
    lagtime : int
        Non-negative lag in frames. Trajectories not longer than ``lagtime``
        contribute no pairs.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (n_pairs, 2, n_features); ``[:, 0]`` is X_t and
        ``[:, 1]`` is X_{t+lagtime}. Empty (shape (0,)) if no pairs exist.

    Raises
    ------
    ValueError
        If ``lagtime`` is negative.
    """
    if lagtime < 0:
        raise ValueError('lagtime must be non-negative, got {}'.format(lagtime))
    pair_blocks = []
    for traj in tica_output:
        traj = np.asarray(traj)
        if traj.ndim == 1:
            traj = traj[:, None]
        n_pairs = traj.shape[0] - lagtime
        if n_pairs <= 0:
            continue
        # Vectorized slicing instead of one vstack per frame.
        pair_blocks.append(np.stack([traj[:n_pairs], traj[lagtime:lagtime + n_pairs]], axis=1))
    if not pair_blocks:
        return torch.tensor(np.array([]), dtype=torch.float32)
    return torch.tensor(np.concatenate(pair_blocks), dtype=torch.float32)

def minmax_normalization(data, axis=0):
    """Rescale ``data`` to [0, 1] along ``axis`` (per column for the default axis=0).

    Parameters
    ----------
    data : np.ndarray
        Array to normalize.
    axis : int
        Axis along which min and max are taken.

    Returns
    -------
    np.ndarray
        ``(data - min) / (max - min)``. Slices with zero range map to 0
        instead of NaN.
    """
    data = np.asarray(data)
    data_min = data.min(axis=axis, keepdims=True)
    data_range = data.max(axis=axis, keepdims=True) - data_min
    # Constant slices: numerator is 0 there, so any non-zero divisor gives 0.
    data_range[data_range == 0] = 1
    return (data - data_min) / data_range

def calculate_nmicro(data_concat):
    """Number of k-means microstates to use for ``data_concat`` (frames along axis 0).

    HTMD adaptive-bandit heuristic: round(0.6 * log10(n_frames / 1000) * 1000 + 50),
    never fewer than 100.
    https://github.com/Acellera/htmd/blob/master/htmd/adaptive/adaptivebandit.py
    """
    n_frames = data_concat.shape[0]
    heuristic = np.round(0.6 * np.log10(n_frames / 1000) * 1000 + 50)
    return int(max(100, heuristic))

def rcflow_projection(tica_output_concat,normalized_r,stride,markersize,savefile=None,save=False):
    """Four-panel comparison of the learned RC with the first two TICA components.

    Top row: RC 1 vs RC 2 coloured by tIC1 / tIC2 (axes ticked on [0, 1], so
    ``normalized_r`` is expected to be min-max normalized). Bottom row: tIC1 vs
    tIC2 coloured by RC 1 / RC 2. Every ``stride``-th frame is drawn with
    marker size ``markersize``.

    If ``save == True`` the figure is written at 600 dpi to ``savefile``, or to
    figures/rcflow_projection/rcflow_projection.png when ``savefile`` is None.
    The figure is then shown. Returns None.
    """
    rc_points = normalized_r[::stride]
    tic_points = tica_output_concat[::stride]
    unit_ticks = np.linspace(0, 1, 11)

    fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(16, 12))

    # Top row: RC plane coloured by tIC1, then tIC2.
    for col in range(2):
        panel = ax[0, col]
        scatter = panel.scatter(rc_points[:, 0], rc_points[:, 1], c=tic_points[:, col], s=markersize)
        panel.set_xlabel('RC 1')
        panel.set_ylabel('RC 2')
        panel.set_xticks(unit_ticks)
        panel.set_yticks(unit_ticks)
        fig.colorbar(scatter, ax=panel, label='TICA tIC{}'.format(col + 1))

    # Bottom row: tIC plane coloured by RC 1, then RC 2.
    for col in range(2):
        panel = ax[1, col]
        scatter = panel.scatter(tic_points[:, 0], tic_points[:, 1], c=rc_points[:, col], s=markersize)
        panel.set_xlabel('TICA tIC 1')
        panel.set_ylabel('TICA tIC 2')
        fig.colorbar(scatter, ax=panel, label='RC {}'.format(col + 1), ticks=unit_ticks)

    plt.tight_layout()
    if save == True:
        target = 'figures/rcflow_projection/rcflow_projection.png' if savefile is None else savefile
        plt.savefig(target, dpi=600)
    plt.show()

    return None

def run_TICA(data,lagtime,dim=None,var_cutoff=None,koopman=True):
    """Fit deeptime TICA on ``data`` and project ``data`` onto it.

    Parameters
    ----------
    data : array or list of arrays
        Featurized trajectories passed straight to deeptime.
    lagtime : int
        TICA lag time (frames).
    dim, var_cutoff :
        Forwarded to ``deeptime.decomposition.TICA``.
    koopman : bool
        If ``koopman == True``, fit with Koopman reweighting
        (``KoopmanWeightingEstimator`` at the same lag).

    Returns
    -------
    tica : fitted TICA model (eigenvalues / eigenvectors)
    tica_output : ``tica.transform(data)``, the data in TICA space
    """
    estimator = TICA(lagtime=lagtime, dim=dim, var_cutoff=var_cutoff)
    if koopman != True:
        fitted = estimator.fit(data).fetch_model()
    else:
        weights = KoopmanWeightingEstimator(lagtime=lagtime).fit(data).fetch_model()
        fitted = estimator.fit(data, weights=weights).fetch_model()
    return fitted, fitted.transform(data)

In [None]:
# Keep raw distances only: e^-d terms are just functions of the interatomic
# distances. Columns starting with phi/psi/chi/omega (the raw torsion angles)
# are removed; their sin/cos columns remain.
data = plumed.read_as_pandas('CV/COLVAR')
data = data.drop(columns=['time'])
columns = list(data.columns.values)
angle_prefixes = ('phi', 'psi', 'chi', 'omega')
data = data.drop(columns=[name for name in columns if name.startswith(angle_prefixes)])

In [None]:
# TICA and RCFlow.fit take a list with one (n_frames, n_features) array per
# trajectory; this COLVAR is a single trajectory. The guard makes re-running
# the cell harmless (previously a second run failed on a list).
if isinstance(data, pd.DataFrame):
    data = [data.to_numpy()]

In [None]:
# (n_frames, n_features) of the single trajectory
np.shape(data[0])

##### 2.2 Model training

In [None]:
# Run control and data pre-processing
dim = 20                                           # number of TICA components kept
var_cutoff = None
koopman = False                                    # no Koopman reweighting in TICA

# FMRC training hyperparameters
rcflow_lagtime = 50                                # lag (frames) for TICA and RCFlow pairs
latent_size = 2                                    # dimensions of r
encoder_hidden_size = 256
encode_state_label = False
hidden_size = 256
hidden_depth = 3
activation = nn.ReLU()
sigma = 0.001
learning_rate = 0.001
lr_decay = 0.1
lr_decay_stepsize = 50
val_frac = 0.1
batch_size = 512
batchnorm = False
n_epochs = 100
# Fall back to CPU instead of failing when no GPU is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

batchsize_transform = 1024                         # frames per chunk in RCFlow.transform

# Seed every RNG used downstream (torch: weight init, random_split, flow noise)
# so Restart & Run All is reproducible.
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [None]:
# Accumulators; they appear unused by the cells in this notebook.
data_supp, train_loss_all, validation_loss_all, tica_output_supp = [], [], [], []
# TICA pre-processing at the RCFlow lag time
tica, tica_output = run_TICA(data, rcflow_lagtime, dim=dim, var_cutoff=var_cutoff, koopman=koopman)

In [None]:
# Expected (n_trajectories, n_frames, dim), per the comments in run_TICA.
tica_output.shape

In [None]:
# RCFlow's feature dimension = number of TICA components per frame (last axis).
input_size = tica_output.shape[2]

In [None]:
# All trajectories stacked into one (n_total_frames, dim) array for batched
# encoding and plotting.
tica_output_concat = np.concatenate(tica_output, axis=0)
tica_output_concat.shape

In [None]:
# Train several independently initialised RCFlow models on the same TICA data.
no_models = 10

# Plot settings (same values as the final plotting cell); defined here so this
# cell also runs on a fresh kernel.
stride = 10
markersize = 3

for i in range(no_models):
    rcflow = RCFlow(input_size, latent_size, encoder_hidden_size, encode_state_label, hidden_size, hidden_depth,
                    activation, sigma, learning_rate, lr_decay, lr_decay_stepsize, val_frac, batch_size, batchnorm,
                    n_epochs, device)
    if i == 0:
        print(rcflow)  # architecture is identical for every model
    rcflow.fit(tica_output, rcflow_lagtime)
    # Name the checkpoint after the lag time and width actually used.
    rcflow.save_model('models/rcflow-lag{lag}-sincos-{hidden}hidden-{i}.pt'.format(
        lag=rcflow_lagtime, hidden=hidden_size, i=i))

    # Transform into RC space in chunks of `batchsize_transform` frames
    r = np.concatenate([
        rcflow.transform(tica_output_concat[start:start + batchsize_transform])
        for start in range(0, tica_output_concat.shape[0], batchsize_transform)
    ])

    # Normalize r for better clustering result
    normalized_r = minmax_normalization(r, axis=0)
    # Save one projection figure per model (rcflow_projection has no `i` argument).
    rcflow_projection(tica_output_concat, normalized_r, stride, markersize,
                      savefile='figures/rcflow_projection/rcflow_projection_{i}.png'.format(i=i), save=True)

In [None]:
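# Use this code to load a trained model saved with RCFlow.save_model (a whole pickled module).
# NOTE: since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses pickled
# modules; pass weights_only=False (only for files you trust).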
#rcflow = torch.load('models/rcflow-lag10-sincos-2000hidden-1.pt')

In [None]:
# Encode all frames into RC space, `batchsize_transform` frames at a time.
no_iteration = tica_output_concat.shape[0] // batchsize_transform + 1
chunks = (tica_output_concat[k * batchsize_transform:(k + 1) * batchsize_transform]
          for k in range(no_iteration))
r = np.concatenate([rcflow.transform(chunk) for chunk in chunks])

# Min-max normalize each RC dimension before clustering
normalized_r = minmax_normalization(r, axis=0)

In [None]:
# Plot the RC projection computed above: every 10th frame, small markers.
markersize = 3
stride = 10
rcflow_projection(tica_output_concat, normalized_r, stride=stride, markersize=markersize)