In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import math
import numba
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm,trange
import gc

from IPython.display import HTML
import matplotlib.animation as animation


import torch
import torchani
from torch.nn import ModuleList, Sequential
from collections import OrderedDict
from torch import Tensor
from typing import Tuple, NamedTuple, Optional

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Define system parameters

In [None]:
N = 7 # Number of particles; index 0 is the red particle, 1..N-1 are black
L = 8. # Box length (periodic in x and y)

eps = 8.0 # Epsilon for WCA
sigma = 1.0 # Sigma for WCA
rwca = np.power(2,1/6)*sigma # WCA cutoff at 2^(1/6)*sigma

re = 1. # Morse minimum position
a  = 6. # Morse width parameter
De_rb = 12. # Morse depth for red-black pairs
De_bb = 0. # Morse depth for black-black pairs
# Pairwise Morse depth matrix: row/column 0 (red) get De_rb, all other pairs De_bb.
De = np.zeros((N,N))
De[:] = De_bb
De[:,0] = De_rb
De[0,:] = De_rb

gamma = 0.25 # Friction coeff.
kBT = 0.5 # Thermal energy k_B*T
dt = 2.5e-5 # Timestep
etasig = np.sqrt(2*gamma*kBT/dt) # Std. of the random force drawn each step

tf = 0.5 # Observation time for the trajectories
# round() guards against tf/dt evaluating to e.g. 19999.999... and int() truncating it.
Nsteps = int(round(tf/dt)) # Total timesteps for the trajectories

### Define the modified TorchANI modules needed to load the trained model

In [None]:
class SpeciesEnergies(NamedTuple):
    # (species, summed per-molecule outputs) as returned by ANIModel1.forward
    species: Tensor
    energies: Tensor


class SpeciesCoordinates(NamedTuple):
    # (species, coordinates) pair
    species: Tensor
    coordinates: Tensor
    
