<a href="https://colab.research.google.com/github/John-D-Boom/geom_virtual_nodes/blob/main/Experiment_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title [RUN] Install geometric and chem modules
# Environment setup cell: verifies that a CUDA device is visible, then installs
# torch-scatter, torch-sparse, PyTorch Geometric, RDKit and py3Dmol with pip,
# unless the IS_GRADESCOPE_ENV environment variable is set.
import os
import torch
# A failed assert raises AssertionError, so despite the "WARNING" wording the
# cell stops here on a runtime without a GPU.
assert torch.cuda.is_available(), "WARNING! You are running on a non-GPU instance. For this practical a GPU is highly recommended."
if 'IS_GRADESCOPE_ENV' not in os.environ:
    print('Installing scatter')
    # The wheel index URL is built from the installed torch version string
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
    print('Installing sparse')
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
    # torch-cluster is intentionally left uninstalled (kept for reference):
    # print('Installing cluster')
    # !pip install -q torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
    print('Installing pytorch geometric')
    # NOTE(review): installs whatever is at the head of the repository's default
    # branch (no version pin), so results may change between runs.
    !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
    print('Installing rdkit')
    !pip install -q rdkit-pypi==2021.9.4
    print('Installing py3Dmol')
    !pip install -q py3Dmol
else:
    # Reached only when IS_GRADESCOPE_ENV is set: nothing is installed
    print('already installed. Not repeating')
    print('To uninstall: !pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y')



Installing scatter
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling sparse
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling pytorch geometric
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone
Installing rdkit
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.8/20.8 MB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling py3Dmol


In [None]:
#@title [RUN] Import python modules

import os
import sys
import time
import random
import numpy as np
import copy

from scipy.stats import ortho_group

import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential

import torch_geometric
from torch_geometric.data import ClusterData
from torch_geometric.data import Data
from torch_geometric.data import Batch
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.datasets import QM9
from torch_scatter import scatter

import rdkit.Chem as Chem
from rdkit.Geometry.rdGeometry import Point3D
from rdkit.Chem import QED, Crippen, rdMolDescriptors, rdmolops
from rdkit.Chem.Draw import IPythonConsole

import py3Dmol
from rdkit.Chem import AllChem

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from google.colab import files
from IPython.display import HTML

print("All imports succeeded.")
print("Python version {}".format(sys.version))
print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))

All imports succeeded.
Python version 3.9.16 (main, Dec  7 2022, 01:11:51) 
[GCC 9.4.0]
PyTorch version 2.0.0+cu118
PyG version 2.4.0


In [None]:
#@title [RUN] Set random seed for deterministic results

