## Tutorial 1: Implementing GNN layers using the `MessagePassing` class

In general, Graph Neural Network (GNN) layers fall into two families: spectral-based and spatial-based methods. Spectral-based GNN layers are defined in the (graph) Fourier domain, while spatial-based GNN layers are defined directly in the vertex domain. Most mainstream GNN models are spatial-based. This is because spectral-based methods have limitations, including difficulty generalizing to unseen graphs and high computational complexity.

In this tutorial, we will focus on implementing spatial-based GNN layers (i.e., message passing networks) using the `MessagePassing` class of `PyG`.

```python
import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing

from common.graph_gen import generate_random_graph  # local helper module, not part of PyG
```

## `MessagePassing` Base Class in PyG

The `MessagePassing` base class in PyG implements the message passing scheme as follows:

$$
\mathbf{x}'_i=
\underbrace{
f_\theta \left(\mathbf{x}_i,
\underbrace{ 
\bigoplus_{j \in \mathcal{N}(i)} 
}_{\text{(2) aggregation}}
\underbrace{
g_\theta\left(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{ij}\right)
}_{\text{(1) message}}
\right)
}_{\text{(3) update}},
$$

where 
- $\mathbf{x}_i \in \mathbb{R}^{d_n}$ and $\mathbf{x}'_i \in \mathbb{R}^{d'_n}$ are the input and updated node features (embeddings), respectively.
- $d_n$ and $d'_n$ are the dimensions of node features before and after the update.
- $\mathbf{e}_{ij} \in \mathbb{R}^{d_{e}}$ is the edge feature between nodes $i$ and $j$.
- $d_e$ is the dimension of edge features.
- $\bigoplus$ is a differentiable, permutation-invariant aggregation function (e.g., sum, mean, or maximum).
- $g_\theta : \mathbb{R}^{d_n} \times \mathbb{R}^{d_n} \times \mathbb{R}^{d_{e}} \rightarrow \mathbb{R}^{d''}$ is a trainable message function (e.g., an MLP), and $d''$ is the message dimension.
- $f_\theta : \mathbb{R}^{d_n} \times \mathbb{R}^{d''} \rightarrow \mathbb{R}^{d'_n}$ is a trainable update function (e.g., an MLP).
- $\mathcal{N}(i)$ is the set of neighbors of node $i$.
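The sketch below makes the three stages concrete: a single message passing step in plain PyTorch, without `MessagePassing`. It makes these assumptions:

- sum aggregation for $\bigoplus$;
- small MLPs for $g_\theta$ and $f_\theta$;
- PyG's `[2, E]` `edge_index` layout, in which messages flow from `edge_index[0]` (neighbor $j$) to `edge_index[1]` (receiving node $i$).

The name `message_passing_step` and the chosen dimensions are illustrative, not part of PyG.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_n, d_e, d_msg, d_out = 4, 3, 8, 5  # d_n, d_e, d'' and d'_n in the notation above

# Message function g_theta(x_i, x_j, e_ij): R^{d_n} x R^{d_n} x R^{d_e} -> R^{d''}
g_theta = nn.Sequential(nn.Linear(2 * d_n + d_e, d_msg), nn.ReLU())
# Update function f_theta(x_i, aggregated message): R^{d_n} x R^{d''} -> R^{d'_n}
f_theta = nn.Sequential(nn.Linear(d_n + d_msg, d_out), nn.ReLU())


def message_passing_step(x, edge_index, edge_attr):
    src, dst = edge_index          # messages flow from src (j) to dst (i)
    x_j, x_i = x[src], x[dst]      # lift node features to edge level: [E, d_n]

    # (1) message: one vector per edge
    msg = g_theta(torch.cat([x_i, x_j, edge_attr], dim=-1))  # [E, d'']

    # (2) aggregation: sum the messages that arrive at each target node
    aggr = torch.zeros(x.size(0), msg.size(-1), dtype=msg.dtype)
    aggr.index_add_(0, dst, msg)   # [N, d'']

    # (3) update: combine each node's own features with its aggregated message
    return f_theta(torch.cat([x, aggr], dim=-1))  # [N, d'_n]


x = torch.randn(3, d_n)                    # 3 nodes
edge_index = torch.tensor([[0, 1, 2, 0],   # source nodes j
                           [1, 2, 0, 2]])  # target nodes i
edge_attr = torch.randn(edge_index.size(1), d_e)
out = message_passing_step(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([3, 5])
```

Because the sum in step (2) does not depend on edge order, permuting the edges together with their features leaves the output unchanged up to floating-point rounding. This is the permutation invariance required of $\bigoplus$.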


This formulation is flexible and can be used to implement various GNN variants, including GCN, GAT, GraphSAGE, Interaction Networks, etc.

In PyG, calling `propagate` on a `MessagePassing` layer triggers the (1) message, (2) message aggregation, and (3) node update functions in the order (1) -> (2) -> (3). In the rest of this tutorial, we demonstrate how to implement each of these components (message, aggregation, and update) to create different graph convolution layers.

## Very First Message Passing Layer: NaiveGCN

We will use the `MessagePassing` base class to implement the NaiveGCN layer that performs the following message passing scheme:

$$
x'_i = \sigma \left(\sum_{j \in \mathcal{N}(i)} \left(W\mathbf{x}_j+b\right) \right),
$$
where $\sigma$ is a non-linear activation function (e.g., ReLU, Tanh, SiLU, or GELU).
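To see the formula in action, the snippet below evaluates it directly in plain PyTorch on a hypothetical 3-node graph. It assumes $W = I$, $b = \mathbf{1}$ and $\sigma = \mathrm{ReLU}$, so every entry can be checked by hand.

```python
import torch

x = torch.tensor([[1.0, -2.0],
                  [3.0, 0.5],
                  [-4.0, 1.0]])         # 3 nodes with 2 features each
W = torch.eye(2)                        # assumed weight matrix: identity
b = torch.ones(2)                       # assumed bias: all ones
edge_index = torch.tensor([[0, 1, 2],   # source nodes j
                           [2, 2, 0]])  # target nodes i

h = x @ W.T + b                         # W x_j + b, computed once per node
src, dst = edge_index
out = torch.zeros_like(h).index_add_(0, dst, h[src])  # sum over j in N(i)
out = torch.relu(out)                   # sigma = ReLU
print(out.tolist())                     # [[0.0, 2.0], [0.0, 0.0], [6.0, 0.5]]
```

Because $b$ sits inside the sum, node 2 (two incoming edges) accumulates the bias twice. Node 1 (no incoming edges) receives nothing and outputs zeros.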

```python
# Yes, this is it! This is the complete implementation of NaiveGCN.

class NaiveGCNConv(MessagePassing):

    def __init__(self, dim: int, act: str = 'ReLU'):
        super().__init__(aggr='add')  # Aggregate messages by summation (i.e., addition)
        self.linear = nn.Linear(dim, dim)
        self.act = getattr(nn, act)()  # e.g., 'ReLU' -> nn.ReLU()

    def forward(self, x, edge_index):
        x = self.linear(x)  # Compute Wx + b for every node
        # The default `message` returns x_j (the source-node features), so `propagate`
        # sums the transformed neighbor features into each target node
        x = self.propagate(edge_index, x=x)
        x = self.act(x)  # Apply the activation function
        return x
```

```python
num_node = 5
n_dim = 16

g = generate_random_graph(num_node=num_node,
                          p_edges=0.5,
                          node_feat_dim=n_dim,
                          edge_feat_dim=n_dim)
```

```python
conv = NaiveGCNConv(dim=n_dim)
gc_out = conv(g.x, g.edge_index)
print(gc_out.shape)
```

```
torch.Size([5, 16])
```


## Can we have more fine-grained control over the message passing scheme?

As explained earlier, we can manipulate (1) the message generation routine, (2) the message aggregation routine, and (3) the node update routine to implement GNN layers.

In the following, we implement these subroutines in subclasses of `MessagePassing`. Generally, (1) is done by overriding the `message` method, (2) by specifying `aggr` in the `MessagePassing` constructor, and (3) by overriding the `update` method.
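The aggregation choice in (2) alone can change what a layer computes. The plain-PyTorch sketch below (not PyG's internal implementation) aggregates the same hypothetical per-edge messages with sum, mean and max via `Tensor.scatter_reduce` (PyTorch 1.12 or later). These roughly correspond to passing `aggr='add'`, `'mean'` or `'max'` to `MessagePassing`. Nodes that receive no messages are left at zero in this sketch.

```python
import torch

# Hypothetical 1-dimensional messages on 4 edges, all pointing into nodes 0 and 1
msg = torch.tensor([[1.0], [3.0], [2.0], [6.0]])
dst = torch.tensor([0, 0, 1, 1])   # target node of each edge
num_nodes = 3                      # node 2 receives no messages

index = dst.unsqueeze(-1).expand_as(msg)  # scatter_reduce needs an index shaped like msg
results = {}
for reduce in ['sum', 'mean', 'amax']:
    out = torch.zeros(num_nodes, msg.size(-1))
    # include_self=False: ignore the initial zeros at positions that do receive messages
    results[reduce] = out.scatter_reduce(0, index, msg, reduce=reduce, include_self=False)

for reduce, out in results.items():
    print(reduce, out.squeeze(-1).tolist())
# sum [4.0, 8.0, 0.0]
# mean [2.0, 4.0, 0.0]
# amax [3.0, 6.0, 0.0]
```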

### `message` method of `MessagePassing` class

NaiveGCN can be implemented without an explicit message function because its messages are simply the (transformed) source node features. In many cases, however, we employ more sophisticated schemes that generate a "message" from the source, destination, and edge features (if applicable). For example, [Edge Convolution](https://arxiv.org/abs/1801.07829) uses the following scheme:

$$
x'_i = \max_{j \in \mathcal{N}(i)} h_\theta\left(x_i, x_j-x_i\right),
$$

where $h_\theta$ is a learnable function (e.g., an MLP).
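How do `x_i` and `x_j` relate to `edge_index`? With PyG's default `flow='source_to_target'`:

- `x_j` holds the features of `edge_index[0]` (the neighbor $j$);
- `x_i` holds those of `edge_index[1]` (the central node $i$).

The plain-PyTorch sketch below builds the EdgeConv input $[x_i, x_j - x_i]$ for a hypothetical 3-node graph by explicit indexing. This is roughly what `MessagePassing` does before calling `message`.

```python
import torch

x = torch.tensor([[0.0, 1.0],
                  [2.0, 2.0],
                  [5.0, -1.0]])
edge_index = torch.tensor([[0, 2],   # source nodes j
                           [1, 1]])  # target nodes i: both edges point into node 1

x_j = x[edge_index[0]]  # neighbor (source) features, [E, 2]
x_i = x[edge_index[1]]  # central (target) features, [E, 2]
edge_input = torch.cat([x_i, x_j - x_i], dim=-1)  # [E, 4], the input to h_theta
print(edge_input.tolist())  # [[2.0, 2.0, -2.0, -1.0], [2.0, 2.0, 3.0, -3.0]]

# If h_theta were the identity (an assumption for illustration only), max aggregation
# into node 1 would be the element-wise max over its two incoming edges:
print(edge_input.max(dim=0).values.tolist())  # [2.0, 2.0, 3.0, -1.0]
```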

```python
class EdgeConv(MessagePassing):

    def __init__(self, dim: int):
        super().__init__(aggr='max')  # Aggregate messages with an element-wise max
        self.h = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # By overriding this method, we specify how the message along each edge is computed.
        # PyG lifts node features to edge level based on the `_i`/`_j` suffix. With the
        # default flow ('source_to_target'), x_j = x[edge_index[0]] holds the source
        # (neighbor) features and x_i = x[edge_index[1]] the target (central) features.
        # Both have shape [E, dim], where E is the number of edges.
        msg = torch.cat([x_i, x_j - x_i], dim=-1)  # [E, 2*dim]
        msg = self.h(msg)  # [E, dim]
        return msg
```

```python
conv = EdgeConv(dim=n_dim)
gc_out = conv(g.x, g.edge_index)
print(gc_out.shape)
```

```
torch.Size([5, 16])
```


### `update` method of `MessagePassing` class

The `update` method computes the new node features from the aggregated messages. As a running example, we consider a simplified Interaction Network layer, an iconic GNN model for learning the interactions between objects (e.g., atoms) represented as nodes of a graph. The Interaction Network layer is defined as follows:
$$
\begin{align}
e'_{ij} &= f_\theta(x_i, x_j, e_{ij}), \\
x'_{i} &= g_\theta(x_i, \sum_{j \in \mathcal{N}(i)} e'_{ij}),
\end{align}
$$

where $f_\theta$ and $g_\theta$ are the edge and node updaters, respectively. They are usually implemented as learnable functions (e.g., MLPs). Note that the symbols play different roles than in the general message passing equation above: here $f_\theta$ updates edges and $g_\theta$ updates nodes.

Unlike `NaiveGCNConv` or `EdgeConv`, the Interaction Network layer must update the edge features before it can update the nodes. To account for the edge update, we additionally implement the `edge_update` method. It is called by the `edge_updater` method already defined in the `MessagePassing` class.

```python
class InteractionNetworkLayer(MessagePassing):

    def __init__(self, dim: int):
        super().__init__(aggr='add')
        # Edge updater f_theta (assumes edge and node features have the same dimension)
        self.f = nn.Sequential(
            nn.Linear(dim * 3, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )
        # Node updater g_theta
        self.g = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x, edge_index, edge_attr):
        # `edge_updater` collects the arguments requested by `edge_update` and calls it
        updated_ef = self.edge_updater(edge_index=edge_index, x=x, edge_attr=edge_attr)
        updated_nf = self.propagate(edge_index=edge_index,
                                    x=x, edge_attr=updated_ef)
        return updated_nf, updated_ef

    def edge_update(self, x_i, x_j, edge_attr):
        # As in `message`, the `_i`/`_j` suffixes lift node features to edge level:
        # x_i = x[edge_index[1]] (target i), x_j = x[edge_index[0]] (source j), both [E, dim]
        return self.f(torch.cat([x_i, x_j, edge_attr], dim=-1))  # Eq. (1)

    # Eq. (2): the message along each edge is the updated edge feature e'_ij,
    # and the messages are summed (aggr='add') before the node update
    def message(self, edge_attr):
        return edge_attr

    def update(self, aggr_msg, x):
        # !!! `update` takes the aggregated messages as its first argument !!!
        # The other arguments can be any arguments passed to `self.propagate`.

        # Exercise: request an argument that is NOT passed to `self.propagate`, such as `y`,
        # by changing `def update(self, aggr_msg, x)` to `def update(self, aggr_msg, x, y)`,
        # and observe what happens.
        return self.g(torch.cat([x, aggr_msg], dim=-1))
```

```python
conv = InteractionNetworkLayer(dim=n_dim)
updated_nf, updated_ef = conv(g.x, g.edge_index, g.edge_attr)
print(updated_nf.shape, updated_ef.shape)
```

```
torch.Size([5, 16]) torch.Size([12, 16])
```


### `aggr` argument of the `MessagePassing` constructor

#### Implementing `AttentiveINLayer` with advanced aggregation

So far, we have only used "simple" aggregation functions (sum and max) to aggregate messages. Many GNN variants instead rely on more advanced aggregations, for instance attention-based ones. Luckily, PyG provides a set of aggregation modules in `torch_geometric.nn.aggr` that can be passed as `aggr` to implement such variants. For example, let's implement attentive aggregation with `aggr.AttentionalAggregation`. Formally, the following layer performs this message passing scheme:

$$
\begin{align}
e'_{ij} &= f_\theta(x_i, x_j, e_{ij}), \\
w_{ij} &= \text{softmax}_j \left(\text{gate}_\theta(e'_{ij}) \right), \\
x'_{i} &= g_\theta(x_i, \sum_{j \in \mathcal{N}(i)} w_{ij} e'_{ij}),
\end{align}
$$
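The weights $w_{ij}$ come from a softmax taken separately over the incoming edges of each target node $i$. The plain-PyTorch sketch below implements this grouped softmax and the weighted sum. Hypothetical fixed scores stand in for $\text{gate}_\theta(e'_{ij})$.

```python
import torch

# Hypothetical updated edge features e'_ij (one row per edge) and their gate scores
e_new = torch.tensor([[1.0, 0.0],
                      [0.0, 1.0],
                      [2.0, 2.0]])
score = torch.tensor([0.0, 0.0, 5.0])  # stands in for gate_theta(e'_ij), shape [E]
dst = torch.tensor([0, 0, 1])          # target node i of each edge
num_nodes = 2

# Softmax over the edges that share the same target node (shifted by the group max for stability)
max_per_node = torch.full((num_nodes,), float('-inf')).scatter_reduce(
    0, dst, score, reduce='amax', include_self=True)
exp = (score - max_per_node[dst]).exp()
denom = torch.zeros(num_nodes).index_add_(0, dst, exp)
w = exp / denom[dst]                   # w_ij, sums to 1 over each node's incoming edges

# Weighted sum of the edge features at each target node
aggr = torch.zeros(num_nodes, e_new.size(-1)).index_add_(0, dst, w.unsqueeze(-1) * e_new)
print(w.tolist())     # [0.5, 0.5, 1.0]
print(aggr.tolist())  # [[0.5, 0.5], [2.0, 2.0]]
```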

```python
from torch_geometric.nn import aggr

class AttentiveINLayer(MessagePassing):

    def __init__(self, dim: int):
        # gate_theta: scores each message (here, each updated edge feature) with a scalar
        gate_nn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Tanh(),
            nn.Linear(dim, 1)
        )
        # Softmax-normalize the scores over each node's incoming edges, then take the weighted sum
        super().__init__(aggr=aggr.AttentionalAggregation(gate_nn=gate_nn))

        # Edge updater f_theta (assumes edge and node features have the same dimension)
        self.f = nn.Sequential(
            nn.Linear(dim * 3, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )
        # Node updater g_theta
        self.g = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x, edge_index, edge_attr):
        updated_ef = self.edge_updater(edge_index=edge_index, x=x, edge_attr=edge_attr)
        updated_nf = self.propagate(edge_index, x=x, edge_attr=updated_ef)
        return updated_nf, updated_ef

    def edge_update(self, x_i, x_j, edge_attr):
        # x_i: target-node features, x_j: source-node features, both [E, dim]
        return self.f(torch.cat([x_i, x_j, edge_attr], dim=-1))  # Eq. (1)

    def message(self, edge_attr):
        return edge_attr  # the messages are the updated edge features e'_ij

    def update(self, aggr_msg, x):
        return self.g(torch.cat([x, aggr_msg], dim=-1))  # Eq. (3)
```

```python
conv = AttentiveINLayer(dim=n_dim)
updated_nf, updated_ef = conv(g.x, g.edge_index, g.edge_attr)
print(updated_nf.shape, updated_ef.shape)
```

```
torch.Size([5, 16]) torch.Size([12, 16])
```
