In [None]:
import os

import bussilab
import plumed
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pandas
import MDAnalysis as mda
from MDAnalysis import transformations
from MDAnalysis.analysis import distances
import MDAnalysis.analysis.distances as distances
from MDAnalysis.analysis.base import AnalysisBase
from scipy.interpolate import griddata

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from skmatter.feature_selection import FPS

%matplotlib inline

In [1]:
# Copy the simulation inputs into the current working directory: the structure
# file diala.pdb, the run input topol.tpr and every file matching
# newbox_diala_DESRES*.
# NOTE(review): the source is an absolute path inside one user's home directory;
# adjust it before running this notebook on another machine.
!scp /home/frohlkin@farma.unige.ch/Projects/PathOptimization/Toymodels/Dialanine/OneWayRatchet/Path1/diala.pdb .
!scp /home/frohlkin@farma.unige.ch/Projects/PathOptimization/Toymodels/Dialanine/OneWayRatchet/Path1/topol.tpr .
!scp /home/frohlkin@farma.unige.ch/Projects/PathOptimization/Toymodels/Dialanine/OneWayRatchet/Path1/newbox_diala_DESRES* .

In [None]:
# PLUMED input for the "getA" run: defines the phi/psi torsions, applies an
# ABMD bias on (phi, psi) with TO=-2,2, adds a COMMITTOR action with one basin
# in (phi, psi), and prints every available value to COLVAR_getA at each step.
lines = [
    'MOLINFO STRUCTURE=diala.pdb\n',
    'phi: TORSION ATOMS=@phi-2\n',
    'psi: TORSION ATOMS=@psi-2\n',
    'abmd: ABMD ARG=phi,psi TO=-2,2 KAPPA=5.0,5.0\n',
    'COMMITTOR ...\n',
    '   ARG=phi,psi\n',
    '   STRIDE=1\n',
    '   BASIN_LL1=-2.6,2.2\n',
    '   BASIN_UL1=-2.5,2.5\n',
    '...\n',
    # The print format is substituted in so that the literal '%8.4f' ends up in the file.
    'PRINT FMT=%s STRIDE=1 FILE=COLVAR_getA ARG=* \n' % '%8.4f',
]

# The context manager closes the file even if a write fails.
with open("plumed_getA.dat", "w") as f:
    f.writelines(lines)

In [None]:
%%bash

# Run gmx_mpi mdrun on a single MPI rank with the PLUMED input plumed_getA.dat;
# -deffnm makes all output files use the "getA" prefix.
# NOTE(review): the trailing '&' puts mdrun in the background, so this cell may
# return before the run has finished -- confirm getA.xtc is complete before
# running the trjconv cell that reads it.
mpirun -n 1 gmx_mpi mdrun -s topol -plumed plumed_getA.dat -deffnm getA -nsteps 50000000 -pin on -pinoffset 0 -nb gpu &

In [None]:
# NOTE: 3.62 (the -dump time used in the next cell) should correspond to the last frame of getA.xtc; replace it with the value from your own run.

In [None]:
%%bash
# Write the frame of getA.xtc selected by "-dump 3.62" to stateA.pdb, using
# newbox_diala_DESRES as the structure/run-input file (-s).
# "echo 0" feeds the answer "0" to trjconv's interactive prompt
# (presumably the output-group selection -- TODO confirm).
echo 0| gmx_mpi trjconv -f getA.xtc -s newbox_diala_DESRES -dump 3.62 -o stateA.pdb

In [None]:
# Minimum separation in atom index for a pair of atoms to be kept.
neighbor_cutoff = 2

# Atom indices 1..21 considered for the pairwise distances.
contact_ndx = range(1, 22)

# Unordered index pairs (lo, hi) whose indices differ by at least
# neighbor_cutoff. Pairs are generated in the same first-appearance order as a
# full double loop followed by sorting each pair, so the set is built identically.
sel = {
    (lo, hi)
    for lo in contact_ndx
    for hi in contact_ndx
    if lo <= hi and hi - lo >= neighbor_cutoff
}

# Number of selected pairs (shown as the cell output).
len(sel)

