References:
- https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html

## Message Passing Framework

Generalizing the convolution operator to irregular domains is typically expressed as a neighborhood aggregation or message passing scheme. With $\mathbf{x}^{(k-1)}_i \in \mathbb{R}^F$ denoting node features of node $i$ in layer $(k-1)$ and $\mathbf{e}_{j,i} \in \mathbb{R}^D$ denoting (optional) edge features from node $j$ to node $i$, message passing graph neural networks can be described as:

$$\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right)$$

where:
- $\bigoplus$ denotes a differentiable, permutation-invariant aggregation function, e.g., sum, mean, or max.
- $\gamma^{(k)}$ denotes a differentiable update function, e.g., an MLP (multi-layer perceptron).
- $\phi^{(k)}$ denotes a differentiable message function, e.g., an MLP.
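A minimal sketch of one message passing layer in plain PyTorch. It assumes that $\phi$ and $\gamma$ are single linear layers (rather than full MLPs) and that $\bigoplus$ is a sum implemented with `Tensor.index_add_`. The class name `SimpleMessagePassing` is hypothetical. `edge_index` follows the PyG convention that a column $(j, i)$ is an edge from source $j$ to target $i$.

```python
import torch
import torch.nn as nn


class SimpleMessagePassing(nn.Module):
    """Hypothetical sketch of the message passing equation.

    phi (message):  Linear on [x_i, x_j, e_ji]
    aggregation:    sum, via index_add_
    gamma (update): Linear on [x_i, aggregated messages]
    """

    def __init__(self, in_dim, edge_dim, out_dim):
        super().__init__()
        self.phi = nn.Linear(2 * in_dim + edge_dim, out_dim)
        self.gamma = nn.Linear(in_dim + out_dim, out_dim)

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index  # edge j -> i is stored as (src=j, dst=i)
        # phi: one message per edge, computed from x_i (target), x_j (source), e_ji
        msg = self.phi(torch.cat([x[dst], x[src], edge_attr], dim=-1))
        # sum aggregation: scatter-add every message into its target node i
        aggr = x.new_zeros((x.size(0), msg.size(-1)))
        aggr.index_add_(0, dst, msg)
        # gamma: combine each node's previous features with its aggregated messages
        return self.gamma(torch.cat([x, aggr], dim=-1))


torch.manual_seed(0)
x = torch.randn(3, 4)                                  # 3 nodes, 4 features each
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_attr = torch.randn(4, 2)                          # one 2-dim feature per edge
layer = SimpleMessagePassing(in_dim=4, edge_dim=2, out_dim=8)
out = layer(x, edge_index, edge_attr)

# Shuffle the edge list (and its edge features) consistently
perm = torch.randperm(edge_index.size(1))
out_shuffled = layer(x, edge_index[:, perm], edge_attr[perm])
print(out.shape)
print(torch.allclose(out, out_shuffled, atol=1e-6))
```

Shuffling the order of the edges leaves the output unchanged. This is expected because the sum aggregation does not depend on the order in which messages arrive.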

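- Permutation invariance: the output does not change when the input is permuted. For graph neural networks, this means the function does not depend on the arbitrary ordering of the rows/columns of the adjacency matrix $\mathbf{A}$. For any permutation matrix $\mathbf{P}$ and node feature matrix $\mathbf{X}$, $f(\mathbf{P}\mathbf{A}\mathbf{P}^{\top}, \mathbf{P}\mathbf{X}) = f(\mathbf{A}, \mathbf{X})$.
- Permutation equivariance: the output is permuted in the same way as the input. For graph neural networks, this means that permuting the adjacency matrix (and the node features) permutes the node-level output consistently: $f(\mathbf{P}\mathbf{A}\mathbf{P}^{\top}, \mathbf{P}\mathbf{X}) = \mathbf{P}\, f(\mathbf{A}, \mathbf{X})$.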


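Because a graph has no canonical node ordering, ensuring permutation invariance (for graph-level outputs) or equivariance (for node-level outputs) is a key challenge when learning over graphs.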
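The two properties can be checked numerically. The following small sketch uses two hypothetical functions. `layer(A, X) = A @ X` is one round of sum aggregation over neighbors and should be equivariant. `readout(A, X)`, which sums that result over all nodes, should be invariant.

```python
import torch

torch.manual_seed(0)
n, f = 5, 3
A = (torch.rand(n, n) > 0.5).float()
A = ((A + A.T) > 0).float()           # symmetric (undirected) adjacency matrix
X = torch.randn(n, f)                 # node feature matrix

P = torch.eye(n)[torch.randperm(n)]   # random permutation matrix


def layer(A, X):
    # Node-level output: sum of neighbor features (hypothetical, equivariant)
    return A @ X


def readout(A, X):
    # Graph-level output: sum the node-level output over all nodes (invariant)
    return layer(A, X).sum(dim=0)


A_p, X_p = P @ A @ P.T, P @ X         # relabel the nodes

print(torch.allclose(readout(A_p, X_p), readout(A, X), atol=1e-5))
print(torch.allclose(layer(A_p, X_p), P @ layer(A, X), atol=1e-5))
```

Both checks follow from $\mathbf{P}^{\top}\mathbf{P} = \mathbf{I}$. The layer gives $\mathbf{P}\mathbf{A}\mathbf{P}^{\top}\mathbf{P}\mathbf{X} = \mathbf{P}\mathbf{A}\mathbf{X}$. For the readout, summing over rows removes $\mathbf{P}$ entirely.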

### Implementing the GCN Layer

The [GCN layer](https://arxiv.org/abs/1609.02907) is mathematically defined as:

$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b}$$

Here, neighboring node features are first transformed by a weight matrix $\mathbf{W}$, then normalized by their degree, and finally summed up. The bias vector $\mathbf{b}$ is then added to the aggregated output. Note that $\deg(\cdot)$ is the node degree after self-loops have been added, so it is always at least 1. This formula can be divided into the following steps:

1. Add self-loops to the adjacency matrix.
2. Linearly transform the node feature matrix.
3. Compute normalization coefficients.
4. Normalize node features in $\phi$.
5. Sum up neighboring node features ("add" aggregation).
6. Apply a final bias vector.
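For small graphs, the six steps can be written in matrix form as $\mathbf{X}^{(k)} = \hat{\mathbf{D}}^{-1/2}\hat{\mathbf{A}}\hat{\mathbf{D}}^{-1/2}\mathbf{X}^{(k-1)}\mathbf{W} + \mathbf{b}$. Here $\hat{\mathbf{A}} = \mathbf{A} + \mathbf{I}$, and $\hat{\mathbf{D}}$ is the diagonal degree matrix of $\hat{\mathbf{A}}$. The dense sketch below follows that form. It assumes `A[i, j] = 1` for an edge $j \to i$ and a weight `W` of shape `[F_in, F_out]`. The helper name `gcn_dense` is hypothetical, and the sketch is only a reference for checking results. It is not an efficient implementation.

```python
import torch


def gcn_dense(A, X, W, b):
    """Hypothetical dense reference for one GCN layer (small graphs only).

    A: [N, N] adjacency, X: [N, F_in], W: [F_in, F_out], b: [F_out]
    """
    A_hat = A + torch.eye(A.size(0))          # step 1: add self-loops
    XW = X @ W                                # step 2: linear transform
    deg = A_hat.sum(dim=1)                    # step 3: degrees including the self-loop
    d_inv_sqrt = deg.pow(-0.5)
    # 1 / sqrt(deg(i) * deg(j)) on every (self-)edge, 0 elsewhere
    norm = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]
    return norm @ XW + b                      # steps 4-6: normalize, sum, add bias


# Path graph 0 - 1 - 2; identity features and weights expose the coefficients
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
out = gcn_dense(A, torch.eye(3), torch.eye(3), torch.zeros(3))
print(out)
```

With identity features and weights, `out` is exactly the normalized adjacency matrix. After adding self-loops, node 0 has degree 2 and node 1 has degree 3. Row 0 is therefore $[1/2,\ 1/\sqrt{6},\ 0]$, and the diagonal entry for node 1 is $1/3$.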

```python
import torch
import torch.nn as nn
import torch_geometric
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

print("imports done")
print("torch version:", torch.__version__)
print("torch_geometric version:", torch_geometric.__version__)
```

```
imports done
torch version: 2.6.0+cpu
torch_geometric version: 2.5.3
```


```python
class GCNConv(pyg_nn.MessagePassing):
    def __init__(self, in_channels, out_channels):
        # "add" aggregation implements the sum in the GCN formula (step 5).
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = pyg_utils.add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        # row holds the source nodes j, col holds the target nodes i.
        row, col = edge_index
        deg = pyg_utils.degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Defensive guard: a zero-degree node would give inf (cannot occur after step 1).
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # shape [E]

        # Steps 4-5: Start propagating messages.
        # self.propagate() internally calls message(), aggregate() and update().
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out = out + self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]: features of the source node of each edge.
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j
```
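Under PyG's default `flow="source_to_target"`, `self.propagate` does roughly the following for this layer. It gathers `x_j = x[edge_index[0]]`, calls `message()`, and scatter-adds the results into the target nodes `edge_index[1]`. The sketch below mimics that in plain PyTorch, without `torch_geometric`, and compares the result against a direct node-by-node evaluation of the GCN formula. The function name `gcn_propagate_sketch` is hypothetical. Its `W` has shape `[F_in, F_out]`, so it corresponds to the transpose of `self.lin.weight`.

```python
import math
import torch


def gcn_propagate_sketch(x, edge_index, W, b):
    """Hypothetical plain-PyTorch sketch of GCNConv.forward.

    x: [N, F_in], edge_index: [2, E] (row = source j, col = target i),
    W: [F_in, F_out], b: [F_out]. Assumes edge_index has no self-loops yet.
    """
    n = x.size(0)
    loops = torch.arange(n).repeat(2, 1)               # [[0..n-1], [0..n-1]]
    edge_index = torch.cat([edge_index, loops], dim=1)  # step 1: add self-loops
    x = x @ W                                           # step 2: linear transform
    row, col = edge_index
    # step 3: in-degree of every target node (self-loops included)
    deg = torch.zeros(n).index_add_(0, col, torch.ones(col.numel()))
    deg_inv_sqrt = deg.pow(-0.5)
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
    x_j = x[row]                                        # gather source features per edge
    msg = norm.view(-1, 1) * x_j                        # step 4: what message() returns
    out = x.new_zeros((n, x.size(1))).index_add_(0, col, msg)  # step 5: "add" aggregation
    return out + b                                      # step 6: bias


torch.manual_seed(0)
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])         # undirected path 0-1-2-3
W = torch.randn(3, 2)
b = torch.randn(2)
out = gcn_propagate_sketch(x, edge_index, W, b)

# Direct evaluation of the per-node GCN formula for comparison
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
deg = {i: len(nbrs) + 1 for i, nbrs in neighbors.items()}  # +1 for the self-loop
ref = torch.stack([
    sum((x[j] @ W) / math.sqrt(deg[i] * deg[j]) for j in neighbors[i] + [i]) + b
    for i in range(4)
])
print(torch.allclose(out, ref, atol=1e-5))
```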
