# ALIGNN/ALIGNN-d graph representation for molecular structures

This notebook will guide you through encoding molecular structures into the ALIGNN/ALIGNN-d graph representation.

You will need to install `ase` for reading molecular structure files and creating molecular graphs.

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import ase.io

## Read structure files

In [3]:
d_mol = ase.io.read('./data/D-alanine.mol')
l_mol = ase.io.read('./data/L-alanine.mol')

# Remove hydrogens
d_mol = d_mol[d_mol.numbers != 1]
l_mol = l_mol[l_mol.numbers != 1]

## Convert to PyG graph data

In the ALIGNN formulation, the input representation consists of two graphs: the original atomic graph ("G") and its angular graph ("A"). The cell below uses helper functions from `graphite` (which depend on `ase`) to extract atom types, bond lengths, bond angles, and dihedral angles, and stores them in a custom PyG `Data` object with the following fields:
- `edge_index_G`: atomic graph ("G") connectivity
- `edge_index_A`: angular graph ("A") connectivity
- `x_atm`: atom type information (one-hot encoded)
- `x_bnd`: bond lengths
- `x_ang`: bond angles and optionally dihedral angles (in radians)
- `mask_dih_ang`: mask vector indicating which angles are dihedral angles, if any

To omit dihedral angles (i.e., the original ALIGNN formulation), skip the dihedral angle calculation and set `mask_dih_ang` to `None`.
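To make the relationship between "G" and "A" concrete, here is a minimal NumPy sketch of how an angular (line) graph can be derived from a directed edge list. `toy_line_graph` is a hypothetical helper written for illustration, not `graphite`'s `line_graph`; the edge ordering and the rule for excluding reversed bonds are assumptions.

```python
import numpy as np

def toy_line_graph(edge_index_G):
    """Connect bond a = (i -> j) to bond b = (j -> k) whenever k != i.

    Each such pair of bonds shares atom j and defines the bond angle i-j-k.
    Illustrative only; graphite's `line_graph` may order or filter pairs differently.
    """
    src, dst = edge_index_G
    pairs = []
    for a in range(src.size):
        for b in range(src.size):
            # Bond b must start where bond a ends, and must not be a's reverse
            if dst[a] == src[b] and dst[b] != src[a]:
                pairs.append((a, b))
    return np.array(pairs, dtype=int).reshape(-1, 2).T

# Three-atom chain 0-1-2, with every bond stored in both directions
edge_index_G = np.array([
    [0, 1, 1, 2],
    [1, 0, 2, 1],
])
edge_index_A = toy_line_graph(edge_index_G)
print(edge_index_A)
```

In this toy chain there is a single physical angle (0-1-2), yet the angular graph contains two edges: one pairing bonds 0→1 and 1→2, and one pairing bonds 2→1 and 1→0.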

In [4]:
from ase.neighborlist import neighbor_list
from sklearn.preprocessing import OneHotEncoder
from graphite.utils.alignn import line_graph, dihedral_graph, get_bnd_angs, get_dih_angs
from graphite.data import AngularGraphPairData

# OVITO's default element-specific cutoffs for creating molecular bonds
ovito_cutoff = {
    ('H', 'C'): 1.74, ('H', 'N'): 1.65,  ('H', 'O'): 1.632,
    ('C', 'C'): 2.04, ('C', 'N'): 1.95,  ('C', 'O'): 1.932,
    ('N', 'N'): 1.86, ('N', 'O'): 1.842, ('O', 'O'): 1.824,
}

def get_molecular_graph(atoms):
    """Returns edge indices and bond lengths of strong chemical bonds according to
    predefined element-specific cutoff criteria.
    """
    i, j, d = neighbor_list('ijd', atoms, cutoff=ovito_cutoff)
    return np.stack((i, j)), d

def atoms2pygdata(atoms):
    """Converts ASE `atoms` into a PyG graph data holding the molecular graph (G) and the angular graph (A).
    The angular graph holds both bond and dihedral angles.
    """
    edge_index_G, x_bnd = get_molecular_graph(atoms)
    edge_index_bnd_ang = line_graph(edge_index_G)
    edge_index_dih_ang = dihedral_graph(edge_index_G)
    edge_index_A = np.hstack([edge_index_bnd_ang, edge_index_dih_ang])
    x_atm = OneHotEncoder(sparse_output=False).fit_transform(atoms.numbers.reshape(-1,1))
    x_bnd_ang = get_bnd_angs(atoms, edge_index_G, edge_index_bnd_ang)
    x_dih_ang = get_dih_angs(atoms, edge_index_G, edge_index_dih_ang)
    x_ang = np.concatenate([x_bnd_ang, x_dih_ang])
    mask_dih_ang = [False]*len(x_bnd_ang) + [True]*len(x_dih_ang)
    
    data = AngularGraphPairData(
        edge_index_G = torch.tensor(edge_index_G, dtype=torch.long),
        edge_index_A = torch.tensor(edge_index_A, dtype=torch.long),
        x_atm        = torch.tensor(x_atm,        dtype=torch.float),
        x_bnd        = torch.tensor(x_bnd,        dtype=torch.float),
        x_ang        = torch.tensor(x_ang,        dtype=torch.float),
        mask_dih_ang = torch.tensor(mask_dih_ang, dtype=torch.bool),
    )
    return data
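
One caveat about `atoms2pygdata`: `OneHotEncoder` is fit on each molecule separately, so the width and column order of `x_atm` follow whichever elements that molecule contains. This is harmless for the two alanine molecules, which both contain C, N, and O once hydrogens are removed, but molecules with different element sets would receive inconsistent encodings. A minimal sketch of a fixed-vocabulary alternative follows; `SPECIES` and `onehot_species` are hypothetical names introduced here, not part of `graphite`.

```python
import numpy as np

# Assumed element vocabulary (atomic numbers of C, N, O); extend it for your dataset
SPECIES = np.array([6, 7, 8])

def onehot_species(numbers, species=SPECIES):
    """One-hot encode atomic numbers against a fixed, ordered species list."""
    numbers = np.asarray(numbers)
    if not np.isin(numbers, species).all():
        raise ValueError("found an element outside the species list")
    # Broadcast comparison: rows are atoms, columns are species
    return (numbers[:, None] == species[None, :]).astype(np.float32)

# Heavy atoms of alanine: one N, three C, two O (the atom order here is illustrative)
x_atm = onehot_species([7, 6, 6, 6, 8, 8])
print(x_atm.shape)  # (6, 3)
```

With a fixed list, the second column always means nitrogen, regardless of which molecule is being encoded.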

Using toy data, L-alanine and D-alanine, as examples, we can see that while all the bond angles are the same between the L- and D- counterparts, some of the dihedral angles are clearly different. Many conventional graph neural networks cannot capture such chiral distinctions.

Note that each angle appears twice because the atomic graph (and hence its angular graph) stores every bond in both directions, as the message-passing formulation requires.
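
The chirality claim above can be checked on a toy geometry without any of the helper functions. The sketch below computes a bond angle and a signed dihedral angle for four points and for their mirror image. `bond_angle` and `dihedral_angle` are hypothetical helpers, not `graphite`'s `get_bnd_angs` and `get_dih_angs`, and the sign convention used for the dihedral is an assumption.

```python
import numpy as np

def bond_angle(p0, p1, p2):
    """Angle (radians) at p1 formed by p0-p1-p2."""
    u, v = p0 - p1, p2 - p1
    cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.arccos(np.clip(cos, -1.0, 1.0))

def dihedral_angle(p0, p1, p2, p3):
    """Signed dihedral (radians) about the p1-p2 axis."""
    b0, b1, b2 = p1 - p0, p2 - p1, p3 - p2
    n1, n2 = np.cross(b0, b1), np.cross(b1, b2)  # normals of the two planes
    m1 = np.cross(n1, b1 / np.linalg.norm(b1))
    return np.arctan2(np.dot(m1, n2), np.dot(n1, n2))

# A twisted four-atom chain and its reflection through the yz-plane
pts = np.array([[1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 1.0],
                [0.0, 1.0, 1.0]])
mirror = pts * np.array([-1.0, 1.0, 1.0])

ang, ang_m = bond_angle(*pts[:3]), bond_angle(*mirror[:3])
dih, dih_m = dihedral_angle(*pts), dihedral_angle(*mirror)
print(np.rad2deg([ang, ang_m]))  # bond angles: identical
print(np.rad2deg([dih, dih_m]))  # dihedrals: equal magnitude, opposite sign
```

Reflection preserves all distances and bond angles but flips the sign of the dihedral, which is why a representation needs dihedral (or equivalent) information to tell enantiomers apart.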

In [9]:
import pandas as pd

l_mol_data = atoms2pygdata(l_mol)
d_mol_data = atoms2pygdata(d_mol)

angles = pd.DataFrame(
    {
        'Is a dihedral angle': l_mol_data.mask_dih_ang.numpy(),
        'l-mol angles': l_mol_data.x_ang.rad2deg(),
        'd-mol angles': d_mol_data.x_ang.rad2deg(),
    }
)
pd.set_option("display.precision", 1)
angles

| | Is a dihedral angle | l-mol angles | d-mol angles |
|---|---|---|---|
| 0 | False | 109.7 | 109.7 |
| 1 | False | 108.1 | 108.1 |
| 2 | False | 109.7 | 109.7 |
| 3 | False | 110.8 | 110.8 |
| 4 | False | 108.1 | 108.1 |
| 5 | False | 110.8 | 110.8 |
| 6 | False | 123.2 | 123.2 |
| 7 | False | 110.4 | 110.4 |
| 8 | False | 123.2 | 123.2 |
| 9 | False | 126.4 | 126.4 |


## Define model architecture

Following DeepMind's nomenclature in the MeshGraphNets formulation, a graph (message-passing) neural network can be described as three components:
- Encoder
    - Initial encoding/embedding that transforms graph-based data into the format necessary for graph convolutions or message passing.
    - Often customized depending on the user's data format. For example, the original ALIGNN work also encodes bond types (e.g., single-bond, double-bond, triple-bond, etc.) but such information is not encoded in this example.
- Processor
    - Graph convolutions or message-passing layers. Refer to the ALIGNN paper for how it is implemented.
    - Normally does not need to be customized.
- Decoder
    - Final operation that transforms latent node/edge features into an appropriate output format.
    - Often customized by the user for a specific output format. For example, local nodal output vs. global pooled output.
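
The three-part split above can be sketched in a few lines of PyTorch. The `ToyEncoder`, `ToyProcessor`, and `ToyDecoder` classes are hypothetical stand-ins written for illustration and are not `graphite`'s `Encoder`, `Processor`, and `Decoder`. The sketch passes messages on the atomic graph only, whereas ALIGNN additionally updates bond features through the angular graph.

```python
import torch
from torch import nn

class ToyEncoder(nn.Module):
    """Embeds raw node and edge features into a shared latent width."""
    def __init__(self, num_species, dim):
        super().__init__()
        self.embed_node = nn.Linear(num_species, dim)
        self.embed_edge = nn.Linear(1, dim)

    def forward(self, x_atm, x_bnd):
        return self.embed_node(x_atm), self.embed_edge(x_bnd.unsqueeze(-1))

class ToyProcessor(nn.Module):
    """Stack of sum-aggregation message-passing layers with residual updates."""
    def __init__(self, num_convs, dim):
        super().__init__()
        self.msg = nn.ModuleList([nn.Linear(2 * dim, dim) for _ in range(num_convs)])

    def forward(self, h, e, edge_index):
        src, dst = edge_index
        for layer in self.msg:
            # Message on each edge, built from the sender node and the edge feature
            m = torch.relu(layer(torch.cat([h[src], e], dim=-1)))
            # Sum incoming messages at each receiving node, then add residually
            h = h + torch.zeros_like(h).index_add_(0, dst, m)
        return h

class ToyDecoder(nn.Module):
    """Maps latent node features to a per-node output."""
    def __init__(self, dim, out_dim):
        super().__init__()
        self.out = nn.Linear(dim, out_dim)

    def forward(self, h):
        return self.out(h)

torch.manual_seed(0)
x_atm = torch.eye(3)                             # three atoms, one per species
x_bnd = torch.tensor([1.5, 1.5, 1.3, 1.3])       # one length per directed bond
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

enc, proc, dec = ToyEncoder(3, 16), ToyProcessor(2, 16), ToyDecoder(16, 2)
h, e = enc(x_atm, x_bnd)
out = dec(proc(h, e, edge_index))
print(out.shape)  # torch.Size([3, 2])
```

Swapping `ToyDecoder` for a module that pools over nodes before the final linear layer would turn the per-node output into a single global prediction, which is the kind of customization the decoder bullet refers to.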

In [6]:
from graphite.nn.models.alignn import Encoder, Processor, Decoder, ALIGNN

# In case it has not been made clear, this ALIGNN implementation can encode dihedral angles
gnn = ALIGNN(
    encoder   = Encoder(num_species=3, cutoff=2.0, dim=128, dihedral=True),
    processor = Processor(num_convs=5, dim=128),
    decoder   = Decoder(node_dim=128, out_dim=2),
)

## Model forward pass

For now I'm only demonstrating a successful forward pass. Model training and inference are currently omitted.

In [7]:
gnn(l_mol_data)

tensor([[ 0.0957,  0.1131],
        [ 0.0341,  0.1423],
        [ 0.1306,  0.3828],
        [ 0.1108,  0.0066],
        [ 0.2235, -0.3803],
        [ 0.4413, -0.1710]], grad_fn=<AddmmBackward0>)