In [None]:
# PLUMED input shared by the five forward ratchet runs: phi/psi torsions, one
# DISTANCE per selected atom pair in `sel`, an ABMD bias on (phi, psi) with
# TO=0.8,-1.0, and a COMMITTOR action with one basin in (phi, psi).
# These lines do not depend on the run index, so they are built once.
common_lines = [
    'MOLINFO STRUCTURE=diala.pdb\n',
    'phi: TORSION ATOMS=@phi-2\n',
    'psi: TORSION ATOMS=@psi-2\n',
]
common_lines += [
    'dist%s:  DISTANCE ATOMS=%s,%s \n' % (e, pair[0], pair[1])
    for e, pair in enumerate(sel)
]
common_lines += [
    'abmd: ABMD ARG=phi,psi TO=0.8,-1.0 KAPPA=8,4\n',
    'COMMITTOR ...\n',
    '   ARG=phi,psi\n',
    '   STRIDE=1\n',
    '   BASIN_LL1=0,-2\n',
    '   BASIN_UL1=1,-1\n',
    '...\n',
]

for i in range(1, 6):
    # Only the COLVAR file name in the PRINT line differs between runs.
    lines = common_lines + [
        'PRINT FMT=%s STRIDE=1 FILE=COLVAR_ratchet_path1_forward_%s ARG=* \n' % ('%8.4f', i)
    ]

    # The context manager closes the file even if a write fails.
    with open("plumed_ratchet_path1_forward_%s.dat" % i, "w") as f:
        f.writelines(lines)

In [None]:
%%bash

# Run the five forward ratchet simulations, each with its own PLUMED input
# plumed_ratchet_path1_forward_$j.dat and output prefix ratchet_path1_forward_$j.
for j in 1 2 3 4 5
do
    # mdrun is started in the background and immediately waited for, so the
    # five runs execute one after another.
    mpirun -n 1 gmx_mpi mdrun -s topol -plumed plumed_ratchet_path1_forward_$j.dat -deffnm ratchet_path1_forward_$j -nsteps 50000000 -pin on -pinoffset 0 -nb gpu &
    wait    
done

In [None]:
# Collect the phi/psi time series of the five forward ratchet runs.
phi_forward = []
psi_forward = []

for i in range(1, 6):
    colvar_forward = plumed.read_as_pandas("COLVAR_ratchet_path1_forward_%s" % i)

    print(len(colvar_forward['phi']))
    phi_forward.append(colvar_forward['phi'])
    psi_forward.append(colvar_forward['psi'])

# Reference free-energy surface, loaded once (it does not depend on the run).
# Columns 0/1/2 are used as phi, psi and free energy respectively
# (presumably the layout of the reference file -- TODO confirm).
data = np.load('raw_data/Dialanine/Dialanine_reference_FES.npy')

phi_values = data[:, 0]
psi_values = data[:, 1]
file_free_values = data[:, 2]

# Interpolate the scattered FES values onto a regular 100 x 100 grid.
phi_grid, psi_grid = np.meshgrid(
    np.linspace(phi_values.min(), phi_values.max(), 100),
    np.linspace(psi_values.min(), psi_values.max(), 100),
)
file_free_grid = griddata(
    (phi_values, psi_values), file_free_values, (phi_grid, psi_grid), method='cubic'
)

# Contour levels 0, 5, ..., 50.
levels = np.arange(0, 55, 5)

# Not used in this cell; kept in case other cells rely on it.
num_bins = 11

fig, ax = plt.subplots(1, figsize=(6, 6), dpi=100)
ax.contour(phi_grid, psi_grid, file_free_grid, levels, colors='grey')

# Overlay every frame of every trajectory on the reference contours.
for phi_traj, psi_traj in zip(phi_forward, psi_forward):
    ax.scatter(phi_traj, psi_traj, c='black', s=1)

ax.set_xlabel('phi')
ax.set_ylabel('psi')

plt.tight_layout()
plt.show()

In [None]:
# Per-trajectory arrays read from the COLVAR files of the five forward runs:
#   features       -> COLVAR columns 3..192 (190 columns)
#   features_plot  -> COLVAR columns 1..2
features, features_plot = [], []

for run in range(1, 6):
    colvar = plumed.read_as_pandas("COLVAR_ratchet_path1_forward_%s" % run)
    features_plot.append(colvar.iloc[:, 1:3].to_numpy())
    features.append(colvar.iloc[:, 3:193].to_numpy())
    print(features[-1].shape)

In [None]:
# For every forward trajectory, split the frames into bins along COLVAR
# column 1 and keep up to `min_propensity_count` frames per bin, chosen with
# farthest point sampling (FPS) on the 190 feature columns.
training_batches = []
training_batches_plot = []

