In [None]:
!pip install dgl

## Overview
This chapter discusses how to train a graph neural network for node classification, edge classification, link prediction, and graph classification on small graph(s), using the message passing methods introduced in Chapter 2: Message Passing and the neural network modules introduced in Chapter 3: Building GNN Modules.

This chapter assumes that your graph, together with all of its node and edge features, fits into GPU memory; see Chapter 6: Stochastic Training on Large Graphs if it does not.

The following text assumes that the graph(s) and node/edge features are already prepared. If you plan to use a dataset that DGL provides, or another compatible `DGLDataset` as described in Chapter 4: Graph Data Pipeline, you can get the graph of a single-graph dataset as follows.

In [7]:
import dgl
import torch

dataset = dgl.data.CiteseerGraphDataset()
graph = dataset[0]

  NumNodes: 3327
  NumEdges: 9228
  NumFeats: 3703
  NumClasses: 6
  NumTrainingSamples: 120
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


## Node Classification/Regression

One of the most popular and widely adopted tasks for graph neural networks is node classification, where each node in the training/validation/test set is assigned a ground truth category from a set of predefined categories. Node regression is similar, where each node in the training/validation/test set is assigned a ground truth number.

### Overview
To classify nodes, a graph neural network performs the message passing discussed in Chapter 2: Message Passing, so that the prediction for each node uses not only the node’s own features but also the features of its neighboring nodes and edges. Message passing can be repeated for multiple rounds to incorporate information from a larger neighborhood.
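To make the effect of repeated rounds concrete, here is a minimal pure-Python sketch on a toy path graph. It is not DGL code: the helper `message_passing_round` is hypothetical, the features are single scalars, edge features are ignored, and the fixed 0.5 weights stand in for the learned weights a real graph convolution module would use.

```python
# Toy undirected path graph 0 - 1 - 2 - 3, stored as adjacency lists.
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}


def message_passing_round(h, neighbors):
    """One round: mix each node's value with the mean of its neighbors' values."""
    new_h = {}
    for node, nbrs in neighbors.items():
        neigh_mean = sum(h[n] for n in nbrs) / len(nbrs)
        # Fixed 0.5 weights are an assumption; real modules learn these.
        new_h[node] = 0.5 * (h[node] + neigh_mean)
    return new_h


# Only node 3 carries a nonzero feature at the start.
h0 = {0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0}
h1 = message_passing_round(h0, neighbors)  # information travels one hop
h2 = message_passing_round(h1, neighbors)  # information travels two hops

print(h1[1], h2[1])  # node 1 is two hops away from node 3
```

After one round node 1 has received nothing from node 3, since they are two hops apart; after the second round its value becomes nonzero. Node 0, three hops away, would need a third round.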

### Writing neural network model
DGL provides a few built-in graph convolution modules, each of which performs one round of message passing. This guide uses `dgl.nn.pytorch.SAGEConv` (also available for MXNet and TensorFlow), the graph convolution module for GraphSAGE.

Usually for deep learning models on graphs we need a multi-layer graph neural network, where we do multiple rounds of message passing. This can be achieved by stacking graph convolution modules as follows.



In [8]:
# Construct a two-layer GNN model
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h

**Note** that you can use the model above not only for node classification, but also to obtain hidden node representations for other downstream tasks such as 5.2 Edge Classification/Regression, 5.3 Link Prediction, or 5.4 Graph Classification.

For a complete list of built-in graph convolution modules, please refer to the `dgl.nn` API reference.

For more details on how DGL neural network modules work and how to write a custom neural network module with message passing, please refer to the example in Chapter 3: Building GNN Modules.

### Training loop

Training on the full graph simply involves a forward propagation of the model defined above, and computing the loss by comparing the prediction against ground truth labels on the training nodes.

This section uses a DGL built-in dataset `dgl.data.CiteseerGraphDataset` to show a training loop. The node features and labels are stored on its graph instance, and the training-validation-test split is also stored on the graph as boolean masks. This is similar to what you have seen in Chapter 4: Graph Data Pipeline.

In [9]:
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)

The following is an example of evaluating your model by accuracy.



In [10]:
def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
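The `evaluate` function above takes the highest-scoring class of each node as its prediction and measures accuracy only on the nodes selected by the mask. The same computation can be sketched on plain Python lists; the helper `masked_accuracy` and the toy numbers are hypothetical and serve only as illustration.

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of masked nodes whose highest-scoring class equals the label."""
    selected = [i for i, keep in enumerate(mask) if keep]
    correct = 0
    for i in selected:
        row = logits[i]
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(predicted == labels[i])
    return correct / len(selected)


# Four nodes, three classes; the last node is outside the validation set.
logits = [[2.0, 0.1, -1.0],
          [0.3, 0.2, 1.5],
          [0.0, 3.0, 0.5],
          [1.0, 0.9, 0.8]]
labels = [0, 2, 0, 1]
valid_mask = [True, True, True, False]

acc = masked_accuracy(logits, labels, valid_mask)
print(acc)
```

Nodes 0 and 1 are predicted correctly and node 2 is not, so the accuracy over the three masked nodes is 2/3. Node 3 is also mispredicted, but it does not count because the mask excludes it.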

You can then write your training loop as follows.

In [11]:
model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())

for epoch in range(10):
    model.train()
    # forward propagation by using all nodes
    logits = model(graph, node_features)
    # compute loss
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # compute validation accuracy
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

    # Save model if necessary.  Omitted in this example.



1.7925282716751099
1.777905821800232
1.7636406421661377
1.7488237619400024
1.7331427335739136
1.7165350914001465
1.699110746383667
1.6809422969818115
1.661994457244873
1.6423442363739014
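A quick sanity check on the losses printed above: a classifier that assigns equal scores to all classes has a cross-entropy of ln(C) for C classes, so an untrained model on this 6-class dataset is expected to start near ln 6 ≈ 1.79. The sketch below reproduces that number with the standard library only; the helper `cross_entropy` is a hypothetical single-node stand-in for `F.cross_entropy`.

```python
import math


def cross_entropy(logits, label):
    """Negative log-softmax probability of the true class for one node."""
    shift = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[label]


num_classes = 6  # NumClasses reported when the dataset was loaded
uniform_loss = cross_entropy([0.0] * num_classes, label=3)
first_epoch_loss = 1.7925282716751099  # first loss value printed above

print(uniform_loss, first_epoch_loss - uniform_loss)
```

The first printed loss lies within 0.01 of ln 6, which is consistent with a freshly initialized model making nearly uniform predictions. The steady decrease afterwards shows that the optimizer is making progress.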


The DGL GraphSAGE example provides an end-to-end node classification example on a homogeneous graph. The corresponding model implementation is in the `GraphSAGE` class of that example, with an adjustable number of layers, dropout probabilities, and customizable aggregation functions and nonlinearities.