def seed(seed=0):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy, and PyTorch (CPU, current CUDA device and
    all CUDA devices), then configures cuDNN for deterministic execution.

    Args:
        seed: (int) - value passed to each generator
    """
    # Same seeding calls as before, applied in the same order
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels, no autotuning between runs
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

seed(0)
print("All seeds set.")

All seeds set.


In [None]:
#@title [RUN] Helper functions for data preparation

class SetTarget:
    """
    This transform modifies the labels vector per data sample to only keep
    the label for a specific target (there are 19 targets in QM9).

    Note: for this practical the default is target #0,
    i.e. the electric dipole moment of a drug-like molecule.
    (https://en.wikipedia.org/wiki/Electric_dipole_moment)

    Args:
        target: (int) - column of `data.y` to keep. Defaults to 0, which
            reproduces the previously hard-coded behaviour of `SetTarget()`.
    """
    def __init__(self, target=0):
        self.target = target

    def __call__(self, data):
        """Replace `data.y` by its `target` column and return the same object.

        Args:
            data: sample whose `y` attribute supports `y[:, target]` indexing

        Returns:
            data: the input object, modified in place
        """
        data.y = data.y[:, self.target]
        return data


class CompleteGraph:
    """
    Transform that rewires a sample into a complete graph: an edge is created
    for every ordered pair of nodes and self loops are then removed.

    When the sample has edge features, the features of the original edges are
    carried over and every newly created edge gets an all-zero feature row.
    """
    def __call__(self, data):
        n = data.num_nodes
        node_ids = torch.arange(n, dtype=torch.long, device=data.edge_index.device)

        # Enumerate all n*n ordered pairs (src, dst) in row-major order
        src = node_ids.repeat_interleave(n)
        dst = node_ids.repeat(n)
        dense_edge_index = torch.stack((src, dst), dim=0)

        dense_edge_attr = None
        if data.edge_attr is not None:
            # Pair (i, j) lives at flat position i * n + j of the dense layout
            flat_pos = data.edge_index[0] * n + data.edge_index[1]
            dense_shape = [n * n, *data.edge_attr.shape[1:]]
            dense_edge_attr = data.edge_attr.new_zeros(dense_shape)
            dense_edge_attr[flat_pos] = data.edge_attr

        dense_edge_index, dense_edge_attr = remove_self_loops(dense_edge_index, dense_edge_attr)
        data.edge_attr = dense_edge_attr
        data.edge_index = dense_edge_index

        return data

print("Helper functions loaded.")

Helper functions loaded.


In [None]:
#@title [RUN] Helper functions for visualization

# Element symbol for each position of the atom-type encoding.
# NOTE(review): "C" and "H" each appear twice (indices 5 and 9). The callers
# visible in this notebook only pass the first five entries of a feature row,
# so the later entries look like placeholders - confirm before relying on them.
allowable_atoms = [
    "H",
    "C",
    "N",
    "O",
    "F",
    "C",
    "Cl",
    "Br",
    "I",
    "H", 
    "Unknown",
]

def to_atom(t):
    """Map an atom-type feature vector to an element symbol.

    Args:
        t: object with an `argmax()` method (e.g. a 1-D tensor); the index of
            its largest entry selects the symbol from `allowable_atoms`

    Returns:
        (str) - the selected symbol, or "C" when the lookup fails for any
        ordinary reason (bad input type, index out of range, ...)
    """
    try:
        return allowable_atoms[int(t.argmax())]
    except Exception:
        # Best-effort fallback kept from the original, but narrowed from a bare
        # `except:` so KeyboardInterrupt / SystemExit are no longer swallowed.
        return "C"


def to_bond_index(t):
    """Decode a bond feature vector into one of the bond codes 1, 2, 3, 4.

    The squeezed vector is dotted with [0, 1, ..., n-1]; for a one-hot input
    this yields the position of the hot entry, which is then used to look up
    the bond code.

    Args:
        t: float tensor holding the bond features (singleton dims are squeezed)

    Returns:
        (int) - bond code taken from (1, 2, 3, 4)
    """
    flat = t.squeeze()
    positions = torch.arange(flat.size(0), dtype=torch.float, device=t.device)
    hot_position = int(torch.dot(flat, positions).item())
    return (1, 2, 3, 4)[hot_position]

def to_rdkit(data, device=None):
    """Build an RDKit molecule from a graph sample and return its largest fragment.

    Atom symbols are decoded from the first five entries of each row of
    `data.x`; hydrogens are skipped. Bonds are decoded from `data.edge_attr`
    for the edges listed in `data.edge_index`.

    Args:
        data: graph sample providing `x`, `edge_index` and `edge_attr`
            (and `pos`, which is only read by the currently disabled
            conformer branch)
        device: unused by this function

    Returns:
        The fragment of the assembled molecule with the most atoms (the whole
        molecule if no fragments are returned).
    """
    # has_pos is never set to True in this function, so the conformer block
    # further down does not run.
    has_pos = False
    node_list = []
    for i in range(data.x.size()[0]):
        node_list.append(to_atom(data.x[i][:5]))

    # create empty editable mol object
    mol = Chem.RWMol()
    # add atoms to mol and keep track of index
    node_to_idx = {}  # graph node index -> atom index inside `mol`
    invalid_idx = set([])  # graph nodes that were NOT added as atoms
    for i in range(len(node_list)):
        # Hydrogens (and a "Stop" marker, if one is ever produced) are skipped
        if node_list[i] == "Stop" or node_list[i] == "H":
            invalid_idx.add(i)
            continue
        a = Chem.Atom(node_list[i])
        molIdx = mol.AddAtom(a)
        node_to_idx[i] = molIdx

    added_bonds = set([])  # stringified (ix, iy) pairs that already have a bond
    for i in range(0, data.edge_index.size()[1]):
        ix = data.edge_index[0][i].item()
        iy = data.edge_index[1][i].item()
        # TODO (from the original author): fix this bond decoding.
        # Note it is evaluated before the all-zero check below.
        bond = to_bond_index(data.edge_attr[i])
        # bond = 1
        # add bonds between adjacent atoms

        # Edges with an all-zero feature row carry no bond
        if data.edge_attr[i].sum() == 0:
          continue

        # Skip edges already added in either direction, and edges touching a
        # skipped (hydrogen) node
        if (
            (str((ix, iy)) in added_bonds)
            or (str((iy, ix)) in added_bonds)
            or (iy in invalid_idx or ix in invalid_idx)
        ):
            continue
        # add relevant bond type (there are many more of these)

        # to_bond_index returns a value from [1, 2, 3, 4], so the `bond == 0`
        # branch is not expected to trigger.
        if bond == 0:
            continue
        elif bond == 1:
            bond_type = Chem.rdchem.BondType.SINGLE
            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
        elif bond == 2:
            bond_type = Chem.rdchem.BondType.DOUBLE
            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
        elif bond == 3:
            bond_type = Chem.rdchem.BondType.TRIPLE
            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
        elif bond == 4:
            # Code 4 is added as a SINGLE bond
            # NOTE(review): presumably the aromatic bond class - confirm
            bond_type = Chem.rdchem.BondType.SINGLE
            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        added_bonds.add(str((ix, iy)))

    # Disabled: would attach the 3D coordinates from data.pos as conformer 0
    if has_pos:
        conf = Chem.Conformer(mol.GetNumAtoms())
        for i in range(data.pos.size(0)):
            if i in invalid_idx:
                continue
            p = Point3D(
                data.pos[i][0].item(),
                data.pos[i][1].item(),
                data.pos[i][2].item(),
            )
            conf.SetAtomPosition(node_to_idx[i], p)
        conf.SetId(0)
        mol.AddConformer(conf)

    # Convert RWMol to Mol object
    mol = mol.GetMol()
    # Fragments are extracted without sanitization; keep the one with most atoms
    mol_frags = rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
    return largest_mol


def MolTo3DView(mol, size=(300, 300), style="stick", surface=False, opacity=0.5):
    """Draw molecule in 3D

    Hydrogens are added, a 3D conformer is embedded and MMFF-optimized
    (at most 200 iterations), and the result is handed to py3Dmol as a
    mol block.

    Args:
    ----
        mol: rdMol, molecule to show
        size: tuple(int, int), canvas size (width, height)
        style: str, type of drawing molecule
               style can be 'line', 'stick', 'sphere', 'cartoon'
               (the former misspelling 'carton' is still accepted and is
               treated as 'cartoon')
        surface, bool, display SAS
        opacity, float, opacity of surface, range 0.0-1.0
    Return:
    ----
        viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.
    Raises:
    ----
        AssertionError: if `style` is not one of the supported names
    """
    # 'carton' was the only spelling accepted previously, which made the real
    # 3Dmol style name 'cartoon' unusable; accept both and normalize.
    assert style in ('line', 'stick', 'sphere', 'cartoon', 'carton')
    if style == 'carton':
        style = 'cartoon'

    mol = Chem.AddHs(mol)
    # NOTE(review): the return value of EmbedMolecule is not checked here
    AllChem.EmbedMolecule(mol)
    AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
    mblock = Chem.MolToMolBlock(mol)
    viewer = py3Dmol.view(width=size[0], height=size[1])
    viewer.addModel(mblock, 'mol')
    viewer.setStyle({style:{}})
    if surface:
        viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
    viewer.zoomTo()
    return viewer

def smi2conf(smiles):
    '''Convert SMILES to rdkit.Mol with 3D coordinates

    Returns None when RDKit cannot parse the SMILES string; otherwise the
    molecule with hydrogens added, a conformer embedded and MMFF-optimized
    (at most 200 iterations).
    '''
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return None

    with_hs = Chem.AddHs(parsed)
    AllChem.EmbedMolecule(with_hs)
    AllChem.MMFFOptimizeMolecule(with_hs, maxIters=200)
    return with_hs

print("Helper functions added.")

Helper functions added.


In [None]:
#@title [RUN] Helper functions for managing experiments, training, and evaluating models.

def train(model, train_loader, optimizer, device):
    """Train `model` for one pass over `train_loader` with an MSE objective.

    Args:
        model: module called as `model(batch)` to produce predictions
        train_loader: iterable of batches exposing `.to()`, `.y` and
            `.num_graphs`; must also expose `.dataset` with a length
        optimizer: optimizer holding the model parameters
        device: device every batch is moved to

    Returns:
        (float) - sum over batches of (batch loss * graphs in batch),
        divided by the number of samples in `train_loader.dataset`
    """
    model.train()
    weighted_loss_sum = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_loss = F.mse_loss(model(batch), batch.y)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so the final division yields a per-sample mean
        weighted_loss_sum += batch_loss.item() * batch.num_graphs

    return weighted_loss_sum / len(train_loader.dataset)


def eval(model, loader, device):
    """Compute the mean absolute error of `model` over `loader`.

    Predictions and labels are both multiplied by the module-level `std`
    (computed when preparing the data) before the absolute difference is
    taken.

    Args:
        model: module called as `model(batch)` to produce predictions
        loader: iterable of batches exposing `.to()` and `.y`; must also
            expose `.dataset` with a length
        device: device every batch is moved to

    Returns:
        (float) - summed absolute error divided by `len(loader.dataset)`
    """
    model.eval()
    abs_error_sum = 0

    # Gradients are never needed during evaluation
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            prediction = model(batch)
            abs_error_sum += (prediction * std - batch.y * std).abs().sum().item()

    return abs_error_sum / len(loader.dataset)


def run_experiment(model, model_name, train_loader, val_loader, test_loader, n_epochs=100):
    """Train `model`, track validation/test MAE per epoch and report the result.

    The model is optimized with Adam (LR 1e-3) and a ReduceLROnPlateau
    scheduler driven by the validation MAE. The test set is evaluated only in
    epochs where the validation MAE is at least as good as the best seen so
    far, so the reported test MAE always belongs to the best validation epoch.

    Args:
        model: module to train; moved to CUDA when available, else CPU
        model_name: (str) - label used in log lines and in the returned records
        train_loader: loader for the training split
        val_loader: loader for the validation split
        test_loader: loader for the test split
        n_epochs: (int) - number of training epochs, must be >= 1

    Returns:
        best_val_error: (float) - best validation MAE
        test_error: (float) - test MAE measured at the best validation epoch
        train_time: (float) - total training time in minutes
        perf_per_epoch: list of (test_error, val_error, epoch, model_name)

    Raises:
        ValueError: if `n_epochs` is smaller than 1 (previously this crashed
            only at the very end, while formatting the summary line)
    """
    if n_epochs < 1:
        raise ValueError(f"n_epochs must be at least 1, got {n_epochs}")

    print(f"Running experiment for {model_name}, training on {len(train_loader.dataset)} samples for {n_epochs} epochs.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("\nModel architecture:")
    print(model)
    # numel() keeps the count an int, also for 0-dimensional parameters
    total_param = sum(param.numel() for param in model.parameters())
    print(f'Total parameters: {total_param}')
    model = model.to(device)

    # Adam optimizer with LR 1e-3
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # LR scheduler which decays LR when validation metric doesn't improve
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.9, patience=5, min_lr=0.00001)

    print("\nStart training:")
    best_val_error = None
    perf_per_epoch = [] # Track Test/Val MAE vs. epoch (for plotting)
    t = time.time()
    for epoch in range(1, n_epochs+1):
        # Read the current LR for logging only; the scheduler itself is
        # stepped at the end of the epoch, once the validation MAE is known
        lr = scheduler.optimizer.param_groups[0]['lr']

        # Train model for one epoch, return avg. training loss
        loss = train(model, train_loader, optimizer, device)

        # Evaluate model on validation set
        val_error = eval(model, val_loader, device)

        if best_val_error is None or val_error <= best_val_error:
            # Evaluate model on test set if validation metric improves.
            # Always taken in the first epoch, so test_error is defined below.
            test_error = eval(model, test_loader, device)
            best_val_error = val_error

        if epoch % 10 == 0:
            # Print and track stats every 10 epochs
            print(f'Epoch: {epoch:03d}, LR: {lr:5f}, Loss: {loss:.7f}, '
                  f'Val MAE: {val_error:.7f}, Test MAE: {test_error:.7f}')

        scheduler.step(val_error)
        perf_per_epoch.append((test_error, val_error, epoch, model_name))

    t = time.time() - t
    train_time = t/60
    print(f"\nDone! Training took {train_time:.2f} mins. Best validation MAE: {best_val_error:.7f}, corresponding test MAE: {test_error:.7f}.")

    return best_val_error, test_error, train_time, perf_per_epoch

In [None]:
# print(f"Total number of samples: {len(dataset)}.")

# # Split datasets (in case of using the full dataset)
# # test_dataset = dataset[:10000]
# # val_dataset = dataset[10000:20000]
# # train_dataset = dataset[20000:]

# # Split datasets (our 3K subset)
# train_dataset = dataset[:1000]
# val_dataset = dataset[1000:2000]
# test_dataset = dataset[2000:3000]
# print(f"Created dataset splits with {len(train_dataset)} training, {len(val_dataset)} validation, {len(test_dataset)} test samples.")

# # Create dataloaders with batch size = 32
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
#@title [RUN] Helper function for visualizing molecules with virtual nodes
import plotly.graph_objects as go
from torch_geometric.utils.convert import to_networkx
import networkx as nx
def plot_molecule_3d(molecule):
    """Show an interactive 3D plot of a molecular graph, colored by atom type.

    Node coordinates come from a 3D networkx spring layout of the graph (not
    from the molecule's own positions). The atom type is the argmax over the
    first 5 columns of `molecule.x` (11 features) or the first 6 columns
    (12 features, where index 5 marks a virtual node). For any other feature
    width an error is printed and nothing is plotted.

    Args:
        molecule: graph sample providing `x` and convertible by `to_networkx`
    """
    graph = to_networkx(molecule)
    layout = nx.spring_layout(graph, dim=3)

    # Line segments: each edge contributes its two endpoints and a None break
    edge_x, edge_y, edge_z = [], [], []
    for src, dst in graph.edges():
        x0, y0, z0 = layout[src]
        x1, y1, z1 = layout[dst]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        edge_z += [z0, z1, None]

    # Atom identity -> marker color
    palette = {
        0: 'white',   # H
        1: 'black',   # C
        2: 'blue',    # N
        3: 'red',     # O
        4: 'purple',  # F
        5: 'green',   # Virtual Node
    }

    node_x, node_y, node_z, node_color = [], [], [], []
    for node in graph.nodes():
        x, y, z = layout[node]
        node_x.append(x)
        node_y.append(y)
        node_z.append(z)

        # The leading columns of x are a one-hot encoding of the atom identity
        feature_width = molecule.x.shape[1]
        if feature_width == 11:
            num_type_columns = 5
        elif feature_width == 12:
            num_type_columns = 6
        else:
            print("ERROR: molecule.x has unrecognized dimensions (should be 11 or 12) and so number of atom types cannot be determined")
            return
        atom_identity = torch.argmax(molecule.x[node][:num_type_columns]).item()

        if atom_identity in palette:
            color = palette[atom_identity]
        else:
            print('Unrecognized molecule type')
            color = 'pink' 
        node_color.append(color)

    node_trace = go.Scatter3d(
        x=node_x, y=node_y, z=node_z, mode='markers',
        marker=dict(size=8, color=node_color))
    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z, mode='lines',
        line=dict(color='black', width=1), hoverinfo='none')

    # All three axes are hidden in the same way
    scene = {
        axis: dict(title='', showticklabels=False, showgrid=False, zeroline=False)
        for axis in ('xaxis', 'yaxis', 'zaxis')
    }
    figure_layout = go.Layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=scene,
        showlegend=False)
    fig = go.Figure(data=[edge_trace, node_trace], layout=figure_layout)
    fig.show()

In [None]:
#@title [RUN] Helper Function to add virtual nodes
def add_virtual_one_hot(tensor):
    """Insert an all-zero "virtual node" column at index 5 of a feature matrix.

    Columns 0-4 of the input are kept in place, input columns 5-10 move to
    output columns 6-11, and the new column 5 is zero for every row.

    Args:
        tensor: (n, 11) - node feature matrix

    Returns:
        new_tensor: (n, 12) - padded features, with the same dtype and device
            as the input (previously always default dtype on CPU)

    Raises:
        AssertionError: if the input does not have exactly 11 columns
    """
    # Get the shape of the input tensor
    n, m = tensor.shape
    assert m == 11, "input tensor should have 11 features, just like atom3d dataset. found {}".format(m)
    # new_zeros inherits dtype/device, so GPU or non-float32 inputs stay intact
    new_tensor = tensor.new_zeros((n, 12))
    new_tensor[:, :5] = tensor[:, :5]
    new_tensor[:, 6:] = tensor[:, 5:]
    return new_tensor

def add_virtual_node(graph_in: "torch_geometric.data.Data", nodes_to_connect: list):
    """Return a copy of the graph with one extra virtual node appended.

    The virtual node
      * is appended as the last node, with feature column 5 set to 1,
      * is connected in both directions to every node in `nodes_to_connect`
        and to every virtual node already present in the graph,
      * is positioned at the arithmetic mean of the positions of
        `nodes_to_connect` (the origin if that list is empty),
      * gets atomic number 0 in `z` and "single bond" edge features
        ([1, 0, 0, 0]) on all of its edges.

    Fixes compared to the previous version: the position is now a true mean
    (it used to be divided by `virtual_index - 1`), the node range check
    rejects `node == virtual_index` (it used `len(graph.x-1)`, which equals
    `len(graph.x)`), the new edge index is always an integer tensor even when
    no edges are added, and new tensors follow the dtype/device of the graph.

    Args:
        graph_in: graph sample providing `x`, `edge_index`, `edge_attr`, `pos`
            and `z`; it is shallow-copied and its attributes are not reassigned
        nodes_to_connect: list of int node indices of the input graph

    Returns:
        graph: the copied graph including the new virtual node

    Raises:
        AssertionError: if a node index is outside `[0, number of input nodes)`
    """
    graph = copy.copy(graph_in)
    # Node will be appended to the end of the array
    virtual_index = len(graph.x)

    # Pad the features with an extra one hot column in the 6th position to
    # indicate the virtual node.
    # Only do this if no virtual nodes have been added before
    if graph.x.shape[1] == 11:
        graph.x = add_virtual_one_hot(graph.x)

    # Add node to x features
    new_node_x = graph.x.new_zeros((1, 12))
    new_node_x[0, 5] = 1
    graph.x = torch.cat([graph.x, new_node_x], dim=0)
    assert graph.x.shape[0] == virtual_index + 1, \
        "unexpected feature matrix shape {}".format(tuple(graph.x.shape))

    # Every requested node must exist in the graph as it was passed in
    for node in nodes_to_connect:
        assert 0 <= node < virtual_index, \
            "node index {} is outside the input graph".format(node)

    # Virtual nodes added earlier (x[:-1] excludes the node added just now)
    earlier_virtual = torch.nonzero(graph.x[:-1, 5] == 1, as_tuple=False).view(-1).tolist()

    # Edges in both directions: first to the requested nodes, then to the
    # other virtual nodes (this fully connects the virtual nodes over time)
    new_edges = [[], []]
    for node in list(nodes_to_connect) + earlier_virtual:
        new_edges[0] += [virtual_index, node]
        new_edges[1] += [node, virtual_index]

    # Explicit dtype: torch.tensor([[], []]) would otherwise be a float tensor
    new_edge_index = torch.tensor(
        new_edges, dtype=graph.edge_index.dtype, device=graph.edge_index.device
    ).view(2, -1)
    graph.edge_index = torch.cat([graph.edge_index, new_edge_index], dim=1)

    # Add a position to node based on arithmetic mean of positions
    if len(nodes_to_connect) > 0:
        virtual_pos = graph.pos[list(nodes_to_connect)].mean(dim=0, keepdim=True)
    else:
        virtual_pos = graph.pos.new_zeros((1, graph.pos.shape[1]))
    graph.pos = torch.cat([graph.pos, virtual_pos])

    # The virtual node has no element, so it gets atomic number 0
    graph.z = torch.cat([graph.z, graph.z.new_zeros(1)])

    # Update edge attributes to be "single" bonds. The MPNN may not analyze
    # the actual bonds, but this helps plotting and ensures consistency
    new_edge_attr = graph.edge_attr.new_zeros((len(new_edges[0]), 4))
    new_edge_attr[:, 0] = 1
    graph.edge_attr = torch.cat([graph.edge_attr, new_edge_attr], dim=0)

    return graph

In [None]:
#@title [RUN] Helper Function to use to assign nodes to a virtual node with METIS Clustering
def get_clusters(data:torch_geometric.data.Data, num_clusters: int):
    """Partition the nodes of a graph into clusters with `ClusterData`.

    Nodes of each partition are mapped back to indices of the input graph by
    looking up the first row of `data.pos` that equals the node's position.

    Args:
        data: graph sample providing `pos` and accepted by `ClusterData`
        num_clusters: (int) - number of partitions requested

    Returns:
        clusters: dict mapping cluster number -> list of node indices of
            `data` that belong to that cluster
    """
    partition = ClusterData(data, num_parts=num_clusters, recursive=False, log=False)

    clusters = {} #key: cluster_number | value: list of nodes in that cluster
    for cluster_id, subgraph in enumerate(partition):
        members = []
        for node_pos in subgraph.pos:
            # Rows of data.pos whose coordinates all match this node
            same_position = torch.eq(node_pos, data.pos).all(dim=1)
            members.append(int(torch.nonzero(same_position)[0][0]))
        clusters[cluster_id] = members

    return clusters

In [None]:
class AddVirtualNodes:
    """
    Transform that augments a molecule with virtual nodes.

    The atoms are split into `(num_atoms // 8) + 1` clusters via
    `get_clusters`, and one virtual node is added per cluster with
    `add_virtual_node`, connected to the atoms of that cluster. Each newly
    added virtual node is also connected to the virtual nodes added before
    it, so the virtual nodes end up fully connected to each other.

    At least one virtual node is therefore added to every molecule. According
    to the original author, one virtual node per 8 atoms adds about 3 virtual
    nodes per molecule on average, and the resulting clusters looked
    chemically sensible on 20 randomly selected molecules.
    """
    def __call__(self, data):
        atom_count = len(data.x)
        assert atom_count>0, "Error: data should have more than 0 atoms"

        cluster_count = 1 + atom_count // 8

        # Clusters are computed once on the unmodified input graph
        augmented = data
        for members in get_clusters(data, cluster_count).values():
            augmented = add_virtual_node(augmented, members)

        return augmented

In [None]:
#@title [RUN] Download Dataset
# Loads QM9 with the label-selection and virtual-node transforms, then
# normalizes the stored labels. Skipped when IS_GRADESCOPE_ENV is set.
if 'IS_GRADESCOPE_ENV' not in os.environ:
    path = './qm9'
    # Index of the regression target; SetTarget hardcodes the same index (0)
    target = 0

    # Transforms which are applied each time a sample is accessed:
    # (1) Select the target/label, (2) Add virtual nodes
    # Previous pipeline, which fully connected the graphs instead:
    # transform = T.Compose([CompleteGraph(), SetTarget()])
    transform = T.Compose([SetTarget(), AddVirtualNodes()]) #Removed CompleteGraph() to stop having them be fully connected
    
    # Load the QM9 dataset with the transforms defined
    dataset = QM9(path, transform=transform)

    # Normalize every target column of the stored labels to mean = 0 and
    # std = 1 (statistics are taken over dim 0 of the stored label matrix).
    # NOTE(review): recent PyG releases may warn about direct `dataset.data`
    # access - confirm against the installed version.
    mean = dataset.data.y.mean(dim=0, keepdim=True)
    std = dataset.data.y.std(dim=0, keepdim=True)
    dataset.data.y = (dataset.data.y - mean) / std
    # Keep only the statistics of the chosen target as Python floats; the
    # module-level `std` is read by eval() to rescale errors.
    mean, std = mean[:, target].item(), std[:, target].item()

In [None]:
plot_molecule_3d(dataset[1001])

In [None]:
#@title [RUN] Prepare splits for dataset
print(f"Total number of samples: {len(dataset)}.")

# Alternative split when training on the full dataset:
# test_dataset = dataset[:10000]
# val_dataset = dataset[10000:20000]
# train_dataset = dataset[20000:]

# 3K subset: three consecutive slices of 1000 samples each
train_dataset, val_dataset, test_dataset = (
    dataset[:1000],
    dataset[1000:2000],
    dataset[2000:3000],
)
print(f"Created dataset splits with {len(train_dataset)} training, {len(val_dataset)} validation, {len(test_dataset)} test samples.")

# Batches of 32 graphs; only the training loader reshuffles its samples
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)

Total number of samples: 130831.
Created dataset splits with 1000 training, 1000 validation, 1000 test samples.


In [None]:
#@title [RUN] Rotation/Translation Equivariance Unit Test

def rot_trans_equivariance_unit_test(module, dataloader):
    """Unit test for checking whether a module (GNN layer) is
    rotation and translation equivariant.

    The first batch of `dataloader` is passed through `module` before and
    after a random orthogonal transform plus translation of its coordinates.
    Only the returned coordinates are compared; the returned features are
    not checked. The batch's `pos` attribute is overwritten in the process.

    Args:
        module: callable as `module(x, pos, edge_index, edge_attr)` returning
            a pair `(features, positions)`
        dataloader: iterable of batches; only the first batch is used

    Returns:
        (bool) - True if transforming the output coordinates of the first
        pass matches the output coordinates of the second pass (atol=1e-4)
    """
    data = next(iter(dataloader))

    # Reference pass on the untouched coordinates
    _, pos_before = module(data.x, data.pos, data.edge_index, data.edge_attr)

    rotation = random_orthogonal_matrix(dim=3)
    shift = torch.rand(3)

    # Perform random rotation + translation on data.
    data.pos = data.pos @ rotation + shift
    _, pos_after = module(data.x, data.pos, data.edge_index, data.edge_attr)

    expected_pos = pos_before @ rotation + shift
    return torch.allclose(expected_pos, pos_after, atol=1e-04)
 

# E(n) Equivariant MPNN From Satorras et al.

In [None]:
class MPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        r"""Message Passing Neural Network Layer

        Args:
            emb_dim: (int) - hidden dimension `d`
            edge_dim: (int) - edge feature dimension `d_e`
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim

        # MLP `\psi` for computing messages `m_ij`
        # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU
        # dims: (2d + d_e) -> d
        self.mlp_msg = Sequential(
            Linear(2*emb_dim + edge_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
          )

        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU
        # dims: 2d -> d
        self.mlp_upd = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
          )

    def forward(self, h, edge_index, edge_attr):
        """
        The forward pass updates node features `h` via one round of message passing.

        As our MPNNLayer class inherits from the PyG MessagePassing parent class,
        we simply need to call the `propagate()` function which starts the
        message passing procedure: `message()` -> `aggregate()` -> `update()`.

        The MessagePassing class handles most of the logic for the implementation.
        To build custom GNNs, we only need to define our own `message()`,
        `aggregate()`, and `update()` functions (defined subsequently).

        Args:
            h: (n, d) - initial node features
            edge_index: (2, e) - edges as columns (source j, target i)
            edge_attr: (e, d_e) - edge features

        Returns:
            out: (n, d) - updated node features
        """
        out = self.propagate(edge_index, h=h, edge_attr=edge_attr)
        return out

    def message(self, h_i, h_j, edge_attr):
        r"""Step (1) Message

        The `message()` function constructs messages from source nodes j
        to destination nodes i for each edge in `edge_index`.

        The arguments can be a bit tricky to understand: `message()` can take
        any arguments that were initially passed to `propagate`. Additionally,
        we can differentiate destination nodes and source nodes by appending
        `_i` or `_j` to the variable name, e.g. for the node features `h`, we
        can use `h_i` and `h_j`.

        This part is critical to understand as the `message()` function
        constructs messages for each edge in the graph. The indexing of the
        original node features `h` (or other node variables) is handled under
        the hood by PyG.

        Args:
            h_i: (e, d) - destination node features; with PyG's default
                source-to-target flow this is h[edge_index[1]]
            h_j: (e, d) - source node features; with PyG's default
                source-to-target flow this is h[edge_index[0]]
            edge_attr: (e, d_e) - edge features

        Returns:
            msg: (e, d) - messages `m_ij` passed through MLP `\psi`
        """
        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)
        return self.mlp_msg(msg)

    def aggregate(self, inputs, index, dim_size=None):
        """Step (2) Aggregate

        The `aggregate` function aggregates the messages arriving at each
        destination node, according to the chosen aggregation function
        ('add' by default).

        Args:
            inputs: (e, d) - messages `m_ij` from source to destination nodes
            index: (e,) - destination node of each edge/message in `inputs`
            dim_size: (int, optional) - number of nodes `n`; filled in by
                `propagate()`. Passing it on guarantees an (n, d) result even
                when the highest-numbered nodes receive no message; without
                it the output would be shorter than `h` and `update()` would
                fail when concatenating the two.

        Returns:
            aggr_out: (n, d) - aggregated messages `m_i`
        """
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)

    def update(self, aggr_out, h):
        r"""
        Step (3) Update

        The `update()` function computes the final node features by combining the
        aggregated messages with the initial node features.

        `update()` takes the first argument `aggr_out`, the result of `aggregate()`,
        as well as any optional arguments that were initially passed to
        `propagate()`. E.g. in this case, we additionally pass `h`.

        Args:
            aggr_out: (n, d) - aggregated messages `m_i`
            h: (n, d) - initial node features

        Returns:
            upd_out: (n, d) - updated node features passed through MLP `\phi`
        """
        upd_out = torch.cat([h, aggr_out], dim=-1)
        return self.mlp_upd(upd_out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')

In [None]:
class MPNNModel(Module):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=12, edge_dim=4, out_dim=1):
        """Message Passing Neural Network model for graph property prediction

        Pipeline: linear input projection -> `num_layers` MPNN layers with
        residual connections -> mean pooling per graph -> linear prediction.

        Args:
            num_layers: (int) - number of message passing layers `L`
            emb_dim: (int) - hidden dimension `d`
            in_dim: (int) - initial node feature dimension `d_n`
            edge_dim: (int) - edge feature dimension `d_e`
            out_dim: (int) - output dimension (fixed to 1)
        """
        super().__init__()

        # Input projection, dim: d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)

        # Message passing stack, each layer maps d -> d
        self.convs = torch.nn.ModuleList(
            [MPNNLayer(emb_dim, edge_dim, aggr='add') for _ in range(num_layers)]
        )

        # Readout function `R`: mean over the nodes of each graph,
        # delegated to PyG's `global_mean_pool()`
        self.pool = global_mean_pool

        # Prediction head, dim: d -> out_dim
        self.lin_pred = Linear(emb_dim, out_dim)

    def forward(self, data):
        """
        Args:
            data: (PyG.Data) - batch of PyG graphs

        Returns:
            out: (batch_size * out_dim,) - flattened prediction for each graph
        """
        node_emb = self.lin_in(data.x)  # (n, d_n) -> (n, d)

        for layer in self.convs:
            # Residual connection around every MPNN layer, (n, d) -> (n, d)
            node_emb = node_emb + layer(node_emb, data.edge_index, data.edge_attr)

        graph_emb = self.pool(node_emb, data.batch)  # (n, d) -> (batch_size, d)
        prediction = self.lin_pred(graph_emb)  # (batch_size, d) -> (batch_size, 1)
        return prediction.view(-1)

