In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import wandb
from pytorch_lightning.loggers import WandbLogger
import torch
import torch.nn as nn
import torch.optim as optim
from ocpmodels.datasets import LmdbDataset
from torch.utils.data import random_split
import torch_geometric.loader as geom_loader
import torch_geometric.data as data
from typing import Any
import yaml
import torch.nn.init as init
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


# EGformer

In [68]:
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import logging
import os
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.nn import radius_graph
from torch_scatter import scatter, segment_coo
import torch.nn.utils.rnn as rnn_utils
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import (
    compute_neighbors,
    conditional_grad,
    get_max_neighbors_mask,
    get_pbc_distances,
    radius_graph_pbc,
    scatter_det
    
)
from ocpmodels.datasets import LmdbDataset
from ocpmodels.modules.scaling.compat import load_scales_compat

from ocpmodels.models.gemnet_oc.initializers import get_initializer
from ocpmodels.models.gemnet_oc.interaction_indices import (
    get_mixed_triplets,
    get_quadruplets,
    get_triplets,
)
from ocpmodels.models.gemnet_oc.layers.atom_update_block import OutputBlock
from ocpmodels.models.gemnet_oc.layers.base_layers import Dense, ResidualLayer
from ocpmodels.models.gemnet_oc.layers.efficient import BasisEmbedding
from ocpmodels.models.gemnet_oc.layers.embedding_block import AtomEmbedding, EdgeEmbedding
from ocpmodels.models.gemnet_oc.layers.force_scaler import ForceScaler
from ocpmodels.models.gemnet_oc.layers.interaction_block import InteractionBlock
from ocpmodels.models.gemnet_oc.layers.radial_basis import RadialBasis
from ocpmodels.models.gemnet_oc.layers.spherical_basis import CircularBasisLayer, SphericalBasisLayer
from ocpmodels.models.gemnet_oc.utils import (
    get_angle,
    get_edge_id,
    get_inner_idx,
    inner_product_clamped,
    mask_neighbors,
    repeat_blocks,
)
from ocpmodels.trainers import ForcesTrainer
from ocpmodels import models
from ocpmodels.models.gemnet_oc.gemnet_oc import GemNetOC
from Models.encoder_layer import EncoderLayer

