# Stochastic Training of GNN for Node Classification on Large Graphs (TO BE REFINED)

This tutorial shows how to train a multi-layer GraphSAGE for node classification on the Amazon Copurchase Network provided by [Open Graph Benchmark (OGB)](https://ogb.stanford.edu/).  The dataset contains 2.4 million nodes and 61 million edges, so it does not fit into the memory of a single GPU.

This tutorial's contents include

* Creating your DGL graph from your own data in other formats such as CSV.
* Training a GNN model with a single machine, a single GPU, on a graph of any size, with DGL's GNN modules.

## Load Dataset

OGB already provides the data as a DGL graph.

<div class="alert alert-info">
    
**Note**: If you wish to load your own large graph, please refer to <a href=2_load_data.ipynb>this tutorial</a>.

</div>

In [1]:
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-products')

Using backend: pytorch



An OGB dataset is a collection of graphs and their labels.  The Amazon Copurchase Network dataset contains only a single graph, so you can get the graph and its node labels like this:

In [None]:
graph, node_labels = dataset[0]
print(graph)
print(node_labels)

node_features = graph.ndata['feat']
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)

You can get the training-validation-test split of the nodes with the `get_idx_split` method.

In [None]:
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']

## Multi-layer Message Passing in Detail

The message passing formulation defined in [Gilmer et al.](https://arxiv.org/abs/1704.01212) is as follows:

$$
\begin{gathered}
m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right) \\
m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)} \\
h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
\end{gathered}
$$

Essentially, the $l$-th layer representation of a node depends on the $(l-1)$-th layer representation of the same node, as well as the $(l-1)$-th layer representation of the neighboring nodes.

<div class="alert alert-info">
    <b>Note: </b>See <a href=3_message_passing.ipynb>this tutorial</a> for more details on message passing in DGL.
</div>
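To make the three equations concrete, here is a minimal pure-Python sketch of one message passing layer.  It is not how DGL implements message passing: the toy graph, the function names `message_passing_layer`, `message_fn` and `update_fn`, and the choice of message and update functions are all assumptions made for illustration, and edge features are left out.

```python
# Hypothetical toy graph stored as neighbor lists; features are short lists of floats.
neighbors = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}

def message_passing_layer(h, neighbors, message_fn, update_fn):
    """Apply one round of message passing to the node features in `h`."""
    new_h = {}
    for v, nbrs in neighbors.items():
        # m_{u->v} = M(h_v, h_u); edge features are omitted in this sketch.
        messages = [message_fn(h[v], h[u]) for u in nbrs]
        # m_v = sum of the incoming messages, computed per feature dimension.
        m_v = [sum(m[i] for m in messages) for i in range(len(h[v]))]
        # h_v^{(l)} = U(h_v^{(l-1)}, m_v)
        new_h[v] = update_fn(h[v], m_v)
    return new_h

# Simplest possible choices: the message is the sender's feature,
# and the update adds the aggregated message to the node's own feature.
message_fn = lambda h_v, h_u: h_u
update_fn = lambda h_v, m_v: [a + b for a, b in zip(h_v, m_v)]

h0 = {0: [1.0], 1: [2.0], 2: [3.0], 3: [4.0]}
h1 = message_passing_layer(h0, neighbors, message_fn, update_fn)
print(h1)
```

Every entry of `h1` is computed only from entries of `h0`, which is the layer-to-layer dependency that the rest of this section builds on.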

For instance, if you would like to compute the representation of the red node by a 2-layer GNN:

![Imgur](assets/seed.png)

The formulation shows that to compute the red node's second GNN layer output $\boldsymbol{h}_8^{(2)}$ you will need the first GNN layer output of the same node $\boldsymbol{h}_8^{(1)}$, as well as the first GNN layer output of its neighboring nodes $\boldsymbol{h}_4^{(1)}$, $\boldsymbol{h}_5^{(1)}$, $\boldsymbol{h}_7^{(1)}$, and $\boldsymbol{h}_{11}^{(1)}$ (colored green).  The message passing will happen on the green dashed edges visualized below.

![Imgur](assets/3.png)

To compute the first-layer representation of the red and green nodes, you further need to perform message passing on the yellow edges visualized below.  Therefore, besides the red and green nodes, the yellow nodes' input features are also necessary to compute the red node's second GNN layer output.

<div class="alert alert-info">
    <b>Note</b>: The edges that show up in the second layer (i.e. as the green dashed arrows) appear again in the first layer (i.e. as the yellow dashed arrows), but they represent different message passing computations.  The second layer computes from $\boldsymbol{h}_\cdot^{(1)}$ to $\boldsymbol{h}_\cdot^{(2)}$ while the first layer computes from $\boldsymbol{h}_\cdot^{(0)}$ (i.e. the input) to $\boldsymbol{h}_\cdot^{(1)}$.
</div>

![Imgur](assets/4.png)

You may notice that to figure out which nodes' input features are necessary, you are going in the opposite direction of message aggregation: you start from the layer closest to the output and work backward to the input.  Message passing, in contrast, goes from the layer closest to the input towards the output.
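The backward walk described above can be sketched in a few lines of plain Python.  This is only an illustration: the function `required_inputs` is a hypothetical helper, and apart from node 8's neighbors (4, 5, 7 and 11, as in the text), the adjacency below is made up rather than read off the figures.

```python
# Node 8's neighbors follow the text; all other adjacency is hypothetical.
neighbors = {
    8: [4, 5, 7, 11],
    4: [1, 8], 5: [2, 8], 7: [6, 8], 11: [8, 10],
    1: [4], 2: [5], 6: [7], 10: [11],
}

def required_inputs(seeds, neighbors, num_layers):
    """List the nodes needed at each layer, starting from the output layer."""
    frontier = set(seeds)
    needed = [frontier]
    for _ in range(num_layers):
        # A node needs its own previous-layer representation and its neighbors'.
        frontier = frontier | {u for v in frontier for u in neighbors[v]}
        needed.append(frontier)
    return needed

needed = required_inputs([8], neighbors, num_layers=2)
for depth, nodes in enumerate(needed):
    print("layers back from output:", depth, "nodes:", sorted(nodes))
```

In this toy graph a single output node already pulls in every node within two hops, which is the growth that motivates neighbor sampling.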

## Neighbor Sampling Overview

You can also see from the previous example that computing representation for a small number of nodes often requires input features of a significantly larger number of nodes.  Taking all neighbors for message aggregation is often too costly since the nodes needed for input features would easily cover a large portion of the graph, especially for real-world graphs which are often [scale-free](https://en.wikipedia.org/wiki/Scale-free_network).

Neighbor sampling addresses this issue by aggregating over only a subset of the neighbors.  For instance, to compute $\boldsymbol{h}_8^{(2)}$, you can aggregate over two of the neighbors instead of all of them, so you need the first-layer representation of the red node and of only two green nodes.

![Imgur](assets/5.png)

Similarly, to compute the red and green nodes' first layer representation, you can also do neighbor sampling that takes two of each node's neighbors.

<div class="alert alert-info">
    <b>Note</b>: Since you need the first layer representation of the red node, you will need to sample the neighbors of the red node again.
</div>

![Imgur](assets/6.png)

You can see that with neighbor sampling, fewer nodes' input features are needed.
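As a rough illustration of the idea, the sketch below samples at most two neighbors per node for two layers on a small toy graph.  It is not DGL's sampler: `sample_dependencies` is a hypothetical helper, the same fanout is used on every layer, and apart from node 8's neighbors the adjacency is made up.

```python
import random

# Node 8's neighbors follow the text above; all other adjacency is hypothetical.
neighbors = {
    8: [4, 5, 7, 11],
    4: [1, 8], 5: [2, 8], 7: [6, 8], 11: [8, 10],
    1: [4], 2: [5], 6: [7], 10: [11],
}

def sample_dependencies(seeds, neighbors, fanout, num_layers, rng):
    """Sample at most `fanout` neighbors per node, layer by layer, from output to input."""
    frontier = set(seeds)
    layers = []
    for _ in range(num_layers):
        edges = []
        for v in sorted(frontier):
            nbrs = neighbors[v]
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            edges.extend((u, v) for u in picked)
        layers.append(edges)
        # The next (earlier) layer must produce every node seen so far.
        frontier = frontier | {u for u, _ in edges}
    layers.reverse()  # put the layer closest to the input first
    return frontier, layers

input_nodes, layers = sample_dependencies(
    [8], neighbors, fanout=2, num_layers=2, rng=random.Random(0))
print(len(input_nodes), "input nodes with sampling,", len(neighbors), "without")
```

Note that node 8 is sampled again when building the layer closer to the input, so the two layers may use different neighbors of the same node.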

### Defining neighbor sampler and data loader in DGL

DGL provides useful tools to generate such computation dependencies while iterating over the dataset in minibatches.  For node classification, you can use `dgl.dataloading.NodeDataLoader` for iterating over the dataset, and `dgl.dataloading.MultiLayerNeighborSampler` to generate computation dependencies of the nodes from a multi-layer GNN with neighbor sampling.

The syntax of `dgl.dataloading.NodeDataLoader` is mostly similar to a PyTorch `DataLoader`, with the addition that it needs a graph to generate computation dependency from, a set of node IDs to iterate on, and the neighbor sampler you defined.

Let's consider training a 3-layer GraphSAGE with neighbor sampling, where each node gathers messages from 4 neighbors on each layer.  The code defining the data loader and neighbor sampler will look like the following.

In [None]:
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

We can iterate over the data loader we created and see what it gives us.

In [None]:
example_minibatch = next(iter(train_dataloader))
print(example_minibatch)

`NodeDataLoader` gives us three items per iteration.

* The input node list for the nodes whose input features are needed to compute the outputs.
* The output node list whose GNN representations are to be computed.
* The list of computation dependency for each layer.

In [None]:
input_nodes, output_nodes, bipartites = example_minibatch
print("To compute {} nodes' output we need {} nodes' input features".format(len(output_nodes), len(input_nodes)))

The variable `bipartites` shows how messages aggregate on each layer.  As its name suggests, it is a **list** of bipartite graphs.

So why does DGL return a list of *bipartite* graphs for training on a *homogeneous* graph?  The reason is that a given GNN layer takes input from a different number of nodes than it produces output for.  Take the example above:

![Imgur](assets/6.png)

That GNN layer will output the representation of three nodes (two green nodes and one red node), but it will require input from 7 nodes (the green nodes and red node, plus 4 yellow nodes).  Only a bipartite graph can describe such computation:

![](assets/bipartite.png)

Minibatch training of GNNs usually involves message passing on such bipartite graphs.

In [None]:
print(bipartites)
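To see what such a bipartite graph encodes, here is a small pure-Python sketch that turns a list of sampled edges into one.  The helper `to_bipartite` and the yellow node IDs (1, 2, 3 and 6) are hypothetical, and placing the output nodes first among the input nodes is an assumption of this sketch rather than a description of DGL's internals.

```python
def to_bipartite(edges, output_nodes):
    """Relabel sampled (source, destination) edges into input-side and output-side indices."""
    # Output nodes come first, followed by extra source nodes in order of first appearance.
    input_nodes = list(output_nodes)
    seen = set(output_nodes)
    for u, _ in edges:
        if u not in seen:
            seen.add(u)
            input_nodes.append(u)
    src_index = {n: i for i, n in enumerate(input_nodes)}
    dst_index = {n: i for i, n in enumerate(output_nodes)}
    local_edges = [(src_index[u], dst_index[v]) for u, v in edges]
    return input_nodes, local_edges

# Red node 8 and green nodes 4 and 5 are the outputs; each aggregates from two neighbors.
output_nodes = [8, 4, 5]
edges = [(4, 8), (5, 8), (1, 4), (2, 4), (3, 5), (6, 5)]
input_nodes, local_edges = to_bipartite(edges, output_nodes)
print(len(output_nodes), "output nodes,", len(input_nodes), "input nodes")
print(local_edges)
```

The input side has 7 nodes and the output side has 3, matching the counts in the example above, and each edge points from an input-side index to an output-side index.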

## Defining Model

The model can be written as follows:

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        
    def forward(self, bipartites, x):
        for l, (layer, bipartite) in enumerate(zip(self.layers, bipartites)):
            x = layer(bipartite, x)
            if l != self.n_layers - 1:
                x = F.relu(x)
        return x

You can see that here we are iterating over the pairs of NN module layer and bipartite graphs generated by the data loader.

## Defining Training Loop

The following initializes the model and defines the optimizer.

In [None]:
model = SAGE(num_features, 128, num_classes, 3).cuda()
opt = torch.optim.Adam(model.parameters())

When computing the validation score for model selection, usually you can also do neighbor sampling.  To do that, you need to define another data loader.

In [None]:
valid_dataloader = dgl.dataloading.NodeDataLoader(
    graph, valid_nids, sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

The following is a training loop that performs validation every epoch.  It also saves the model with the best validation accuracy into a file.

In [None]:
import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(10):
    model.train()
    
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            inputs = node_features[input_nodes].cuda()
            labels = node_labels[output_nodes].cuda()
            predictions = model(bipartites, inputs)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())
            
            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)
        
    model.eval()
    
    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, bipartites in tq:
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            inputs = node_features[input_nodes].cuda()
            labels.append(node_labels[output_nodes].numpy())
            predictions.append(model(bipartites, inputs).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

## Offline Inference without Neighbor Sampling

Usually for offline inference it is desirable to aggregate over the entire neighborhood to eliminate randomness introduced by neighbor sampling.  However, using the same methodology as in training is not efficient, because there will be a lot of redundant computation.  Moreover, simply taking all neighbors in the neighbor sampler will often exhaust GPU memory, because the input features of all the required nodes may not fit.

Instead, you need to compute the representations layer by layer: you first compute the output of the first GNN layer for all nodes, then you compute the output of second GNN layer for all nodes using the first GNN layer's output as input, etc.  This gives us a different algorithm from what is being used in training.  During training we have an outer loop that iterates over the nodes, and an inner loop that iterates over the layers.  In contrast, during inference we have an outer loop that iterates over the layers, and an inner loop that iterates over the nodes.

If you do not care about randomness too much (e.g., during model selection in validation), you can still use the `dgl.dataloading.MultiLayerNeighborSampler` and `dgl.dataloading.NodeDataLoader` to do offline inference, since it is usually faster for evaluating a small number of nodes.

![Imgur](assets/anim.gif)
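The swapped loop order can be sketched without any GNN library.  In the toy example below, `mean_layer` is a hypothetical stand-in for a GNN layer that averages a node's value with its neighbors' values, and the three-node path graph is made up for illustration.

```python
def mean_layer(h_self, h_neighbors):
    """Stand-in for a GNN layer: average a node's value with its neighbors' values."""
    values = [h_self] + h_neighbors
    return sum(values) / len(values)

def layerwise_inference(features, neighbors, num_layers, batch_size):
    nodes = sorted(neighbors)
    h = dict(features)
    for _ in range(num_layers):                          # outer loop: layers
        next_h = {}
        for start in range(0, len(nodes), batch_size):   # inner loop: node batches
            for v in nodes[start:start + batch_size]:
                # Uses all neighbors, and only the previous layer's outputs.
                next_h[v] = mean_layer(h[v], [h[u] for u in neighbors[v]])
        h = next_h  # this layer's output becomes the next layer's input
    return h

neighbors = {0: [1], 1: [0, 2], 2: [1]}
features = {0: 0.0, 1: 3.0, 2: 6.0}
out = layerwise_inference(features, neighbors, num_layers=2, batch_size=2)
print(out)
```

Each layer's output is computed exactly once per node and then reused, which avoids the redundant computation of expanding a full multi-hop neighborhood for every batch.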

In [None]:
def inference(model, graph, input_features, batch_size):
    nodes = torch.arange(graph.number_of_nodes())
    
    sampler = dgl.dataloading.MultiLayerNeighborSampler([None])  # one layer at a time, taking all neighbors
    dataloader = dgl.dataloading.NodeDataLoader(
        graph, nodes, sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0)
    
    with torch.no_grad():
        for l, layer in enumerate(model.layers):
            # Allocate a buffer of output representations for every node
            # Note that the buffer is on CPU memory.
            output_features = torch.zeros(
                graph.number_of_nodes(), model.n_hidden if l != model.n_layers - 1 else model.n_classes)

            for input_nodes, output_nodes, bipartites in tqdm.tqdm(dataloader):
                bipartite = bipartites[0].to(torch.device('cuda'))

                x = input_features[input_nodes].cuda()

                # the following code is identical to the loop body in model.forward()
                x = layer(bipartite, x)
                if l != model.n_layers - 1:
                    x = F.relu(x)

                output_features[output_nodes] = x.cpu()
            input_features = output_features
    return output_features

The following code loads the best model from the file saved previously and performs offline inference.  It computes the accuracy on the test set afterwards.

In [None]:
model.load_state_dict(torch.load(best_model_path))
all_predictions = inference(model, graph, node_features, 8192)

In [None]:
test_predictions = all_predictions[test_nids].argmax(1)
test_labels = node_labels[test_nids]
test_accuracy = sklearn.metrics.accuracy_score(test_labels.numpy(), test_predictions.numpy())
print('Test accuracy:', test_accuracy)

## Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling on a large dataset that cannot fit into GPU memory.  The method you have learned can scale to a graph of any size, and works on a single machine with a single GPU.

## What's next?

* [Stochastic Training of GNN for Link Prediction on Large Graphs](L2_large_link_prediction.ipynb).
* [Customizing Neighbor Sampler (TODO)]().
* [Adapting your custom GNN module for large scale training (TODO)]().

For large-scale heterogeneous graph training, please see [Scaling to large heterogeneous graphs](H5_large_heterogeneous_graph.ipynb).  We recommend going through the tutorial [Node Classification on Heterogeneous Graphs](H1_node_classification.ipynb) first to get an idea of how DGL handles heterogeneous graphs.

For single-machine multi-GPU training on a single large graph, please see the tutorial [Stochastic Training of GNN with Multiple GPUs](D2_multi_gpu_large_graph.ipynb).