# Bin edges along COLVAR column 1.
bins = [-4, 2.2, 4]
# Number of frames requested from each bin.
min_propensity_count = 300

for j in range(1, 6):
    colvar = plumed.read_as_pandas("COLVAR_ratchet_path1_forward_%s" % j)

    data = np.array(colvar.iloc[:, 1])
    all_data_plot = colvar.iloc[:, 1:3].to_numpy()
    all_data = colvar.iloc[:, 3:193].to_numpy()
    print(bins)
    print(all_data.shape)

    hist, bin_edges = np.histogram(data, bins=bins)
    # Population of the least populated bin (diagnostic only).
    print(hist.min())

    col_indices = []
    n_bins = len(bin_edges) - 1
    for i in range(n_bins):
        lower, upper = bin_edges[i], bin_edges[i + 1]
        # Half-open bins except the last one (same convention as np.histogram),
        # so a value lying exactly on an inner edge is not selected twice.
        if i == n_bins - 1:
            in_bin = np.logical_and(data >= lower, data <= upper)
        else:
            in_bin = np.logical_and(data >= lower, data < upper)
        indices = np.where(in_bin)[0]
        print(len(indices))

        if len(indices) == 0:
            # Nothing to sample from in this bin.
            continue

        t_data = all_data[indices]
        # FPS selects columns, so the data is transposed to select frames.
        # The request is capped at the bin population; asking for more frames
        # than exist cannot be satisfied (presumably rejected by FPS -- TODO confirm).
        selector = FPS(n_to_select=min(min_propensity_count, len(indices)), initialize=0)
        selector.fit(t_data.T)
        r_ndx = selector.selected_idx_

        col_indices.append(indices[r_ndx])

    # Restore the time ordering of the selected frames.
    temp_sorted_ndx = np.sort(np.concatenate(col_indices))
    training_batches.append(all_data[temp_sorted_ndx])
    training_batches_plot.append(all_data_plot[temp_sorted_ndx])

training_datapoints = np.concatenate(training_batches)
training_datapoints_plot = np.concatenate(training_batches_plot)
print(training_datapoints.shape)

# One colour per trajectory.
plt.figure()
for elem in training_batches_plot:
    plt.scatter(elem[:, 0], elem[:, 1], alpha=0.2)
plt.show()

In [None]:
# Pick the reference ("neighbour") points from the pooled training data:
# the data is binned along column 0 of training_datapoints_plot and an equal
# share of `nneighbors` is drawn from every bin with farthest point sampling.
nneighbors = 300

bin_edges = [-4, 1, 2, 4]

# Frames requested per bin.
min_propensity_count = nneighbors // (len(bin_edges) - 1)

data = training_datapoints_plot[:, 0]
col_indices = []
n_bins = len(bin_edges) - 1
for i in range(n_bins):
    lower, upper = bin_edges[i], bin_edges[i + 1]
    # Half-open bins except the last one, so a value lying exactly on an
    # inner edge is not selected twice.
    if i == n_bins - 1:
        in_bin = np.logical_and(data >= lower, data <= upper)
    else:
        in_bin = np.logical_and(data >= lower, data < upper)
    indices = np.where(in_bin)[0]

    if len(indices) == 0:
        # Nothing to sample from in this bin.
        continue

    t_data = training_datapoints[indices]
    # FPS selects columns, so the data is transposed to select frames.
    # The request is capped at the bin population; asking for more frames
    # than exist cannot be satisfied (presumably rejected by FPS -- TODO confirm).
    selector = FPS(n_to_select=min(min_propensity_count, len(indices)), initialize=0)
    selector.fit(t_data.T)
    r_ndx = selector.selected_idx_
    col_indices.append(indices[r_ndx])

ndx = np.concatenate(col_indices)
neighbours = training_datapoints[ndx]
neighbours_plot = training_datapoints_plot[ndx]

print(neighbours.shape)

# The points carry no colour-mapped data (no c= argument), so no colorbar is drawn.
plt.figure()
plt.scatter(neighbours_plot[:, 0], neighbours_plot[:, 1], s=3, alpha=0.5)
plt.show()

