# Chapter 3: Building GNN Modules

A DGL NN module inherits from the NN base class of the backend in use: PyTorch's NN Module, MXNet Gluon's NN Block, or TensorFlow's Keras Layer.

DGL has integrated many commonly used modules: conv layers (`apinn-pytorch-conv`), dense conv layers (`apinn-pytorch-dense-conv`), pooling layers (`apinn-pytorch-pooling`), and utility modules (`apinn-pytorch-util`).

This chapter uses SAGEConv with the PyTorch backend as an example to show how to build a custom DGL NN module.

## 3.1 DGL NN Module Construction Function

The construction function performs three sequential steps:
1. Set options
2. Register learnable parameters or submodules
3. Reset parameters

```
import torch.nn as nn
from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()
        # Split the input size into source and destination node sizes
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
```

When using the construction function, we first need to set the data dimensions. For general PyTorch modules, the dimensions are usually input dimension, output dimension and hidden dimensions. For graph neural networks, the input dimension can be split into source node dimension and destination node dimension. 

Besides data dimensions, a typical option for GNNs is the aggregation type (`self._aggre_type`). The aggregation type determines how messages on different edges are aggregated for a given destination node. Commonly used aggregation types include `mean`, `sum`, `max`, and `min`. Some modules may apply a more complicated aggregation such as `lstm`.
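To make the aggregation types concrete, here is a minimal plain-Python sketch that needs neither DGL nor PyTorch. The edge list, the scalar features, and the names `mailbox` and `reducers` are made up for illustration; DGL performs the equivalent reduction on tensors.

```python
from collections import defaultdict

# Illustration data (not from DGL): (src, dst) edges and one scalar feature per node.
edges = [(0, 2), (1, 2), (0, 1)]
feats = {0: 1.0, 1: 3.0, 2: 5.0}

# Collect the messages (source features) arriving at each destination node.
mailbox = defaultdict(list)
for src, dst in edges:
    mailbox[dst].append(feats[src])

# One reduce function per aggregation type.
reducers = {
    "sum": sum,
    "mean": lambda msgs: sum(msgs) / len(msgs),
    "max": max,
    "min": min,
}

# aggregated[name][dst] is the reduced value for destination node dst.
aggregated = {
    name: {dst: reduce_fn(msgs) for dst, msgs in mailbox.items()}
    for name, reduce_fn in reducers.items()
}
```

Node 2 receives messages from nodes 0 and 1, so the four reducers give different results for it, while node 0 receives no messages and never appears in the output.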

`norm` here is a callable for feature normalization. In the GraphSAGE paper, this normalization can be L2 normalization: $h_v = h_v / \lVert h_v \rVert_2$.
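As a small worked example of the formula, the sketch below normalizes one feature vector in plain Python. The helper name `l2_normalize` and the `eps` guard against division by zero are assumptions of this sketch, not part of SAGEConv.

```python
import math

def l2_normalize(h, eps=1e-12):
    """Return h / ||h||_2 for a list of floats.

    eps is a hypothetical safeguard so an all-zero vector does not divide by zero.
    """
    norm = math.sqrt(sum(x * x for x in h))
    return [x / max(norm, eps) for x in h]

h_v = [3.0, 4.0]                    # ||h_v||_2 = 5
h_v_normalized = l2_normalize(h_v)  # unit-length vector in the same direction
```

With the PyTorch backend, a callable such as `lambda x: torch.nn.functional.normalize(x, p=2, dim=-1)` could presumably be passed as `norm` to get the same effect on tensors.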

```
        # Continuation of SAGEConv.__init__()
        # Supported aggregator types: mean, pool, lstm, gcn
        if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        # Every aggregator except gcn applies a separate transform to the node's own feature
        if aggregator_type in ['mean', 'pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
```

The next step is to register parameters and submodules. In SAGEConv, the submodules vary according to the aggregation type. They are pure PyTorch nn modules such as `nn.Linear` and `nn.LSTM`. At the end of the construction function, weight initialization is applied by calling `reset_parameters()`.

```
def reset_parameters(self):
    """Reinitialize learnable parameters."""
    gain = nn.init.calculate_gain('relu')
    if self._aggre_type == 'pool':
        nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
    if self._aggre_type == 'lstm':
        self.lstm.reset_parameters()
    if self._aggre_type != 'gcn':
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
    nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
```

## 3.2 DGL NN Module Forward Function

In an NN module, the `forward()` function does the actual message passing and computation.

Compared with PyTorch's NN module, which usually takes tensors as parameters, DGL's NN module takes an additional parameter of type `dgl.DGLGraph`. The workload of the `forward()` function can be split into three parts:

1. Graph checking and graph type specification
2. Message passing
3. Feature update

### Aside: GraphSAGE

Source: https://youtu.be/LLUxwHc7O4A

GraphSAGE changed the way we think of GNNs through neighbourhood sampling.

```
def forward(self, graph, feat):
    with graph.local_scope():
        # Specify graph type then expand input feature according to graph type
        feat_src, feat_dst = expand_as_pair(feat, graph)
```

`forward()` needs to handle many corner cases in the input that can lead to invalid values during computation and message passing. One typical check in conv modules like `GraphConv` is to verify that the input graph has no 0-in-degree nodes. When a node has 0 in-degree, its `mailbox` is empty and the reduce function produces all-zero values. This may cause a silent regression in model performance. In the `SAGEConv` module, however, the aggregated representation is concatenated with the original node feature, so the output of `forward()` is not all-zero.
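The 0-in-degree case can be illustrated with a plain-Python sketch. The graph, the features, and the helper `sum_neighbors` are hypothetical illustration data, and sum is used as a stand-in for the reduce function.

```python
# Illustration data: node 1 has two in-neighbors; nodes 0 and 2 have none.
edges = [(0, 1), (2, 1)]                       # (src, dst) pairs
feats = {0: [1.0, 2.0], 1: [3.0, 4.0], 2: [5.0, 6.0]}

# The kind of check a conv module can run before message passing.
zero_in_degree = [n for n in feats if all(dst != n for _, dst in edges)]

def sum_neighbors(node):
    """Sum the features of all in-neighbors; an empty mailbox yields zeros."""
    total = [0.0, 0.0]
    for src, dst in edges:
        if dst == node:
            total = [t + x for t, x in zip(total, feats[src])]
    return total

h_neigh = sum_neighbors(2)       # empty mailbox, so the aggregate is all-zero
h_concat = feats[2] + h_neigh    # list concatenation: own feature, then aggregate
```

The aggregate for node 2 carries no information, but the concatenated vector still contains the node's own feature, which is why SAGEConv does not produce an all-zero output here.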

The DGL NN module should be reusable across different types of graph input including homogeneous graphs, heterogeneous graphs, and subgraph blocks.

The math formulas for SAGEConv are:

$$h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate}\left(\{h_{src}^{(l)}, \forall src \in \mathcal{N}(dst)\}\right)$$

$$h_{dst}^{(l+1)} = \sigma\left(W \cdot \mathrm{concat}\left(h_{dst}^{(l)}, h_{\mathcal{N}(dst)}^{(l+1)}\right) + b\right)$$

$$h_{dst}^{(l+1)} = \mathrm{norm}\left(h_{dst}^{(l+1)}\right)$$
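The three formulas can be traced by hand for a single destination node. The sketch below is a plain-Python illustration, not DGL's implementation; it assumes a mean aggregator, ReLU as $\sigma$, and L2 normalization as norm, and the function name `sage_layer` and all numbers are made up.

```python
import math

def sage_layer(h_dst, h_neighbors, W, b):
    """Apply one SAGEConv-style update to a single destination node."""
    # 1. aggregate: element-wise mean over the neighbor feature vectors
    n = len(h_neighbors)
    h_neigh = [sum(col) / n for col in zip(*h_neighbors)]
    # 2. concat the node's own feature with the aggregate, then sigma(W z + b) with ReLU
    z = h_dst + h_neigh
    out = [max(0.0, sum(w * x for w, x in zip(row, z)) + b_i)
           for row, b_i in zip(W, b)]
    # 3. norm: L2-normalize the result (skipped if the vector is all-zero)
    length = math.sqrt(sum(x * x for x in out))
    return [x / length for x in out] if length > 0 else out

h_dst = [1.0, 0.0]
h_neighbors = [[2.0, 0.0], [4.0, 2.0]]      # mean is [3.0, 1.0]
W = [[1.0, 0.0, 1.0, 0.0],                  # maps the 4-dim concat to 2 outputs
     [0.0, 1.0, 0.0, -1.0]]
b = [0.0, 0.0]

h_next = sage_layer(h_dst, h_neighbors, W, b)
```

Here the concatenated vector is `[1, 0, 3, 1]`, the affine step gives `[4, -1]`, ReLU clips it to `[4, 0]`, and normalization scales it to unit length.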

We need to specify the source node feature `feat_src` and destination node feature `feat_dst` according to the graph type. `expand_as_pair()` is a function that specifies the graph type and expands `feat` into `feat_src` and `feat_dst`. 

```
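from collections.abc import Mapping

from dgl import backend as F  # DGL's backend-agnostic tensor operations

def expand_as_pair(input_, g=None):
    if isinstance(input_, tuple):
        # Bipartite graph case: input is already (src, dst)
        return input_
    elif g is not None and g.is_block:
        # Subgraph block case: destination nodes are the first rows of the input
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()
            }
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        # Homogeneous graph case: source and destination features are the same
        return input_, input_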
```

For whole-graph training on a homogeneous graph, the source nodes and destination nodes are the same: they are all the nodes in the graph.

In the heterogeneous case, the graph can be split into several bipartite graphs, one for each relation. The relations are represented as `(src_type, edge_type, dst_type)`. When `expand_as_pair()` identifies that the input feature `feat` is a tuple, it treats the graph as bipartite. The first element of the tuple is the source node feature and the second element is the destination node feature.
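The cases can be compared side by side with a simplified stand-in written in plain Python. `expand_as_pair_sketch` is a hypothetical helper, not DGL's function: it works on lists instead of tensors, and a `num_dst_nodes` argument replaces the graph object.

```python
def expand_as_pair_sketch(feat, num_dst_nodes=None):
    """Simplified stand-in for expand_as_pair(); returns (feat_src, feat_dst)."""
    if isinstance(feat, tuple):
        # Bipartite case: the caller already supplies (src, dst)
        return feat
    if num_dst_nodes is not None:
        # Block case: destination nodes are the first num_dst_nodes rows
        return feat, feat[:num_dst_nodes]
    # Homogeneous case: the same features serve as source and destination
    return feat, feat

rows = [[1.0], [2.0], [3.0]]    # three nodes, one feature each

homo_src, homo_dst = expand_as_pair_sketch(rows)
block_src, block_dst = expand_as_pair_sketch(rows, num_dst_nodes=2)
bip_src, bip_dst = expand_as_pair_sketch(([[1.0]], [[9.0], [8.0]]))

# The same rule applies to sizes, as in the constructor's expand_as_pair(in_feats).
in_src, in_dst = expand_as_pair_sketch(5)
```

In the bipartite call, the source and destination sides have different numbers of nodes, which is exactly what a single relation in a heterogeneous graph can look like.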

## 3.3 Heterogeneous GraphConv Module