# 277B Final - Molecular Energy Prediction
Amar Jilani

Notes:
- Trained on molecules with up to 4 heavy atoms (ANI-1 subsets s01–s04)
- Tested on the 5-heavy-atom dataset (s05)


> Note: Apologies for the missing output cells — the kernel crashed. Running the notebook on the s01 file is fast if you want to reproduce the results.

#### Data Processing

In [None]:
# imports
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# FOR GOOGLE COLAB ONLY 
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# FOR GOOGLE COLAB
# !pip install torchani

In [None]:
# SAVIO
# sys.path.append("/global/scratch/users/amarjilani/ANI-dataset/ANI-1_release")

# COLAB
# sys.path.append("/content/drive/MyDrive/")

import pyanitools as pya
import torchani

In [None]:
# AEV hyper-parameters, taken from rHCNO-5.2R_16-3.5A_a4-8.params in the torchani repository:
# https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params

# radial part
Rcr = 5.2
EtaR = torch.tensor([16], dtype=torch.float)
ShfR = torch.tensor(
    [0.900000, 1.168750, 1.437500, 1.706250, 1.975000, 2.243750, 2.51250, 2.781250,
     3.050000, 3.318750, 3.587500, 3.856250, 4.125000, 4.39375, 4.662500, 4.931250],
    dtype=torch.float,
)

# angular part
Rca = 3.5
EtaA = torch.tensor([8], dtype=torch.float)
Zeta = torch.tensor([32], dtype=torch.float)
ShfA = torch.tensor([0.900000, 1.550000, 2.200000, 2.850000], dtype=torch.float)
ShfZ = torch.tensor(
    [0.19634954, 0.58904862, 0.9817477, 1.3744468, 1.7671459, 2.1598449, 2.552544, 2.945243],
    dtype=torch.float,
)

num_species = 4  # H, C, N, O
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)

In [None]:
# Dataset location: uncomment the line that matches the current environment.

# SAVIO SMALL DATASET (use for testing/debugging)
# data_directory = '/global/scratch/users/amarjilani/ANI-dataset/ANI-1_release/ani_gdb_s01.h5'

# SAVIO PRODUCTION TRAINING SET (s01 to s04)
# data_directory = '/global/scratch/users/amarjilani/ANI-dataset/ANI-1_release/training'

# COLAB SMALL DATASET (use for testing/debugging)
# data_directory = '/content/drive/MyDrive/data/ani_gdb_s01.h5'

# COLAB PRODUCTION TRAINING SET (s01 to s04)
# data_directory = '/content/drive/MyDrive/data/'

# LOCAL TESTING
data_directory = '../datasets/ANI-1_release/ani_gdb_s01.h5'

# TorchANI data pipeline: subtract the fitted self-atomic energies, convert element
# symbols to species indices, shuffle, then split 80/20 into training/validation.
energy_shifter = torchani.utils.EnergyShifter(None)
data = torchani.data.load(data_directory)
training, validation = (
    data.subtract_self_energies(energy_shifter)
    .species_to_indices()
    .shuffle()
    .split(0.8, 0.2)
)
# batch 128 conformations at a time and cache the collated batches
training, validation = (subset.collate(128).cache() for subset in (training, validation))

In [None]:
def convert_aev(mol, computer=None):
    """
    Convert 3D coordinates into AEV representations.

    Works on a single molecule or on a collated batch.

    Args:
        mol: mapping with a 'species' entry (species indices; -1 marks padding atoms)
            and a 'coordinates' entry. Tensors and array-likes are both accepted.
        computer: callable that takes a ``(species, coordinates)`` tuple. Defaults to
            the module-level ``aev_computer``.

    Returns:
        Whatever ``computer`` returns. For a torchani ``AEVComputer`` this is a
        ``(species, aevs)`` pair.
    """
    if computer is None:
        computer = aev_computer
    # as_tensor skips the redundant copy (and the copy-construct warning that
    # torch.tensor emits) when the collated batch already holds tensors
    species = torch.as_tensor(mol['species'], dtype=torch.long)
    coordinates = torch.as_tensor(mol['coordinates'], dtype=torch.float)
    # call the module itself rather than .forward() so nn.Module hooks still run
    return computer((species, coordinates))