@registry.register_model('EGformer')
class EGformer(GemNetOC):
    """
    Arguments
    ---------
    num_atoms (int): Unused argument
    bond_feat_dim (int): Unused argument
    num_targets: int
        Number of prediction targets.

    num_spherical: int
        Controls maximum frequency.
    num_radial: int
        Controls maximum frequency.
    num_blocks: int
        Number of building blocks to be stacked.

    emb_size_atom: int
        Embedding size of the atoms.
    emb_size_edge: int
        Embedding size of the edges.
    emb_size_trip_in: int
        (Down-projected) embedding size of the quadruplet edge embeddings
        before the bilinear layer.
    emb_size_trip_out: int
        (Down-projected) embedding size of the quadruplet edge embeddings
        after the bilinear layer.
    emb_size_quad_in: int
        (Down-projected) embedding size of the quadruplet edge embeddings
        before the bilinear layer.
    emb_size_quad_out: int
        (Down-projected) embedding size of the quadruplet edge embeddings
        after the bilinear layer.
    emb_size_aint_in: int
        Embedding size in the atom interaction before the bilinear layer.
    emb_size_aint_out: int
        Embedding size in the atom interaction after the bilinear layer.
    emb_size_rbf: int
        Embedding size of the radial basis transformation.
    emb_size_cbf: int
        Embedding size of the circular basis transformation (one angle).
    emb_size_sbf: int
        Embedding size of the spherical basis transformation (two angles).

    num_before_skip: int
        Number of residual blocks before the first skip connection.
    num_after_skip: int
        Number of residual blocks after the first skip connection.
    num_concat: int
        Number of residual blocks after the concatenation.
    num_atom: int
        Number of residual blocks in the atom embedding blocks.
    num_output_afteratom: int
        Number of residual blocks in the output blocks
        after adding the atom embedding.
    num_atom_emb_layers: int
        Number of residual blocks for transforming atom embeddings.
    num_global_out_layers: int
        Number of final residual blocks before the output.

    regress_forces: bool
        Whether to predict forces. Default: True
    direct_forces: bool
        If True predict forces based on aggregation of interatomic directions.
        If False predict forces based on negative gradient of energy potential.
    use_pbc: bool
        Whether to use periodic boundary conditions.
    scale_backprop_forces: bool
        Whether to scale up the energy and then scales down the forces
        to prevent NaNs and infs in backpropagated forces.

    cutoff: float
        Embedding cutoff for interatomic connections and embeddings in Angstrom.
    cutoff_qint: float
        Quadruplet interaction cutoff in Angstrom.
        Optional. Uses cutoff per default.
    cutoff_aeaint: float
        Edge-to-atom and atom-to-edge interaction cutoff in Angstrom.
        Optional. Uses cutoff per default.
    cutoff_aint: float
        Atom-to-atom interaction cutoff in Angstrom.
        Optional. Uses maximum of all other cutoffs per default.
    max_neighbors: int
        Maximum number of neighbors for interatomic connections and embeddings.
    max_neighbors_qint: int
        Maximum number of quadruplet interactions per embedding.
        Optional. Uses max_neighbors per default.
    max_neighbors_aeaint: int
        Maximum number of edge-to-atom and atom-to-edge interactions per embedding.
        Optional. Uses max_neighbors per default.
    max_neighbors_aint: int
        Maximum number of atom-to-atom interactions per atom.
        Optional. Uses maximum of all other neighbors per default.
    enforce_max_neighbors_strictly: bool
        When subselected edges based on max_neighbors args, arbitrarily
        select amongst degenerate edges to have exactly the correct number.
    rbf: dict
        Name and hyperparameters of the radial basis function.
    rbf_spherical: dict
        Name and hyperparameters of the radial basis function used as part of the
        circular and spherical bases.
        Optional. Uses rbf per default.
    envelope: dict
        Name and hyperparameters of the envelope function.
    cbf: dict
        Name and hyperparameters of the circular basis function.
    sbf: dict
        Name and hyperparameters of the spherical basis function.
    extensive: bool
        Whether the output should be extensive (proportional to the number of atoms)
    forces_coupled: bool
        If True, enforce that |F_st| = |F_ts|. No effect if direct_forces is False.
    output_init: str
        Initialization method for the final dense layer.
    activation: str
        Name of the activation function.
    scale_file: str
        Path to the pytorch file containing the scaling factors.

    quad_interaction: bool
        Whether to use quadruplet interactions (with dihedral angles)
    atom_edge_interaction: bool
        Whether to use atom-to-edge interactions
    edge_atom_interaction: bool
        Whether to use edge-to-atom interactions
    atom_interaction: bool
        Whether to use atom-to-atom interactions

    scale_basis: bool
        Whether to use a scaling layer in the raw basis function for better
        numerical stability.
    qint_tags: list
        Which atom tags to use quadruplet interactions for.
        0=sub-surface bulk, 1=surface, 2=adsorbate atoms.
    latent: bool
        Decide if output the latent space or not.
    
    """
    def __init__(
        self,
        num_atoms: Optional[int],
        bond_feat_dim: int,
        num_targets: int,
        num_spherical=7,
        num_radial=128,
        num_blocks=4,
        emb_size_atom=256,
        emb_size_edge=512,
        emb_size_trip_in=64,
        emb_size_trip_out=64,
        emb_size_quad_in=32,
        emb_size_quad_out=32,
        emb_size_aint_in=64,
        emb_size_aint_out=64,
        emb_size_rbf=16,
        emb_size_cbf=16,
        emb_size_sbf=32,
        num_before_skip=2,
        num_after_skip=2,
        num_concat=1,
        num_atom=3,
        num_output_afteratom=3,
        num_atom_emb_layers = 0,
        num_global_out_layers = 2,
        regress_forces = True,
        direct_forces = False,
        use_pbc = True,
        scale_backprop_forces = False,
        cutoff = 12.0,
        cutoff_qint = 12.0,
        cutoff_aeaint = 12.0,
        cutoff_aint = 12.0,
        max_neighbors = 30,
        max_neighbors_qint =8,
        max_neighbors_aeaint =20,
        max_neighbors_aint = 1000,
        enforce_max_neighbors_strictly = True,
        rbf = {"name": "gaussian"},
        rbf_spherical = None,
        envelope = {"name": "polynomial", "exponent": 5},
        cbf = {"name": "spherical_harmonics"},
        sbf = {"name": "spherical_harmonics"},
        extensive = True,
        forces_coupled = False,
        output_init = "HeOrthogonal",
        activation = "silu",
        quad_interaction = True,
        atom_edge_interaction = True,
        edge_atom_interaction = True,
        atom_interaction = True,
        scale_basis = False,
        qint_tags = [1, 2],
        num_elements = 83,
        otf_graph = True,
        scale_file = None,
        latent: bool = True,
        num_heads=4,
        emb_size_in=256,
        emb_size_trans=64,
        out_layer1=32,
        out_layer2=1,
        num_attn=4,
        **kwargs,  # backwards compatibility with deprecated arguments
    ):
        super().__init__(
                        num_atoms,
                        bond_feat_dim,
                        num_targets,
                        num_spherical=7,
                        num_radial=128,
                        num_blocks=4,
                        emb_size_atom=256,#256
                        emb_size_edge=512,#512
                        emb_size_trip_in=64,
                        emb_size_trip_out=64,
                        emb_size_quad_in=32,
                        emb_size_quad_out=32,
                        emb_size_aint_in=64,
                        emb_size_aint_out=64,
                        emb_size_rbf=16,
                        emb_size_cbf=16,
                        emb_size_sbf=32,
                        num_before_skip=2,
                        num_after_skip=2,
                        num_concat=1,
                        num_atom=3,
                        num_output_afteratom=3,
                        num_atom_emb_layers = 0,
                        num_global_out_layers = 2,
                        regress_forces = True,
                        direct_forces = False,
                        use_pbc = True,
                        scale_backprop_forces = False,
                        cutoff = 12.0,
                        cutoff_qint = 12.0,
                        cutoff_aeaint = 12.0,
                        cutoff_aint = 12.0,
                        max_neighbors = 30,
                        max_neighbors_qint =8,
                        max_neighbors_aeaint =20,
                        max_neighbors_aint = 1000,
                        enforce_max_neighbors_strictly = True,
                        rbf = {"name": "gaussian"},
                        rbf_spherical = None,
                        envelope = {"name": "polynomial", "exponent": 5},
                        cbf = {"name": "spherical_harmonics"},
                        sbf = {"name": "spherical_harmonics"},
                        extensive = True,
                        forces_coupled = False,
                        output_init = "HeOrthogonal",
                        activation = "silu",
                        quad_interaction = True,
                        atom_edge_interaction = True,
                        edge_atom_interaction = True,
                        atom_interaction = True,
                        scale_basis = False,
                        qint_tags = [1, 2],
                        num_elements = 83,
                        otf_graph = True,
                        scale_file = None,
                        num_heads=4,
                        emb_size_in=256,
                        emb_size_trans=64,
                        out_layer1=64,
                        out_layer2=1,                        
                        num_attn=4                        
                        )
        
        self.num_heads=num_heads 
        self.num_attn=num_attn       
        self.out_layer1=out_layer1
        self.out_layer2=out_layer2
        self.dense1=nn.Sequential(nn.Linear(emb_size_in,emb_size_trans),
                            nn.SiLU()
                            )
        self.encoder=EncoderLayer(emb_size_trans,8,emb_size_trans)
        self.layer_norm = nn.LayerNorm(emb_size_trans)        
        self.dense2E=nn.Sequential(nn.Linear(emb_size_trans*5,out_layer1),
                            nn.SiLU(),
                            nn.Linear(out_layer1,out_layer2)                                 
                    )
        self.dense2F=nn.Sequential(nn.Linear(emb_size_trans,out_layer1),
                    nn.SiLU(),
                    nn.Linear(out_layer1,out_layer2)                                 
                    )

    def forward(self, data):
        pos = data.pos
        batch = data.batch
        #print(len(data))    
        atomic_numbers = data.atomic_numbers.long()
        
        num_atoms = atomic_numbers.shape[0]
        if self.regress_forces and not self.direct_forces:
            pos.requires_grad_(True)

        (
            main_graph,
            a2a_graph,
            a2ee2a_graph,
            qint_graph,
            id_swap,
            trip_idx_e2e,
            trip_idx_a2e,
            trip_idx_e2a,
            quad_idx,
        ) = self.get_graphs_and_indices(data)
        # print('checkpoint1')
        _, idx_t = main_graph["edge_index"]

        (
            basis_rad_raw,
            basis_atom_update,
            basis_output,
            bases_qint,
            bases_e2e,
            bases_a2e,
            bases_e2a,
            basis_a2a_rad,
        ) = self.get_bases(
            main_graph=main_graph,
            a2a_graph=a2a_graph,
            a2ee2a_graph=a2ee2a_graph,
            qint_graph=qint_graph,
            trip_idx_e2e=trip_idx_e2e,
            trip_idx_a2e=trip_idx_a2e,
            trip_idx_e2a=trip_idx_e2a,
            quad_idx=quad_idx,
            num_atoms=num_atoms,
        )
        # Embedding block
        h = self.atom_emb(atomic_numbers)
        # (nAtoms, emb_size_atom)
        m = self.edge_emb(h, basis_rad_raw, main_graph["edge_index"])
        # (nEdges, emb_size_edge)
        
        x_E, x_F = self.out_blocks[0](h, m, basis_output, idx_t)
        
        xs_E, xs_F = [x_E], [x_F]
       
        # (nAtoms, num_targets), (nEdges, num_targets)

        for i in range(self.num_blocks):
            # Interaction block
            h, m = self.int_blocks[i](
                h=h,
                m=m,
                bases_qint=bases_qint,
                bases_e2e=bases_e2e,
                bases_a2e=bases_a2e,
                bases_e2a=bases_e2a,
                basis_a2a_rad=basis_a2a_rad,
                basis_atom_update=basis_atom_update,
                edge_index_main=main_graph["edge_index"],
                a2ee2a_graph=a2ee2a_graph,
                a2a_graph=a2a_graph,
                id_swap=id_swap,
                trip_idx_e2e=trip_idx_e2e,
                trip_idx_a2e=trip_idx_a2e,
                trip_idx_e2a=trip_idx_e2a,
                quad_idx=quad_idx,
            )  # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)

            x_E,x_F= self.out_blocks[i + 1](h, m, basis_output, idx_t)
            # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)
            xs_E.append(x_E)
            xs_F.append(x_F)
        # Previous is the Gemnet part, and the following is the Transformer part
        E_t = torch.stack(xs_E, dim=-1)
        #(E_t.shape==n_batch x 256 x 5)
        E_t=E_t.permute(0,2,1)
        E_t = self.dense1(E_t)
        E_t = self.layer_norm(E_t)
        for _ in range(self.num_attn):             
            E_t, _ = self.encoder(E_t, mask=None)
            E_t = self.layer_norm(E_t)
        E_t=E_t.reshape(E_t.shape[0],-1)
        E_t=self.dense2E(E_t)
        #F_t=self.dense2F(E_t)                  
        #(E_t.shape)=n_atoms*1*1
        nMolecules = torch.max(batch) + 1        
        E_t = scatter_det(
            E_t, batch, dim=0, dim_size=nMolecules, reduce="add"
        )

        return E_t

