# ANI-1 Molecular Potential Prediction using Pre-trained EGNN and Transformer-Encoder

This notebook records the complete workflow for training and fine-tuning the model. Data-processing classes, utility functions, and model substructures are encapsulated in the models folder and are not shown here; if needed, please refer to the [GitHub repository](https://github.com/Curtis-Wu/Equivariant-Graph-Transformer).

### General Import Statements

In [3]:
import os
import gc
import csv
import time
import yaml
import h5py
import math
import shutil
import numpy as np
from copy import deepcopy
from datetime import datetime
from utils import adjust_learning_rate
from sklearn.metrics import mean_absolute_error, mean_squared_error

import torch
import torch.nn as nn
import torch.multiprocessing
import torch.nn.functional as F
from torch.optim import AdamW
from torch_scatter import scatter
from torch_cluster import radius_graph
from torch.utils.tensorboard import SummaryWriter
torch.multiprocessing.set_sharing_strategy('file_system')

### Model Architecture Creation

In [None]:
# Hyperparameter table for the model, grouped by sub-network.
# Key names and insertion order are kept exactly as they were.
config = dict(
    # --- EGNN part of the model ---
    hidden_channels=256,     # Number of hidden channels
    num_edge_feats=0,        # Number of additional edge features
    num_egcl=5,              # Number of EGCL layers
    residual=True,           # Residual calculation
    attention=True,          # Graph attention mechanism
    normalize=True,          # Interatomic distance normalization
    tanh=False,              # Additional activation after each layer
    cutoff=5.0,              # Interatomic distance cutoff
    max_atom_type=28,        # Max atom types
    max_chirality_type=5,    # Max chirality type
    max_num_neighbors=32,    # Max number of neighbors

    # --- Transformer-Encoder part of the model ---
    d_model=256,             # Embedding size for each token
    num_heads=4,             # Number of self-attention heads
    dropout_r=0.1,           # Dropout rate
    num_ffn=256,             # Number of neurons in the feedforward MLP
    num_encoder=2,           # Number of encoder units

    # --- Energy head ---
    num_neuron=512,          # Number of neurons for the final energy head
)

In [21]:
from EGNN import E_GCL
from TF_Encoder import EncoderLayer

class EGTF(nn.Module):
    """EGNN + Transformer-Encoder network ending in a scalar energy head.

    Atom types are embedded, refined together with the coordinates by a stack
    of ``E_GCL`` layers, passed through a stack of ``EncoderLayer`` blocks and
    finally mapped to one value per input row by a two-layer MLP.

    Args:
        hidden_channels: Width of the atom-type embedding and of the input,
            hidden and output features of every ``E_GCL`` layer.
        num_edge_feats: Forwarded to ``E_GCL`` as ``add_edge_feats``.
        num_egcl: Number of ``E_GCL`` layers.
        act_fn: Activation module forwarded to every ``E_GCL`` layer. When
            ``None`` (the default) a fresh ``nn.SiLU()`` is created.
        residual: Forwarded unchanged to ``E_GCL``.
        attention: Forwarded unchanged to ``E_GCL``.
        normalize: Forwarded unchanged to ``E_GCL``.
        tanh: Forwarded unchanged to ``E_GCL``.
        max_atom_type: Number of rows in the atom-type embedding table.
        cutoff: Radius handed to ``radius_graph`` when ``forward`` has to
            build the edges itself.
        max_num_neighbors: Neighbour cap; ``radius_graph`` is called with
            this value plus one.
        d_model: First positional argument of every ``EncoderLayer``.
        num_encoder: Number of ``EncoderLayer`` blocks.
        num_heads: Second positional argument of every ``EncoderLayer``.
        num_ffn: Third positional argument of every ``EncoderLayer``.
        dropout_r: Fourth positional argument of every ``EncoderLayer``.
        num_neurons: Input and hidden width of the energy head.

    Note:
        Node features flow from the ``E_GCL`` stack straight into the encoder
        and then into the energy head, so ``hidden_channels``, ``d_model`` and
        ``num_neurons`` presumably have to be equal -- TODO confirm against
        ``EncoderLayer``.
    """

    def __init__(self, # EGNN/EGCL parameters
                 hidden_channels, num_edge_feats = 0, num_egcl = 4,
                 act_fn = None, residual = True, attention = True,
                 normalize = False, tanh = False, max_atom_type = 100,
                 cutoff = 5.0, max_num_neighbors = 32,
                 # Transformer-Encoder parameters
                 d_model = 256, num_encoder = 2, num_heads = 8,
                 num_ffn = 512, dropout_r = 0.1,
                 # Energy Head parameter
                 num_neurons = 512):

        super().__init__()
        # A module instance used as a default argument is created once at
        # definition time and shared by every model; build it per instance.
        if act_fn is None:
            act_fn = nn.SiLU()

        self.n_layers = num_egcl
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        # Create embeddings of dimension (hidden_channels, ) for each atom type
        self.type_embedding = nn.Embedding(max_atom_type, hidden_channels)

        # EGC layers. The "gcl_<i>" names are part of the state_dict keys, so
        # they are kept stable for checkpoint compatibility.
        for i in range(num_egcl):
            self.add_module("gcl_%d" % i, E_GCL(
                input_nf = hidden_channels,
                output_nf = hidden_channels,
                hidden_nf = hidden_channels,
                add_edge_feats = num_edge_feats,
                act_fn=act_fn, residual=residual,
                attention=attention, normalize=normalize, tanh=tanh))

        # Transformer-Encoder layers
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, num_ffn, dropout_r)
             for _ in range(num_encoder)])

        # Energy Head
        self.energy_fc = nn.Sequential(
            nn.Linear(num_neurons, num_neurons),
            nn.SiLU(),
            nn.Linear(num_neurons, 1)
        )

    def forward(self, z, pos, batch, edge_index=None, edge_feats=None):
        """Run the network and return the energy-head output per input row.

        Args:
            z: Integer atom-type indices used to look up the type embedding.
            pos: Atom coordinates. The layers work on a clone of this tensor.
            batch: Handed to ``radius_graph`` when ``edge_index`` is ``None``;
                unused otherwise.
            edge_index: Optional precomputed edges. When ``None`` they are
                built with ``radius_graph`` from ``pos`` and ``cutoff``.
            edge_feats: Optional edge features forwarded to every ``E_GCL``.

        Returns:
            The energy head applied to the node features; the last dimension
            has size 1.
        """
        h = self.type_embedding(z)
        # clone() rather than deepcopy(): deepcopy raises on non-leaf tensors,
        # while clone() works for any tensor.
        x = pos.clone()
        if edge_index is None:
            # Calculates edge_index from graph structure based on cutoff radius
            edge_index = radius_graph(
                pos,
                r=self.cutoff,
                batch=batch,
                loop=False,
                max_num_neighbors=self.max_num_neighbors + 1,
            )
        # EGC layers: each one updates the node features h and coordinates x.
        for i in range(self.n_layers):
            gcl = getattr(self, "gcl_%d" % i)
            h, x, _ = gcl(h, edge_index, x, edge_feats=edge_feats)
        # Encoder layers operate on the node features. The coordinates x are
        # only 3 wide and do not match the d_model-wide encoder input.
        for layer in self.encoder_layers:
            h = layer(h)
        # Energy Head
        out = self.energy_fc(h)

        return out