In [None]:
class AutoEncoderCV(nn.Module):
    """Autoencoder with a learned metric and a soft nearest-neighbour lookup.

    An input is mapped by the ``metric`` network into a ``d``-dimensional
    space, where ``n`` soft nearest neighbours are picked among the embedded
    reference points ``ref``. Their concatenated coordinates are encoded into
    a single value ``s`` in (0, 1) (sigmoid output) and decoded back to
    ``f`` features.

    Parameters
    ----------
    f : int
        Number of input features.
    d : int
        Output dimension of the metric network.
    n : int
        Number of soft nearest neighbours fed to the encoder.
    ref : torch.Tensor
        Reference points, shape ``(n_ref, f)``. Stored as a plain attribute
        (not a buffer), so it is not part of ``state_dict``.
    act : str
        Activation name: ``'ReLU'``, ``'Tanh'``, ``'Sigmoid'`` or ``'ELU'``.

    Raises
    ------
    ValueError
        If ``act`` is not one of the supported names.
    """

    def __init__(self,
                f: int,
                d: int,
                n: int,
                ref: torch.Tensor,
                act: str):

        super(AutoEncoderCV, self).__init__()

        # =======   LOSS  =======
        # Reconstruction (MSE) loss
        self.loss_mse = torch.nn.MSELoss()

        # ======= BLOCKS =======
        self.n_features = f
        self.n_neighbors = n
        self.d_metric = d
        self.training_datapoints = ref

        # Per-feature statistics of the reference points (not used by
        # normalize/denormalize, which are identity maps).
        self.mean = torch.mean(self.training_datapoints, dim=0)
        self.std = torch.std(self.training_datapoints, dim=0)
        self.range = (torch.max(self.training_datapoints, dim=0).values
                      - torch.min(self.training_datapoints, dim=0).values) / 2

        activations = {
            'ReLU': torch.nn.ReLU,
            'Tanh': torch.nn.Tanh,
            'Sigmoid': torch.nn.Sigmoid,
            'ELU': torch.nn.ELU,
        }
        if act not in activations:
            # Fail here instead of with an AttributeError when the networks are built.
            raise ValueError(
                "Unknown activation %r; expected one of %s" % (act, sorted(activations))
            )
        self.activationf = activations[act]()

        # metric network: features -> d-dimensional embedding
        self.metric = torch.nn.Sequential(
                                    torch.nn.Linear(self.n_features, 50),
                                    self.activationf,
                                    torch.nn.Linear(50, 24),
                                    self.activationf,
                                    torch.nn.Linear(24, self.d_metric))

        # encoder: concatenated neighbour embeddings -> s in (0, 1)
        self.encoder = torch.nn.Sequential(
                                torch.nn.Linear(int(self.n_neighbors*self.d_metric), 24),
                                self.activationf,
                                torch.nn.Linear(24, 16),
                                self.activationf,
                                torch.nn.Linear(16, 1),
                                torch.nn.Sigmoid())

        # decoder: s -> reconstructed features
        self.decoder = torch.nn.Sequential(
                                torch.nn.Linear(1, 16),
                                self.activationf,
                                torch.nn.Linear(16, 24),
                                self.activationf,
                                torch.nn.Linear(24, self.n_features))

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Identity placeholder applied to the inputs."""
        return x

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Identity placeholder applied to the decoder output."""
        return x

    def softmax_w(self, x: torch.Tensor, t=1e-1) -> torch.Tensor:
        """Row-wise soft-max of ``x / t`` with 1e-6 added to every numerator.

        The row maximum is subtracted before exponentiating for numerical
        stability. Because of the added 1e-6 the rows do not sum exactly to 1.
        """
        x = x / t
        x = x - torch.max(x, dim=1, keepdim=True)[0]
        weights = torch.exp(x)
        return (weights + 1e-6) / torch.sum(weights, dim=1, keepdim=True)

    def soft_top_k(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Differentiable selection of ``n_neighbors`` reference embeddings.

        ``x`` holds similarities, shape ``(batch, n_ref)``; ``t`` holds the
        embedded reference points, shape ``(n_ref, d_metric)``. In every pass
        the similarities are damped by the weights accumulated so far, a
        soft-max is taken, and the weighted average of ``t`` is stored.

        One soft-max pass is accumulated *before* the loop, so the
        highest-similarity reference point is already damped when the first
        neighbour is extracted (presumably to skip the query point itself
        when it belongs to the reference set -- TODO confirm).

        Returns a tensor of shape ``(batch, n_neighbors * d_metric)``.
        """
        # Initial pass: x * (1 - 0) == x.
        y = self.softmax_w(x)

        selected = []
        for _ in range(self.n_neighbors):
            x_w_softmax = self.softmax_w(x * (1 - y))
            y = y + x_w_softmax
            # (d_metric, n_ref) @ (n_ref, batch) -> (d_metric, batch)
            selected.append(torch.matmul(t.T, x_w_softmax.T))

        return torch.cat(selected).T

    def learn_metric(self, x: torch.Tensor) -> tuple:
        """Return ``(d, t)``: embeddings of ``x`` and of the reference points."""
        d = self.metric(x)
        t = self.metric(self.training_datapoints)
        return d, t

    def find_nearest_neighbors(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Soft nearest neighbours of the embeddings ``x`` among ``t``.

        Pairwise distances are turned into similarities with ``exp(-dist)``
        before being passed to :meth:`soft_top_k`.
        """
        dist = torch.cdist(x, t)
        dist = torch.exp(-dist)
        return self.soft_top_k(dist, t)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the encoder network."""
        return self.encoder(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the decoder network."""
        return self.decoder(x)

    def encode_decode(self, x: torch.Tensor) -> tuple:
        """Full pass; returns ``(x_hat, s, d, dn)``.

        ``x_hat`` is the reconstruction, ``s`` the encoder output, ``d`` the
        metric embedding of ``x`` and ``dn`` the concatenated soft neighbours.
        """
        x_norm = self.normalize(x)
        d, t = self.learn_metric(x_norm)
        dn = self.find_nearest_neighbors(d, t)

        s = self.encode(dn)
        x_pre = self.decode(s)
        x_hat = self.denormalize(x_pre)

        return x_hat, s, d, dn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch, 2)`` tensor with columns ``s`` and ``z``."""
        x_norm = self.normalize(x)
        d, t = self.learn_metric(x_norm)
        dn = self.find_nearest_neighbors(d, t)
        s = self.encode(dn).reshape(-1, 1)
        z = self.compute_z(x).reshape(-1, 1)

        return torch.hstack((s, z))

    def compute_z(self, x: torch.Tensor, l=10) -> torch.Tensor:
        """Soft-minimum distance of ``x`` to the reconstructed reference points.

        Computes ``-(1/l) * log(sum_j exp(-l * |x - x_hat_j|))`` where
        ``x_hat_j`` are the reconstructions of this model's own reference
        points. The model's own state is used rather than notebook globals,
        and ``logsumexp`` keeps the result finite when every distance is
        large (the naive form underflows to ``log(0)``).
        """
        x_hat, _, _, _ = self.encode_decode(self.training_datapoints)
        z_dist = torch.cdist(x, x_hat)
        return -torch.logsumexp(-l * z_dist, dim=1) / l

In [None]:
# Persist the selected data so that later cells can reload it.
np.save('training_datapoints_path1.npy', training_datapoints)
np.save('training_datapoints_plot_path1.npy', training_datapoints_plot)
np.save('neighbours_path1.npy', neighbours)

training_datapoints = torch.Tensor(training_datapoints)
neighbours = torch.Tensor(neighbours)

n_features = 190
d_metric = 3
n_neighbors = 3
activation_function = 'Tanh' #'ReLU'

model = AutoEncoderCV(f=n_features, d=d_metric, n=n_neighbors, ref=training_datapoints, act=activation_function)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# Convert every batch to a tensor once instead of once per epoch.
batch_tensors = [torch.Tensor(batch) for batch in training_batches]
n_training_samples = sum(batch.size(0) for batch in batch_tensors)

# Mean reconstruction loss per sample, one entry per epoch.
track = []
best_loss = float('inf')

num_epochs = 5001
for epoch in range(num_epochs):
    train_loss = 0.0
    for x in batch_tensors:
        # Forward Pass
        x_hat, s, d, dn = model.encode_decode(x)

        # Compute Loss
        loss = model.loss_mse(x_hat, x)

        # Backward Pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the sum of per-sample losses.
        train_loss += loss.item() * x.size(0)

    # Divide by the number of samples (not the number of batches), matching
    # the batch-size weighting used in the accumulation above.
    train_loss = train_loss / n_training_samples

    track.append(train_loss)

    if epoch % 100 == 0:
        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch+1, train_loss))

        # Checkpoint only when the loss at a reporting epoch improved.
        if train_loss < best_loss:
            best_loss = train_loss
            filename = 'model_DeepLNE_path1.pth'
            torch.save(model, filename)
            torch.save(model.state_dict(), 'model_params_DeepLNE_path1.pt')

In [None]:
# Training loss per epoch on a logarithmic y axis.
fig, ax = plt.subplots()
ax.plot(track, 'o-', color='orangered')
ax.set_yscale('log')
plt.show()

In [None]:
# Reload the saved data sets as tensors.
training_datapoints = torch.Tensor(np.load('training_datapoints_path1.npy'))
neighbours = torch.Tensor(np.load('neighbours_path1.npy'))

# Network hyper-parameters.
n_features, d_metric, n_neighbors = 190, 3, 3
activation_function = 'Tanh'

# Rebuild the model with the `neighbours` subset as reference points and load
# the saved weights into it.
model = AutoEncoderCV(
    f=n_features,
    d=d_metric,
    n=n_neighbors,
    ref=neighbours,
    act=activation_function,
)
model.load_state_dict(
    torch.load('model_params_DeepLNE_path1.pt'),
    strict=False,
)

# Reconstructions of the reference points, detached from the autograd graph.
neighbours_d = model.encode_decode(neighbours)[0].detach()

In [None]:
class AutoEncoderCV_Speed(nn.Module):
    """Inference variant of the autoencoder with precomputed ``z`` references.

    Same networks as ``AutoEncoderCV`` (metric, encoder, decoder), but
    :meth:`compute_z` measures distances to the fixed tensor ``ref_z``
    instead of re-running the autoencoder on the reference points at every
    call.

    Parameters
    ----------
    f : int
        Number of input features.
    d : int
        Output dimension of the metric network.
    n : int
        Number of soft nearest neighbours fed to the encoder.
    ref : torch.Tensor
        Reference points for the neighbour search, shape ``(n_ref, f)``.
    ref_z : torch.Tensor
        Points used by :meth:`compute_z`, shape ``(n_z, f)``.
    act : str
        Activation name: ``'ReLU'``, ``'Tanh'``, ``'Sigmoid'`` or ``'ELU'``.

    Raises
    ------
    ValueError
        If ``act`` is not one of the supported names.
    """

    def __init__(self,
                f: int,
                d: int,
                n: int,
                ref: torch.Tensor,
                ref_z: torch.Tensor,
                act: str):

        super(AutoEncoderCV_Speed, self).__init__()

        # =======   LOSS  =======
        # Reconstruction (MSE) loss
        self.loss_mse = torch.nn.MSELoss()

        # ======= BLOCKS =======
        self.n_features = f
        self.n_neighbors = n
        self.d_metric = d
        self.training_datapoints = ref
        self.training_datapoints_z = ref_z

        # Per-feature statistics of the reference points (not used by
        # normalize/denormalize, which are identity maps).
        self.mean = torch.mean(self.training_datapoints, dim=0)
        self.std = torch.std(self.training_datapoints, dim=0)
        self.range = (torch.max(self.training_datapoints, dim=0).values
                      - torch.min(self.training_datapoints, dim=0).values) / 2

        activations = {
            'ReLU': torch.nn.ReLU,
            'Tanh': torch.nn.Tanh,
            'Sigmoid': torch.nn.Sigmoid,
            'ELU': torch.nn.ELU,
        }
        if act not in activations:
            # Fail here instead of with an AttributeError when the networks are built.
            raise ValueError(
                "Unknown activation %r; expected one of %s" % (act, sorted(activations))
            )
        self.activationf = activations[act]()

        # metric network: features -> d-dimensional embedding
        self.metric = torch.nn.Sequential(
                                    torch.nn.Linear(self.n_features, 50),
                                    self.activationf,
                                    torch.nn.Linear(50, 24),
                                    self.activationf,
                                    torch.nn.Linear(24, self.d_metric))

        # encoder: concatenated neighbour embeddings -> s in (0, 1)
        self.encoder = torch.nn.Sequential(
                                torch.nn.Linear(int(self.n_neighbors*self.d_metric), 24),
                                self.activationf,
                                torch.nn.Linear(24, 16),
                                self.activationf,
                                torch.nn.Linear(16, 1),
                                torch.nn.Sigmoid())

        # decoder: s -> reconstructed features
        self.decoder = torch.nn.Sequential(
                                torch.nn.Linear(1, 16),
                                self.activationf,
                                torch.nn.Linear(16, 24),
                                self.activationf,
                                torch.nn.Linear(24, self.n_features))

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Identity placeholder applied to the inputs."""
        return x

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Identity placeholder applied to the decoder output."""
        return x

    def softmax_w(self, x: torch.Tensor, t=1e-1) -> torch.Tensor:
        """Row-wise soft-max of ``x / t`` with 1e-6 added to every numerator.

        The row maximum is subtracted before exponentiating for numerical
        stability. Because of the added 1e-6 the rows do not sum exactly to 1.
        """
        x = x / t
        x = x - torch.max(x, dim=1, keepdim=True)[0]
        weights = torch.exp(x)
        return (weights + 1e-6) / torch.sum(weights, dim=1, keepdim=True)

    def soft_top_k(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Differentiable selection of ``n_neighbors`` reference embeddings.

        ``x`` holds similarities, shape ``(batch, n_ref)``; ``t`` holds the
        embedded reference points, shape ``(n_ref, d_metric)``. In every pass
        the similarities are damped by the weights accumulated so far, a
        soft-max is taken, and the weighted average of ``t`` is stored.

        One soft-max pass is accumulated *before* the loop, so the
        highest-similarity reference point is already damped when the first
        neighbour is extracted (presumably to skip the query point itself
        when it belongs to the reference set -- TODO confirm).

        Returns a tensor of shape ``(batch, n_neighbors * d_metric)``.
        """
        # Initial pass: x * (1 - 0) == x.
        y = self.softmax_w(x)

        selected = []
        for _ in range(self.n_neighbors):
            x_w_softmax = self.softmax_w(x * (1 - y))
            y = y + x_w_softmax
            # (d_metric, n_ref) @ (n_ref, batch) -> (d_metric, batch)
            selected.append(torch.matmul(t.T, x_w_softmax.T))

        return torch.cat(selected).T

    def learn_metric(self, x: torch.Tensor) -> tuple:
        """Return ``(d, t)``: embeddings of ``x`` and of the reference points."""
        d = self.metric(x)
        t = self.metric(self.training_datapoints)
        return d, t

    def find_nearest_neighbors(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Soft nearest neighbours of the embeddings ``x`` among ``t``.

        Pairwise distances are turned into similarities with ``exp(-dist)``
        before being passed to :meth:`soft_top_k`.
        """
        dist = torch.cdist(x, t)
        dist = torch.exp(-dist)
        return self.soft_top_k(dist, t)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the encoder network."""
        return self.encoder(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the decoder network."""
        return self.decoder(x)

    def encode_decode(self, x: torch.Tensor) -> tuple:
        """Full pass; returns ``(x_hat, s, d, dn)``.

        ``x_hat`` is the reconstruction, ``s`` the encoder output, ``d`` the
        metric embedding of ``x`` and ``dn`` the concatenated soft neighbours.
        """
        x_norm = self.normalize(x)

        d, t = self.learn_metric(x_norm)
        dn = self.find_nearest_neighbors(d, t)

        s = self.encode(dn)
        x_pre = self.decode(s)
        x_hat = self.denormalize(x_pre)

        return x_hat, s, d, dn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch, 2)`` tensor with columns ``s`` and ``z``."""
        x_norm = self.normalize(x)
        d, t = self.learn_metric(x_norm)
        dn = self.find_nearest_neighbors(d, t)
        s = self.encode(dn).reshape(-1, 1)
        z = self.compute_z(x).reshape(-1, 1)

        return torch.hstack((s, z))

    def compute_z(self, x: torch.Tensor, l=100) -> torch.Tensor:
        """Soft-minimum distance of ``x`` to ``training_datapoints_z``.

        Computes ``-(1/l) * log(sum_j exp(-l * |x - ref_z_j|))``.
        ``logsumexp`` is used because the naive form underflows to ``log(0)``
        (``z = inf``) as soon as every distance is large compared to ``1/l``.
        """
        z_dist = torch.cdist(x, self.training_datapoints_z)
        return -torch.logsumexp(-l * z_dist, dim=1) / l

# Build the inference model: `neighbours` are the reference points for the
# neighbour search and `neighbours_d` the points used by compute_z; then load
# the saved weights into it.
model_speed = AutoEncoderCV_Speed(
    f=n_features,
    d=d_metric,
    n=n_neighbors,
    ref=neighbours,
    ref_z=neighbours_d,
    act=activation_function,
)
model_speed.load_state_dict(
    torch.load('model_params_DeepLNE_path1.pt'),
    strict=False,
)

In [None]:
# Save the whole model object and, separately, its state dict.
# Both file names are the same as those written by the training cell, so the
# training checkpoints are overwritten.
filename = 'model_DeepLNE_path1.pth'
for payload, target in (
    (model_speed, filename),
    (model_speed.state_dict(), 'model_params_DeepLNE_path1.pt'),
):
    torch.save(payload, target)

In [None]:
# Evaluate the model on the full training set. The name `eval_input` avoids
# shadowing the builtin `input`, and no_grad skips building the autograd graph.
eval_input = training_datapoints

with torch.no_grad():
    out = model_speed(eval_input)
    xhat = model_speed.encode_decode(eval_input)[0]

s = out[:, 0].numpy()
z = out[:, 1].numpy()

ndx = np.where(z < 0.2)[0] #find cutoff for harmonic constraint

# Columns 0 and 1 of training_datapoints_plot come from COLVAR columns 1 and 2
# (presumably phi and psi -- TODO confirm).
phi_sel = training_datapoints_plot[ndx, 0]
psi_sel = training_datapoints_plot[ndx, 1]

# (x values, y values, colour values, x label, y label, colorbar label)
panels = [
    (phi_sel, psi_sel, s[ndx], 'phi', 'psi', 's'),
    (phi_sel, psi_sel, z[ndx], 'phi', 'psi', 'z'),
    (phi_sel, s[ndx], z[ndx], 'phi', 's', 'z'),
]

for x_vals, y_vals, colour_vals, x_label, y_label, colour_label in panels:
    fig, ax = plt.subplots()
    points = ax.scatter(x_vals, y_vals, c=colour_vals, alpha=1)
    fig.colorbar(points, ax=ax, label=colour_label)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    plt.show()

In [None]:
# Trace the model with a dummy input of shape (1, n_features) and save the
# traced module to disk.
example_input = torch.ones(1, n_features)
m = torch.jit.trace(model_speed, example_input)
m.save('model_DeepLNE_path1.ptc')

In [None]:
# parameters
multiply_by_stddev = True #whether to multiply derivatives by std dev of inputs
order_by_importance = True #plot results ordered by importance

# Load every 10th frame of the 190 feature columns from all five trajectories.
features = []
for j in range(1, 6):
    colvar = plumed.read_as_pandas("COLVAR_ratchet_path1_forward_%s" % j)
    all_data = colvar.iloc[::10, 3:193].to_numpy()
    features.append(all_data)
    print(all_data.shape)

# Input names are taken from the columns whose name contains 'dist'.
input_names = colvar.filter(regex='dist').columns.values
print(input_names, len(input_names))

n_input = len(input_names)
in_num = np.arange(n_input)

# The ranking is computed once over the frames of all trajectories, instead
# of being recomputed (and overwritten) for each trajectory in turn.
X = torch.Tensor(np.concatenate(features))
if multiply_by_stddev:
    in_std = torch.std(X, dim=0).numpy()

# Accumulate sum_i |ds_i/dx_i| over all frames. Every output row of the model
# depends only on the matching input row, so the gradient of sum(s) over a
# chunk gives the per-frame gradients in a single backward pass.
rank = torch.zeros(n_input)
chunk_size = 1000
for start in range(0, X.shape[0], chunk_size):
    x_chunk = X[start:start + chunk_size].clone().requires_grad_(True)
    s_chunk = model_speed(x_chunk)[:, 0]
    grad_chunk = torch.autograd.grad(s_chunk.sum(), x_chunk)[0]
    rank += grad_chunk.abs().sum(dim=0)
    print(start)

rank = rank.numpy()

if multiply_by_stddev:
    rank = rank * in_std

#normalize to 1
rank /= np.sum(rank)

#sort
if order_by_importance:
    index = rank.argsort()
    input_names = input_names[index]
    rank = rank[index]

In [None]:
#plot
fig = plt.figure(figsize=(5, 0.25*n_input), dpi=100)
ax = fig.add_subplot(111)

if order_by_importance:
    # rank/input_names are already sorted by increasing relevance.
    bar_values, bar_labels = rank, input_names
else:
    # Reverse values and labels together so that the first input is drawn at
    # the top and every bar keeps its own label.
    bar_values, bar_labels = rank[::-1], input_names[::-1]

# Default bar colours are used in both cases: 'fessa1'/'fessa0' are not
# built-in matplotlib colour names and nothing in this notebook registers them.
ax.barh(in_num, bar_values, linewidth=0.3)

ax.set_xlabel('Relevance')
ax.set_ylabel('Inputs')
# Fix the tick positions first, then attach one label to each tick.
ax.set_yticks(in_num)
ax.set_yticklabels(bar_labels, fontsize=9)
ax.yaxis.tick_right()

plt.tight_layout()
plt.show()