# Workflow

In [3]:
#Save the model hyperparameters to a YAML file
# model_hparams ={"num_atoms":0,
#                 "bond_feat_dim":0,
#                 "num_targets":0,
#                 "num_heads":4,
#                 "batch_size":4
#                 }

# with open('params/model_hparams.yml', 'w') as file:
#     yaml.dump(model_hparams, file)

In [69]:
with open('params/model_hparams.yml', 'r') as file:
    hyper_config = yaml.load(file, Loader=yaml.FullLoader)
warmup_epochs = hyper_config['configs'].get("warmup_epochs")
decay_epochs = hyper_config['configs'].get("decay_epochs")
y_mean = hyper_config['configs'].get("y_mean")
y_std = hyper_config['configs'].get("y_std")
num_epochs =hyper_config['configs'].get("num_epochs")
batch_size = hyper_config['configs'].get("batch_size")
learning_rate = hyper_config['configs'].get("learning_rate")

In [None]:

# with open('params/model_hparams.yml', 'r') as file:
#     loaded_model_hparams = yaml.load(file, Loader=yaml.FullLoader)['model']
# # Create the model using the loaded hyperparameters
# model = EGformer(**loaded_model_hparams)
# checkpoint_path='params/best_model_all.pt'
# pretrained_state_dict = torch.load(checkpoint_path)['MODEL_STATE']
# model.load_state_dict(pretrained_state_dict)