# Primary Model - Predicting Energy based on 3D Coordinates (AEVs)

In [None]:
class ANI_sub(nn.Module):
    """Sub-network for ONE type of atom.

    Maps one atom's feature vector (AEV) to an atomic energy. ``architecture`` lists
    the layer widths, e.g. ``[384, 128, 64, 1]``. A Linear layer is created between
    each consecutive pair of widths, with LeakyReLU only between hidden layers.
    """
    def __init__(self, architecture):
        super(ANI_sub, self).__init__()
        layers = []
        n_linear = len(architecture) - 1

        # create fully connected layers
        for i in range(n_linear):
            layers.append(nn.Linear(architecture[i], architecture[i + 1]))
            # No activation after the last Linear: the output is a signed energy, and a
            # LeakyReLU there would squash negative predictions. The Linear layers keep
            # their Sequential indices (0, 2, 4, ...), so saved state_dicts still load
            # (models trained with the old output activation should be retrained).
            if i < n_linear - 1:
                layers.append(nn.LeakyReLU())  # tested different activation functions, leakyrelu was the best
        self.network = nn.Sequential(*layers)

    def forward(self, aev):
        """Return the atomic energy for ``aev`` (last dim = architecture[0])."""
        return self.network(aev)


class ANI(nn.Module):
    """Model for calculating the energy of a specific conformation of an organic molecule consisting of H, C, N or O.

    Holds one ``ANI_sub`` per element. The energy of a conformation is the sum of the
    atomic energies predicted by each atom's element-specific sub-network.
    """
    def __init__(self, architectures):
        super(ANI, self).__init__()
        # one subnet per key of ``architectures``; keys are species indices as strings
        # ("0", "1", ...) and each value is that subnet's list of layer widths
        self.sub_nets = nn.ModuleDict({
            atom: ANI_sub(architecture) for atom, architecture in architectures.items()
        })

    def forward(self, aevs, atom_types):
        """Predict the total energy of every conformation in a batch.

        Args:
            aevs: per-atom features, shape ``(batch, atoms, features)``.
            atom_types: species indices, shape ``(batch, atoms)``; ``-1`` marks padding.

        Returns:
            Tensor of shape ``(batch,)`` holding the summed atomic energies of each
            conformation (0 for a conformation made only of padding). Assumes every
            subnet ends in a single output unit.

        Raises:
            KeyError: if a non-padding species index has no sub-network.
        """
        flat_types = atom_types.reshape(-1)
        flat_aevs = aevs.flatten(0, 1)
        atomic_energies = flat_aevs.new_zeros(flat_types.shape[0])
        covered = flat_types == -1  # padding atoms contribute no energy

        # Run each element's subnet once on all atoms of that element, instead of one
        # tiny forward pass (and a device sync) per atom.
        for atom, sub_net in self.sub_nets.items():
            mask = flat_types == int(atom)
            idx = mask.nonzero().flatten()
            if idx.numel() > 0:
                energies = sub_net(flat_aevs.index_select(0, idx)).flatten()
                atomic_energies = atomic_energies.index_add(0, idx, energies)
            covered |= mask

        if not bool(covered.all()):
            missing = torch.unique(flat_types[~covered]).tolist()
            raise KeyError(f"no sub-network for species index/indices {missing}")

        # sum the energies of all atoms in each conformation
        return atomic_energies.view(atom_types.shape).sum(dim=1)

In [None]:
# Layer widths for each element's subnet (not tuned yet): a 384-wide AEV input,
# two hidden layers (128, 64) and a single atomic-energy output.
# Keys are species indices: 0 = hydrogen, 1 = carbon, 2 = nitrogen, 3 = oxygen.
architectures = {species_index: [384, 128, 64, 1] for species_index in ("0", "1", "2", "3")}

