# Stochastic Training of GNN for Link Prediction on Large Graphs 

This tutorial shows how to train a multi-layer GraphSAGE for link prediction on the Amazon Copurchase Network provided by OGB.  The dataset contains 2.4 million nodes and 61 million edges, so it does not fit on a single GPU.

## Link Prediction Overview

Link prediction requires the model to predict the probability of existence of an edge.  This tutorial does so by computing a dot product between the representations of both incident nodes.

$$
\hat{y}_{u\sim v} = \sigma(h_u^T h_v)
$$

It then minimizes the following binary cross entropy loss.

$$
\mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
$$
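
As a small illustration of the two formulas above, the following plain-Python sketch computes the predicted probability and the summed binary cross entropy for two toy edges. The helper names (`edge_probability`, `bce_loss`) and the 2-dimensional representations are made up for this example; the tutorial itself performs the same computation with PyTorch tensors.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def edge_probability(h_u, h_v):
    """Predicted probability that edge u~v exists: sigmoid of the dot product."""
    return sigmoid(sum(a * b for a, b in zip(h_u, h_v)))

def bce_loss(examples):
    """Sum of binary cross entropy terms over (h_u, h_v, y) triples."""
    total = 0.0
    for h_u, h_v, y in examples:
        y_hat = edge_probability(h_u, h_v)
        total -= y * math.log(y_hat) + (1 - y) * math.log(1 - y_hat)
    return total

# Toy representations, hypothetical values chosen only for illustration.
examples = [
    ([1.0, 2.0], [0.5, 1.0], 1),    # aligned vectors, labeled as an existing edge
    ([1.0, 2.0], [-1.0, -0.5], 0),  # opposed vectors, labeled as a non-existent edge
]
loss = bce_loss(examples)
print(loss)
```

Both toy examples are scored on the correct side of 0.5, so the loss is small; swapping the two labels would make it considerably larger.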

This is identical to the link prediction formulation in [the previous tutorial on link prediction](4_link_predict.ipynb).

## Load Dataset

This tutorial loads the dataset from the `ogb` package as in the [previous tutorial](L1_large_node_classification.ipynb).

In [1]:
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-products')

graph, node_labels = dataset[0]
print(graph)
print(node_labels)

node_features = graph.ndata['feat']
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)

idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']

Using backend: pytorch


Graph(num_nodes=2449029, num_edges=123718280,
      ndata_schemes={'feat': Scheme(shape=(100,), dtype=torch.float32)}
      edata_schemes={})
tensor([[0],
        [1],
        [2],
        ...,
        [8],
        [2],
        [4]])
Number of classes: 47


## Define Data Loader with Neighbor Sampling

Unlike the [link prediction tutorial for small graphs](4_link_predict.ipynb), you need to iterate over the edges in minibatches, since computing the probability of all edges at once is usually infeasible.  For each minibatch of edges, you compute the output representations of their incident nodes using neighbor sampling and a GNN, in a fashion similar to the one introduced in the [large-scale node classification tutorial](L1_large_node_classification.ipynb).

DGL provides `dgl.dataloading.EdgeDataLoader` that allows you to iterate over edges for edge classification or link prediction tasks.

To perform link prediction, you need to specify a negative sampler.  A negative sampler takes in a list of edges as positive examples and returns a list of negative examples.  In DGL, negative samplers can be any callable that has the following signature:

```python
def negative_sampler(g: DGLGraph, eids: Tensor) -> Tuple[Tensor, Tensor]:
    pass
```

The first argument is the original graph and the second argument is the minibatch of edge IDs.  The function returns a pair of $u$-$v^-$ node ID tensors as negative examples.

<div class="alert alert-info">
    
**Note**: for heterogeneous graphs, the signature of the negative sampler will change.  See [here](https://todo) for more details.
    
</div>

The following code implements a negative sampler that finds non-existent edges by sampling `k` $v^-$ for each $u$ according to a distribution $P^-(v) \propto d(v)^{0.75}$, where $d(v)$ is the degree of $v$.

In [2]:
class NegativeSampler(object):
    def __init__(self, g, k):
        self.k = k
        self.weights = g.in_degrees().float() ** 0.75
    def __call__(self, g, eids):
        src, _ = g.find_edges(eids)
        src = src.repeat_interleave(self.k)
        dst = self.weights.multinomial(len(src), replacement=True)
        return src, dst
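
To make the effect of the $0.75$ exponent concrete, here is a plain-Python sketch of the same sampling distribution on a toy degree list. The function names and the degrees are hypothetical and chosen only so the numbers come out round; it uses `random.choices` from the standard library rather than the tensor-based `multinomial` call above.

```python
import random

def negative_sampling_probs(degrees, power=0.75):
    """Normalize d(v) ** power into a probability distribution over nodes."""
    weights = [d ** power for d in degrees]
    total = sum(weights)
    return [w / total for w in weights]

def sample_negative_dst(degrees, num_samples, rng, power=0.75):
    """Draw destination node IDs with replacement, proportional to d(v) ** power."""
    weights = [d ** power for d in degrees]
    return rng.choices(range(len(degrees)), weights=weights, k=num_samples)

# Toy in-degrees for four nodes (hypothetical values).
degrees = [1, 16, 81, 0]
probs = negative_sampling_probs(degrees)
samples = sample_negative_dst(degrees, num_samples=10, rng=random.Random(0))
print(probs)
print(samples)
```

With raw degrees, the node of degree 81 would receive 81/98 (about 83%) of the probability mass; with the exponent it receives 27/36 (75%), so high-degree nodes are still favored but less strongly. A node with degree 0 is never drawn.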

After defining the negative sampler, one can then define the edge data loader with neighbor sampling.  Here this tutorial takes 5 negative examples per positive example.

In [3]:
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
k = 5
train_dataloader = dgl.dataloading.EdgeDataLoader(
    graph, torch.arange(graph.number_of_edges()), sampler,
    negative_sampler=NegativeSampler(graph, k),
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=4
)

You can peek at one minibatch from `train_dataloader` to see what it returns.

In [4]:
example_minibatch = next(iter(train_dataloader))
print(example_minibatch)

(tensor([ 644215, 2208123, 1392033,  ...,   69636, 1440917, 1394437]), Graph(num_nodes=7143, num_edges=1024,
      ndata_schemes={'feat': Scheme(shape=(100,), dtype=torch.float32), '_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}), Graph(num_nodes=7143, num_edges=5120,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), [Block(num_src_nodes=145447, num_dst_nodes=34441, num_edges=137458), Block(num_src_nodes=34441, num_dst_nodes=7143, num_edges=28491)])


The example minibatch consists of four elements.

* The input node list necessary for computing the representation of output nodes.
* The subgraph induced by the nodes being sampled in the minibatch (including those in the negative examples) as well as the edges sampled in the minibatch.
* The subgraph induced by the nodes being sampled in the minibatch (including those in the negative examples) as well as the non-existent edges sampled by the negative sampler.
* The list of bipartite graphs, one for each layer.

In [5]:
input_nodes, pos_graph, neg_graph, bipartites = example_minibatch
print('Number of input nodes:', len(input_nodes))
print('Positive graph # nodes:', pos_graph.number_of_nodes(), '# edges:', pos_graph.number_of_edges())
print('Negative graph # nodes:', neg_graph.number_of_nodes(), '# edges:', neg_graph.number_of_edges())
print(bipartites)

Number of input nodes: 145447
Positive graph # nodes: 7143 # edges: 1024
Negative graph # nodes: 7143 # edges: 5120
[Block(num_src_nodes=145447, num_dst_nodes=34441, num_edges=137458), Block(num_src_nodes=34441, num_dst_nodes=7143, num_edges=28491)]
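
The sizes printed above can be cross-checked with a little arithmetic. The sketch below assumes only the values already shown in this tutorial: 123,718,280 edges in the loaded graph, a batch size of 1024, `k = 5` negative examples per positive edge, and `drop_last=False`, which means the final partial batch is kept.

```python
import math

num_edges = 123_718_280   # from the printed graph in the dataset section
batch_size = 1024
k = 5                     # negative examples per positive edge

# With drop_last=False the final, smaller batch is kept, hence the ceiling.
num_batches = math.ceil(num_edges / batch_size)
last_batch_size = num_edges - (num_batches - 1) * batch_size

# Each positive edge contributes k negative edges to the negative graph.
neg_edges_per_full_batch = batch_size * k

print(num_batches, last_batch_size, neg_edges_per_full_batch)
```

The 5120 negative edges agree with the negative graph printed above. The block sizes also chain as expected: the 34441 destination nodes of the first block are the source nodes of the second, whose 7143 destination nodes are exactly the nodes of the positive and negative graphs.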


## Defining Model for Node Representation

The model is almost identical to the one in the [node classification tutorial](L1_large_node_classification.ipynb).  The only difference is that since you are doing link prediction, the output dimension will not be the number of classes in the dataset.

In [6]:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type='mean')
        self.h_feats = h_feats
        
    def forward(self, bipartites, x):
        h = self.conv1(bipartites[0], x)
        h = F.relu(h)
        h = self.conv2(bipartites[1], h)
        return h
    
model = Model(num_features, 128).cuda()

## Obtaining Node Representation from GNN

The [node classification tutorial](L1_large_node_classification.ipynb) introduced how to obtain node representations without neighbor sampling for inference.  The same code can be reused for link prediction.

In [7]:
def inference(model, graph, input_features, batch_size):
    nodes = torch.arange(graph.number_of_nodes())
    
    sampler = dgl.dataloading.MultiLayerNeighborSampler([None])  # one layer at a time, taking all neighbors
    dataloader = dgl.dataloading.NodeDataLoader(
        graph, nodes, sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0)
    
    layers = [model.conv1, model.conv2]
    
    with torch.no_grad():
        for l, layer in enumerate(layers):
            # Allocate a buffer of output representations for every node
            # Note that the buffer is on CPU memory.
            output_features = torch.zeros(graph.number_of_nodes(), model.h_feats)

            for input_nodes, output_nodes, bipartites in dataloader:
                bipartite = bipartites[0].to(torch.device('cuda'))

                x = input_features[input_nodes].cuda()

                # the following code is identical to the loop body in model.forward()
                x = layer(bipartite, x)
                if l != len(layers) - 1:
                    x = F.relu(x)

                output_features[output_nodes] = x.cpu()
            input_features = output_features
    return output_features
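
The comment about keeping the buffer in CPU memory can be motivated with a rough size estimate. The sketch below is back-of-the-envelope arithmetic, assuming `float32` storage (4 bytes per value), the node count printed earlier, and the hidden size of 128 used in this tutorial; it does not account for any allocator overhead.

```python
num_nodes = 2_449_029     # from the printed graph in the dataset section
in_feats = 100            # input feature size from the printed schema
h_feats = 128             # hidden size used for the model in this tutorial
bytes_per_float32 = 4

# One output buffer is allocated per layer in the inference function.
buffer_bytes = num_nodes * h_feats * bytes_per_float32
feature_bytes = num_nodes * in_feats * bytes_per_float32

buffer_gib = buffer_bytes / 2**30
feature_gib = feature_bytes / 2**30
print(f'output buffer: {buffer_gib:.2f} GiB, input features: {feature_gib:.2f} GiB')
```

Holding these full-graph tensors on the CPU and moving only one batch of rows to the GPU at a time keeps GPU memory use bounded by the batch size rather than by the graph size.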

## Define the Score Predictor for Edges

After getting the node representations needed for the minibatch, the last step is to predict the score of the edges and non-existent edges in the sampled minibatch.  This can be done with the `apply_edges` method.  Here, this tutorial simply computes the score as the dot product of the representations of both incident nodes.

In [8]:
class ScorePredictor(nn.Module):
    def forward(self, subgraph, x):
        with subgraph.local_scope():
            subgraph.ndata['x'] = x
            subgraph.apply_edges(dgl.function.u_dot_v('x', 'x', 'score'))
            return subgraph.edata['score']
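
Conceptually, the predictor computes one number per edge. The following plain-Python stand-in (not DGL's implementation; the function and the toy data are hypothetical) shows the same per-edge dot product on an edge list given as parallel source and destination ID lists.

```python
def u_dot_v(src, dst, x):
    """Return one score per edge: the dot product of its endpoints' features."""
    return [sum(a * b for a, b in zip(x[u], x[v])) for u, v in zip(src, dst)]

# Toy node features for three nodes, and three directed edges.
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 3.0]]
src = [0, 1, 2]
dst = [2, 2, 0]

scores = u_dot_v(src, dst, x)
print(scores)
```

Because the positive and negative graphs in a minibatch share the same node set, the same node representations can be scored against both edge lists.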

## Evaluate Performance

There are various ways to evaluate the performance of link prediction.  This tutorial follows the practice of the [GraphSAGE paper](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf): it evaluates the node embeddings learned by link prediction by training and evaluating a linear classifier on top of them.  The classifier is trained in PyTorch, and accuracy is computed with `scikit-learn`.

In [9]:
import sklearn.metrics

def evaluate(emb, label, train_nids, valid_nids, test_nids):
    # Labels are 0-based class IDs, so the number of classes is max + 1.
    classifier = nn.Linear(emb.shape[1], label.max().item() + 1).cuda()
    opt = torch.optim.LBFGS(classifier.parameters())
    def closure():
        pred = classifier(emb[train_nids].cuda())
        loss = F.cross_entropy(pred, label[train_nids].cuda())
        opt.zero_grad()
        loss.backward()
        return loss
    for _ in range(1000):
        opt.step(closure)
    with torch.no_grad():
        pred = classifier(emb.cuda()).cpu()
        valid_acc = sklearn.metrics.accuracy_score(label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1))
        test_acc = sklearn.metrics.accuracy_score(label[test_nids].numpy(), pred[test_nids].numpy().argmax(1))
    return valid_acc, test_acc

## Defining Training Loop

The following initializes the model and defines the optimizer.

In [10]:
model = Model(node_features.shape[1], 128).cuda()
predictor = ScorePredictor().cuda()
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))

The following training loop performs unsupervised learning with periodic evaluation, and saves the model that performs best on the validation set:

In [None]:
import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(1):
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, bipartites) in enumerate(tq):
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            pos_graph = pos_graph.to(torch.device('cuda'))
            neg_graph = neg_graph.to(torch.device('cuda'))
            inputs = node_features[input_nodes].cuda()
            outputs = model(bipartites, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)
            
            score = torch.cat([pos_score, neg_score])
            label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
            loss = F.binary_cross_entropy_with_logits(score, label)
            
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            tq.set_postfix({'loss': '%.03f' % loss.item()}, refresh=False)
            
            if step % 10000 == 0:
                model.eval()
                emb = inference(model, graph, node_features, 16384)
                valid_acc, test_acc = evaluate(emb, node_labels, train_nids, valid_nids, test_nids)
                print('Epoch {} Validation Accuracy {} Test Accuracy {}'.format(epoch, valid_acc, test_acc))
                if best_accuracy < valid_acc:
                    best_accuracy = valid_acc
                    torch.save(model.state_dict(), best_model_path)
                model.train()

  0%|          | 0/120819 [00:00<?, ?it/s]
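
The training loop passes raw dot-product scores (logits) to `F.binary_cross_entropy_with_logits` instead of applying a sigmoid first. The plain-Python sketch below illustrates why that is equivalent to the loss formula from the overview, up to averaging instead of summing, and why the logits form is numerically safer. The rearranged expression is the standard stable identity for this loss; it is an illustration, not PyTorch's source code, and the toy scores are made up.

```python
import math

def bce_with_logits(scores, labels):
    """Mean binary cross entropy computed directly from logits (stable form)."""
    total = 0.0
    for x, y in zip(scores, labels):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(scores)

def bce_from_probabilities(scores, labels):
    """Same loss via an explicit sigmoid; breaks down for very large |x|."""
    total = 0.0
    for x, y in zip(scores, labels):
        p = 1.0 / (1.0 + math.exp(-x))
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(scores)

# Toy scores: two positive edges followed by three negative edges.
pos_score = [2.5, 0.3]
neg_score = [-2.0, 1.0, -0.5]
score = pos_score + neg_score
label = [1.0] * len(pos_score) + [0.0] * len(neg_score)

stable = bce_with_logits(score, label)
naive = bce_from_probabilities(score, label)
print(stable, naive)
```

For moderate scores the two functions agree. For an extreme score such as 1000 on a negative edge, the sigmoid rounds to exactly 1.0 and the naive version fails on `log(0)`, while the logits form still returns a finite loss.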

## Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE for unsupervised learning via link prediction on a dataset that is too large to fit into GPU memory.  Because only sampled minibatches are moved to the GPU, the method works on a single machine with a single GPU, provided the graph and its features fit in CPU memory.