In [70]:
# with open('params/model_hparams.yml', 'r') as file:
#     loaded_model_hparams = yaml.load(file, Loader=yaml.FullLoader)['model']
# Create the model using the loaded hyperparameters
loaded_model_hparams ={"num_atoms":0,
                "bond_feat_dim":0,
                "num_targets":1,
                "num_heads":4,
                
                }
model = EGformer(**loaded_model_hparams)
checkpoint_path=os.path.join('params', 'gemnet_oc_base_oc20_oc22.pt')
pretrained_state_dict = torch.load(checkpoint_path)['state_dict']
new_model_state_dict = model.state_dict()
filtered_pretrained_state_dict = {k.strip('module.module.'): v for k, v in pretrained_state_dict.items() if k.strip('module.module.') in new_model_state_dict}
new_model_state_dict.update(filtered_pretrained_state_dict)
model.load_state_dict(new_model_state_dict)
for param_name, param in model.named_parameters():
    if param_name in filtered_pretrained_state_dict.keys():                
        param.requires_grad = False     



In [71]:
f_paras,t_paras=0,0
for param_name,param in model.named_parameters():
    if param.requires_grad is False:
        f_paras+=1
    else:
        t_paras+=1
print('Freeze params is',f_paras)
print('Need optimiz params is',t_paras)