model = ANI(architectures).float()
model.sub_nets

In [None]:
# Use the first CUDA GPU when available, otherwise fall back to the CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
device

In [None]:
import copy

# Training script: fit the primary model and checkpoint the epoch with the lowest validation loss
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 20

lowest_val = float('inf')
# state_dict() returns references to the live parameter tensors; deep-copy so that
# `weights` is a real snapshot of the best model rather than tracking the current one
weights = copy.deepcopy(model.state_dict())
losses = []       # per-conformation mean training loss, one entry per epoch
val_losses = []   # per-conformation mean validation loss, one entry per epoch
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_samples = 0

    # training
    for mol in training:
        species, aevs = convert_aev(mol)
        species = species.to(device)
        aevs = aevs.to(device)
        energies = mol['energies'].float().to(device)
        predicted_energies = model(aevs, species)
        loss = loss_func(predicted_energies, energies)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # MSELoss is a batch mean: weight it by the batch size and divide by the number of
        # conformations (not the number of batches) to get a per-conformation epoch mean
        train_loss += loss.item() * species.size(0)
        train_samples += species.size(0)
    train_loss = train_loss / max(train_samples, 1)
    losses.append(train_loss)

    # validation
    model.eval()
    val_loss = 0.0
    val_samples = 0
    with torch.no_grad():
        for mol in validation:
            species, aevs = convert_aev(mol)
            species = species.to(device)
            aevs = aevs.to(device)
            # cast to float32 like the training targets so the dtypes match the predictions
            energies = mol['energies'].float().to(device)
            predicted_energies = model(aevs, species)
            loss = loss_func(predicted_energies, energies)
            val_loss += loss.item() * species.size(0)
            val_samples += species.size(0)

    val_loss = val_loss / max(val_samples, 1)
    val_losses.append(val_loss)
    if val_loss < lowest_val:
        lowest_val = val_loss
        weights = copy.deepcopy(model.state_dict())
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            'val_loss': val_loss
        }, 'ani.pth')

    # print result
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}')



