# Stochastic Training of GNN for Node Classification

This tutorial shows how to train a multi-layer GraphSAGE for node classification on the Amazon Copurchase Network provided by [Open Graph Benchmark (OGB)](https://ogb.stanford.edu/).  The dataset contains 2.4 million nodes and 61 million edges, so full-graph training does not fit on a single GPU.  DGL stores each edge in both directions, which is why the graph printed below reports about 124 million edges.

This tutorial covers:

* Training a GNN model on a single machine with a single GPU, on a graph of any size, using DGL's GNN modules.

## Loading Dataset

OGB already provides the data as a DGL graph.

<div class="alert alert-info">
    
**Note**: If you wish to load your own large graph and a single machine's CPU memory can hold it, please refer to <a href=2_load_data.ipynb>this tutorial</a>.

</div>

In [2]:
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-products')

An OGB dataset is a collection of graphs and their labels.  The Amazon Copurchase Network dataset contains only a single graph, so you can get the graph and its node labels like this:

In [3]:
graph, node_labels = dataset[0]
print(graph)
print(node_labels)

node_features = graph.ndata['feat']
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)

Graph(num_nodes=2449029, num_edges=123718280,
      ndata_schemes={'feat': Scheme(shape=(100,), dtype=torch.float32)}
      edata_schemes={})
tensor([[0],
        [1],
        [2],
        ...,
        [8],
        [2],
        [4]])
Number of classes: 47


<div class="alert alert-info">
    
**Note**: if you plan to use multiprocessing data loaders (`num_workers` greater than 0), consider creating all of the graph's CSR/CSC formats up front, for example with `graph.create_formats_()`.  Otherwise each worker process may compute its own CSR/CSC representation of the graph and waste memory.
    
</div>
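
For readers unfamiliar with the term, CSR (compressed sparse row) stores each node's neighbors contiguously so they can be looked up without scanning the whole edge list; CSC is the same layout keyed by destination node.  The following is a minimal pure-Python sketch of the idea, not DGL's internal implementation.  The function name `edges_to_csr` and the toy edge list are made up for illustration.

```python
def edges_to_csr(num_nodes, src, dst):
    """Build CSR arrays (indptr, indices) from a list of directed edges."""
    # Count the out-degree of every node.
    indptr = [0] * (num_nodes + 1)
    for u in src:
        indptr[u + 1] += 1
    # Prefix-sum the degrees so indptr[u]:indptr[u + 1] spans u's neighbors.
    for u in range(num_nodes):
        indptr[u + 1] += indptr[u]
    # Fill in the neighbor IDs, tracking the next free slot per node.
    indices = [0] * len(src)
    next_slot = indptr[:-1].copy()
    for u, v in zip(src, dst):
        indices[next_slot[u]] = v
        next_slot[u] += 1
    return indptr, indices


# Toy graph (hypothetical) with edges 0->1, 2->0, 0->2, 2->3.
indptr, indices = edges_to_csr(4, [0, 2, 0, 2], [1, 0, 2, 3])
# The neighbors of node 2 are stored in indices[indptr[2]:indptr[3]].
neighbors_of_2 = indices[indptr[2]:indptr[3]]
```

Building these arrays costs time and memory proportional to the number of edges, which is why it is worth doing once rather than once per worker.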

You can get the training-validation-test split of the nodes with the `get_idx_split` method.

In [4]:
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']

## Defining Neighbor Sampler and Data Loader in DGL

DGL provides tools to iterate over the dataset in minibatches while generating the computation dependencies needed to compute their outputs, without involving all the nodes.  For node classification, you can use `dgl.dataloading.NodeDataLoader` to iterate over the dataset, together with `dgl.dataloading.MultiLayerNeighborSampler` to generate the computation dependencies of the nodes in a multi-layer GNN with *neighbor sampling*, i.e. aggregating messages from only a fixed number of neighbors for each node.

The syntax of `dgl.dataloading.NodeDataLoader` is mostly similar to a PyTorch `DataLoader`, with the addition that it needs a graph to generate computation dependency from, a set of node IDs to iterate on, and the neighbor sampler you defined.
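
As with a PyTorch `DataLoader`, the `batch_size`, `shuffle`, and `drop_last` arguments determine how the node IDs are split into minibatches.  The pure-Python sketch below mimics only that splitting step and does no sampling.  `iterate_minibatches` is a hypothetical helper written for this illustration, not a DGL or PyTorch function.

```python
import math
import random


def iterate_minibatches(node_ids, batch_size, shuffle=False, drop_last=False, seed=0):
    """Yield lists of node IDs, one list per minibatch."""
    ids = list(node_ids)
    if shuffle:
        random.Random(seed).shuffle(ids)
    for start in range(0, len(ids), batch_size):
        batch = ids[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the final, incomplete batch
        yield batch


batches = list(iterate_minibatches(range(10), batch_size=4, shuffle=True))
batch_sizes = [len(b) for b in batches]   # the last batch holds the 2 leftover IDs
expected_count = math.ceil(10 / 4)        # number of batches when drop_last=False
```

With `drop_last=False`, every node ID is visited exactly once per epoch and the number of iterations is the node count divided by the batch size, rounded up.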

Let's consider training a 2-layer GraphSAGE with neighbor sampling, where each node gathers messages from 4 neighbors on each layer.  The code defining the data loader and neighbor sampler will look like the following.

In [5]:
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)
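
Conceptually, a multi-layer neighbor sampler starts from the output nodes of a minibatch and works backwards, picking at most a fixed number of neighbors per node for each layer.  The pure-Python sketch below illustrates that idea on a toy graph.  It is not DGL's implementation, and `sample_blocks` and the toy adjacency are hypothetical names introduced here.

```python
import random


def sample_blocks(in_neighbors, seeds, fanouts, seed=0):
    """Return (input_nodes, blocks); each block is (src_nodes, dst_nodes, edges)."""
    rng = random.Random(seed)
    blocks = []
    dst_nodes = list(seeds)
    # Walk from the output layer back to the input layer.
    for fanout in reversed(fanouts):
        src_nodes = list(dst_nodes)  # destination nodes come first among the sources
        seen = set(src_nodes)
        edges = []
        for v in dst_nodes:
            candidates = in_neighbors.get(v, [])
            picked = rng.sample(candidates, min(fanout, len(candidates)))
            for u in picked:
                edges.append((u, v))
                if u not in seen:
                    seen.add(u)
                    src_nodes.append(u)
        blocks.insert(0, (src_nodes, dst_nodes, edges))
        dst_nodes = src_nodes  # this layer's inputs are the previous layer's outputs
    return dst_nodes, blocks


# Toy graph: in_neighbors[v] lists the nodes with an edge into v.
in_neighbors = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0, 4, 5], 4: [1, 3], 5: [3]}
input_nodes, blocks = sample_blocks(in_neighbors, seeds=[0, 1], fanouts=[2, 2])
```

The set of nodes grows with every layer walked backwards, which is why a small fan-out matters for keeping minibatches small.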

You can iterate over the data loader and see what it yields.

In [6]:
example_minibatch = next(iter(train_dataloader))
print(example_minibatch)

(tensor([129839, 123246, 162633,  ...,  58772, 793634, 151106]), tensor([129839, 123246, 162633,  ..., 140799, 167173, 185392]), [Block(num_src_nodes=23628, num_dst_nodes=5061, num_edges=20200), Block(num_src_nodes=5061, num_dst_nodes=1024, num_edges=4088)])


`NodeDataLoader` gives us three items per iteration.

* The input node list: the nodes whose input features are needed to compute the outputs.
* The output node list: the nodes whose GNN representations are to be computed.
* The computation dependencies of each layer, as a list of **bipartite graphs**.

In [7]:
input_nodes, output_nodes, bipartites = example_minibatch
print("To compute {} nodes' outputs, we need {} nodes' input features".format(len(output_nodes), len(input_nodes)))

To compute 1024 nodes' outputs, we need 23628 nodes' input features


In [8]:
print(bipartites)

[Block(num_src_nodes=23628, num_dst_nodes=5061, num_edges=20200), Block(num_src_nodes=5061, num_dst_nodes=1024, num_edges=4088)]


Minibatch training of GNNs usually involves message passing on such bipartite graphs.
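
The sizes printed above are consistent with a fan-out of 4.  The short check below redoes that arithmetic, assuming each destination node contributes at most `fanout` sampled edges (nodes with fewer neighbors contribute fewer) and that the source nodes of a block include its destination nodes.

```python
fanout = 4

# (num_src_nodes, num_dst_nodes, num_edges) copied from the printed blocks.
printed_blocks = [(23628, 5061, 20200), (5061, 1024, 4088)]

for num_src, num_dst, num_edges in printed_blocks:
    max_edges = num_dst * fanout     # at most `fanout` sampled edges per destination node
    max_src = num_dst + max_edges    # the destinations plus at most one new source per edge
    assert num_edges <= max_edges
    assert num_dst <= num_src <= max_src

# Upper bounds on the number of edges in the two blocks.
edge_bounds = [num_dst * fanout for _, num_dst, _ in printed_blocks]
# The source nodes of the last block are the destination nodes of the first.
chained = printed_blocks[0][1] == printed_blocks[1][0]
```

The observed edge counts fall slightly below the bounds because some nodes have fewer than 4 neighbors.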

<div class="alert alert-info">
   
**Note**: if you are interested in the details of neighbor sampling, or if you are curious about why neighbor sampling will yield a list of *bipartite* graphs instead of simply some subgraphs of the original graph, please refer to the [neighbor sampling tutorial](L3_custom_sampler.ipynb).
    
</div>

## Defining Model

The model can be written as follows:

In [21]:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type='mean')
        self.h_feats = h_feats
        
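    def forward(self, bipartites, x):
        # x holds the features of all input nodes of the first bipartite graph.
        h = self.conv1(bipartites[0], x)
        h = F.relu(h)
        # The outputs of the first layer are the inputs of the second.
        h = self.conv2(bipartites[1], h)
        return h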
    
model = Model(num_features, 128, num_classes).cuda()

If you compare against the code in the [introduction](1_introduction.ipynb), you will notice a difference in the `forward()` function: instead of computing on the full graph:

```python
h = self.conv1(g, x)
```

you only compute on the sampled bipartite graph:

```python
h = self.conv1(bipartites[0], x)
```
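
To see why a layer takes the features of all source nodes but returns one row per destination node, the sketch below runs a simplified mean aggregation on a tiny bipartite graph using plain Python lists.  It replaces the learnable weights of `SAGEConv` with the identity, so it only illustrates the data flow and is not the DGL module.  `sage_mean_block` is a hypothetical name, and the sketch assumes that the first `num_dst` source rows belong to the destination nodes.

```python
def sage_mean_block(num_dst, edges, x_src):
    """Return x_self + mean(x_neighbors) for every destination node.

    edges holds (src_index, dst_index) pairs local to the block, and the first
    num_dst rows of x_src are assumed to belong to the destination nodes.
    """
    dim = len(x_src[0])
    sums = [[0.0] * dim for _ in range(num_dst)]
    counts = [0] * num_dst
    for u, v in edges:
        for k in range(dim):
            sums[v][k] += x_src[u][k]
        counts[v] += 1
    out = []
    for v in range(num_dst):
        denom = max(counts[v], 1)  # nodes without sampled neighbors get a zero mean
        out.append([x_src[v][k] + sums[v][k] / denom for k in range(dim)])
    return out


# Block with 4 source nodes, of which the first 2 are also destination nodes.
x_src = [[1.0, 0.0], [0.0, 1.0], [3.0, 2.0], [5.0, 4.0]]
edges = [(2, 0), (3, 0), (2, 1)]
h_dst = sage_mean_block(2, edges, x_src)
# Four input rows go in, two output rows come out.
```

Stacking two such layers shrinks the node set twice, from the input nodes of the first bipartite graph down to the output nodes of the minibatch.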

## Defining Training Loop

The following defines the optimizer for the model created above.

In [22]:
opt = torch.optim.Adam(model.parameters())

When computing the validation score for model selection, usually you can also do neighbor sampling.  To do that, you need to define another data loader.

In [23]:
valid_dataloader = dgl.dataloading.NodeDataLoader(
    graph, valid_nids, sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

The following is a training loop that performs validation every epoch.  It also saves the model with the best validation accuracy into a file.

In [24]:
import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(10):
    model.train()
    
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
            # Move the bipartite graphs, input features, and labels to the GPU.
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            inputs = node_features[input_nodes].cuda()
            labels = node_labels[output_nodes].cuda()
            predictions = model(bipartites, inputs)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())
            
            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)
        
    model.eval()
    
    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, bipartites in tq:
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            inputs = node_features[input_nodes].cuda()
            labels.append(node_labels[output_nodes].numpy())
            predictions.append(model(bipartites, inputs).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
        # Keep the checkpoint with the best validation accuracy seen so far.
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

100%|██████████| 193/193 [00:05<00:00, 35.30it/s, loss=0.512, acc=0.857]
100%|██████████| 39/39 [00:00<00:00, 46.53it/s]

Epoch 0 Validation Accuracy 0.793988251150726

100%|██████████| 193/193 [00:04<00:00, 46.67it/s, loss=1.074, acc=0.571]
100%|██████████| 39/39 [00:00<00:00, 65.66it/s]

Epoch 1 Validation Accuracy 0.8278106960303131

100%|██████████| 193/193 [00:04<00:00, 40.74it/s, loss=0.656, acc=0.857]
100%|██████████| 39/39 [00:00<00:00, 53.47it/s]

Epoch 2 Validation Accuracy 0.8419754342242454

100%|██████████| 193/193 [00:04<00:00, 41.73it/s, loss=0.410, acc=0.857]
100%|██████████| 39/39 [00:00<00:00, 54.97it/s]

Epoch 3 Validation Accuracy 0.8506980647458231

100%|██████████| 193/193 [00:04<00:00, 41.32it/s, loss=0.057, acc=1.000]
100%|██████████| 39/39 [00:00<00:00, 58.68it/s]

Epoch 4 Validation Accuracy 0.8559112987310226

100%|██████████| 193/193 [00:04<00:00, 41.21it/s, loss=0.144, acc=1.000]
100%|██████████| 39/39 [00:00<00:00, 48.56it/s]

Epoch 5 Validation Accuracy 0.8609973806678025

100%|██████████| 193/193 [00:04<00:00, 46.25it/s, loss=0.102, acc=1.000]
100%|██████████| 39/39 [00:00<00:00, 51.76it/s]

Epoch 6 Validation Accuracy 0.8643287643363935

100%|██████████| 193/193 [00:05<00:00, 32.19it/s, loss=0.055, acc=1.000]
100%|██████████| 39/39 [00:00<00:00, 47.01it/s]

Epoch 7 Validation Accuracy 0.8668718053047835

100%|██████████| 193/193 [00:06<00:00, 31.50it/s, loss=0.027, acc=1.000]
100%|██████████| 39/39 [00:01<00:00, 26.57it/s]

Epoch 8 Validation Accuracy 0.8682196170180302

100%|██████████| 193/193 [00:05<00:00, 35.87it/s, loss=0.189, acc=0.857]
100%|██████████| 39/39 [00:00<00:00, 43.48it/s]

Epoch 9 Validation Accuracy 0.8715255702769371
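
The validation loop above concatenates the predictions of all minibatches before computing the accuracy, rather than averaging per-batch accuracies.  The two can differ when the last minibatch is smaller than the others, which happens with `drop_last=False`.  The sketch below shows the difference on made-up labels and predictions; the `accuracy` helper is a plain-Python stand-in for `sklearn.metrics.accuracy_score`.

```python
def accuracy(labels, predictions):
    """Fraction of positions where the prediction equals the label."""
    correct = sum(int(l == p) for l, p in zip(labels, predictions))
    return correct / len(labels)


# Made-up (labels, predictions) for three minibatches; the last one is smaller.
batches = [
    ([0, 1, 2, 3], [0, 1, 2, 3]),   # 4 of 4 correct
    ([0, 1, 2, 3], [0, 1, 2, 0]),   # 3 of 4 correct
    ([1, 2], [0, 0]),               # 0 of 2 correct
]

all_labels = [l for labels, _ in batches for l in labels]
all_predictions = [p for _, preds in batches for p in preds]

# Accuracy over all nodes: 7 correct out of 10.
overall = accuracy(all_labels, all_predictions)
# Unweighted mean of per-batch accuracies: (1 + 0.75 + 0) / 3.
mean_of_batches = sum(accuracy(l, p) for l, p in batches) / len(batches)
```

Concatenating first weights every node equally, so the reported score does not depend on how the nodes happen to be batched.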





## Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling on a large dataset that cannot fit into a GPU.  The method you have learned can scale to a graph of any size, and works on a single machine with a single GPU.

## What's next?

* [Stochastic training of GNN for link prediction](L2_large_link_prediction.ipynb).
* [Neighbor sampling and customization](L3_custom_sampler.ipynb).
* [Adapting your custom GNN module for stochastic training](L4_message_passing.ipynb).
* During inference you may wish to disable neighbor sampling.  If so, please refer to the [user guide on exact offline inference](https://docs.dgl.ai/guide/minibatch-inference.html).

For large-scale heterogeneous graph training, please see [Scaling to large heterogeneous graphs](H5_large_heterogeneous_graph.ipynb).  We recommend first going through the tutorial [Node Classification on Heterogeneous Graphs](H1_node_classification.ipynb) to get an idea of how DGL handles heterogeneous graphs.

For single-machine multi-GPU training on a single large graph, please see the tutorial [Stochastic Training of GNN with Multiple GPUs](D2_multi_gpu_large_graph.ipynb).