# Modified ANIModel 
class ANIModel1(torch.nn.ModuleDict):
    """ModuleDict of per-species networks, each emitting 10 outputs per atom.

    Network number ``i`` (in insertion order) is applied to every atom whose
    species index equals ``i``. ``forward`` sums the per-atom outputs over the
    atom dimension.
    """

    @staticmethod
    def ensureOrderedDict(modules):
        """Return ``modules`` as an OrderedDict, keying a plain sequence by '0', '1', ..."""
        if isinstance(modules, OrderedDict):
            return modules
        od = OrderedDict()
        for i, m in enumerate(modules):
            od[str(i)] = m
        return od

    def __init__(self, modules):
        super().__init__(self.ensureOrderedDict(modules))

    def forward(self, species_aev: Tuple[Tensor, Tensor], 
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> SpeciesEnergies:
        """Sum per-atom network outputs over atoms.

        ``cell`` and ``pbc`` are accepted but unused. They exist so that
        Sequential1 can call this stage with the same keyword arguments as
        the AEV stage.

        Returns:
            SpeciesEnergies(species, outputs summed over dim=1).
        """
        species, aev = species_aev
        
        assert species.shape == aev.shape[:-1]
        
        atomic_energies = self._atomic_energies((species, aev))

        return SpeciesEnergies(species, torch.sum(atomic_energies, dim=1))


    @torch.jit.export
    def _atomic_energies(self, species_aev: Tuple[Tensor, Tensor]) -> Tensor:
        """Evaluate the per-species networks atom by atom.

        Args:
            species_aev: (species, aev), where ``species`` has shape
                ``aev.shape[:-1]`` (molecules x atoms).

        Returns:
            Tensor of shape (n_molecules, n_atoms, 10). Rows for atoms whose
            species index matches no network stay zero.
        """
        species, aev = species_aev
        assert species.shape == aev.shape[:-1]
        species_ = species.flatten()
        aev = aev.flatten(0, 1)
        
        output = aev.new_zeros([species_.shape[0],10])
        # Repeat each atom's species across the 10 output channels so that a
        # mask selects whole output rows.
        tmp_species_ = torch.tile(species_[:,None],(1,10))

        for i, m in enumerate(self.values()):
            
            mask = (tmp_species_ == i)
            midx = mask[...,0].nonzero().flatten()

            if midx.shape[0] > 0:
                input_ = aev.index_select(0, midx)
                # masked_scatter_ consumes its source in row-major order, which
                # matches the ascending midx order. The (len(midx), 10) output
                # can therefore be scattered directly, with no assumption that
                # every molecule holds the same number of atoms of species i.
                output.masked_scatter_(mask, m(input_))

        return output.view([species.shape[0], species.shape[1], 10])

# Modified Sequential Model 
class Sequential1(torch.nn.ModuleList):
    """Sequential container whose leading stages take (species, tensor) tuples.

    Stages 0 and 1 are called as ``module(x, cell=cell, pbc=pbc)``. Stage 2
    receives only the second field of the previous output (``x[1]``). Every
    later stage is called as ``module(x)``.
    """

    def __init__(self, *modules):
        super().__init__(modules)

    def forward(self, input_: Tuple[Tensor, Tensor],
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None):
        out = input_
        for stage, module in enumerate(self):
            if stage >= 3:
                out = module(out)
            elif stage == 2:
                out = module(out[1])
            else:
                out = module(out, cell=cell, pbc=pbc)
        return out

### Load the trained committor model and the configurations sampled in the reactant state

In [None]:
# Load the reactant-state configurations and shuffle their order.
# np.random.permutation(n) draws the same random stream as the previous
# np.random.choice(n, n, replace=False).
Aconfs = np.load('../jupyter_data/colloid_A.npy')
Aconfs = Aconfs[np.random.permutation(Aconfs.shape[0])]

# Load the trained committor network. It lives next to colloid_A.npy (that
# path is a file, so it cannot also be a directory). map_location lets a model
# saved on GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles a full nn.Module (it needs the class
# definitions above); only load trusted files. Newer torch versions default to
# weights_only=True and will need weights_only=False here.
model = torch.load('../jupyter_data/colloid_committor.pt', map_location=device)

gc.collect();

### Define functions for performing the simulation without any control forces

In [None]:
# Wrap displacement vectors with the minimum image convention
@numba.jit(nopython=True)
def periodic(rv,L):
    """Wrap the x and y columns of rv into [-L/2, L/2) in place and return rv.

    rv is a 2-D array of displacement rows; only columns 0 and 1 are wrapped.
    Subtracting L*floor(x/L + 0.5) folds a component that is any number of box
    lengths away. The earlier version only handled |x| < 3L/2, and inside that
    range both versions produce the same values.
    """
    for k in range(rv.shape[0]):
        for d in range(2):
            rv[k, d] -= L * np.floor(rv[k, d] / L + 0.5)
    return rv

# Return the periodic distance using minimum image convention
# r and rv are holders of dimension [N,N] and [N,N,2]
# Note: Due to the small number of particles and the small
# box size, neighborlists are not used
@numba.jit(nopython=True)
def computeDistances(pos,r,rv):
    """Compute pairwise minimum-image displacements and distances.

    rv[i, j] is set to the wrapped value of pos[i] - pos[j], written into the
    rv holder in place. The returned r is a newly allocated array; the r
    argument itself is not written to.

    Returns:
        (r, rv)
    """
    # Iterate over the particles actually present in pos instead of a
    # hard-coded 7, so the routine follows N.
    for i in range(pos.shape[0]):
        rv[i] = periodic(pos[i] - pos,L)
    r = np.sqrt(rv[...,0]**2 + rv[...,1]**2)
    return r,rv


# WCA pair potential matrix
@numba.jit(nopython=True)
def computeWCAPotential(r,holder):
    """Pairwise WCA energies eps*((r/sigma)^-12 - (r/sigma)^-6 + 1/4) for r < rwca.

    Pairs at or beyond the cutoff, and the diagonal (self pairs), are zero.
    """
    scaled = r/sigma
    inside = (r<rwca).astype(np.float32)
    energy = eps*(np.power(scaled,-12) - np.power(scaled,-6) + 0.25)*inside
    np.fill_diagonal(energy, 0.)
    return energy

# Morse pair potential matrix
@numba.jit(nopython=True)
def computeMorsePotential(r,holder):
    """Pairwise Morse energies De*(exp(-2a(r-re)) - 2exp(-a(r-re))), diagonal zeroed."""
    shifted = r-re
    energy = De*(np.exp(-2*a*shifted) - 2*np.exp(-a*shifted))
    np.fill_diagonal(energy, 0.)
    return energy

# Compute WCA Forces
@numba.jit(nopython=True)
def computeWCAForces(r,holder):
    """Radial WCA pair force f(r) = -dU/dr for U = eps*((sigma/r)^12 - (sigma/r)^6 + 1/4).

    Positive values are repulsive. The force is zero at or beyond rwca and on
    the diagonal. sigma is now included so the force matches
    computeWCAPotential; for sigma == 1 the values are unchanged.
    """
    sr = r/sigma
    holder  =  6*eps/sigma*np.power(sr,-7)*(2*np.power(sr,-6)-1)*(r<rwca).astype(np.float32)
    np.fill_diagonal(holder, 0.)
    return holder

# Radial Morse pair force matrix
@numba.jit(nopython=True)
def computeMorseForces(r,holder):
    """-dU/dr of the Morse potential, 2*a*De*e*(e-1) with e = exp(-a(r-re)); diagonal zeroed."""
    decay = np.exp(-a*(r-re))
    force = 2*a*De*decay*(decay-1)
    np.fill_diagonal(force, 0.)
    return force


# Net conservative force on every particle.
# r, holder1, holder2 and holder3 are [N,N] scratch arrays;
# F and rv are [N,N,2] scratch arrays.
@numba.jit(nopython=True)
def computeForces(pos,r,rv,F,holder1,holder2,holder3):
    """Return the (N, 2) sum of Morse + WCA pair forces acting on each particle.

    rv is overwritten with unit displacement vectors (zero on the diagonal),
    and F with the per-pair force vectors.
    """
    dist, disp = computeDistances(pos,r,rv)
    for d in range(2):
        disp[...,d] /= dist
        np.fill_diagonal(disp[:,:,d],0.)
    magnitude = computeMorseForces(dist,holder1) + computeWCAForces(dist,holder2)
    for d in range(2):
        F[:,:,d] = magnitude*disp[...,d]
    return np.sum(F,axis=1)

# Compute total potentials.
@numba.jit(nopython=True)
def computePotential(pos,r,rv,holder1,holder2):
    """Total WCA + Morse potential energy of one configuration.

    The pair-energy matrices count every pair twice, so the sum is halved.
    This uses computeMorsePotential (the potential whose derivative
    computeMorseForces uses in the dynamics). The previously referenced
    computeMorsePotentialWithSwitching is not defined in this notebook.
    """
    r,rv = computeDistances(pos,r,rv)
    return np.sum(computeWCAPotential(r,holder1) \
                  + computeMorsePotential(r,holder2))/2

# Compute total potentials for a trajectory
def computePotentialHelper(pos,wca=True):
    """Potential energy of every frame of a trajectory.

    Args:
        pos: sequence of configurations; pos[i] is one frame.
        wca: if True, add the WCA energy to the Morse energy.

    Returns:
        1-D array with one energy per frame (pairs counted once).
    """
    r = np.zeros((N,N))
    rv = np.zeros((N,N,2))
    holder1 = np.zeros((N,N))
    holder2 = np.zeros((N,N))
    pot = np.zeros((pos.shape[0]))
    for i in range(pos.shape[0]):
        r,rv = computeDistances(pos[i],r,rv)
        # computeMorsePotential is the Morse potential defined in this
        # notebook; the previous computeMorsePotentialWithSwitching does not exist.
        pot[i]= np.sum(computeMorsePotential(r,holder2))/2
        if wca:
            pot[i] += np.sum(computeWCAPotential(r,holder1))/2
    return pot

# Propagate system by 1 step
@numba.jit(nopython=True)
def langevinStep(pos,r,rv,eta,F,holder1,holder2,holder3,
          dt=dt,etasig=etasig):
    """Advance one overdamped Langevin (Euler-Maruyama) step and wrap into the box.

    Update: pos + dt*(F + eta)/gamma, with eta ~ Normal(0, etasig) per
    component. The eta argument is only a placeholder; fresh noise is drawn here.
    """
    F = computeForces(pos,r,rv,F,holder1,holder2,holder3)
    eta = np.random.normal(0,etasig,(N,2))
    # Divide by the friction coefficient. With etasig = sqrt(2*gamma*kBT/dt) the
    # displacement noise then has variance 2*kBT*dt/gamma (Einstein relation),
    # the same update runDriven applies.
    return periodic(pos + (F + eta)*dt/gamma,L)


# Propagate system for Nsteps
@numba.jit(nopython=True)
def runJit(Nsteps,pos,r,rv,eta,F,holder1,holder2,holder3,
           dt=dt,etasig=etasig):
    """Apply exactly Nsteps Langevin steps to pos and return the final positions.

    The loop previously ran over range(1, Nsteps), i.e. only Nsteps-1 steps.
    """
    for _ in range(Nsteps):
        pos = langevinStep(pos,r,rv,eta,F,holder1,holder2,holder3,dt,etasig)
    return pos


# Unbiased simulation, storing one frame every dc steps
def run(Nsteps,dc,pos0,dt=dt,etasig=etasig):
    """Integrate an unbiased trajectory from pos0 and keep periodic snapshots.

    Returns an array of shape (int(Nsteps/dc), N, 2). Frame 0 is a copy of
    pos0, and each later frame is the result of runJit(dc, previous frame, ...).
    """
    n_frames = int(Nsteps/dc)
    frames = np.zeros((n_frames,N,2))
    frames[0] = pos0.copy()

    # Scratch buffers reused by the jitted force routines
    dist = np.zeros((N,N))
    disp = np.zeros((N,N,2))
    pair_forces = np.zeros((N,N,2))
    noise = np.zeros((N,N,2))
    scratch = [np.zeros((N,N)) for _ in range(3)]

    for k in tqdm(range(1,n_frames)):
        frames[k] = runJit(dc,frames[k-1],dist,disp,noise,pair_forces,*scratch,
                           dt,etasig,).copy()
    return frames



### Define functions for performing simulation with the control forces

In [None]:
# Conservative forces and random forces for one driven step
@numba.jit(nopython=True)
def drivenStep(pos,r,rv,eta,F,holder1,holder2,holder3,dt=dt,etasig=etasig):
    """Return (net conservative force, random force), both (N, 2), without moving particles.

    The eta and dt arguments are accepted but not used; the noise is drawn here
    with std etasig.
    """
    net_force = computeForces(pos,r,rv,F,holder1,holder2,holder3)
    random_force = np.random.normal(0,etasig,(N,2))
    return net_force,random_force
   

# Compute the control force
def computeLamb(model,r,species,pbc,cell,tau,mu2=0.089339,pB=0.410):
    """Committor-based control force for a batch of configurations.

    Computes 2*kBT * d(sum y)/dr / (y + pB*(exp(mu2*tau) - 1)), where y is the
    model output. Only the first two coordinate columns of the gradient are
    kept (the last column is dropped).

    Args:
        model: network called as model((species, coords), pbc=pbc, cell=cell).
        r: numpy array of configurations. The caller here passes
            (Ntraj, N, 3) with an unused last column.
        species, pbc, cell: tensors forwarded to the model.
        tau: time argument of the denominator shift.
        mu2, pB: constants of the denominator shift. NOTE(review): presumably
            fitted rate/probability parameters; confirm their origin.

    Returns:
        numpy array grad / denominator, broadcast over the particle axis.
    """
    # Convert the configurations to tensors
    dat = torch.tensor(r).to(device).float().requires_grad_(True)
    # Compute the committor 
    y = model((species, dat),pbc=pbc,cell=cell)
    # Only a first-order gradient is needed and it is detached immediately, so
    # no higher-order graph is built or retained (create_graph/retain_graph
    # only cost memory and time on every step).
    grad = torch.autograd.grad(y.sum(), dat)[0]
    grad = 2*kBT*grad.detach().cpu().numpy()[...,:-1]
    # Compute the denominator of the control force
    denom = y.detach().cpu().numpy() + pB*(np.exp(mu2*tau)-1.)
    return grad/denom[:,None,:]


# Obtain Ntraj driven trajectories integrated for Nsteps
# initiated at pos0
# Returns 3 arrays of dimensions [Ntraj,Nsteps,N,2]
# containing the trajectory, noises and the control forces
# respectively
def runDriven(Ntraj,Nsteps,pos0,dt=dt,etasig=etasig):
    """Integrate Ntraj trajectories under conservative + control forces.

    Update per step: x += dt*(F + eta + lamb)/gamma. Positions are not wrapped
    into the box. eta[:, k] and lamb[:, k] are the noise and control force used
    for the step from frame k to k+1, so the last entries (k = Nsteps-1) stay zero.

    Returns:
        (traj, eta, lamb), each of shape (Ntraj, Nsteps, N, 2).
    """
    # Define the arrays arrays for storing
    # trajectories, noises and control forces
    traj = np.zeros((Ntraj,Nsteps,N,2))
    traj[:,0] = pos0.copy()
    eta =np.zeros((Ntraj,Nsteps,N,2))
    lamb = np.zeros((Ntraj,Nsteps,N,2))
    
    # Define the holders required for the simulation
    r    = np.zeros((Ntraj,N,N))
    rv   = np.zeros((Ntraj,N,N,2))
    F = np.zeros((Ntraj,N,N,2))     # per-pair force scratch for computeForces
    Ftot = np.zeros((Ntraj,N,2))    # net conservative force on each particle
    holder1 = np.zeros((Ntraj,N,N))
    holder2 = np.zeros((Ntraj,N,N))
    holder3 = np.zeros((Ntraj,N,N))
    
    
    # Define the arrays and tensors for computing the forces
    # using the TorchANI module
    dat = np.zeros((Ntraj,N,3),dtype=np.float32) # Holder for the configurations
    cell = np.zeros((3,3)) # Cell size for PBC information 
    np.fill_diagonal(cell,L) 
    cell = torch.tensor(cell).to(device).float()
    pbc = torch.tensor(np.array([True,True,False])).to(device).bool() # Dimensions that have PBC
    species = torch.tensor(np.tile(np.array([0,1,1,1,1,1,1],dtype=int),
                                   (Ntraj,1))).to(device).int() # Species of the system
    
    # Perform simulation
    for i in tqdm(range(1,Nsteps)):
        for j in range(Ntraj):
            # Compute conservative forces and the noises. The (N,2) net force
            # goes into Ftot; writing it into the (N,N,2) scratch F[j] would
            # broadcast it and break the position update below.
            Ftot[j],eta[j,i-1] = drivenStep(traj[j,i-1],r[j],rv[j],eta[j,i-1],
                                            F[j],holder1[j],holder2[j],holder3[j],
                                            dt=dt,etasig=etasig)
            # Prepare data for computing control forces (coordinates relative
            # to particle 0, minimum image; the z column stays zero)
            dat[j,:,:-1] = periodic(traj[j,i-1]-traj[j,i-1,:1],L)
        
        # Compute the control forces
        lamb[:,i-1] = computeLamb(model,dat,species,pbc=pbc,cell=cell,tau=(Nsteps-i-1)*dt)
        # Propagate the system by one step
        traj[:,i] = traj[:,i-1] +  dt*(Ftot+eta[:,i-1]+lamb[:,i-1])/gamma
    
    return traj,eta,lamb


In [None]:
Ntraj = 2000  # Number of driven trajectories
# Draw Ntraj distinct reactant-state configurations as starting points
pos0 = Aconfs[np.random.choice(Aconfs.shape[0],Ntraj,replace=False)]
# Pass Ntraj (the previous `tot` is undefined in this notebook)
traj, eta, lamb = runDriven(Ntraj,Nsteps,pos0,dt,etasig)

In [None]:
# Re-express every frame relative to the red particle (index 0), wrapped with
# the minimum-image convention, so the red particle sits at the origin
traj = np.array([
    [periodic(frame - frame[:1], L) for frame in trajectory]
    for trajectory in tqdm(traj)
])



In [None]:
# Store positions, noises and control forces side by side along the last axis
np.save(
    '../jupyter_data/colloids_reactive_driven.npy',
    np.concatenate((traj, eta, lamb), axis=-1),
)