# Final Project

This notebook is adapted from here: https://aiqm.github.io/torchani/examples/nnp_training.html

## Checkpoint 1: Data preparation

1. Create a working directory: `/global/scratch/users/[USER_NAME]/[DIR_NAME]`. Replace `[USER_NAME]` with your own username and choose any `[DIR_NAME]` you like.
2. Copy the Jupyter Notebook to the working directory
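3. Download the ANI dataset `ani_dataset_gdb_s01_to_s04.h5` from bCourses and upload it to the working directory. Note that the data-loading cell below opens `./ani_gdb_s01_to_s04.h5`, so rename the file or update that path if the names differ.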

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torchani
import matplotlib.pyplot as plt

### Use GPU

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Every tensor and model in this notebook is created on / moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

### Set up AEV computer

#### AEV: Atomic Environment Vector (atomic features)

Ref: Chem. Sci., 2017, 8, 3192

In [None]:
def init_aev_computer():
    """Build the torchani AEV computer used to featurize atomic environments.

    All hyperparameter tensors are created as float tensors on the
    notebook-level ``device``. The values are passed positionally to
    ``torchani.AEVComputer`` in the order
    ``(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)``.

    Returns:
        The constructed ``torchani.AEVComputer``.
    """
    def _float_tensor(values):
        # Shared dtype/device settings for every hyperparameter tensor.
        return torch.tensor(values, dtype=torch.float, device=device)

    # Cutoffs (Rcr / Rca) -- presumably radial and angular cutoff radii;
    # see the torchani AEVComputer documentation.
    radial_cutoff = 5.2
    angular_cutoff = 3.5

    # Radial hyperparameters (EtaR, ShfR).
    eta_radial = _float_tensor([16])
    shifts_radial = _float_tensor([
        0.900000, 1.168750, 1.437500, 1.706250,
        1.975000, 2.243750, 2.512500, 2.781250,
        3.050000, 3.318750, 3.587500, 3.856250,
        4.125000, 4.393750, 4.662500, 4.931250,
    ])

    # Angular hyperparameters (EtaA, Zeta, ShfA, ShfZ).
    eta_angular = _float_tensor([8])
    zeta = _float_tensor([32])
    shifts_angular = _float_tensor([0.90, 1.55, 2.20, 2.85])
    shifts_zeta = _float_tensor([
        0.19634954, 0.58904862, 0.9817477, 1.37444680,
        1.76714590, 2.15984490, 2.5525440, 2.94524300,
    ])

    # Number of distinct element types handled by the AEV computer.
    species_count = 4

    return torchani.AEVComputer(
        radial_cutoff,
        angular_cutoff,
        eta_radial,
        shifts_radial,
        eta_angular,
        zeta,
        shifts_angular,
        shifts_zeta,
        species_count,
    )


aev_computer = init_aev_computer()
# Length of the feature vector produced for each atom.
aev_dim = aev_computer.aev_length
print(aev_dim)

### Prepare dataset & split

In [None]:
def load_ani_dataset(dspath):
    """Load an ANI HDF5 dataset and prepare it for training.

    The loaded dataset is processed in this order: self energies are
    subtracted, element symbols are converted to indices following
    ``species_order``, and the entries are shuffled.

    Args:
        dspath: Path to the HDF5 dataset file, passed to ``torchani.data.load``.

    Returns:
        The transformed, shuffled torchani dataset.
    """
    species_order = ['H', 'C', 'N', 'O']
    # No self energies are supplied to the shifter (None).
    # NOTE(review): torchani presumably fits them from the data inside
    # subtract_self_energies -- confirm against the torchani documentation.
    energy_shifter = torchani.utils.EnergyShifter(None)

    dataset = torchani.data.load(dspath)
    dataset = dataset.subtract_self_energies(energy_shifter, species_order)
    dataset = dataset.species_to_indices(species_order)
    dataset = dataset.shuffle()
    return dataset

dataset = load_ani_dataset("./ani_gdb_s01_to_s04.h5")
# Split into train / validation / test parts with fractions 0.8 / 0.1 / 0.1.
train_data, val_data, test_data = dataset.split(.8,.1,.1)

### Batching

In [15]:
# Mini-batch size handed to dataset.collate(); note that the training cell
# near the end of the notebook reassigns this name.
batch_size = 64

In [10]:
# Batching is done with dataset.collate(batch_size).cache(); ANITrainer calls it internally, so the lines below stay commented out.
# train_data_loader = train_data.collate(batch_size).cache()
# val_data_loader = val_data.collate(batch_size).cache()
# test_data_loader = test_data.collate(batch_size).cache()

### Torchani API

In [None]:
class AtomicNet(nn.Module):
    """Per-element feed-forward network mapping an atomic feature vector to one scalar.

    Args:
        input_dim: Size of the input feature vector. Defaults to 384, the
            value previously hard-coded; pass the AEV length (``aev_dim``)
            when the AEV hyperparameters change.
        hidden_dim: Width of the single hidden layer. Defaults to 128.
    """

    def __init__(self, input_dim=384, hidden_dim=128):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        """Apply the network to ``x`` (last dimension must equal ``input_dim``)."""
        return self.layers(x)

# One independent atomic network per element, created in the order H, C, N, O.
net_H, net_C, net_N, net_O = (AtomicNet() for _ in range(4))

# The ANI model requires a network for each atom type;
# torchani.ANIModel() combines the atomic networks into one module.
# NOTE(review): the list order presumably has to match the species order
# used when indexing the dataset (H, C, N, O) -- confirm in the torchani docs.
atomic_networks = [net_H, net_C, net_N, net_O]
ani_net = torchani.ANIModel(atomic_networks)
ani_net = ani_net.to(device)

# Full pipeline: (species, coordinates) -> AEVs -> energies.
model = nn.Sequential(aev_computer, ani_net)
model = model.to(device)

In [7]:
# train_data_batch = next(iter(train_data_loader))
# loss_func = nn.MSELoss()
# species = train_data_batch['species'].to(device)
# coords = train_data_batch['coordinates'].to(device)
# true_energies = train_data_batch['energies'].to(device).float()
# _, pred_energies = model((species, coords))
# loss = loss_func(true_energies, pred_energies)
# print(loss)

NameError: name 'train_data_loader' is not defined

In [None]:
class ANITrainer:
    """Train and evaluate an energy model on batched ANI-style data.

    The model is called as ``model((species, coordinates))`` and must return a
    pair whose second element holds the predicted energies. Datasets passed to
    :meth:`train` and :meth:`evaluate` must support
    ``data.collate(batch_size).cache()`` and yield batches that are mappings
    with the keys ``'species'``, ``'coordinates'`` and ``'energies'``.

    Args:
        model: The ``torch.nn.Module`` to optimize.
        batch_size: Batch size handed to ``collate`` when building loaders.
        learning_rate: Learning rate of the Adam optimizer.
        epoch: Number of training epochs.
        l2: Weight decay (L2 regularization) of the Adam optimizer.
    """

    def __init__(self, model, batch_size, learning_rate, epoch, l2):
        self.model = model

        num_params = sum(item.numel() for item in model.parameters())
        print(f"{model.__class__.__name__} - Number of parameters: {num_params}")

        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2)
        self.epoch = epoch

    def _model_device(self):
        """Return the device the model parameters live on (CPU if it has none)."""
        first_param = next(self.model.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def _forward_batch(self, batch, device):
        """Move one batch to ``device`` and return ``(true_energies, pred_energies)``."""
        species = batch['species'].to(device)
        coords = batch['coordinates'].to(device)
        true_energies = batch['energies'].to(device).float()
        _, pred_energies = self.model((species, coords))
        return true_energies, pred_energies

    def train(self, train_data, val_data, early_stop=True, draw_curve=True):
        """Optimize the model and track the per-epoch losses.

        Args:
            train_data: Dataset used for the gradient updates.
            val_data: Dataset used to compute the validation loss each epoch.
            early_stop: If True, the weights with the lowest validation loss
                are restored into the model after the last epoch.
            draw_curve: If True, plot train/validation loss against the epoch.

        Returns:
            ``(train_loss_list, val_loss_list)``: per-epoch mean MSE losses,
            each weighted by the number of samples in every batch.
        """
        self.model.train()
        device = self._model_device()

        # Build both loaders once; they do not change between epochs.
        print("Initialize training data...")
        train_data_loader = train_data.collate(self.batch_size).cache()
        val_data_loader = val_data.collate(self.batch_size).cache()

        loss_func = torch.nn.MSELoss()

        train_loss_list = []
        val_loss_list = []
        lowest_val_loss = np.inf
        best_weights = None

        for _ in tqdm(range(self.epoch), leave=True):
            loss_sum = 0.0
            num_samples = 0
            for batch in train_data_loader:
                true_energies, pred_energies = self._forward_batch(batch, device)
                batch_loss = loss_func(pred_energies, true_energies)

                # Standard optimization step: clear stale gradients,
                # back-propagate, then update the parameters.
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                # Weight each batch by its sample count so a smaller final
                # batch does not bias the epoch mean.
                batch_count = true_energies.shape[0]
                loss_sum += batch_loss.detach().cpu().item() * batch_count
                num_samples += batch_count

            train_epoch_loss = loss_sum / num_samples if num_samples else float("nan")
            # No parity plot during training; only the loss value is needed.
            val_epoch_loss = self._evaluate_loader(val_data_loader, draw_plot=False)

            train_loss_list.append(train_epoch_loss)
            val_loss_list.append(val_epoch_loss)

            if early_stop and val_epoch_loss < lowest_val_loss:
                lowest_val_loss = val_epoch_loss
                # state_dict() returns references to the live tensors, so the
                # snapshot must be cloned or later updates would overwrite it.
                best_weights = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }

        if draw_curve:
            fig, ax = plt.subplots(1, 1, figsize=(5, 4), constrained_layout=True)
            ax.set_yscale("log")
            ax.plot(range(len(train_loss_list)), train_loss_list, label='Train')
            ax.plot(range(len(val_loss_list)), val_loss_list, label='Validation')
            ax.legend()
            ax.set_xlabel("# Epoch")
            ax.set_ylabel("Loss")

        # best_weights stays None when no epoch ran or no loss was comparable.
        if early_stop and best_weights is not None:
            self.model.load_state_dict(best_weights)

        return train_loss_list, val_loss_list

    def evaluate(self, data, draw_plot=True):
        """Compute the mean MSE loss of the model on ``data``.

        Args:
            data: Dataset to evaluate; batched with ``self.batch_size``.
            draw_plot: If True, draw a predicted-vs-true scatter plot whose
                legend reports the mean absolute error in kcal/mol.

        Returns:
            The sample-weighted mean MSE loss as a float (``nan`` when the
            dataset yields no samples).
        """
        data_loader = data.collate(self.batch_size).cache()
        return self._evaluate_loader(data_loader, draw_plot)

    def _evaluate_loader(self, data_loader, draw_plot):
        """Evaluate on an already-collated loader; see :meth:`evaluate`."""
        device = self._model_device()
        loss_func = torch.nn.MSELoss()
        loss_sum = 0.0
        num_samples = 0
        true_energies_all = []
        pred_energies_all = []

        # Evaluate in eval mode, then restore whichever mode was active.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in data_loader:
                    true_energies, pred_energies = self._forward_batch(batch, device)
                    batch_loss = loss_func(pred_energies, true_energies)

                    batch_count = true_energies.shape[0]
                    loss_sum += batch_loss.detach().cpu().item() * batch_count
                    num_samples += batch_count

                    if draw_plot:
                        true_energies_all.append(true_energies.detach().cpu().numpy().flatten())
                        pred_energies_all.append(pred_energies.detach().cpu().numpy().flatten())
        finally:
            self.model.train(was_training)

        if num_samples == 0:
            # Nothing to average or plot.
            return float("nan")

        if draw_plot:
            self._plot_parity(
                np.concatenate(true_energies_all),
                np.concatenate(pred_energies_all),
            )

        return loss_sum / num_samples

    def _plot_parity(self, true_energies_all, pred_energies_all):
        """Scatter predicted against true energies and report the MAE in kcal/mol."""
        # Dataset energies are in hartree; 1 hartree = 627.5094738898777 kcal/mol.
        hartree2kcalmol = 627.5094738898777
        mae = np.mean(np.abs((true_energies_all - pred_energies_all) * hartree2kcalmol))
        fig, ax = plt.subplots(1, 1, figsize=(5, 4), constrained_layout=True)
        ax.scatter(true_energies_all, pred_energies_all, label=f"MAE: {mae:.2f} kcal/mol", s=2)
        ax.set_xlabel("Ground Truth")
        ax.set_ylabel("Predicted")
        # Use identical limits on both axes so the red diagonal marks
        # perfect agreement.
        xmin, xmax = ax.get_xlim()
        ymin, ymax = ax.get_ylim()
        vmin, vmax = min(xmin, ymin), max(xmax, ymax)
        ax.set_xlim(vmin, vmax)
        ax.set_ylim(vmin, vmax)
        ax.plot([vmin, vmax], [vmin, vmax], color='red')
        ax.legend()


In [None]:
# Hyperparameters for this training run.
learning_rate, num_epochs = 1e-3, 15
l2 = 0.0  # weight decay (L2 regularization) handed to the optimizer
# NOTE(review): 16324 is an unusual batch size -- confirm it is intended.
batch_size = 16324

# Build the trainer around the assembled model.
trainer = ANITrainer(
    model=model,
    batch_size=batch_size,
    learning_rate=learning_rate,
    epoch=num_epochs,
    l2=l2,
)

# Run training; the best validation weights are restored afterwards and the
# loss curves are plotted.
trainer.train(train_data, val_data, early_stop=True, draw_curve=True)