In [None]:
# plot training and validation curves of the primary model
plt.plot(val_losses, label="Validation Loss")
plt.plot(losses, label="Training Loss")
plt.title("Training and Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
# one tick per recorded epoch, derived from the data instead of a hard-coded 20
plt.xticks(range(len(losses)))
plt.legend()

In [None]:
# move the primary model back to the CPU
model = model.cpu()

# Alternative Method: AutoEncoder for Compression of AEVs

#### AutoEncoder class
I chose an architecture similar to the previous neural network: there is one sub-network per atom type. This way, each atom's AEV in a molecule is compressed independently by the sub-network for its element.

In [None]:
class AtomEncoder(nn.Module):
    """Sub-AutoEncoder for a specific atom type.

    ``architecture`` gives the encoder widths from input to latent code, e.g.
    ``[384, 128, 64, 32, 16]``; the decoder mirrors them back to the input width.
    LeakyReLU is applied between layers, but not to the latent code or the reconstruction.
    """
    def __init__(self, architecture):
        super(AtomEncoder, self).__init__()
        encoder_layers = []
        decoder_layers = []

        # encoder
        for i in range(len(architecture) - 1):
            encoder_layers.append(nn.Linear(architecture[i], architecture[i + 1]))
            encoder_layers.append(nn.LeakyReLU())
        self.encoder = nn.Sequential(*encoder_layers[:-1])  # don't use activation on the latent space

        # decoder (mirror image of the encoder)
        for i in range(len(architecture) - 1, 0, -1):
            decoder_layers.append(nn.Linear(architecture[i], architecture[i - 1]))
            decoder_layers.append(nn.LeakyReLU())
        self.decoder = nn.Sequential(*decoder_layers[:-1])  # unconstrained reconstruction output

    def forward(self, aev):
        """Return ``(encoded, decoded)`` for ``aev`` (last dim = architecture[0])."""
        encoded = self.encoder(aev)
        decoded = self.decoder(encoded)
        return encoded, decoded


class MolEncoder(nn.Module):
    """AutoEncoder network for a molecule, with sub-encoders for each specific atom.

    Each atom's AEV is compressed by the sub-encoder of its element. Padding atoms
    (species -1) get all-zero latent codes and all-zero reconstructions.
    """
    def __init__(self, architectures):
        super(MolEncoder, self).__init__()
        self.sub_encoders = nn.ModuleDict({
            atom: AtomEncoder(architecture) for atom, architecture in architectures.items()
        })
        # The latent width is the last entry of each architecture. Codes of different
        # elements are stored in one tensor, so all sub-encoders must agree on it.
        latent_sizes = {architecture[-1] for architecture in architectures.values()}
        if len(latent_sizes) > 1:
            raise ValueError(f"all sub-encoders must share one latent size, got {sorted(latent_sizes)}")
        self.latent_size = latent_sizes.pop() if latent_sizes else 0

    def forward(self, aevs, atom_types):
        """Encode and reconstruct a batch of per-atom AEVs.

        Args:
            aevs: shape ``(batch, atoms, aev_length)``.
            atom_types: species indices, shape ``(batch, atoms)``; ``-1`` marks padding.

        Returns:
            ``(encoded, decoded)`` with shapes ``(batch, atoms, latent_size)`` and
            ``(batch, atoms, aev_length)``.

        Raises:
            KeyError: if a non-padding species index has no sub-encoder.
        """
        flat_types = atom_types.reshape(-1)
        flat_aevs = aevs.flatten(0, 1)
        # Padding rows stay zero. These buffers follow the input's device and dtype
        # instead of a global device, and their width follows the architecture.
        encoded = flat_aevs.new_zeros(flat_types.shape[0], self.latent_size)
        decoded = torch.zeros_like(flat_aevs)
        covered = flat_types == -1

        # one batched pass per element instead of one pass per atom
        for atom, sub_encoder in self.sub_encoders.items():
            mask = flat_types == int(atom)
            idx = mask.nonzero().flatten()
            if idx.numel() > 0:
                atom_encoded, atom_decoded = sub_encoder(flat_aevs.index_select(0, idx))
                encoded = encoded.index_copy(0, idx, atom_encoded)
                decoded = decoded.index_copy(0, idx, atom_decoded)
            covered |= mask

        if not bool(covered.all()):
            missing = torch.unique(flat_types[~covered]).tolist()
            raise KeyError(f"no sub-encoder for species index/indices {missing}")

        batch, atoms = atom_types.shape
        return encoded.view(batch, atoms, self.latent_size), decoded.view_as(aevs)


In [None]:
# Encoder widths for every element: 384-wide AEV compressed to a 16-wide latent code
# (the decoder mirrors these widths). Keys: 0 = H, 1 = C, 2 = N, 3 = O.
ae_architectures = {species_index: [384, 128, 64, 32, 16] for species_index in ("0", "1", "2", "3")}
autoencoder = MolEncoder(ae_architectures).to(device)
device

In [None]:
import os

# Resume from a previous autoencoder checkpoint when one exists; otherwise ae_load is None
# and training starts from scratch. map_location lets a checkpoint saved on another
# device (e.g. a GPU) be loaded onto the current one.
ae_load = torch.load('ae.pth', map_location=device) if os.path.exists('ae.pth') else None

In [None]:
# training loss stored in the checkpoint (None when no checkpoint was loaded)
ae_load["loss"] if ae_load else None

#### Training AutoEncoder network

In [None]:
import copy

# Train (or resume training) the autoencoder to reconstruct AEVs
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)
total_epochs = 20

