# Tutorial 3: Write Your Own GNN Module
Source: https://docs.dgl.ai/tutorials/blitz/3_message_passing.html

In [1]:
# Import 
import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn

## Message Passing and GNNs

Message passing framework:

$m_{u \rightarrow v}^{(l)} = M^{(l)}(h_v^{(l-1)}, h_u^{(l-1)}, e_{u \rightarrow v}^{(l-1)})$

$m_v^{(l)} = \sum_{u \in \mathcal{N}(v)} m_{u \to v}^{(l)}$

$h_v^{(l)} = U^{(l)} ( h_v^{(l-1)}, m_v^{(l)} )$

Here $M^{(l)}$ is the message function, $\sum$ is the reduce function (it need not be a summation; mean, max and other aggregations are also common), and $U^{(l)}$ is the update function.

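In this tutorial we implement the GraphSAGE convolution. It uses the average of the neighbor features as the reduce function. Its update function is a linear layer applied to the concatenation of the node's own features and the aggregated neighbor features: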

$h_{\mathcal{N}(v)}^k \leftarrow Average\{h_{u}^{k-1}, \forall u \in \mathcal{N}(v)\}$

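$h_v^k \leftarrow \operatorname{ReLU}\left(W^k \cdot \operatorname{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k)\right)$

In the code below, `SAGEConv` computes only the linear part $W^k \cdot \operatorname{CONCAT}(\cdot)$ (plus a bias term). The ReLU is applied between the two layers in `Model.forward`.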

In [2]:
class SAGEConv(nn.Module):
    """GraphSAGE convolution with mean neighbor aggregation.

    For every node v the layer averages the features sent by v's in-neighbors
    (the sources of edges pointing to v). It concatenates ``[h_v, mean]`` along
    the feature axis and projects the result with a single linear layer. No
    nonlinearity is applied here; the caller is expected to add one.

    Args:
        in_feat: size of the input node features.
        out_feat: size of the output node features.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # The projection sees [own features | neighbor mean], i.e. twice the input width.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h):
        """Apply the convolution.

        Args:
            g: DGL graph; it is only modified inside ``local_scope``, so the
               caller's node data is left untouched.
            h: node features, presumably of shape (num_nodes, in_feat).

        Returns:
            Linear projection of ``[h, neighbor_mean]`` for every node.
        """
        with g.local_scope():
            g.ndata['h'] = h
            # Each edge copies its source's 'h' into message 'm'; every
            # destination then averages its incoming 'm' into 'h_N'.
            g.update_all(
                message_func=fn.copy_u("h", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            neighbor_mean = g.ndata["h_N"]
        return self.linear(torch.cat((h, neighbor_mean), dim=1))

In [3]:
class Model(nn.Module):
    """Two-layer GraphSAGE node classifier: SAGEConv -> ReLU -> SAGEConv.

    Args:
        in_feats: input node-feature size.
        h_feats: hidden size produced by the first layer.
        num_classes: number of output classes (size of the returned logits).
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return unnormalized per-node class scores (no softmax applied)."""
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)
    
    

## Training Loop

In [4]:
import dgl.data

# Load the Cora citation dataset (DGL downloads and caches it on first use)
# and take the graph at index 0. The node features, labels and
# train/val/test masks used by `train` are read from g.ndata.
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]


Downloading /Users/dloader/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /Users/dloader/.dgl/cora_v2_d697a464
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.


In [5]:

def train(g, model, num_epochs=200, lr=0.01, log_every=5):
    """Train ``model`` full-batch on ``g`` with Adam and cross-entropy.

    The loss is computed only on nodes selected by ``g.ndata['train_mask']``.
    Validation and test accuracies come from the same forward pass as the
    loss, i.e. before that epoch's optimizer step. The "Best" test accuracy
    is the one observed at the epoch with the highest validation accuracy,
    so model selection uses validation data only.

    Args:
        g: graph whose ``ndata`` provides 'feat', 'label', 'train_mask',
           'val_mask' and 'test_mask'.
        model: module called as ``model(g, features)`` that returns per-node
           logits (one row per node, one column per class).
        num_epochs: number of full-batch optimization steps (default 200).
        lr: Adam learning rate (default 0.01).
        log_every: print a progress line every ``log_every`` epochs
           (default 5). Use 0 or None to disable logging.

    Returns:
        None. Progress is reported via ``print``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val_acc = 0.0
    best_test_acc = 0.0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    # Make sure layers such as dropout/batch-norm (if any) are in training mode.
    model.train()
    for e in range(num_epochs):
        # Forward pass over the whole graph.
        logits = model(g, features)
        pred = logits.argmax(1)

        # Only the training nodes contribute to the loss.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # .item() keeps plain Python floats instead of holding on to tensors.
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean().item()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean().item()

        # Track the best validation accuracy and the test accuracy at that epoch.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and e % log_every == 0:
            print(f"In epoch {e}, loss: {loss.item():.4f}, val acc: {val_acc:.4f} (Best:{best_val_acc:.4f}), test acc: {test_acc:.4f} (Best:{best_test_acc:.4f})")
        

In [6]:
# Two-layer GraphSAGE with a 16-unit hidden layer, sized to Cora's features/classes.
model = Model(
    in_feats=g.ndata['feat'].shape[1],
    h_feats=16,
    num_classes=dataset.num_classes,
)
train(g, model)

  assert input.numel() == input.storage().size(), (


In epoch 0, loss: 1.9474, val acc: 0.1620 (Best:0.1620), test acc: 0.1490 (Best:0.1490)
In epoch 5, loss: 1.8673, val acc: 0.2280 (Best:0.2280), test acc: 0.2610 (Best:0.2610)
In epoch 10, loss: 1.7141, val acc: 0.5480 (Best:0.5480), test acc: 0.5370 (Best:0.5370)
In epoch 15, loss: 1.4813, val acc: 0.6180 (Best:0.6180), test acc: 0.6300 (Best:0.6300)
In epoch 20, loss: 1.1809, val acc: 0.6560 (Best:0.6560), test acc: 0.6700 (Best:0.6700)
In epoch 25, loss: 0.8510, val acc: 0.6920 (Best:0.6920), test acc: 0.7230 (Best:0.7230)
In epoch 30, loss: 0.5500, val acc: 0.7360 (Best:0.7360), test acc: 0.7570 (Best:0.7570)
In epoch 35, loss: 0.3255, val acc: 0.7580 (Best:0.7580), test acc: 0.7690 (Best:0.7690)
In epoch 40, loss: 0.1850, val acc: 0.7580 (Best:0.7580), test acc: 0.7710 (Best:0.7690)
In epoch 45, loss: 0.1061, val acc: 0.7640 (Best:0.7640), test acc: 0.7740 (Best:0.7740)
In epoch 50, loss: 0.0639, val acc: 0.7640 (Best:0.7640), test acc: 0.7740 (Best:0.7740)
In epoch 55, loss: 0.04

In [7]:
class WeightedSAGEConv(nn.Module):
    """GraphSAGE-style layer that scales each neighbor message by an edge weight.

    Every edge u -> v sends ``h_u * w_{u->v}``. Each destination averages the
    messages it receives. This is a plain mean of the scaled messages, not a
    weight-normalized average. The result is concatenated with the node's own
    features and projected by one linear layer.

    Args:
        in_feat: size of the input node features.
        out_feat: size of the output node features.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # Input to the projection is [own features | aggregated neighbors].
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h, w):
        """Apply the weighted convolution.

        Args:
            g: DGL graph; only modified inside ``local_scope``.
            h: node features, presumably of shape (num_nodes, in_feat).
            w: per-edge weights that must broadcast against the source
               features, e.g. shape (num_edges, 1) as used in this notebook.

        Returns:
            Linear projection of ``[h, weighted_neighbor_mean]`` for every node.
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.edata['w'] = w
            g.update_all(
                message_func=fn.u_mul_e("h", "w", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            aggregated = g.ndata["h_N"]
        return self.linear(torch.cat((h, aggregated), dim=1))

In [8]:
class Model(nn.Module):
    """Two-layer node classifier built from ``WeightedSAGEConv``.

    Args:
        in_feats: input node-feature size.
        h_feats: hidden size produced by the first layer.
        num_classes: number of output classes.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat, edge_weight=None):
        """Return per-node logits.

        Args:
            g: DGL graph.
            in_feat: input node features.
            edge_weight: optional per-edge weights, e.g. shape (num_edges, 1).
                Defaults to all ones, which reduces both layers to unweighted
                mean aggregation.
        """
        if edge_weight is None:
            # Allocate directly on the graph's device (no CPU round-trip) and
            # reuse the same tensor for both layers.
            edge_weight = torch.ones(g.num_edges(), 1, device=g.device)
        h = self.conv1(g, in_feat, edge_weight)
        h = F.relu(h)
        h = self.conv2(g, h, edge_weight)
        return h

# Same training recipe, now with the weighted-convolution model.
model = Model(
    in_feats=g.ndata['feat'].shape[1],
    h_feats=16,
    num_classes=dataset.num_classes,
)
train(g, model)

In epoch 0, loss: 1.9532, val acc: 0.1560 (Best:0.1560), test acc: 0.1440 (Best:0.1440)
In epoch 5, loss: 1.8732, val acc: 0.4580 (Best:0.4580), test acc: 0.4490 (Best:0.4490)
In epoch 10, loss: 1.7286, val acc: 0.3340 (Best:0.4620), test acc: 0.3460 (Best:0.4800)
In epoch 15, loss: 1.5148, val acc: 0.4340 (Best:0.4620), test acc: 0.4330 (Best:0.4800)
In epoch 20, loss: 1.2393, val acc: 0.5480 (Best:0.5480), test acc: 0.5360 (Best:0.5360)
In epoch 25, loss: 0.9312, val acc: 0.6640 (Best:0.6640), test acc: 0.6540 (Best:0.6540)
In epoch 30, loss: 0.6383, val acc: 0.7200 (Best:0.7200), test acc: 0.7180 (Best:0.7180)
In epoch 35, loss: 0.4030, val acc: 0.7480 (Best:0.7480), test acc: 0.7330 (Best:0.7330)
In epoch 40, loss: 0.2416, val acc: 0.7540 (Best:0.7560), test acc: 0.7450 (Best:0.7420)
In epoch 45, loss: 0.1432, val acc: 0.7660 (Best:0.7680), test acc: 0.7460 (Best:0.7450)
In epoch 50, loss: 0.0869, val acc: 0.7660 (Best:0.7680), test acc: 0.7490 (Best:0.7450)
In epoch 55, loss: 0.05

## WeightedSageConvolution Implementation

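With the built-in message and reduce functions from the `dgl.function` package, we can build new graph convolutions. The following `WeightedSageConv` class multiplies each neighbor's representation by the weight `w` of the connecting edge, then averages these weighted messages. Note that this is a plain mean of the scaled messages, not a weight-normalized average.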

In [11]:
class WeightedSageConv(nn.Module):
    """GraphSAGE-style layer that scales each neighbor message by an edge weight.

    Every edge u -> v sends ``h_u * w_{u->v}``, and each destination averages
    what it receives (a plain mean of the scaled messages). The result is
    concatenated with the node's own features and projected by one linear layer.

    Args:
        in_feat: size of the input node features.
        out_feat: size of the output node features.
    """

    def __init__(self, in_feat, out_feat):
        # Zero-argument super() binds to this class. Naming a different class
        # here (e.g. WeightedSAGEConv) would make instantiation raise TypeError.
        super().__init__()
        # A linear submodule that projects the input and neighbor features to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Apply the weighted convolution.

        Args:
            g: DGL graph; only modified inside ``local_scope``.
            h: node features, presumably of shape (num_nodes, in_feat).
            w: per-edge weights broadcastable against the source features,
               e.g. shape (num_edges, 1).
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.edata['w'] = w
            g.update_all(
                message_func=fn.u_mul_e("h", "w", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)


class Model(nn.Module):
    """Two-layer node classifier built from the ``WeightedSageConv`` defined above.

    Args:
        in_feats: input node-feature size.
        h_feats: hidden size produced by the first layer.
        num_classes: number of output classes.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSageConv(in_feats, h_feats)
        self.conv2 = WeightedSageConv(h_feats, num_classes)

    def forward(self, g, in_feat, edge_weight=None):
        """Return per-node logits.

        Args:
            g: DGL graph.
            in_feat: input node features.
            edge_weight: optional per-edge weights, e.g. shape (num_edges, 1).
                Defaults to all ones (unweighted mean aggregation).
        """
        if edge_weight is None:
            # Allocate on the graph's device directly and share it between layers.
            edge_weight = torch.ones(g.num_edges(), 1, device=g.device)
        h = self.conv1(g, in_feat, edge_weight)
        h = F.relu(h)
        h = self.conv2(g, h, edge_weight)
        return h

In [12]:
# Train the model defined in the previous cell with the shared training loop.
model = Model(
    in_feats=g.ndata['feat'].shape[1],
    h_feats=16,
    num_classes=dataset.num_classes,
)
train(g, model)

In epoch 0, loss: 1.9478, val acc: 0.3160 (Best:0.3160), test acc: 0.3190 (Best:0.3190)
In epoch 5, loss: 1.8653, val acc: 0.5620 (Best:0.5620), test acc: 0.5440 (Best:0.5440)
In epoch 10, loss: 1.6988, val acc: 0.4780 (Best:0.5620), test acc: 0.4630 (Best:0.5440)
In epoch 15, loss: 1.4456, val acc: 0.4620 (Best:0.5620), test acc: 0.4670 (Best:0.5440)
In epoch 20, loss: 1.1262, val acc: 0.5520 (Best:0.5620), test acc: 0.5390 (Best:0.5440)
In epoch 25, loss: 0.7866, val acc: 0.6440 (Best:0.6440), test acc: 0.6200 (Best:0.6200)
In epoch 30, loss: 0.4896, val acc: 0.7140 (Best:0.7140), test acc: 0.6930 (Best:0.6930)
In epoch 35, loss: 0.2785, val acc: 0.7260 (Best:0.7260), test acc: 0.7250 (Best:0.7250)
In epoch 40, loss: 0.1522, val acc: 0.7340 (Best:0.7340), test acc: 0.7380 (Best:0.7380)
In epoch 45, loss: 0.0846, val acc: 0.7380 (Best:0.7380), test acc: 0.7480 (Best:0.7480)
In epoch 50, loss: 0.0499, val acc: 0.7340 (Best:0.7380), test acc: 0.7560 (Best:0.7480)
In epoch 55, loss: 0.03

## Customization with User-Defined Functions

In [13]:
def u_mul_e_udf(edges):
    """User-defined message function mirroring ``fn.u_mul_e("h", "w", "m")``.

    For each edge, multiplies the source node's feature ``h`` by the edge's
    feature ``w`` and emits the product as message ``m``.

    Args:
        edges: an edge batch exposing ``edges.src`` (source-node data) and
            ``edges.data`` (edge data).

    Returns:
        dict: ``{"m": edges.src["h"] * edges.data["w"]}``.
    """
    src_feat = edges.src["h"]
    edge_weight = edges.data["w"]
    return {"m": src_feat * edge_weight}

In [14]:
def mean_udf(nodes):
    """User-defined reduce function mirroring ``fn.mean("m", "h_N")``.

    Averages each node's received messages ``m`` over dim 1 and stores the
    result as ``h_N``. Dim 1 is presumably the per-node message axis of
    DGL's mailbox layout.

    Args:
        nodes: a node batch exposing ``nodes.mailbox`` (received messages).

    Returns:
        dict: ``{"h_N": nodes.mailbox["m"].mean(1)}``.
    """
    inbox = nodes.mailbox["m"]
    return {"h_N": inbox.mean(dim=1)}