Freeze params is 238
Need optimiz params is 159


In [72]:
warmup_epochs=2
decay_epochs=3
    
y_mean=-7
y_std=6
num_epochs=10
batch_size = 6
learning_rate=0.001
CHECKPOINT_PATH="./checkpoints"
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset=LmdbDataset({"src":"/shareddata/ocp/ocp22/oc22_trajectories/trajectories/Transformer_clean_valid/data.0000.lmdb"})

criterion=nn.MSELoss()
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)
warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda epoch: (epoch+1)/warmup_epochs)
decay_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=decay_epochs,gamma=0.1)

In [21]:
train_length = int(0.8 * len(dataset))
val_length = len(dataset) - train_length

# Split the dataset into train and validation
train_dataset, val_dataset =random_split(dataset, [train_length, val_length])
train_loader = geom_loader.DataLoader(train_dataset, batch_size=4,drop_last=True)
val_loader = geom_loader.DataLoader(val_dataset, batch_size=4,drop_last=True)

In [29]:
data = next(iter(train_loader))
batched_data = data.batch
print(batched_data.shape)

torch.Size([311])


In [73]:
#test of model
test_input=model(data)

In [75]:
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Number of trainable parameters: {num_trainable_params}")

Number of trainable parameters: 5872834


In [74]:
test_input

tensor([[-33.6354],
        [-22.4121],
        [-34.7020],
        [-26.5363]], grad_fn=<ScatterAddBackward0>)

In [None]:
from tqdm import tqdm

In [None]:
def get_loss(predictions, targets, y_mean=-7, y_std=6):
    masks = (targets.y/ targets.natoms - y_mean) / y_std
    mask_loss = nn.MSELoss()
    mask_acc=nn.L1Loss()
    loss = mask_loss(predictions.view(-1, 1), masks.view(-1, 1))
    accuracy = mask_acc(predictions.view(-1, 1) , masks.view(-1, 1))
    return loss, accuracy

In [None]:
def train_fn(data_loader, model, optimizer, device, optimize_after=8):
    model.train()
    total_loss = 0.0
    iteration = 0

    for images in tqdm(data_loader):
        model=model.to(device)
        images = images.to(device)
        optimizer.zero_grad()
        predictions = model(images)
        targets = images
        loss, acc = get_loss(predictions, targets)
        loss.backward()

        # Accumulate gradients for a specified number of iterations
        iteration += 1
        if iteration % optimize_after == 0:
            optimizer.step()
            iteration = 0
            total_loss += loss.item()
        break

    return total_loss / (len(data_loader) // optimize_after)

In [None]:
def eval_fn(data_loader,model,device):
    model.eval()
    total_loss=0.0
    total_acc=0    
    with torch.no_grad():
        for images in tqdm(data_loader):
            model=model.to(device) 
            images=images.to(device)
            predictions = model(images)
            targets = images            
            loss, acc =get_loss(predictions, targets)
            total_loss += loss.item()
            total_acc += acc.item()

    return total_loss/len(data_loader),total_acc/len(data_loader)

In [None]:
best_valid_loss=np.Inf
for epoch in range(num_epochs):
    train_loss=train_fn(train_loader,model,optimizer,device=DEVICE)
    valid_loss,acc=eval_fn(val_loader,model,device=DEVICE)
    if epoch<warmup_epochs:
        warmup_scheduler.step()
    else:
        decay_scheduler.step()    
    if valid_loss< best_valid_loss:
        torch.save(model.state_dict(),'best_model_n.pt')
        print('saved-model')
        best_valid_loss=valid_loss
    current_lr=optimizer.param_groups[0]['lr']
    
    print(f'epoch:{epoch+1} Train_loss:{train_loss} Valid_loss:{valid_loss} Valid acc:{acc} lr:{current_lr}')