This notebook shows how to:
* Use the built-in network and evaluate that network
* Train a small network from scratch using the subset of the ANI dataset shipped with TorchANI

# Use the built-in network

In [1]:
import torch
import torchani
import tqdm
import math

# Build the model on top of the AEV computer.
# NOTE(review): from_nc=None presumably selects the network parameters shipped
# with TorchANI (this section is titled "Use the built-in network") -- confirm.
aev_computer = torchani.AEV()
model = torchani.ModelOnAEV(aev_computer, from_nc=None)
ds = torchani.Dataset(torchani.buildin_dataset_dir)
energy_shifter = torchani.EnergyShifter()
# size_average=False makes the loss return the SUM of squared errors, so the
# per-batch results can be accumulated and divided by the conformation count.
loss = torch.nn.MSELoss(size_average=False)

# math.inf asks the dataset iterator for batches without a size cap.
batch_size = math.inf
total_se = 0
total_conformations = 0

# This is evaluation only: nothing is backpropagated, so autograd is switched
# off to avoid building (and keeping in memory) a graph for the whole dataset.
with torch.no_grad():
    for coordinates, energies, species in tqdm.tqdm_notebook(ds.iter(batch_size)):
        total_conformations += coordinates.shape[0]
        predicted_energies = model(coordinates, species).squeeze()
        predicted_energies = energy_shifter.add_sae(predicted_energies, species)
        total_se += loss(predicted_energies, energies).item()

mse = total_se / total_conformations
rmse = math.sqrt(mse)
# 627.509 scales the RMSE to the kcal/mol value reported below
# (presumably the dataset energies are in Hartree -- TODO confirm).
print('RMSE is:', rmse * 627.509, 'kcal/mol')

  from ._conv import register_converters as _register_converters


A Jupyter Widget


RMSE is: 26.548645933722828 kcal/mol


# Train a network

In [8]:
import torch
import torchani
import tqdm
import math
import timeit

# Shared AEV computer: AtomicNetwork reads its dtype/device when creating its
# layers, and it is handed to torchani.ModelOnAEV when the model is built.
aev_computer = torchani.AEV()

