In [None]:
import numpy as np
import gc
from tqdm import tqdm,trange
import matplotlib.pyplot as plt

import torch
import torchani
import os
import math
from torch.nn import ModuleList, Sequential
from collections import OrderedDict
from torch import Tensor
from typing import Tuple, NamedTuple, Optional

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Every tensor and the model below are moved to this device with `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Load stationary samples from the reactant (A), the product (B), and the region outside both of them (C)

In [None]:
# Locations of the stored samples for regions A, B and C
flA, flB, flC = (
    '../jupyter_data/colloid_A.npy',
    '../jupyter_data/colloid_B.npy',
    '../jupyter_data/colloid_C.npy',
)

# Read the three arrays (in A, B, C order) and cast them to single precision
datA, datB, datC = (np.load(path).astype(np.float32) for path in (flA, flB, flC))

# Number of stored samples per region, i.e. the length of the leading axis
NA, NB, NC = (samples.shape[0] for samples in (datA, datB, datC))


### Divide the data into batches for the training set and validation set

In [None]:
Nbatch = 10   # number of training batches per dataset
Nval = 5000   # number of validation samples per dataset

#Number of samples per batch
NsampC = int((4.0e4)/Nbatch)
NsampA = NsampC
NsampB = NsampC


def _append_zero_column(coords):
    """Return `coords` with one zero-filled entry appended along the last axis.

    The padding takes its leading shape and dtype from `coords`, so no sample
    count or atom count is hard-coded and float32 input stays float32.
    NOTE(review): presumably the stored samples are 2-D positions and the extra
    column is a dummy z coordinate (the cell below is 3x3 and non-periodic in
    the third direction) -- confirm against the data files.
    """
    padding = np.zeros(coords.shape[:-1] + (1,), dtype=coords.dtype)
    return np.concatenate([coords, padding], axis=-1)


#Obtain randomized, non-repeating sample indices for each dataset.
#The first Nval entries are reserved for validation, the rest for training.
idA = np.random.choice(NA,(Nbatch*NsampA+Nval),replace=False)
idB = np.random.choice(NB,(Nbatch*NsampB+Nval),replace=False)
idC = np.random.choice(NC,(Nbatch*NsampC+Nval),replace=False)

#Define the validation dataset and the species holder.
#The validation rows are selected through idX[:Nval]; taking datX[:Nval]
#directly would overlap with the training rows drawn from idX[Nval:].
xtA_val = torch.tensor(_append_zero_column(datA[idA[:Nval]])).to(device).float()
xtB_val = torch.tensor(_append_zero_column(datB[idB[:Nval]])).to(device).float()
#The C samples need gradients with respect to the coordinates
xtC_val = torch.tensor(_append_zero_column(datC[idC[:Nval]])).to(device).float().requires_grad_(True)
#One species row per sample: index 0 for the first particle, index 1 for the other six
species_val = torch.tensor(np.tile(np.array([0,1,1,1,1,1,1],dtype=int),(Nval,1))).to(device).int()

#Define the training dataset and the species holder
idA = idA[Nval:].reshape(Nbatch,NsampA)
idB = idB[Nval:].reshape(Nbatch,NsampB)
idC = idC[Nval:].reshape(Nbatch,NsampC)

#Indexing with the (Nbatch, Nsamp) index arrays yields one leading batch axis
xtA = torch.tensor(_append_zero_column(datA[idA])).to(device).float()
xtB = torch.tensor(_append_zero_column(datB[idB])).to(device).float()
#C batches are kept as a list of separate leaf tensors so that each batch can be
#differentiated with respect to its own coordinates
xtC = [torch.tensor(_append_zero_column(datC[idC[i]]))
        .to(device).float().requires_grad_(True) for i in range(Nbatch)]
species = torch.tensor(np.tile(np.array([0,1,1,1,1,1,1],dtype=int),(NsampC,1))).to(device).int()

#Define parameters for the PBC information
L = 8.  # edge length placed on the diagonal of the cell matrix
cell = np.zeros((3,3))
np.fill_diagonal(cell,L)
cell = torch.tensor(cell).to(device).float()
#Periodic along the first two cell vectors only
pbc = torch.tensor(np.array([True,True,False])).to(device).bool()

#Delete the temporary dataset files
del idA,idB,idC,datA,datB,datC

gc.collect();

### Define the modified TorchANI and Sequential model for training the committor

In [None]:
class SpeciesEnergies(NamedTuple):
    """Pair returned by `ANIModel1.forward`: the species tensor and the per-sample summed outputs."""
    species: Tensor
    energies: Tensor


class SpeciesCoordinates(NamedTuple):
    """Pair of a species tensor and a coordinates tensor."""
    species: Tensor
    coordinates: Tensor


class ANIModel1(torch.nn.ModuleDict):
    """Apply one sub-network per species index and sum the vector outputs over atoms.

    The i-th module (in insertion order) is applied to every atom whose species
    index equals i. Atoms whose index matches no module (for example -1) keep an
    all-zero output row.

    Args:
        modules: an ``OrderedDict`` of modules, or any iterable of modules
            (which is then keyed by position: "0", "1", ...).
        out_features: width of the vector each sub-network returns per atom.
            Defaults to 10, the value that used to be hard-coded.
    """

    @staticmethod
    def ensureOrderedDict(modules):
        """Return `modules` unchanged if it is an OrderedDict, else key it by position."""
        if isinstance(modules, OrderedDict):
            return modules
        od = OrderedDict()
        for i, m in enumerate(modules):
            od[str(i)] = m
        return od

    def __init__(self, modules, out_features: int = 10):
        super().__init__(self.ensureOrderedDict(modules))
        self.out_features = out_features

    def forward(self, species_aev: Tuple[Tensor, Tensor],  # type: ignore
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> SpeciesEnergies:
        """Return the species together with the per-atom outputs summed over the atom axis.

        Args:
            species_aev: ``(species, aev)`` where ``species.shape == aev.shape[:-1]``.
            cell, pbc: accepted so the module can be called like the AEV computer
                inside `Sequential1`; they are not used here.

        Returns:
            ``SpeciesEnergies(species, energies)`` with ``energies`` obtained by
            summing the per-atom outputs over dimension 1.
        """
        species, aev = species_aev

        assert species.shape == aev.shape[:-1]

        atomic_energies = self._atomic_energies((species, aev))

        return SpeciesEnergies(species, torch.sum(atomic_energies, dim=1))

    @torch.jit.export
    def _atomic_energies(self, species_aev: Tuple[Tensor, Tensor]) -> Tensor:
        """Evaluate the per-species networks atom by atom.

        Args:
            species_aev: ``(species, aev)`` with a 2-D ``species`` tensor and
                ``aev`` carrying one extra trailing feature axis.

        Returns:
            Tensor of shape ``species.shape + (out_features,)``.

        Raises:
            ValueError: if a sub-network returns a width other than ``out_features``.
        """
        species, aev = species_aev
        assert species.shape == aev.shape[:-1]
        flat_species = species.flatten()
        flat_aev = aev.flatten(0, 1)

        output = flat_aev.new_zeros([flat_species.shape[0], self.out_features])

        for i, m in enumerate(self.values()):
            atom_mask = (flat_species == i)
            midx = atom_mask.nonzero().flatten()

            if midx.shape[0] > 0:
                atom_out = m(flat_aev.index_select(0, midx))
                if atom_out.shape[-1] != self.out_features:
                    raise ValueError(
                        "sub-network %d returned %d features, expected %d"
                        % (i, atom_out.shape[-1], self.out_features))
                # The network output already has one row per selected atom, in
                # midx order, so it is scattered as-is. No reshape through a
                # per-sample atom count is needed, which also lets the number
                # of atoms of a species differ between samples.
                output.masked_scatter_(atom_mask.unsqueeze(-1).expand_as(output), atom_out)

        return output.view(species.shape[0], species.shape[1], self.out_features)

class Sequential1(torch.nn.ModuleList):
    """Chain of modules in which only the first two stages receive `cell` and `pbc`.

    Stage 0 and stage 1 are called as ``stage(x, cell=cell, pbc=pbc)``.
    Stage 2 receives element ``[1]`` of the previous stage's result, and every
    later stage receives the previous result unchanged.
    """

    def __init__(self, *modules):
        super().__init__(modules)

    def forward(self, input_: Tuple[Tensor, Tensor],  # type: ignore
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None):
        """Run `input_` through all stages in order and return the last result."""
        result = input_
        for position, stage in enumerate(self):
            if position >= 3:
                # Plain stages: feed the previous output straight through
                result = stage(result)
            elif position == 2:
                # Unpack the second element of the tuple produced by stage 1
                result = stage(result[1])
            else:
                # The first two stages also need the periodic-cell information
                result = stage(result, cell=cell, pbc=pbc)
        return result

### Initialize the model

In [None]:
# Define the AEV descriptors. The positional arguments are passed exactly as
# before; see the TorchANI `AEVComputer.cover_linearly` documentation for the
# meaning of each position.
aev_computer = torchani.AEVComputer.cover_linearly(
    4*np.sqrt(2), 2., 16.0, 8.0, 30, 3, 32.0, 6, 2
)
aev_dim = aev_computer.aev_length

# Both particle types use the same layer layout (aev_dim -> 5 -> 10 -> 10 with
# CELU(0.1) in between) but separate parameters. The red network is created
# first and the black network second.
red_network, black_network = (
    torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 5),
        torch.nn.CELU(0.1),
        torch.nn.Linear(5, 10),
        torch.nn.CELU(0.1),
        torch.nn.Linear(10, 10),
    )
    for _particle_type in ("red", "black")
)

# Initialize the ANI and the Sequential model:
# AEV computer -> per-species networks -> Tanh -> Linear(10, 1) -> Sigmoid
nn = ANIModel1([red_network, black_network])
model = Sequential1(
    aev_computer,
    nn,
    torch.nn.Tanh(),
    torch.nn.Linear(10,1),
    torch.nn.Sigmoid(),
).to(device)

### Initialize the model parameters and the validation method

In [None]:
#Initialize the weights and biases of the model
def init_params(m):
    """Re-initialise a single module in place; meant to be used with `Module.apply`.

    Linear layers get Kaiming-normal weights (negative slope ``a=1.0``) and
    zero biases. Every other module type is left untouched.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.kaiming_normal_(m.weight, a=1.0)
    torch.nn.init.zeros_(m.bias)

# `apply` visits `nn` and all of its submodules, so only the Linear layers of the
# two per-species networks are re-initialised here. The final Linear(10, 1) head
# belongs to `model`, not to `nn`, and therefore keeps PyTorch's default init.
nn.apply(init_params);


In [None]:
#Compute the validation set loss
def validate():
    """Evaluate both loss terms on the validation set.

    Reads the notebook globals `model`, `species_val`, `xtA_val`, `xtB_val`,
    `xtC_val`, `pbc`, `cell`, `reg` and `Nval`.

    Returns:
        (lossBKE, lossBoundary) as detached scalar tensors, where
        * lossBKE is the sum of squared gradients of the model output on the C
          samples with respect to all but the last coordinate column,
          divided by `Nval`;
        * lossBoundary is ``reg * (mean(q_A**2) + mean((1 - q_B)**2))``.

    The values are only reported, never backpropagated, so no second-order
    graph is built and the A/B passes run without gradient tracking. The
    model's previous train/eval mode is restored even if evaluation raises.
    """
    was_training = model.training
    model.train(False)
    try:
        # Boundary term: plain forward passes, no autograd bookkeeping needed
        with torch.no_grad():
            committorA = model((species_val,xtA_val),pbc=pbc,cell=cell)
            committorB = model((species_val,xtB_val),pbc=pbc,cell=cell)
            lossBoundary = reg*(torch.square(committorA).mean()
                                + torch.square(1-committorB).mean())

        # Gradient term: first-order gradient with respect to the coordinates only
        committorC = model((species_val,xtC_val),pbc=pbc,cell=cell)
        lamb = torch.autograd.grad(committorC.sum(), xtC_val)[0][...,:-1]
        lossBKE = (torch.square(lamb).sum()/Nval).detach()
    finally:
        model.train(was_training)

    return lossBKE,lossBoundary
gc.collect();

### Perform training

In [None]:
# Optimizer. NOTE: this is torch.optim.Adam with `weight_decay`; the variable
# name `AdamW` is kept only so that later cells referring to it keep working.
AdamW = torch.optim.Adam(model.parameters(),lr=1e-3,weight_decay=0.001) #Optimizer
Nepochs = 1000 # Number of epochs
reg = 5000. # Lagrange multipliers for the boundary conditions
# Running losses per epoch: [train BKE, train boundary, val BKE, val boundary]
losses = np.zeros((Nepochs,4)) #Array for saving the running loss

# Iterator for optimization 
tqs= trange(Nepochs,desc="Loss : ",leave=True)
desc = 'Tr BKE : {: .3f} | Tr Bdy : {: .3f} | Val BKE: {:.3f} | Val Bdy: {:.3f}'

for i in tqs:
    # Reset the optimizer every 500 epochs (not at epoch 0).
    # The previous test `if i % 500:` was true on every epoch that is NOT a
    # multiple of 500, so Adam's moment estimates were thrown away almost
    # every epoch.
    if i > 0 and i % 500 == 0:
        AdamW = torch.optim.Adam(model.parameters(),lr=1e-3,weight_decay=0.0001)
        
    #Compute the validation loss. Detach immediately so that no autograd graph
    #from the validation pass is kept alive for the rest of the epoch.
    val1,val2 = (v.detach() for v in validate())
    val_bke = val1.item()
    val_bdy = val2.item()
    losses[i][2] = val_bke
    losses[i][3] = val_bdy

    # Randomized iterator for selecting the batches
    rangeA = np.random.choice(Nbatch,Nbatch,replace=False)
    rangeB = np.random.choice(Nbatch,Nbatch,replace=False)
    rangeC = np.random.choice(Nbatch,Nbatch,replace=False)
    
    for j,k,l in zip(rangeC,rangeB,rangeA):
        
        #Compute the boundary loss on the A and B samples
        committorA = model((species, xtA[l]),pbc=pbc,cell=cell)
        committorB = model((species, xtB[k]),pbc=pbc,cell=cell)
        lossBoundary = reg*(torch.square(committorA).sum()/NsampA
                            + torch.square(1-committorB).sum()/NsampB)
        
        #Compute BKE loss on the C samples.
        #create_graph=True keeps the gradient itself differentiable, which is
        #required because the loss below is a function of that gradient.
        committorC = model((species, xtC[j]),pbc=pbc,cell=cell)
        lamb = torch.autograd.grad(committorC.sum(), xtC[j], create_graph=True,
                                   retain_graph=True)[0][...,:-1]
        lossBKE = torch.square(lamb).sum()/NsampC

        #Compute total loss and perform backpropagation
        loss = lossBoundary + lossBKE 
        AdamW.zero_grad()
        loss.backward()
        AdamW.step()
        
        #Store the loss (each value is copied to the host once and then reused)
        train_bke = lossBKE.item()
        train_bdy = lossBoundary.item()
        losses[i][0] += train_bke/Nbatch
        losses[i][1] += train_bdy/Nbatch
        
        #The progress bar shows the natural log of each loss
        tqs.set_description(desc.format(np.log(train_bke),
                                np.log(train_bdy),
                                np.log(val_bke),
                                np.log(val_bdy)),refresh=True)
    

### Save the model

In [None]:
# Passing the module object (rather than model.state_dict()) pickles the whole
# model. NOTE(review): per the PyTorch serialization notes, loading such a file
# needs the ANIModel1 / Sequential1 class definitions to be importable -- confirm
# that the notebook which loads this file defines them first.
torch.save(model,'../jupyter_data/colloids_committor.pt')