if ae_load:
    # resume from the saved (best) checkpoint
    autoencoder.load_state_dict(ae_load['model_state_dict'])
    optimizer.load_state_dict(ae_load['optimizer_state_dict'])
    # 'epoch' is the 0-based index of the epoch that produced the checkpoint,
    # so training continues with the epoch after it
    start_epoch = ae_load['epoch'] + 1
    lowest_val = ae_load['val_loss']
    losses = [ae_load['loss']]
    val_losses = [ae_load['val_loss']]
else:
    start_epoch = 0
    losses = []
    val_losses = []
    lowest_val = float('inf')

num_epochs = max(total_epochs - start_epoch, 0)  # epochs left to run
# state_dict() aliases the live parameters; deep-copy to get a real snapshot
weights = copy.deepcopy(autoencoder.state_dict())

for epoch in range(start_epoch, total_epochs):
    autoencoder.train()
    train_loss = 0.0
    train_samples = 0

    # training
    for mol in training:
        species, aevs = convert_aev(mol)
        # use the configured device so this also runs on CPU-only machines
        species = species.to(device)
        aevs = aevs.to(device)
        _, decoded = autoencoder(aevs, species)
        loss = loss_func(decoded, aevs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # weight the batch-mean loss by batch size; divide by conformations, not batches
        train_loss += loss.item() * species.size(0)
        train_samples += species.size(0)
    train_loss = train_loss / max(train_samples, 1)
    losses.append(train_loss)

    # val
    autoencoder.eval()
    val_loss = 0.0
    val_samples = 0
    with torch.no_grad():
        for mol in validation:
            species, aevs = convert_aev(mol)
            species = species.to(device)
            aevs = aevs.to(device)
            _, decoded = autoencoder(aevs, species)
            loss = loss_func(decoded, aevs)
            val_loss += loss.item() * species.size(0)
            val_samples += species.size(0)

    val_loss = val_loss / max(val_samples, 1)
    val_losses.append(val_loss)

    # save results (the absolute epoch index is stored so later resumes stay correct)
    if val_loss < lowest_val:
        lowest_val = val_loss
        weights = copy.deepcopy(autoencoder.state_dict())
        torch.save({
            'epoch': epoch,
            'model_state_dict': autoencoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            'val_loss': val_loss
        }, 'ae.pth')

    print(f'Epoch {epoch+1}/{total_epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}')

In [None]:
# autoencoder reconstruction loss per epoch (validation first, then training)
for series, series_label in ((val_losses, "Validation Loss"), (losses, "Training Loss")):
    plt.plot(series, label=series_label)
plt.title("Training and Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()

#### Evaluating the Encoded AEV Representation on a Second Energy Prediction Model

In [None]:
# Subnet widths for the energy model that consumes the 16-wide encoded AEVs
# (keys: 0 = H, 1 = C, 2 = N, 3 = O)
architectures2 = {species_index: [16, 64, 32, 1] for species_index in ("0", "1", "2", "3")}
model_2 = ANI(architectures2).to(device)  # same ANI class, smaller input layer

In [None]:
import copy

# Train the energy model on AEVs compressed by the (frozen) autoencoder
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model_2.parameters(), lr=1e-3)
num_epochs = 20

# the autoencoder is only used as a fixed feature extractor here
autoencoder.eval()

lowest_val = float('inf')
# state_dict() aliases the live parameters; deep-copy to get a real snapshot
weights = copy.deepcopy(model_2.state_dict())
losses = []
val_losses = []
for epoch in range(num_epochs):
    model_2.train()  # Set the model to training mode
    train_loss = 0.0
    train_samples = 0

    for mol in training:
        species, aevs = convert_aev(mol)
        species = species.to(device)
        aevs = aevs.to(device)
        energies = mol['energies'].float().to(device)
        # encode without building an autograd graph: the autoencoder is not trained here
        with torch.no_grad():
            encoded_aevs, _ = autoencoder(aevs, species)
        predicted_energies = model_2(encoded_aevs, species)
        loss = loss_func(predicted_energies, energies)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # weight the batch-mean loss by batch size; divide by conformations, not batches
        train_loss += loss.item() * species.size(0)
        train_samples += species.size(0)
    train_loss = train_loss / max(train_samples, 1)
    losses.append(train_loss)

    # val
    model_2.eval()
    val_loss = 0.0
    val_samples = 0
    with torch.no_grad():
        for mol in validation:
            species, aevs = convert_aev(mol)
            species = species.to(device)
            aevs = aevs.to(device)
            energies = mol['energies'].float().to(device)
            encoded_aevs, _ = autoencoder(aevs, species)
            predicted_energies = model_2(encoded_aevs, species)
            loss = loss_func(predicted_energies, energies)
            val_loss += loss.item() * species.size(0)
            val_samples += species.size(0)

    val_loss = val_loss / max(val_samples, 1)
    val_losses.append(val_loss)

    # save model
    if val_loss < lowest_val:
        lowest_val = val_loss
        weights = copy.deepcopy(model_2.state_dict())
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_2.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            'val_loss': val_loss
        }, 'model_2.pth')

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}')

