# A Simple Implementation of E(n)-Equivariant Graph Neural Networks

Original paper: https://arxiv.org/pdf/2102.09844.pdf by Victor Garcia Satorras, Emiel Hoogeboom, and Max Welling.

In [None]:
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# Load QM9 Dataset

In [None]:
# Fetch the helper repository (presumably the source of the `qm9` package imported in the next cell)
# and make it the working directory so that package is importable.
!git clone https://github.com/senya-ashukha/simple-equivariant-gnn.git
%cd simple-equivariant-gnn

Cloning into 'simple-equivariant-gnn'...
remote: Enumerating objects: 87, done.[K
remote: Counting objects: 100% (87/87), done.[K
remote: Compressing objects: 100% (80/80), done.[K
remote: Total 87 (delta 37), reused 31 (delta 5), pack-reused 0[K
Unpacking objects: 100% (87/87), done.
/content/simple-equivariant-gnn


In [None]:
# QM9 is a molecular property prediction dataset: http://quantum-machine.org/datasets/
# The regression target here is the energy of the highest occupied molecular
# orbital (HOMO), see https://en.wikipedia.org/wiki/HOMO_and_LUMO
# The data loaders are taken from the official repo.

from qm9.data_utils import BatchGraph, get_data

train_loader, val_loader, test_loader, charge_scale = get_data(
    num_workers=1,
)

dict_keys([0, 1, 6, 7, 8, 9])
dict_keys([0, 1, 6, 7, 8, 9])
dict_keys([0, 1, 6, 7, 8, 9])


# Graph Representation

In [None]:
# Peek at one training batch to inspect the graph representation.
# Use the `next()` builtin: it works for every Python iterator, whereas the
# Python-2 style `.next()` method is not part of the iterator protocol (PyTorch
# DataLoader iterators, for example, no longer provide it).
# The second argument mirrors the training loop's `cuda` flag; False should keep
# this batch on the CPU (confirm against qm9.data_utils.BatchGraph).
batch = BatchGraph(next(iter(train_loader)), False, charge_scale)
batch

In the batch: num_graphs 96 num_nodes 1735
> .h 		 a tensor of nodes representations 		shape 1735 x 15
> .x 		 a tensor of nodes positions  			shape 1735 x 3
> .edges 	 a tensor of edges, a fully connected graph 	shape 30312 x 2
> .batch  	 a tensor of graph_ids for each node 		tensor([ 0,  0,  0,  ..., 95, 95, 95])

# Define the Equivariant Graph Convolutions and the GNN

In [None]:
def index_sum(agg_size, source, idx, cuda=None):
    """
        Sum the rows of ``source`` that share the same index in ``idx``.

        agg_size  number of output rows
        source    is N x hid_dim [float]
        idx       is N           [int], every value in [0, agg_size)
        cuda      ignored; kept only so existing call sites keep working.
                  The output buffer is allocated with ``source``'s device and
                  dtype, so it always matches the input. Previously a float32
                  buffer was built and moved by this flag, which failed for CPU
                  inputs with cuda=True and for any non-float32 input.

        Returns an agg_size x hid_dim tensor whose row r is the sum of all
        ``source[k]`` with ``idx[k] == r``; rows no index points to are zero.
    """
    buffer = source.new_zeros((agg_size, source.shape[1]))
    # out-of-place index_add keeps autograd simple: gradients flow to `source`
    return buffer.index_add(0, idx, source)

In [None]:
def BlockLinearSiLU(input_dim, hidden_dim, **layers):
    """Build one of the two SiLU MLPs used inside a ConvEGNN layer.

    Args:
        input_dim:  size of a node feature vector h_i.
        hidden_dim: output size (and inner width) of the MLP.
        layer_type (keyword, required): which MLP to build.
            'f_e' - edge MLP; input is [h_i, h_j, d_ij], i.e. 1 + 2 * input_dim
                    features, and the output ends with a SiLU.
            'f_h' - node MLP; input is [h_i, m_i], i.e. input_dim + hidden_dim
                    features, and the output is linear so it can serve as a
                    residual update.

    Raises:
        KeyError:   if ``layer_type`` is not given.
        ValueError: if ``layer_type`` is not 'f_e' or 'f_h'. Previously this
                    silently returned None and failed only later, when the layer
                    was first called.
    """
    layer_type = layers['layer_type']
    if layer_type == 'f_e':
        return nn.Sequential(
            nn.Linear(1 + 2 * input_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
        )
    if layer_type == 'f_h':
        return nn.Sequential(
            nn.Linear(hidden_dim + input_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
    raise ValueError(f"unknown layer_type {layer_type!r}; expected 'f_e' or 'f_h'")


def BlockLinearSigmoid(hidden_dim):
    """Return a scalar gate head: Linear(hidden_dim -> 1) followed by Sigmoid.

    The sigmoid squashes each row's score into (0, 1).
    """
    gate_layers = [nn.Linear(hidden_dim, 1), nn.Sigmoid()]
    return nn.Sequential(*gate_layers)

  
class ConvEGNN(nn.Module):
    """One EGNN message-passing layer that updates node features only.

    Messages depend on the two endpoint features and on the Euclidean distance
    between node positions. Positions ``b.x`` are read but never modified, so
    the updated features are invariant to rotations, reflections and
    translations of ``b.x``.

    The residual update ``h + f_h(...)`` requires ``in_dim == hid_dim``.
    """

    def __init__(self, in_dim, hid_dim, cuda=True):
        super().__init__()
        self.hid_dim = hid_dim
        # Stored as `use_cuda`, NOT `self.cuda`. Assigning a bool to
        # `self.cuda` shadows nn.Module.cuda() on the instance, so a later
        # `layer.cuda()` raises "'bool' object is not callable". The flag is
        # kept for backward compatibility only: forward() allocates its
        # scratch buffer on the device of the input tensors.
        self.use_cuda = cuda

        # edge MLP: [h_i, h_j, d_ij] -> message m_ij (ends with SiLU, so the
        # output is NOT bounded to [0, 1])
        self.f_e = BlockLinearSiLU(in_dim, hid_dim, layer_type='f_e')

        # predicts a "soft" edge weight e_ij in (0, 1) from each message
        self.f_inf = BlockLinearSigmoid(hid_dim)

        # node MLP: [h_i, m_i] -> residual update for h_i (linear output)
        self.f_h = BlockLinearSiLU(in_dim, hid_dim, layer_type='f_h')

    def forward(self, b):
        """Run one message-passing step; ``b.h`` is replaced and ``b`` is returned.

        Reads from ``b``:
            h     - num_nodes x in_dim node features
            x     - num_nodes x D node positions (only pairwise distances are used)
            edges - num_edges x 2 integer (i, j) pairs; messages are summed
                    into node i = edges[:, 0]
        """
        src, dst = b.edges[:, 0], b.edges[:, 1]

        # NOTE(review): the EGNN paper feeds the *squared* distance
        # ||x_i - x_j||^2 to the edge MLP; this uses the plain norm. Left as is
        # to keep the trained behaviour, but confirm this is intended.
        dists = torch.norm(b.x[src] - b.x[dst], dim=1).reshape(-1, 1)

        # messages from both endpoint features and their distance
        m_ij = self.f_e(torch.hstack([b.h[src], b.h[dst], dists]))

        # gate each message by its predicted soft edge weight
        weighted = self.f_inf(m_ij) * m_ij

        # SUM (not average) the gated messages per receiving node:
        # m_i is num_nodes x hid_dim; the scratch buffer follows the data's device
        m_i = index_sum(b.h.shape[0], weighted, src, weighted.is_cuda)

        # residual update of node features (see appendix C, "Implementation
        # details", of the paper)
        b.h = b.h + self.f_h(torch.hstack([b.h, m_i]))

        return b


class NetEGNN(nn.Module):
    """Graph-level regressor: embed -> n_layers ConvEGNN -> sum-pool -> MLP.

    Args:
        in_dim:   size of the raw node features ``b.h``.
        hid_dim:  width of the embedding, every ConvEGNN layer and the MLPs.
        out_dim:  number of regression outputs per graph.
        n_layers: number of stacked ConvEGNN layers.
        cuda:     if True, move the parameters to the GPU at construction.
    """

    def __init__(self, in_dim=15, hid_dim=128, out_dim=1, n_layers=7, cuda=True):
        super().__init__()
        self.hid_dim = hid_dim

        self.emb = nn.Linear(in_dim, hid_dim)

        # n_layers message-passing layers; each takes and returns the batch object
        self.gnn = nn.Sequential(*[ConvEGNN(hid_dim, hid_dim, cuda=cuda) for _ in range(n_layers)])

        # per-node MLP applied before pooling
        self.pre_mlp = nn.Sequential(
            nn.Linear(hid_dim, hid_dim), nn.SiLU(),
            nn.Linear(hid_dim, hid_dim))

        # per-graph regression head applied after pooling
        self.post_mlp = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(hid_dim, hid_dim), nn.SiLU(),
            nn.Linear(hid_dim, out_dim))

        if cuda:
            self.cuda()
        # Do NOT assign to `self.cuda`: that shadows nn.Module.cuda() on the
        # instance, so a later `model.cuda()` would raise "'bool' object is not
        # callable". The flag is kept (as `use_cuda`) for backward
        # compatibility; pooling allocates on the device of the node tensors.
        self.use_cuda = cuda

    def forward(self, b):
        """Return a num_graphs x out_dim prediction tensor for batch ``b``.

        Reads ``b.h`` (num_nodes x in_dim), ``b.x`` and ``b.edges`` (via the
        ConvEGNN layers), ``b.batch`` (graph id of every node) and ``b.nG``
        (number of graphs).

        NOTE: ``b.h`` is overwritten in place (embedding, then each layer), so
        calling the model twice on the same batch object feeds hid_dim-wide
        features into ``emb``. Build a fresh batch for each call.
        """
        b.h = self.emb(b.h)

        b = self.gnn(b)
        h_nodes = self.pre_mlp(b.h)

        # sum-pool node embeddings per graph: h_graph is num_graphs x hid_dim
        h_graph = index_sum(b.nG, h_nodes, b.batch, h_nodes.is_cuda)

        return self.post_mlp(h_graph)

In [None]:
epochs = 1000

# torch.cuda.is_available() already returns a bool
cuda = torch.cuda.is_available()

model = NetEGNN(n_layers=7, cuda=cuda)  # TODO: is 7 layers the optimal depth?

# NOTE: weight_decay=1e-16 is so small that this is effectively plain Adam.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-16)
# Cosine-anneal the learning rate over `epochs` scheduler steps
# (lr_scheduler.step() is called once per epoch in the training loop).
# The `verbose` argument is deprecated in recent PyTorch releases and False was
# its default, so it is no longer passed.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

### Alternatives to try ###
# optimizer =  torch.optim.Adadelta(model.parameters(), lr=1e-3)
# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
#     factor=0.1, patience=10, threshold=0.0001, threshold_mode='abs')
# ignite.handlers.EarlyStopping

In [None]:
# Preferred torch device: the first GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training

In [None]:
print('> start training')

# Normalise targets with the training-set mean and mean absolute deviation
# (MAD): the network regresses (y - me) / mad, and predictions are mapped back
# to target units with out * mad + me.
tr_ys = train_loader.dataset.data['homo']
me = torch.mean(tr_ys)
mad = torch.mean(torch.abs(tr_ys - me))

if cuda:
    me = me.cuda()
    mad = mad.cuda()

# Logged losses are MAE / LOSS_SCALE, i.e. MAE in thousandths of the target
# unit (presumably eV -> meV; confirm against qm9.data_utils).
LOSS_SCALE = 0.001

train_loss = []
val_loss = []
test_loss = []


def _mean_abs_error(net, loader):
    """Return the MAE (in target units) of ``net`` over every graph in ``loader``.

    Each per-batch MAE is weighted by its batch size, so a smaller final batch
    does not count as much as a full one. Call under torch.no_grad() with
    ``net`` in eval mode. Uses the notebook globals ``cuda``, ``charge_scale``,
    ``me`` and ``mad``. Returns NaN for an empty loader.
    """
    total, count = 0.0, 0
    for raw in loader:
        b = BatchGraph(raw, cuda, charge_scale)
        pred = net(b).reshape(-1)
        n = pred.numel()
        total += F.l1_loss(pred * mad + me, b.y).item() * n
        count += n
    return total / count if count else float('nan')


for epoch in range(epochs):
    print('> epoch %s:' % str(epoch).zfill(3), end=' ', flush=True)
    start = time.time()

    model.train()
    train_total, train_count = 0.0, 0
    for batch in train_loader:
        batch = BatchGraph(batch, cuda, charge_scale)

        out = model(batch).reshape(-1)
        # optimise the L1 loss on the normalised target
        loss = F.l1_loss(out, (batch.y - me) / mad)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # log the L1 loss in target units, weighted by batch size
        with torch.no_grad():
            n = out.numel()
            train_total += F.l1_loss(out * mad + me, batch.y).item() * n
            train_count += n

    train_mae = train_total / train_count if train_count else float('nan')
    train_loss.append(train_mae / LOSS_SCALE)

    print('train %.3f' % train_loss[-1], end=' ', flush=True)

    model.eval()
    with torch.no_grad():
        val_loss.append(_mean_abs_error(model, val_loader) / LOSS_SCALE)

        print('val %.3f' % val_loss[-1], end=' ', flush=True)

        # NOTE: the test set is evaluated every epoch for monitoring only;
        # pick the final model by val_loss to avoid leaking test information.
        test_loss.append(_mean_abs_error(model, test_loader) / LOSS_SCALE)

    end = time.time()

    print('test %.3f (%.1f sec)' % (test_loss[-1], end - start), flush=True)
    lr_scheduler.step()

> start training
> epoch 000: train 371.102 val 288.344 test 286.007 (90.6 sec)
> epoch 001: train 271.244 val 263.889 test 262.990 (88.8 sec)
> epoch 002: train 217.452 val 188.507 test 187.269 (89.0 sec)
> epoch 003: train 188.802 val 167.326 test 166.806 (88.9 sec)
> epoch 004: train 172.135 val 161.350 test 161.055 (88.8 sec)
> epoch 005: train 155.915 val 149.394 test 147.931 (88.9 sec)
> epoch 006: train 144.612 val 144.822 test 142.887 (88.7 sec)
> epoch 007: train 136.459 val 122.410 test 123.573 (88.5 sec)
> epoch 008: train 129.189 val 129.402 test 128.389 (88.9 sec)
> epoch 009: train 122.464 val 118.462 test 118.207 (88.5 sec)
> epoch 010: train 116.397 val 108.384 test 107.435 (88.2 sec)
> epoch 011: train 111.147 val 103.638 test 103.466 (88.8 sec)
> epoch 012: train 107.639 val 100.743 test 100.658 (88.9 sec)
> epoch 013: train 104.670 val 95.762 test 96.242 (89.3 sec)
> epoch 014: train 100.548 val 91.055 test 90.625 (88.3 sec)
> epoch 015: train 97.609 val 90.724 test 