In [1]:
import torch
import torchani
import tqdm
import math
import timeit
import matplotlib.pyplot as plt
%pylab inline

# AEV computer shared by the whole notebook: it is handed to ModelOnAEV below,
# and its dtype/device are used to create every AtomicNetwork layer so the
# networks match the tensors it produces.
aev_computer = torchani.AEV()

def celu(x, alpha):
    """Continuously differentiable exponential linear unit (CELU).

    Returns ``x`` where ``x > 0`` and ``alpha * (exp(x / alpha) - 1)`` elsewhere.

    The exponential is evaluated on ``min(x, 0)`` instead of ``x``. Evaluating it
    on raw ``x`` overflows to ``inf`` for large positive inputs. ``torch.where``
    discards that branch in the forward pass, but backprop still multiplies the
    zero gradient it routes there by ``exp(...) = inf``, which yields NaN
    gradients. Clamping keeps both passes finite without changing any output
    value.

    Args:
        x: input tensor.
        alpha: positive scale of the negative branch.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # For x <= 0 the clamp is the identity; for x > 0 this branch is discarded
    # anyway, so its value (0) and gradient (0) are irrelevant but finite.
    negative_branch = alpha * (torch.exp(torch.clamp(x, max=0) / alpha) - 1)
    return torch.where(x > 0, x, negative_branch)

class AtomicNetwork(torch.nn.Module):
    """Per-element feed-forward network: 384 -> 64 -> 32 -> 1, CELU in between.

    Every ``Linear`` layer is cast to ``aev_computer.dtype`` and moved to
    ``aev_computer.device`` so it matches the tensors that computer produces.
    """

    def __init__(self):
        super(AtomicNetwork, self).__init__()
        # NOTE(review): presumably read by torchani.ModelOnAEV as the per-atom
        # output width -- confirm against the torchani version in use.
        self.output_length = 1

        def dense(n_in, n_out):
            # One Linear layer placed like the AEV computer's tensors.
            layer = torch.nn.Linear(n_in, n_out).type(aev_computer.dtype)
            return layer.to(aev_computer.device)

        # Names stay layer1..layer3 so state_dict keys are unchanged; creation
        # order is also unchanged, so random initialisation is identical.
        self.layer1 = dense(384, 64)
        self.layer2 = dense(64, 32)
        self.layer3 = dense(32, 1)

    def forward(self, aev):
        """Map input rows with last dimension 384 to outputs of last dimension 1."""
        hidden = celu(self.layer1(aev), 0.1)
        hidden = celu(self.layer2(hidden), 0.1)
        return self.layer3(hidden)

# One AtomicNetwork per element symbol, created in C, H, N, O order.
# NOTE(review): reducer=torch.sum presumably sums the per-atom outputs into one
# value per conformation -- confirm against torchani.ModelOnAEV.
model = torchani.ModelOnAEV(
    aev_computer,
    reducer=torch.sum,
    per_species={element: AtomicNetwork() for element in ('C', 'H', 'N', 'O')},
)
energy_shifter = torchani.EnergyShifter()
# Summed (not averaged) squared error per call.
# NOTE: `size_average` is deprecated in newer PyTorch; reduction='sum' is the
# equivalent there.
loss = torch.nn.MSELoss(size_average=False)

ds = torchani.Dataset(torchani.buildin_dataset_dir)
ds.shuffle()

# NOTE(review): neither value is read in this cell -- batch_size is rebound by
# the benchmark loop and backward_every is never used; kept in case other
# cells rely on them.
batch_size, backward_every = 256, 4
optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

class Averager:
    """Accumulate a running total and a sample count, and report their ratio.

    ``subtotal`` may be a plain number or a tensor. ``avg`` returns whatever
    ``subtotal / count`` produces, so a tensor subtotal keeps its autograd
    graph and can be backpropagated.
    """

    def __init__(self):
        self.count = self.subtotal = 0

    def add(self, count, subtotal):
        """Record ``count`` more samples whose values sum to ``subtotal``."""
        self.subtotal += subtotal
        self.count += count

    def avg(self):
        """Return ``subtotal / count``.

        Raises ZeroDivisionError if nothing has been added yet.
        """
        total, n = self.subtotal, self.count
        return total / n
    
def evaluate(coordinates, energies, species):
    """Run the model on one mini-batch and score it against reference energies.

    Args:
        coordinates: batch of conformations; its first dimension is used as the
            sample count.
        energies: reference energies compared with the prediction.
        species: passed to both ``model`` and ``energy_shifter.add_sae``.

    Returns:
        ``(count, squared_error)``: ``coordinates.shape[0]`` and
        ``loss(prediction, energies)``. With the module-level ``loss`` built
        with ``size_average=False`` this is a sum over the batch, not a mean.
    """
    # NOTE(review): a bare .squeeze() also drops the batch dimension when the
    # batch holds a single conformation -- confirm the model's output shape.
    raw_prediction = model(coordinates, species).squeeze()
    # NOTE(review): add_sae presumably adds back per-species energy offsets so
    # the prediction is on the same scale as `energies` -- confirm.
    shifted_prediction = energy_shifter.add_sae(raw_prediction, species)
    return coordinates.shape[0], loss(shifted_prediction, energies)

def optimize_step(mse):
    """Apply one update of the module-level ``optimizer`` that descends ``mse``.

    Gradients are zeroed first so that only ``mse`` contributes to this step.
    """
    opt = optimizer
    opt.zero_grad()
    mse.backward()
    opt.step()

# Benchmark: every optimizer step sees roughly `step_size` samples, split into
# `step_every` mini-batches of `step_size // step_every` samples whose summed
# squared errors are accumulated before a single update. (When step_every does
# not divide step_size, e.g. 6, a step covers slightly fewer samples.)
import copy

# 627.509 kcal/mol per Hartree -- assumes the dataset energies are in Hartree,
# matching the kcal/mol axis labels below.
HARTREE_TO_KCAL_PER_MOL = 627.509

step_size = 1024

# Snapshot the untrained model and the fresh optimizer so that every
# configuration starts from the same point. Without this, each run would keep
# training the model (and Adam moments) left behind by the previous one, and
# the curves could not be compared.
initial_model_state = copy.deepcopy(model.state_dict())
initial_optimizer_state = copy.deepcopy(optimizer.state_dict())


def _record_step(averager, start, elapsed_times, rmses):
    """Optimize on the averager's mean loss, then log its RMSE and wall time."""
    mean_squared_error = averager.avg()
    optimize_step(mean_squared_error)
    rmse = math.sqrt(float(mean_squared_error)) * HARTREE_TO_KCAL_PER_MOL
    rmses.append(round(rmse, 2))
    elapsed_times.append(round(timeit.default_timer() - start, 2))