In [None]:
# model 2 loss per epoch (validation first, then training)
for series, series_label in ((val_losses, "Validation Loss"), (losses, "Training Loss")):
    plt.plot(series, label=series_label)
plt.title("Training and Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()

## Testing on 5 Heavy Atoms dataset 

Now that the models have been trained on molecules with up to 4 heavy atoms, we look at how well they generalize to larger molecules.

#### Testing the Original Model

In [None]:
# load in trained weights
# map_location='cpu': the model is evaluated on the CPU, and the checkpoint may have
# been saved from a GPU (which would otherwise fail to load on a CPU-only machine)
ani_load = torch.load('ani.pth', map_location='cpu')
model.load_state_dict(ani_load['model_state_dict'])

In [None]:
# load in 5 heavy atom set
data_directory = '../datasets/ANI-1_release/ani_gdb_s05.h5'

# Shift the test energies by the self-atomic energies fitted on the *training* data, so the
# targets use the same reference the models were trained on. Fitting a fresh EnergyShifter
# on the test set would use a different reference and bias every prediction.
# NOTE(review): assumes the training set contained all four elements, so that
# energy_shifter.self_energies is ordered H, C, N, O; confirm if the training data changes.
if len(energy_shifter.self_energies) != 4:
    raise ValueError("expected fitted self energies for exactly H, C, N and O")
train_self_energies = dict(zip(('H', 'C', 'N', 'O'), energy_shifter.self_energies.tolist()))

data = torchani.data.load(data_directory)
test = data.subtract_self_energies(train_self_energies).species_to_indices().shuffle()
test = test.collate(1024).cache()

In [None]:
import itertools
import time

# Evaluate the original (full-AEV) model on part of the 5-heavy-atom set
model.eval()
sum_squared_error = 0.0
total_samples = 0
actual_energies = []            # observed energies, for plotting
predicted_energies_list = []    # predicted energies, for plotting
batch_counter = 0
max_batches = 20    # only going through 1024*20 of the conformations in dataset 5
timings = []        # seconds spent in the model's forward pass, per batch
with torch.no_grad():
    for mol in itertools.islice(test, max_batches):
        # make prediction (only the energy network is timed; AEVs are computed beforehand)
        species, aevs = convert_aev(mol)
        energies = mol['energies']
        start = time.perf_counter()  # monotonic, high-resolution clock for intervals
        predicted_energies = model(aevs, species)
        end = time.perf_counter()

        # store values for plotting
        actual_energies.extend(energies.numpy())
        predicted_energies_list.extend(predicted_energies.numpy())

        # accumulate squared errors
        squared_errors = (predicted_energies - energies) ** 2
        sum_squared_error += squared_errors.sum().item()
        total_samples += energies.size(0)

        batch_counter += 1
        timings.append(end - start)

# calculate RMSE (NaN if no batches were evaluated)
rmse = np.sqrt(sum_squared_error / total_samples) if total_samples else float('nan')
print("RMSE: ", rmse)

# calculate average time the model took to make prediction
avg_time = np.mean(timings)
print("Average time taken for original model prediction: {:.2f} seconds".format(avg_time))

In [None]:
# plot predicted vs observed energies; the red line marks perfect agreement
plt.figure(figsize=(8, 8))
plt.scatter(actual_energies, predicted_energies_list, alpha=0.5)
plt.xlabel('Actual Energies (Hartrees)')
plt.ylabel('Predicted Energies (Hartrees)')
plt.title('Original Model: True vs Predicted Energies')
plt.plot([min(actual_energies), max(actual_energies)], [min(actual_energies), max(actual_energies)], 'r')
plt.show()

In [None]:
from scipy.stats import pearsonr

# Pearson correlation between observed and predicted energies of the original model
correlation = pearsonr(actual_energies, predicted_energies_list)[0]
print("Model 1 Correlation Coefficient: ", correlation)

#### Testing the Secondary Model

In [None]:
# load in saved training data
# map_location='cpu': evaluation runs on the CPU, and the checkpoints may have been
# saved from a GPU (which would otherwise fail to load on a CPU-only machine)
model2 = ANI(architectures2)
latent = torch.load('model_2.pth', map_location='cpu')
model2.load_state_dict(latent['model_state_dict'])

autoencoder2 = MolEncoder(ae_architectures)
ae2 = torch.load('ae.pth', map_location='cpu')
autoencoder2.load_state_dict(ae2['model_state_dict'])

In [None]:
import itertools
import time

# Evaluate the compressed-input model (autoencoder + model2) on part of the 5-heavy-atom set
model2.eval()
autoencoder2.eval()
sum_squared_error = 0.0
total_samples = 0
actual_energies = []            # stores values for plotting
predicted_energies_list = []    # stores values for plotting
batch_counter = 0
max_batches = 20    # only going through 1024*20 of the conformations in dataset 5
timings = []    # stores the time taken for predictions

with torch.no_grad():
    for mol in itertools.islice(test, max_batches):
        species, aevs = convert_aev(mol)
        energies = mol['energies']

        # Time the whole compressed pipeline. Encoding the AEVs is part of every prediction
        # of this model, so leaving it out would understate its cost relative to model 1.
        start_time = time.perf_counter()
        encoded_aevs, _ = autoencoder2(aevs, species)
        predicted_energies = model2(encoded_aevs, species)
        end_time = time.perf_counter()
        timings.append(end_time - start_time)

        # store values for plotting
        actual_energies.extend(energies.numpy())
        predicted_energies_list.extend(predicted_energies.numpy())

        # accumulate squared errors
        squared_errors = (predicted_energies - energies) ** 2
        sum_squared_error += squared_errors.sum().item()
        total_samples += energies.size(0)

        batch_counter += 1

# calculate RMSE (NaN if no batches were evaluated)
rmse = np.sqrt(sum_squared_error / total_samples) if total_samples else float('nan')
print("RMSE: ", rmse)

avg_time = np.mean(timings)
print("Average time taken for compressed model prediction: {:.2f} seconds".format(avg_time))

In [None]:
# plot model2 prediction vs observed; the red line marks perfect agreement.
# The energies come from the same ANI-1 data as the original model's plot, so they
# use the same unit (Hartree).
plt.figure(figsize=(8, 8))
plt.scatter(actual_energies, predicted_energies_list, alpha=0.5)
plt.xlabel('Actual Energies (Hartrees)')
plt.ylabel('Predicted Energies (Hartrees)')
plt.title('Encoded Input Model: True vs Predicted Energies')
plt.plot([min(actual_energies), max(actual_energies)], [min(actual_energies), max(actual_energies)], 'r')
plt.show()

In [None]:
# Pearson correlation between observed and predicted energies of the compressed-input model
correlation = pearsonr(actual_energies, predicted_energies_list)[0]
print("Model 2 Correlation Coefficient: ", correlation)