def celu(x, alpha):
    """CELU activation: ``x`` for ``x > 0``, ``alpha * (exp(x / alpha) - 1)`` otherwise.

    Args:
        x: input tensor.
        alpha: positive scale of the exponential branch.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # torch.where evaluates BOTH branches. For large positive x, exp(x / alpha)
    # overflows to inf; the forward value is still fine (the other branch is
    # selected), but the backward pass multiplies a zero gradient by inf and
    # produces NaN. Clamping the exponent input to <= 0 keeps the unselected
    # branch finite and leaves the result for x <= 0 unchanged.
    non_positive = torch.clamp(x, max=0)
    return torch.where(x > 0, x, alpha * (torch.exp(non_positive / alpha) - 1))

class AtomicNetwork(torch.nn.Module):
    """Fully connected network with layer sizes 384 -> 128 -> 128 -> 64 -> 1.

    Every layer except the last is followed by ``celu`` with ``alpha=0.1``.
    All layers are created with the dtype and device of the module-level
    ``aev_computer``.
    """

    def __init__(self):
        super().__init__()
        # Not read in this class; presumably consumed by torchani.ModelOnAEV
        # -- TODO confirm.
        self.output_length = 1

        def make_linear(n_in, n_out):
            # Cast and move the layer so it matches the AEV computer.
            layer = torch.nn.Linear(n_in, n_out)
            return layer.type(aev_computer.dtype).to(aev_computer.device)

        # 384 is presumably the AEV length per atom -- verify against aev_computer.
        self.layer1 = make_linear(384, 128)
        self.layer2 = make_linear(128, 128)
        self.layer3 = make_linear(128, 64)
        self.layer4 = make_linear(64, 1)

    def forward(self, aev):
        """Run ``aev`` through the three activated layers and the linear output layer."""
        hidden = aev
        for layer in (self.layer1, self.layer2, self.layer3):
            hidden = celu(layer(hidden), 0.1)
        return self.layer4(hidden)

# One independently initialised AtomicNetwork per element symbol; the
# per-atom outputs are combined with torch.sum (the `reducer`).
species_networks = {symbol: AtomicNetwork() for symbol in ('C', 'H', 'N', 'O')}
model = torchani.ModelOnAEV(aev_computer, reducer=torch.sum,
                            per_species=species_networks)
energy_shifter = torchani.EnergyShifter()
# Summed (not averaged) squared error, so batch results can be pooled.
loss = torch.nn.MSELoss(size_average=False)

# Shuffle the bundled dataset, then split it 80/10/10.
ds = torchani.Dataset(torchani.buildin_dataset_dir)
ds.shuffle()
ds.split(('train', 0.8), ('validate', 0.1), ('test', 0.1))

# NOTE(review): batch_size is assigned again (1024) before the training loop,
# and backward_every is not read anywhere in the visible code.
batch_size = 256
backward_every = 4
optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

class Averager:
    """Accumulates a running count and a running subtotal and reports their ratio."""

    def __init__(self):
        # Both accumulators start empty.
        self.count, self.subtotal = 0, 0

    def add(self, count, subtotal):
        """Add ``count`` items whose values sum to ``subtotal``."""
        self.subtotal += subtotal
        self.count += count

    def avg(self):
        """Return the accumulated subtotal divided by the accumulated count."""
        numerator, denominator = self.subtotal, self.count
        return numerator / denominator
    
def evaluate(coordinates, energies, species):
    """Score one batch with the module-level ``model``, ``energy_shifter`` and ``loss``.

    Returns:
        A ``(count, squared_error)`` pair: ``count`` is ``coordinates.shape[0]``
        and ``squared_error`` is the value returned by ``loss`` for the shifted
        predictions against ``energies`` (still a tensor, so it can be
        backpropagated by the caller).
    """
    raw_prediction = model(coordinates, species).squeeze()
    shifted_prediction = energy_shifter.add_sae(raw_prediction, species)
    return coordinates.shape[0], loss(shifted_prediction, energies)

def subset_rmse(name):
    """Return the RMSE of the current model over the dataset subset ``name``.

    Args:
        name: subset name passed to ``ds.iter`` (e.g. 'validate' or 'test').

    Returns:
        Square root of (total squared error / total number of conformations),
        as a Python float.
    """
    a = Averager()
    # Pure evaluation: only .item() of the loss is used, so autograd is
    # disabled to avoid building a graph over the uncapped (math.inf) batch.
    with torch.no_grad():
        for i in ds.iter(math.inf, name):
            count, squared_error = evaluate(*i)
            a.add(count, squared_error.item())
    return math.sqrt(a.avg())

def optimize_step(mse):
    """Take one optimizer step on an objective derived from ``mse``.

    While the module-level ``epoch`` is at most 10 the objective is ``mse``
    itself; afterwards it is ``0.5 * exp(2 * mse)``.
    """
    if epoch > 10:
        objective = 0.5 * torch.exp(2 * mse)
    else:
        objective = mse
    optimizer.zero_grad()
    objective.backward()
    optimizer.step()

batch_size = 1024
# Number of consecutive batches whose squared errors are pooled into a single
# optimizer step.
step_every = 1
# Stop once this many epochs have passed without a new best validation RMSE.
patience = 1000
# Scale applied to every reported RMSE
# (presumably Hartree -> kcal/mol -- TODO confirm the dataset's energy unit).
rmse_scale = 627.509

epoch = 0
best_validate_rmse = math.inf
best_epoch = 0
print('epoch','time','training_rmse', 'validation_rmse')
start = timeit.default_timer()
while True:
    total_a = Averager()   # whole-epoch statistics (plain floats)
    step_a = Averager()    # statistics since the last optimizer step (tensors)
    # Count batches from 0 within each epoch so that every accumulation group
    # holds exactly `step_every` batches; only the last group may be shorter.
    batch = 0
    for i in ds.iter(batch_size, 'train'):
        count, squared_error = evaluate(*i)
        total_a.add(count, squared_error.item())
        step_a.add(count, squared_error)
        batch += 1
        if batch % step_every == 0:
            optimize_step(step_a.avg())
            step_a = Averager()
    # Flush a trailing, incomplete accumulation group.
    if step_a.count > 0:
        optimize_step(step_a.avg())

    training_rmse = round(math.sqrt(total_a.avg()) * rmse_scale, 2)
    validate_rmse = round(subset_rmse('validate') * rmse_scale, 2)
    elapsed = round(timeit.default_timer() - start, 2)
    print(epoch, elapsed, training_rmse, validate_rmse)

    # Early stopping: remember the best epoch and stop when it is more than
    # `patience` epochs old.
    if validate_rmse < best_validate_rmse:
        best_validate_rmse = validate_rmse
        best_epoch = epoch
    epoch += 1
    if epoch - best_epoch > patience:
        break

test_rmse = round(subset_rmse('test') * rmse_scale, 2)
print('Done training, test RMSE is', test_rmse)

epoch time training_rmse validation_rmse
0 52.01 46.14 16.31
1 100.36 26.39 15.62
2 151.88 19.67 11.68
3 198.96 17.16 12.52
4 246.32 15.35 14.11
5 293.18 12.67 12.49
6 340.02 10.34 10.1
7 386.81 8.38 9.43
8 433.52 7.14 9.65
9 483.01 6.57 9.54
10 537.41 7.6 8.98
11 591.07 6.7 9.23
12 644.96 5.83 8.83
13 699.96 5.59 8.87
14 754.26 5.06 8.69
15 820.27 4.76 8.9
16 888.07 4.55 9.05
17 957.94 4.35 9.26
18 1025.35 4.22 9.57


KeyboardInterrupt: 