for step_every in [1, 2, 4, 6, 16, 32]:
    batch_size = step_size // step_every
    model.load_state_dict(initial_model_state)
    optimizer.load_state_dict(initial_optimizer_state)

    elapsed_times = []
    rmses = []

    start = timeit.default_timer()
    step_a = Averager()
    # Counting from 1 makes every optimizer step, including the first, cover
    # exactly `step_every` mini-batches.
    for batch, minibatch in enumerate(ds.iter(batch_size), start=1):
        count, squared_error = evaluate(*minibatch)
        step_a.add(count, squared_error)
        if batch % step_every == 0:
            _record_step(step_a, start, elapsed_times, rmses)
            step_a = Averager()
    # Flush the trailing partial step, if any mini-batches are left over.
    if step_a.count > 0:
        _record_step(step_a, start, elapsed_times, rmses)

    title = 'step_every = {}'.format(step_every)

    plt.plot(rmses)
    plt.title(title)
    plt.xlabel('Step')
    plt.ylabel('RMSE (kcal/mol)')
    plt.show()

    plt.plot(elapsed_times, rmses)
    plt.title(title)
    plt.xlabel('Elapsed time (seconds)')
    plt.ylabel('RMSE (kcal/mol)')
    plt.show()

  from ._conv import register_converters as _register_converters


Populating the interactive namespace from numpy and matplotlib


KeyboardInterrupt: 