In [None]:
class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        """Message passing layer that updates node features and 3D coordinates.

        Feature messages see the coordinates only through pairwise distances,
        and coordinates are shifted along relative position vectors scaled by
        a learned scalar. The layer is therefore meant to be equivariant to 3D
        rotations and translations.

        Args:
            emb_dim: (int) - hidden dimension `d`
            edge_dim: (int) - edge feature dimension `d_e`
            aggr: (str) - aggregation function `oplus` (sum/mean/max) used for
                the feature messages
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim

        # MLP `psi_m` for computing feature messages `m_ij`
        # dims: 2d + d_e + 1 -> d (the +1 is the distance between the nodes)
        self.mlp_msg = Sequential(
            Linear(2*emb_dim + edge_dim + 1, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

        # `psi_x`: maps a message to the scalar weight applied to the relative
        # position vector. dims: d -> 1 (a single linear layer for now)
        self.mlp_coord = Linear(emb_dim, 1)

        # MLP `phi` for computing updated node features `h_i^{l+1}`
        # dims: 2d -> d
        self.mlp_upd = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

    def forward(self, h, pos, edge_index, edge_attr):
        """Run one round of message passing on features and coordinates.

        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - edges in PyG COO format
            edge_attr: (e, d_e) - edge features

        Returns:
            out: [(n, d), (n, 3)] - updated node features and the new node
                coordinates (`pos` plus the aggregated coordinate update)
        """
        feat_upd, coord_upd = self.propagate(edge_index, h=h, edge_attr=edge_attr, pos=pos)
        new_coords = coord_upd + pos
        return [feat_upd, new_coords]

    def message(self, h_i, h_j, pos_i, pos_j, edge_attr):
        """Build the feature and coordinate message for every edge.

        With PyG's default `source_to_target` flow, the `_j` tensors belong to
        the sending node (`edge_index[0]`) and the `_i` tensors to the
        receiving node (`edge_index[1]`).

        Args:
            h_i: (e, d) - features of the receiving nodes, h[edge_index[1]]
            h_j: (e, d) - features of the sending nodes, h[edge_index[0]]
            pos_i: (e, 3) - coordinates of the receiving nodes
            pos_j: (e, 3) - coordinates of the sending nodes
            edge_attr: (e, d_e) - edge features

        Returns:
            msg: (e, d) - messages `m_ij` passed through MLP `psi_m`
            coord_update: (e, 3) - relative position scaled by a learned
                scalar, (x_i - x_j) * psi_x(m_ij)
        """
        rel_pos = pos_i - pos_j

        # Euclidean distance per edge, kept as a column for concatenation.
        # `vector_norm` is used instead of sqrt(sum(square)) because the
        # latter has an infinite derivative at zero distance, which turns
        # into NaN gradients for coincident nodes or self-loops.
        dist = torch.linalg.vector_norm(rel_pos, dim=-1, keepdim=True)

        msg = self.mlp_msg(torch.cat([h_i, h_j, edge_attr, dist], dim=-1))

        coord_weight = self.mlp_coord(msg)  # (e, 1), broadcast over x/y/z
        coord_update = rel_pos * coord_weight
        return [msg, coord_update]

    def aggregate(self, inputs, index, dim_size=None):
        """Aggregate the per-edge messages at their receiving nodes.

        Args:
            inputs: [(e, d), (e, 3)] - output of `message()`:
                [0] feature messages `m_ij`,
                [1] coordinate messages
            index: (e,) - index of the receiving node of every message
            dim_size: (int, optional) - number of nodes `n`, supplied by
                `propagate()`. Passing it on keeps one output row per node
                even when the highest-numbered nodes receive no message;
                without it the output would be too short for `update()` and
                for the coordinate residual in `forward()`.

        Returns:
            feat_out: (n, d) - aggregated messages `m_i`
            coord_out: (n, 3) - aggregated coordinate update
        """
        feat_out = scatter(inputs[0], index, dim=self.node_dim,
                           dim_size=dim_size, reduce=self.aggr)
        # Coordinate messages are always averaged, regardless of `self.aggr`.
        # NOTE(review): meant to mirror the sum-then-normalise step of the
        # original paper - confirm the normalisation constant matches.
        coord_out = scatter(inputs[1], index, dim=self.node_dim,
                            dim_size=dim_size, reduce='mean')
        return [feat_out, coord_out]

    def update(self, inputs, h):
        """Combine the aggregated messages with the current node features.

        Args:
            inputs: [(n, d), (n, 3)] - output of `aggregate()`:
                [0] aggregated messages `m_i`
                [1] aggregated coordinate updates
            h: (n, d) - initial node features

        Returns:
            upd_feat: (n, d) - updated node features passed through MLP `phi`
            upd_coord: (n, 3) - aggregated coordinate update, passed through
                unchanged
        """
        upd_feat = self.mlp_upd(torch.cat([h, inputs[0]], dim=-1))

        upd_coord = inputs[1]
        assert upd_coord.shape[1] == 3
        return [upd_feat, upd_coord]

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')


class FinalMPNNModel(MPNNModel):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=12, edge_dim=4, out_dim=1):
        """Graph property predictor built from `EquivariantMPNNLayer`s.

        The model consumes node features and 3D coordinates. Its layers are
        equivariant to 3D rotations and translations, and only the pooled node
        features reach the prediction head, so the output is meant to be
        invariant to those transformations.

        Args:
            num_layers: (int) - number of message passing layers `L`
            emb_dim: (int) - hidden dimension `d`
            in_dim: (int) - initial node feature dimension `d_n`
            edge_dim: (int) - edge feature dimension `d_e`
            out_dim: (int) - output dimension (fixed to 1)
        """
        # Runs MPNNModel.__init__ with its default arguments; every submodule
        # it creates is replaced below.
        super().__init__()

        # Input embedding, d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)

        # Equivariant message passing stack with sum ('add') aggregation
        equivariant_stack = [
            EquivariantMPNNLayer(emb_dim, edge_dim, aggr='add')
            for _ in range(num_layers)
        ]
        self.convs = torch.nn.ModuleList(equivariant_stack)

        # Readout `R`: PyG's `global_mean_pool` averages node embeddings per graph
        self.pool = global_mean_pool

        # Prediction head, d -> out_dim
        self.lin_pred = Linear(emb_dim, out_dim)

    def forward(self, data):
        """Predict one value per graph from node features and coordinates.

        Args:
            data: (PyG.Data) - batch of PyG graphs; `x`, `pos`, `edge_index`,
                `edge_attr` and `batch` are read from it

        Returns:
            out: prediction of the linear head, flattened to 1-D by `view(-1)`
        """
        node_emb = self.lin_in(data.x)  # (n, d_n) -> (n, d)
        coords = data.pos

        for equiv_layer in self.convs:
            # Each layer returns a feature update and the new coordinates;
            # the coordinates replace the previous ones as they are.
            feat_delta, coords = equiv_layer(
                node_emb, coords, data.edge_index, data.edge_attr
            )
            # Residual connection on the features only
            node_emb = node_emb + feat_delta

        graph_emb = self.pool(node_emb, data.batch)  # (n, d) -> (batch_size, d)
        prediction = self.lin_pred(graph_emb)  # (batch_size, d) -> (batch_size, 1)
        return prediction.view(-1)

In [None]:
# Per-model summary, filled in by the experiment cells:
# model name -> (best validation error, test error, training time)
RESULTS = dict()
# Per-epoch metrics collected across all experiments
DF_RESULTS = pd.DataFrame(
    columns=["Test MAE", "Val MAE", "Epoch", "Model"],
)

In [None]:
import warnings
# NOTE(review): this silences every warning for the rest of the session,
# including deprecation notices - consider narrowing the filter.
warnings.filterwarnings('ignore')

# NOTE(review): `layer` is not used in this cell; kept in case a later cell
# relies on it.
layer = EquivariantMPNNLayer(emb_dim=11, edge_dim=4)
model = FinalMPNNModel(num_layers=4, emb_dim=64, in_dim=12, edge_dim=4, out_dim=1)
# ==========================================

# The class name doubles as the key under which results are stored
model_name = type(model).__name__
best_val_error, test_error, train_time, perf_per_epoch = run_experiment(
    model, 
    model_name, # "MPNN w/ Features and Coordinates (Equivariant Layers)", 
    train_loader,
    val_loader, 
    test_loader,
    n_epochs=100
)

RESULTS[model_name] = (best_val_error, test_error, train_time)
df_temp = pd.DataFrame(perf_per_epoch, columns=["Test MAE", "Val MAE", "Epoch", "Model"])
# `DataFrame.append` was removed in pandas 2.0; `pd.concat` is the supported
# way to add the new rows.
DF_RESULTS = pd.concat([DF_RESULTS, df_temp], ignore_index=True)

Running experiment for FinalMPNNModel, training on 1000 samples for 100 epochs.

Model architecture:
FinalMPNNModel(
  (lin_in): Linear(in_features=12, out_features=64, bias=True)
  (convs): ModuleList(
    (0-3): 4 x EquivariantMPNNLayer(emb_dim=64, aggr=add)
  )
  (lin_pred): Linear(in_features=64, out_features=1, bias=True)
)
Total parameters: 103813

Start training:
Epoch: 010, LR: 0.001000, Loss: 0.4207803, Val MAE: 0.8707978, Test MAE: 0.6727469
Epoch: 020, LR: 0.000900, Loss: 0.2808954, Val MAE: 0.8504293, Test MAE: 0.6207397
Epoch: 030, LR: 0.000810, Loss: 0.1946441, Val MAE: 0.7036817, Test MAE: 0.5423659
Epoch: 040, LR: 0.000729, Loss: 0.0822185, Val MAE: 0.6763496, Test MAE: 0.5817720
Epoch: 050, LR: 0.000656, Loss: 0.0968526, Val MAE: 0.7106811, Test MAE: 0.5817720
Epoch: 060, LR: 0.000590, Loss: 0.0618200, Val MAE: 0.6638162, Test MAE: 0.5559601
Epoch: 070, LR: 0.000590, Loss: 0.0501357, Val MAE: 0.7135672, Test MAE: 0.5536521
Epoch: 080, LR: 0.000478, Loss: 0.0926983, Val

# Custom MPNN that optionally looks at coordinate information of virtual nodes

In [None]:
class VirtualEquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        """Message passing layer that updates node features and 3D coordinates.

        Feature messages see the coordinates only through pairwise distances,
        and coordinates are shifted along relative position vectors scaled by
        a learned scalar. The layer is therefore meant to be equivariant to 3D
        rotations and translations.

        The layer is intended to optionally use the coordinate information of
        virtual nodes. TODO: no virtual-node-specific handling is implemented
        yet; every node and edge is treated the same way.

        Args:
            emb_dim: (int) - hidden dimension `d`
            edge_dim: (int) - edge feature dimension `d_e`
            aggr: (str) - aggregation function `oplus` (sum/mean/max) used for
                the feature messages
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim

        # MLP `psi_m` for computing feature messages `m_ij`
        # dims: 2d + d_e + 1 -> d (the +1 is the distance between the nodes)
        self.mlp_msg = Sequential(
            Linear(2*emb_dim + edge_dim + 1, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

        # `psi_x`: maps a message to the scalar weight applied to the relative
        # position vector. dims: d -> 1 (a single linear layer for now)
        self.mlp_coord = Linear(emb_dim, 1)

        # MLP `phi` for computing updated node features `h_i^{l+1}`
        # dims: 2d -> d
        self.mlp_upd = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

    def forward(self, h, pos, edge_index, edge_attr):
        """Run one round of message passing on features and coordinates.

        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - edges in PyG COO format
            edge_attr: (e, d_e) - edge features

        Returns:
            out: [(n, d), (n, 3)] - updated node features and the new node
                coordinates (`pos` plus the aggregated coordinate update)
        """
        feat_upd, coord_upd = self.propagate(edge_index, h=h, edge_attr=edge_attr, pos=pos)
        new_coords = coord_upd + pos
        return [feat_upd, new_coords]

    def message(self, h_i, h_j, pos_i, pos_j, edge_attr):
        """Build the feature and coordinate message for every edge.

        With PyG's default `source_to_target` flow, the `_j` tensors belong to
        the sending node (`edge_index[0]`) and the `_i` tensors to the
        receiving node (`edge_index[1]`).

        Args:
            h_i: (e, d) - features of the receiving nodes, h[edge_index[1]]
            h_j: (e, d) - features of the sending nodes, h[edge_index[0]]
            pos_i: (e, 3) - coordinates of the receiving nodes
            pos_j: (e, 3) - coordinates of the sending nodes
            edge_attr: (e, d_e) - edge features

        Returns:
            msg: (e, d) - messages `m_ij` passed through MLP `psi_m`
            coord_update: (e, 3) - relative position scaled by a learned
                scalar, (x_i - x_j) * psi_x(m_ij)
        """
        rel_pos = pos_i - pos_j

        # Euclidean distance per edge, kept as a column for concatenation.
        # `vector_norm` is used instead of sqrt(sum(square)) because the
        # latter has an infinite derivative at zero distance, which turns
        # into NaN gradients for coincident nodes or self-loops.
        dist = torch.linalg.vector_norm(rel_pos, dim=-1, keepdim=True)

        msg = self.mlp_msg(torch.cat([h_i, h_j, edge_attr, dist], dim=-1))

        coord_weight = self.mlp_coord(msg)  # (e, 1), broadcast over x/y/z
        coord_update = rel_pos * coord_weight
        return [msg, coord_update]

    def aggregate(self, inputs, index, dim_size=None):
        """Aggregate the per-edge messages at their receiving nodes.

        Args:
            inputs: [(e, d), (e, 3)] - output of `message()`:
                [0] feature messages `m_ij`,
                [1] coordinate messages
            index: (e,) - index of the receiving node of every message
            dim_size: (int, optional) - number of nodes `n`, supplied by
                `propagate()`. Passing it on keeps one output row per node
                even when the highest-numbered nodes receive no message;
                without it the output would be too short for `update()` and
                for the coordinate residual in `forward()`.

        Returns:
            feat_out: (n, d) - aggregated messages `m_i`
            coord_out: (n, 3) - aggregated coordinate update
        """
        feat_out = scatter(inputs[0], index, dim=self.node_dim,
                           dim_size=dim_size, reduce=self.aggr)
        # Coordinate messages are always averaged, regardless of `self.aggr`.
        # NOTE(review): meant to mirror the sum-then-normalise step of the
        # original paper - confirm the normalisation constant matches.
        coord_out = scatter(inputs[1], index, dim=self.node_dim,
                            dim_size=dim_size, reduce='mean')
        return [feat_out, coord_out]

    def update(self, inputs, h):
        """Combine the aggregated messages with the current node features.

        Args:
            inputs: [(n, d), (n, 3)] - output of `aggregate()`:
                [0] aggregated messages `m_i`
                [1] aggregated coordinate updates
            h: (n, d) - initial node features

        Returns:
            upd_feat: (n, d) - updated node features passed through MLP `phi`
            upd_coord: (n, 3) - aggregated coordinate update, passed through
                unchanged
        """
        upd_feat = self.mlp_upd(torch.cat([h, inputs[0]], dim=-1))

        upd_coord = inputs[1]
        assert upd_coord.shape[1] == 3
        return [upd_feat, upd_coord]

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')


class FinalMPNNModel(MPNNModel):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=12, edge_dim=4, out_dim=1):
        """Graph property predictor built from `VirtualEquivariantMPNNLayer`s.

        This model uses both node features and coordinates as inputs, and
        is invariant to 3D rotations and translations (the constituent MPNN layers
        are equivariant to 3D rotations and translations).

        NOTE(review): running this cell replaces any earlier class named
        `FinalMPNNModel` in the notebook; consider giving this variant its
        own name so both models stay available.

        Args:
            num_layers: (int) - number of message passing layers `L`
            emb_dim: (int) - hidden dimension `d`
            in_dim: (int) - initial node feature dimension `d_n`
            edge_dim: (int) - edge feature dimension `d_e`
            out_dim: (int) - output dimension (fixed to 1)
        """
        # Runs MPNNModel.__init__ with its default arguments; every submodule
        # it creates is replaced below.
        super().__init__()
        
        # Linear projection for initial node features
        # dim: d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)
        
        # Stack of the layers defined in this cell. The previous code stacked
        # `EquivariantMPNNLayer`, which left `VirtualEquivariantMPNNLayer`
        # unused.
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(VirtualEquivariantMPNNLayer(emb_dim, edge_dim, aggr='add'))
        
        # Global pooling/readout function `R` (mean pooling)
        # PyG handles the underlying logic via `global_mean_pool()`
        self.pool = global_mean_pool

        # Linear prediction head
        # dim: d -> out_dim
        self.lin_pred = Linear(emb_dim, out_dim)
        
    def forward(self, data):
        """Predict one value per graph from node features and coordinates.

        Args:
            data: (PyG.Data) - batch of PyG graphs; `x`, `pos`, `edge_index`,
                `edge_attr` and `batch` are read from it

        Returns: 
            out: prediction of the linear head, flattened to 1-D by `view(-1)`
        """
        h = self.lin_in(data.x) # (n, d_n) -> (n, d)
        pos = data.pos
        
        for conv in self.convs:
            # Message passing layer
            h_update, pos_update = conv(h, pos, data.edge_index, data.edge_attr)
            
            # Update node features
            h = h + h_update # (n, d) -> (n, d)
            # Note that we add a residual connection after each MPNN layer
            
            # Update node coordinates
            pos = pos_update # (n, 3) -> (n, 3)

        h_graph = self.pool(h, data.batch) # (n, d) -> (batch_size, d)

        out = self.lin_pred(h_graph) # (batch_size, d) -> (batch_size, 1)

        return out.view(-1)