In [22]:
# Load the pre-trained weights onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed PyTorch
# version supports it.
pretrained_model = torch.load('model.pth', map_location=torch.device('cpu'))
# Initialize the modified model
modified_model = EGTF(# EGNN/EGCL parameters
                 hidden_channels = 256, num_edge_feats = 0, num_egcl = 3,
                 act_fn = nn.SiLU(), residual = True, attention = True,
                 normalize = False, tanh = False, max_atom_type = 28,
                 cutoff = 5.0, max_num_neighbors = 32,
                 # Transformer-Encoder parameters
                 d_model = 256, num_encoder = 2, num_heads = 8,
                 num_ffn = 512, dropout_r = 0.1,
                 # Energy Head parameter
                 num_neurons = 256)
# Load weights. strict=False tolerates keys that exist on only one side, so
# report them explicitly: otherwise a mismatched checkpoint would leave layers
# at their random initial values without any notice.
incompatible_keys = modified_model.load_state_dict(pretrained_model, strict=False)
if incompatible_keys.missing_keys:
    print(f"Not found in checkpoint, left at initial values "
          f"({len(incompatible_keys.missing_keys)}): {incompatible_keys.missing_keys}")
if incompatible_keys.unexpected_keys:
    print(f"Present in checkpoint but unused by the model "
          f"({len(incompatible_keys.unexpected_keys)}): {incompatible_keys.unexpected_keys}")

### Data Processing and Model Initialization