In [1]:
import dgl
import torch
import torch.nn as nn

Using backend: pytorch


## Message passing framework of `dgl`

One major advantage of `dgl` over `PyTorch Geometric` is its flexible message passing framework, which lets users define custom message routing easily. This is extremely helpful for designing sophisticated routing mechanisms in MARL applications.

In this tutorial, we will not cover advanced usage of `dgl`'s message passing framework. For more details, refer to this [link](https://docs.dgl.ai/guide/message.html).

In [2]:
u, v = torch.tensor([0, 0, 0, 1]), torch.tensor([1, 2, 3, 3])
g = dgl.graph((u, v), num_nodes=8)
g = dgl.add_self_loop(g)

node_feat_dim = 32 # the node feature dim
edge_feat_dim = 3 # the edge feature dim

g.ndata['feat'] = torch.randn(g.number_of_nodes(), node_feat_dim)
g.edata['feat'] = torch.randn(g.number_of_edges(), edge_feat_dim)

## A simple GCN in message passing framework

GCN is the best-known GNN model and generally works well. Here we define a simple-to-implement variant of GCN, which uses a different $A$ from the original derivation, and implement the layer with `dgl`'s message passing framework.

A simplified GCN layer can be defined as:

$$H^{(l+1)} = \sigma(A H^{(l)} W^{(l)} + b^{(l)})$$

Here $l$ is the layer index of the GCN and $H^{(l)}$ is the input feature of the $l$-th layer. By definition, $H^{(0)}$ is the input feature $V$. $W^{(l)}$ and $b^{(l)}$ are the learnable parameters of the $l$-th GCN layer. $A$ is the adjacency matrix of the input graph.


> Disclaimer: The original formulation of GCN has no bias term $b^{(l)}$, because a bias term makes it impossible to apply a trained GCN model to graphs whose size differs from that of the training graphs.

Checking the math above in terms of matrix multiplication helps to understand what happens in the GCN computation. Assume $n$ is the number of nodes in the input graph and $p^{(l)}$ and $q^{(l)}$ are the input and output feature dimensions, respectively. Then the adjacency matrix $A \in \mathbb{R}^{n \times n}$, the weight matrix $W^{(l)} \in \mathbb{R}^{p^{(l)} \times q^{(l)}}$, the input feature $H^{(l)} \in \mathbb{R}^{n \times p^{(l)}}$, and the bias $b^{(l)} \in \mathbb{R}^{n \times q^{(l)}}$. Because the shape of the bias depends on $n$, it becomes clear again that the bias term prevents the GCN from being used on graphs of different sizes.
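
The shapes listed above can be checked numerically. The following is a minimal NumPy-only sketch (NumPy is used here only so that it runs without `dgl`; the all-zero adjacency matrices and the second graph size `m` are placeholders chosen for illustration). It multiplies in the order that the listed shapes allow, $A H^{(l)} W^{(l)}$, and shows that the weight can be reused on a smaller graph while the $n \times q^{(l)}$ bias cannot.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, q = 8, 32, 256            # nodes, input dim, output dim (sizes used in this tutorial)

A = np.zeros((n, n))            # placeholder adjacency matrix, shape (n, n)
H = rng.normal(size=(n, p))     # input features, shape (n, p)
W = rng.normal(size=(p, q))     # weight, shape (p, q)
b_node = np.zeros((n, q))       # per-node bias, shape (n, q)

out = A @ H @ W + b_node
print(out.shape)                # (8, 256)

# The same W can be applied to a graph with a different number of nodes ...
m = 5                           # hypothetical size of another graph
A_small = np.zeros((m, m))
H_small = rng.normal(size=(m, p))
out_small = A_small @ H_small @ W
print(out_small.shape)          # (5, 256)

# ... but the (n, q) bias cannot be added to an (m, q) output.
try:
    out_small + b_node
    bias_reusable = True
except ValueError:              # NumPy cannot broadcast (5, 256) with (8, 256)
    bias_reusable = False
print(bias_reusable)            # False
```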

## Message passing reformulation of the simple GCN

To implement GNN operations with the message passing framework, it is important to understand them from the perspective of a given node $i$.

From the perspective of a single node, the update rule can be rewritten as follows:

$$h^{(l+1)}_i = \sigma(\sum_{j \in \mathcal{N}(i)} z_j + b^{(l)}_i) $$

where $\mathcal{N}(i)$ is the index set of node $i$'s neighborhood. The fused feature matrix $Z^{(l)}$ is defined as the product of the input feature $H^{(l)}$ and the weight $W^{(l)}$, and $z_j$ is the $j$-th row vector of $Z^{(l)}$. By stacking all $h^{(l+1)}_i$, we obtain $H^{(l+1)}$, which is exactly the output of the simplified GCN.
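
The equivalence between the node-wise sum and the matrix product $AZ$ can be verified on the small graph built earlier. This is a NumPy-only sketch under two assumptions: $\mathcal{N}(i)$ is taken to be the set of source nodes of edges pointing at $i$ (including the self-loop), and the random `Z` with feature size 4 is an arbitrary stand-in for the fused features. The bias and $\sigma$ are left out because they act on each node independently.

```python
import numpy as np

# Edges of the tutorial graph (src -> dst), followed by one self-loop per node.
n = 8
src = np.array([0, 0, 0, 1] + list(range(n)))
dst = np.array([1, 2, 3, 3] + list(range(n)))

rng = np.random.default_rng(0)
Z = rng.normal(size=(n, 4))     # stand-in for the fused features; 4 is arbitrary

# Node-wise view: every node i sums z_j over the sources j of its incoming edges.
h_nodewise = np.zeros_like(Z)
for j, i in zip(src, dst):
    h_nodewise[i] += Z[j]

# Matrix view: A[i, j] = 1 if there is an edge j -> i, then compute A Z.
A = np.zeros((n, n))
A[dst, src] = 1.0
h_matrix = A @ Z

print(np.allclose(h_nodewise, h_matrix))  # True
```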

In the following cell, let's code the message passing reformulation of the simple GCN within `dgl`'s message passing framework.

In [3]:
class MessagePassingGCN(nn.Module):
    
    def __init__(self, 
                 input_dim: int, 
                 output_dim: int):
        super(MessagePassingGCN, self).__init__()
        self.linear = nn.Linear(in_features=input_dim,
                                out_features=output_dim, bias=False)
        
    def forward(self, g, nf):        
        g = g.local_var() # make a local graph so the input graph's features are not modified
        z = self.linear(nf) # compute Z = HW (bias and activation are omitted in this layer)
        g.ndata['z'] = z
        
        # Send source node features to the destination nodes
        g.pull(v=g.nodes(),
               message_func=self.msg_func,
               reduce_func=self.reduce_func)
        return g.ndata['h']
        
    def msg_func(self, edges):        
        return {'z': edges.src['z']}
    
    def reduce_func(self, nodes):
        incoming_msg = nodes.mailbox['z'] # [#nodes x #incoming messages x feat. dim]
        reduced_msg = incoming_msg.sum(dim=1) # sum over incoming messages, i.e. perform AZ
        return {'h' : reduced_msg}

In [4]:
gc_out_dim = 256

In [5]:
gc = MessagePassingGCN(node_feat_dim, gc_out_dim)

In [17]:
h_updated = gc(g, g.ndata['feat'])
print(h_updated.shape)

torch.Size([8, 256])


In [18]:
%%timeit
h_updated = gc(g, g.ndata['feat'])

3.7 ms ± 592 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## A slightly optimized version of simple GCN with `dgl.function`

`dgl` is not only a framework that supports versatile message passing but is also well optimized. Conveniently, most common message and reduce operations, such as copying source features, weighting messages by edge features, and sum, mean, or max reductions, are already implemented in the `dgl.function` package. Related utilities such as graph readouts and top-k selection are available as built-ins in `dgl` itself (for example `dgl.sum_nodes` and `dgl.topk_nodes`). For detailed explanations, please refer to this [link](https://docs.dgl.ai/guide/message.html).
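
Conceptually, "copy the source feature, then sum per destination node" is a gather followed by a scatter-add over the edge list. The NumPy sketch below mimics that pattern on the tutorial graph. It is an illustration only and does not claim to show how `dgl` implements its built-ins; the feature matrix `Z` is made up so that the result is easy to check by hand.

```python
import numpy as np

# Edges of the tutorial graph (src -> dst), followed by one self-loop per node.
n = 8
src = np.array([0, 0, 0, 1] + list(range(n)))
dst = np.array([1, 2, 3, 3] + list(range(n)))

# Made-up features: row i is [2*i, 2*i + 1].
Z = np.arange(n * 2, dtype=float).reshape(n, 2)

messages = Z[src]               # "copy source feature": one message per edge
h = np.zeros_like(Z)
np.add.at(h, dst, messages)     # "sum": accumulate the messages per destination node

# Node 3 receives messages from nodes 0 and 1 and from its own self-loop.
print(h[3])
```

Expressing the whole step as array operations over all edges avoids a Python-level function call for each batch of nodes, which is one plausible reason why the built-in version in the next cell runs faster than the user-defined one.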

In [19]:
class MessagePassingGCN(nn.Module):
    
    def __init__(self, 
                 input_dim: int, 
                 output_dim: int):
        super(MessagePassingGCN, self).__init__()
        self.linear = nn.Linear(in_features=input_dim,
                                out_features=output_dim, bias=False)
        
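        # built-in message function: copy source feature 'z' into message 'z'
        # (copy_u is the current name of the older copy_src)
        self.msg_func = dgl.function.copy_u('z', 'z')
        # built-in reduce function: sum incoming messages 'z' into node feature 'h'
        self.reduce_func = dgl.function.sum('z', 'h')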
        
    def forward(self, g, nf):        
        g = g.local_var() # make a local graph so the input graph's features are not modified
        z = self.linear(nf) # compute Z = HW (bias and activation are omitted in this layer)
        g.ndata['z'] = z
        
        # Send source node features to the destination nodes
        g.pull(v=g.nodes(),
               message_func=self.msg_func,
               reduce_func=self.reduce_func)
        return g.ndata['h']

In [20]:
gc = MessagePassingGCN(node_feat_dim, gc_out_dim)

In [21]:
h_updated = gc(g, g.ndata['feat'])
print(h_updated.shape)

torch.Size([8, 256])


In [22]:
%%timeit
h_updated = gc(g, g.ndata['feat'])

1.59 ms ± 38.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
