# Writing GNN Modules for Stochastic GNN Training

All GNN modules provided by DGL support stochastic GNN training. This tutorial teaches you how to write your own graph neural network module for stochastic GNN training. It assumes that

1. You know [how to write GNN modules for full graph training](3_message_passing.ipynb).
2. You know [how the stochastic GNN training pipeline works](L1_large_node_classification.ipynb).

In [1]:
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-products')

graph, node_labels = dataset[0]
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
node_features = graph.ndata['feat']

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

input_nodes, output_nodes, bipartites = next(iter(train_dataloader))

Using backend: pytorch
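It can help to trace by hand why `MultiLayerNeighborSampler([4, 4])` yields one bipartite graph per layer. The plain-Python sketch below is a simplified, hypothetical stand-in for that sampler. The adjacency format and the function name `sample_layers` are assumptions made for illustration. It samples from the last layer backwards, so the input nodes required by layer $l+1$ become the output nodes that layer $l$ must compute. Unlike DGL, it keeps node IDs sorted rather than listing the output nodes first.

```python
import random


def sample_layers(in_neighbors, seeds, fanouts, seed=0):
    """Return one (input_nodes, output_nodes, edges) triple per layer, first layer first.

    in_neighbors : dict mapping a node ID to the list of its in-neighbor IDs
                   (hypothetical adjacency format, not DGL's).
    seeds        : node IDs whose final representations are needed.
    fanouts      : neighbors sampled per node, one entry per layer.
    """
    rng = random.Random(seed)
    layers = []
    output_nodes = sorted(set(seeds))
    # Walk from the last layer back to the first: the input nodes that layer
    # l + 1 needs are exactly the output nodes that layer l must compute.
    for fanout in reversed(fanouts):
        edges = []
        for v in output_nodes:
            nbrs = in_neighbors.get(v, [])
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            edges.extend((u, v) for u in picked)
        input_nodes = sorted(set(output_nodes) | {u for u, _ in edges})
        layers.append((input_nodes, output_nodes, edges))
        output_nodes = input_nodes
    layers.reverse()
    return layers


# A 6-node ring in which every node has two in-neighbors.
ring = {v: [(v - 1) % 6, (v + 1) % 6] for v in range(6)}
for inputs, outputs, _ in sample_layers(ring, seeds=[0], fanouts=[4, 4]):
    print(len(inputs), len(outputs))  # prints "5 3", then "3 1"
```

With a fanout at least as large as every node's degree, sampling is deterministic, which makes the growth of the node sets from the last layer to the first easy to follow.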


## DGL Bipartite Graph Introduction

In the previous tutorials on [node classification](L1_large_node_classification.ipynb), [link prediction](L2_large_link_prediction.ipynb), and [custom neighbor sampling](L3_custom_sampler.ipynb), you have seen the concept of a *bipartite graph*. This section shows how to manipulate (directed) bipartite graphs.

<div class="alert alert-danger">
     
**Question**: shall we tell the users how to create a bipartite graph here?  We don't have a `dgl.bipartite()` interface, so a user must dive straight into the "heterogeneous graph" concept.
     
</div>
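Conceptually, a bipartite graph in this setting needs only two node sets, edges that always point from the first set to the second, and a separate feature store for each side. The sketch below models that in plain Python. `ToyBlock` and its fields are hypothetical names chosen for illustration; this is not how DGL implements its graphs.

```python
class ToyBlock:
    """Framework-free stand-in for a directed bipartite graph (hypothetical)."""

    def __init__(self, src_ids, dst_ids, edges):
        self.src_ids = list(src_ids)  # original IDs of the input (source) nodes
        self.dst_ids = list(dst_ids)  # original IDs of the output (destination) nodes
        # Edges use local indices: (position in src_ids, position in dst_ids).
        self.edges = list(edges)
        self.srcdata = {}             # feature name -> one row per input node
        self.dstdata = {}             # feature name -> one row per output node

    def num_src_nodes(self):
        return len(self.src_ids)

    def num_dst_nodes(self):
        return len(self.dst_ids)


# Two neighbors (3 and 9) send messages to output node 7.
block = ToyBlock(src_ids=[7, 3, 9], dst_ids=[7], edges=[(1, 0), (2, 0)])
print(block.num_src_nodes(), block.num_dst_nodes())  # 3 1
```

Keeping the two feature stores separate is what lets the input side and the output side hold different numbers of rows.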

You can access the input node features and output node features via the `srcdata` and `dstdata` attributes:

In [2]:
bipartite = bipartites[0]
print(bipartite.srcdata)
print(bipartite.dstdata)

{'feat': tensor([[-0.0808,  0.5835, -1.1753,  ...,  1.0232, -0.4817,  2.8607],
        [ 0.7227, -0.1247, -0.0356,  ...,  0.6085,  0.7714, -0.2870],
        [ 0.1973,  0.0741, -0.0163,  ...,  0.6373,  0.0277,  0.1868],
        ...,
        [-0.0392, -0.2355,  0.1328,  ..., -0.3510,  0.5908,  0.9608],
        [ 0.1111,  0.3545,  0.0535,  ...,  0.1027,  1.1263,  0.6594],
        [ 0.0189,  0.4955,  0.1431,  ..., -0.3645,  0.4985,  0.7945]]), '_ID': tensor([195412,  60541,  49188,  ...,  63542,  44467,  81629])}
{'feat': tensor([[-0.0808,  0.5835, -1.1753,  ...,  1.0232, -0.4817,  2.8607],
        [ 0.7227, -0.1247, -0.0356,  ...,  0.6085,  0.7714, -0.2870],
        [ 0.1973,  0.0741, -0.0163,  ...,  0.6373,  0.0277,  0.1868],
        ...,
        [ 0.2732, -0.8037,  0.1437,  ..., -0.4809,  0.9125,  0.9294],
        [-0.6513,  0.0445,  0.0212,  ..., -0.6451,  0.7263,  0.7248],
        [ 0.4365,  0.0622,  0.3198,  ..., -0.3303,  0.0735,  0.2900]]), '_ID': tensor([195412,  60541,  49188,  .
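The `_ID` entries record, for every node of the bipartite graph, its ID in the original graph, so per-node inputs can be looked up from the full feature table. Below is a plain-Python sketch of that lookup with a made-up feature table; `gather_rows` is a hypothetical helper standing in for tensor indexing.

```python
def gather_rows(table, ids):
    """Return the rows of `table` selected by `ids`, in the order given.

    Plain-Python analogue of indexing a tensor with an ID tensor,
    e.g. node_features[input_nodes] in a training loop.
    """
    return [table[i] for i in ids]


# Hypothetical full-graph feature table: node i has feature [i, 10 * i].
full_feat = [[i, 10 * i] for i in range(10)]
src_ids = [4, 7, 2, 9]  # original IDs of the input nodes
dst_ids = [4, 7]        # original IDs of the output nodes (a prefix of src_ids)

src_feat = gather_rows(full_feat, src_ids)
dst_feat = gather_rows(full_feat, dst_ids)
print(src_feat[:2] == dst_feat)  # True: the output nodes' rows come first
```

This mirrors what the printed tensors show: the first rows of `srcdata['feat']` match the first rows of `dstdata['feat']`.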

Bipartite graphs also provide the `num_src_nodes` and `num_dst_nodes` methods, which return the number of input nodes and output nodes respectively:

In [3]:
print(bipartite.num_src_nodes(), bipartite.num_dst_nodes())

23865 5086


You can assign features to `srcdata` and `dstdata` just as you would with `ndata` on the graphs you have seen earlier:

In [4]:
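# The first dimension must equal the number of nodes on that side; the feature width (10 here) is arbitrary.
bipartite.srcdata['x'] = torch.zeros(bipartite.num_src_nodes(), 10)
bipartite.dstdata['y'] = torch.zeros(bipartite.num_dst_nodes(), 10)
dst_feat = bipartite.dstdata['feat']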
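DGL checks that the first dimension of an assigned feature matches the number of nodes on that side; the remaining dimensions are up to you. The plain-Python `FeatureStore` below is a hypothetical stand-in that mimics this row-count rule. It is not DGL's actual storage class, and DGL raises its own error type rather than `ValueError`.

```python
class FeatureStore(dict):
    """Hypothetical per-side feature store that enforces one row per node."""

    def __init__(self, num_nodes):
        super().__init__()
        self.num_nodes = num_nodes

    def __setitem__(self, name, rows):
        # Reject features whose row count differs from the node count.
        if len(rows) != self.num_nodes:
            raise ValueError(
                f"feature {name!r} has {len(rows)} rows, expected {self.num_nodes}")
        super().__setitem__(name, rows)


srcdata = FeatureStore(num_nodes=3)
dstdata = FeatureStore(num_nodes=1)
srcdata['x'] = [[0.0] * 10 for _ in range(3)]  # 3 rows of width 10: accepted
dstdata['y'] = [[0.0] * 10]                    # 1 row of width 10: accepted
try:
    dstdata['y'] = [[0.0] * 10 for _ in range(3)]  # 3 rows for 1 output node
except ValueError as err:
    print(err)  # feature 'y' has 3 rows, expected 1
```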

Because these bipartite graphs are generated by DGL's sampler, they also record the original IDs of their nodes. You can retrieve the input node IDs (i.e. the nodes required to compute the output) and the output node IDs (i.e. the nodes whose representations the current GNN layer should compute) as follows:

In [5]:
bipartite.srcdata[dgl.NID], bipartite.dstdata[dgl.NID]

(tensor([195412,  60541,  49188,  ...,  63542,  44467,  81629]),
 tensor([195412,  60541,  49188,  ..., 157118,  35460, 802935]))
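Because each side stores its own ID array, an edge stored with local indices can be translated back to the original graph by indexing the two arrays. Below is a minimal sketch with made-up IDs and edges; `to_global_edges` is a hypothetical helper, not a DGL function.

```python
def to_global_edges(local_edges, src_nids, dst_nids):
    """Map (local src index, local dst index) pairs to original node IDs."""
    return [(src_nids[i], dst_nids[j]) for i, j in local_edges]


src_nids = [50, 12, 7, 33]  # hypothetical original IDs of the input nodes
dst_nids = [50, 12]         # hypothetical original IDs of the output nodes
local_edges = [(2, 0), (3, 0), (0, 1)]
print(to_global_edges(local_edges, src_nids, dst_nids))
# [(7, 50), (33, 50), (50, 12)]
```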

## Writing GNN Modules for Bipartite Graphs for Stochastic Training

Recall from the [custom message passing tutorial for small graphs](3_message_passing.ipynb) that the message passing formulation of [Gilmer et al.](https://arxiv.org/abs/1704.01212) works as follows:

$$
m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)
$$

$$
m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}
$$

$$
h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
$$

where $M$ can be any message function, $\sum$ can be any reduce function, and $U$ can be any function that combines the aggregated messages with the representation of node $v$ itself.
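A tiny worked instance of these three equations may help. The sketch below runs one round in plain Python on a made-up three-node graph with scalar representations. It chooses $M$ to copy the source representation, $\sum$ as a plain sum, and $U(h_v, m_v) = h_v + m_v$. These choices and the helper name `message_passing_step` are illustrative assumptions, and edge features $e_{u\to v}$ are omitted.

```python
def message_passing_step(h, edges, message, update):
    """One round of message passing on a homogeneous graph.

    h       : dict node -> representation h_v^{(l-1)} (scalars for brevity)
    edges   : list of (u, v) pairs, each meaning an edge u -> v
    message : M(h_v, h_u) -> message m_{u->v} (edge features omitted)
    update  : U(h_v, m_v) -> new representation h_v^{(l)}
    """
    m = {v: 0.0 for v in h}              # sum aggregation starts from zero
    for u, v in edges:
        m[v] += message(h[v], h[u])      # m_v = sum over u in N(v) of m_{u->v}
    return {v: update(h[v], m[v]) for v in h}


h0 = {0: 1.0, 1: 2.0, 2: 4.0}
edges = [(1, 0), (2, 0), (0, 1)]         # node 2 has no incoming edges
h1 = message_passing_step(
    h0, edges,
    message=lambda h_v, h_u: h_u,        # copy the source representation
    update=lambda h_v, m_v: h_v + m_v,   # combine self and aggregated messages
)
print(h1)  # {0: 7.0, 1: 3.0, 2: 4.0}
```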

Also, recall that the bipartite graphs yielded by `NodeDataLoader` and `EdgeDataLoader` have the property that the first `num_dst_nodes()` input nodes are always the output nodes, in the same order:

![](assets/bipartite2.png)

In [6]:
print(torch.equal(bipartite.srcdata[dgl.NID][:bipartite.num_dst_nodes()], bipartite.dstdata[dgl.NID]))

True
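One way to obtain this ordering, sketched in plain Python below, is to start the input node list with the output nodes and append each remaining source node the first time it appears among the sampled edges. DGL performs this kind of conversion internally (see `dgl.to_block`). The helper `compact_to_block` is a hypothetical name, and the sketch does not claim to reproduce DGL's exact ordering of the non-output nodes.

```python
def compact_to_block(edges, dst_nodes):
    """Relabel edges given in original IDs into a bipartite block.

    edges     : list of (u, v) pairs in original node IDs, every v in dst_nodes
    dst_nodes : original IDs of the output nodes
    Returns (src_nodes, local_edges), where src_nodes starts with dst_nodes.
    """
    src_nodes = list(dst_nodes)                  # output nodes go first
    src_index = {v: i for i, v in enumerate(src_nodes)}
    dst_index = {v: i for i, v in enumerate(dst_nodes)}
    local_edges = []
    for u, v in edges:
        if u not in src_index:                   # first time this neighbor is seen
            src_index[u] = len(src_nodes)
            src_nodes.append(u)
        local_edges.append((src_index[u], dst_index[v]))
    return src_nodes, local_edges


src_nodes, local_edges = compact_to_block(
    edges=[(8, 5), (2, 5), (5, 2), (9, 2)], dst_nodes=[5, 2])
print(src_nodes)    # [5, 2, 8, 9]
print(local_edges)  # [(2, 0), (1, 0), (0, 1), (3, 1)]
```

Because the output nodes occupy the first positions, a module can recover $h_v^{(l-1)}$ with a simple slice instead of an ID lookup.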


Suppose you have obtained the input node representations $h_u^{(l-1)}$:

In [7]:
bipartite.srcdata['h'] = torch.randn(bipartite.num_src_nodes(), 10)

Because the input nodes are the union of the output nodes and their neighbors, with the output nodes placed first, you can conveniently obtain the term $h_v^{(l-1)}$ by slicing:

In [8]:
h_v = bipartite.srcdata['h'][:bipartite.num_dst_nodes()]

Suppose that the message function simply copies the source feature (i.e. $M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right) = h_u^{(l-1)}$) and that the reduce function is a simple average. You can still use `update_all` to compute $m_{v}^{(l)}$:

In [9]:
import dgl.function as fn

bipartite.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
m_v = bipartite.dstdata['h']
m_v

tensor([[-0.0733, -0.2727,  0.2788,  ...,  0.4624,  0.0740,  0.6660],
        [ 0.1595,  0.6646, -0.8184,  ..., -0.7316, -1.1893, -0.7929],
        [-0.2563,  0.1930,  0.0794,  ..., -0.4013,  0.5718, -0.0207],
        ...,
        [-0.9491,  0.1209, -0.6648,  ...,  0.2813, -1.0527,  1.0429],
        [-0.0577,  0.3248,  0.2989,  ...,  0.0615, -0.3329,  0.2392],
        [ 0.0766,  0.1231, -0.7424,  ..., -0.4826,  0.4557, -0.3635]])
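To make the built-in functions concrete, here is what `fn.copy_u('h', 'm')` followed by `fn.mean('m', 'h')` computes on a bipartite graph, written out in plain Python with made-up features. The helper name `copy_u_mean` is hypothetical. The zero vector for output nodes without incoming edges is an assumption about DGL's built-in reducers made for this sketch.

```python
def copy_u_mean(src_feat, local_edges, num_dst):
    """Average the source features arriving at each output node.

    src_feat    : list of feature vectors, one per input node
    local_edges : list of (src index, dst index) pairs
    Output nodes without incoming edges get a zero vector in this sketch.
    """
    dim = len(src_feat[0])
    sums = [[0.0] * dim for _ in range(num_dst)]
    counts = [0] * num_dst
    for i, j in local_edges:
        counts[j] += 1
        for k in range(dim):
            sums[j][k] += src_feat[i][k]  # copy_u: the message is the source feature
    return [[x / c for x in row] if c else row for row, c in zip(sums, counts)]


src_feat = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
print(copy_u_mean(src_feat, [(1, 0), (2, 0)], num_dst=2))
# [[4.0, 3.0], [0.0, 0.0]]
```

Note that the result has one row per output node, not per input node, which is why it lands in `dstdata`.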

Putting these pieces together, you can implement a GraphSAGE convolution for large-graph training as follows. The differences from the [small-graph counterpart](3_message_passing.ipynb) are marked with arrows (`# <---`).

<div class="alert alert-danger">
    
**Question**: should we still suggest that users slice
    
```python
h_dst = h[:g.num_dst_nodes()]
```
    
within the NN module?
    
* If so, we will still need to have an `is_block` flag which is set from `to_block`.  Otherwise the DGL NN module will never know whether it should perform the slicing or not: it can be both a normal bipartite graph (which doesn't require slicing) or something returned by neighbor sampler (which requires slicing).
* If not, then it will break existing code, and lots of examples need to change.  Moreover, doing so for heterogeneous graph is a bit heavy as you will need something like:

  ```python
  h_dst = {ntype: h[:g.num_dst_nodes(ntype)] for ntype in g.ntypes if ntype.endswith('_dst')}
  ```
    
</div>
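One way to frame the first option: a conv module could route its input through a small helper that slices only when told the graph is a sampled block, and that also handles per-type dictionaries for heterogeneous graphs. The helper `split_src_dst` and its `is_block` argument below are hypothetical, written in plain Python to illustrate the dispatch; they are not DGL APIs.

```python
def split_src_dst(h, is_block, num_dst_nodes=None):
    """Hypothetical helper: return (h_src, h_dst) for a conv module.

    h             : one feature list, a dict of per-type feature lists, or an
                    explicit (h_src, h_dst) pair for a general bipartite graph.
    is_block      : True when the graph came from neighbor sampling, so the
                    output nodes are the first num_dst_nodes input nodes.
    num_dst_nodes : an int, or a dict node type -> int for heterogeneous input.
    """
    if isinstance(h, tuple):        # caller already supplied both sides
        return h
    if is_block:
        if isinstance(h, dict):     # heterogeneous: slice each node type
            return h, {nt: feat[:num_dst_nodes[nt]] for nt, feat in h.items()}
        return h, h[:num_dst_nodes]  # output nodes are a prefix of the inputs
    return h, h                     # ordinary graph: src and dst sets coincide


feats = [[1.0], [2.0], [3.0]]
print(split_src_dst(feats, is_block=True, num_dst_nodes=2)[1])  # [[1.0], [2.0]]
print(split_src_dst((feats, [[9.0]]), is_block=False)[1])       # [[9.0]]
hetero = {'user': [[1.0], [2.0]], 'item': [[3.0], [4.0], [5.0]]}
print(split_src_dst(hetero, True, {'user': 1, 'item': 2})[1])
# {'user': [[1.0]], 'item': [[3.0], [4.0]]}
```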

In [10]:
import torch.nn as nn
import torch.nn.functional as F
import tqdm

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.
    
    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)
    
    def forward(self, g, h):
        """Forward computation
        
        Parameters
        ----------
        g : Graph
            The input bipartite graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            h_dst = h[:g.num_dst_nodes()]                 # <---
            g.srcdata['h'] = h                            # <---
            # update_all is a message passing API.
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
            h_neigh = g.dstdata['h_neigh']
            h_total = torch.cat([h_dst, h_neigh], dim=1)  # <---
            return self.linear(h_total)
        
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)
    
    def forward(self, bipartites, in_feat):
        h = self.conv1(bipartites[0], in_feat)
        h = F.relu(h)
        h = self.conv2(bipartites[1], h)
        return h
    
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)
model = Model(graph.ndata['feat'].shape[1], 128, dataset.num_classes).cuda()

with tqdm.tqdm(train_dataloader) as tq:
    for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
        bipartites = [b.to(torch.device('cuda')) for b in bipartites]
        inputs = node_features[input_nodes].cuda()
        labels = node_labels[output_nodes].cuda()
        predictions = model(bipartites, inputs)
        # Only the forward pass is run here; loss computation and the optimizer step are omitted.

100%|██████████| 193/193 [00:07<00:00, 24.59it/s]


Both `update_all` and the built-in functions used with it (such as `fn.copy_u` and `fn.mean`) support bipartite graphs, so you can migrate code written for small graphs to large-graph training with only the minimal changes marked above.
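To see how the layers compose without any framework, the sketch below re-implements the SAGE-style update from `SAGEConv` in plain Python and chains it over two toy blocks, where the second block's input nodes are the first block's output nodes. The function name, weights, and edges are made up for illustration. The point is that each layer returns exactly one row per output node of its block, which is what the next layer consumes as input.

```python
def sage_layer(h_src, local_edges, num_dst, weight):
    """SAGE-style layer on one block: concat(h_dst, mean of neighbors) @ weight.

    weight is a list of 2 * in_dim rows, each with out_dim columns.
    """
    dim = len(h_src[0])
    h_dst = h_src[:num_dst]                       # output nodes are a prefix
    neigh = [[0.0] * dim for _ in range(num_dst)]
    deg = [0] * num_dst
    for i, j in local_edges:
        deg[j] += 1
        for k in range(dim):
            neigh[j][k] += h_src[i][k]
    neigh = [[x / d for x in row] if d else row for row, d in zip(neigh, deg)]
    out = []
    for self_row, neigh_row in zip(h_dst, neigh):
        z = self_row + neigh_row                  # list concatenation, length 2 * dim
        out.append([sum(z[r] * weight[r][c] for r in range(len(z)))
                    for c in range(len(weight[0]))])
    return out


x = [[1.0], [2.0], [3.0], [4.0]]                  # 4 input nodes, in_dim = 1
w1 = [[1.0, 0.0], [0.0, 1.0]]                     # 2 * 1 rows, 2 columns
w2 = [[1.0], [0.0], [0.0], [1.0]]                 # 2 * 2 rows, 1 column
h1 = sage_layer(x, [(2, 0), (3, 1)], num_dst=2, weight=w1)   # first block: 4 -> 2 nodes
h1 = [[max(0.0, v) for v in row] for row in h1]              # ReLU, as in Model.forward
h2 = sage_layer(h1, [(1, 0)], num_dst=1, weight=w2)          # second block: 2 -> 1 node
print(h1)  # [[1.0, 3.0], [2.0, 4.0]]
print(h2)  # [[5.0]]
```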