# Stochastic Training of GNN for Link Prediction 

This tutorial shows how to train a multi-layer GraphSAGE model for link prediction on the Amazon co-purchase network (`ogbn-products`) provided by OGB. The dataset contains 2.4 million nodes and 61 million edges, which makes full-graph training on a single GPU impractical.

## Link Prediction Overview

Link prediction requires the model to predict the probability that an edge exists between two nodes. This tutorial computes it by taking the dot product of the representations of the two incident nodes and passing the result through the sigmoid function $\sigma$.

$$
\hat{y}_{u\sim v} = \sigma(h_u^T h_v)
$$

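The model is trained to minimize the following binary cross-entropy loss, where $\mathcal{D}$ contains both existing edges ($y_{u\sim v} = 1$) and sampled non-existent edges ($y_{u\sim v} = 0$).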

$$
\mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
$$
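
The following NumPy sketch makes the two formulas concrete on made-up data: random 4-dimensional representations for six nodes, two existing edges, and two negative pairs (all hypothetical). It evaluates the loss directly from $\hat{y}_{u\sim v}$ and again in an equivalent, numerically stable form written in terms of the raw dot products (the logits).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: 6 nodes with 4-dimensional representations h_u.
h = rng.normal(size=(6, 4))

# Node pairs (u, v) and labels: 1 for an existing edge, 0 for a negative example.
pairs = np.array([[0, 1], [2, 3], [0, 4], [5, 2]])
y = np.array([1.0, 1.0, 0.0, 0.0])

# Logit s = h_u^T h_v for every pair, then y_hat = sigma(s).
s = np.sum(h[pairs[:, 0]] * h[pairs[:, 1]], axis=1)
y_hat = 1.0 / (1.0 + np.exp(-s))

# Binary cross entropy summed over the pairs, exactly as in the formula.
loss_direct = -np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

# Equivalent form on the logits: log(1 + e^s) - y * s, evaluated with
# logaddexp so that large |s| does not overflow.
loss_logits = np.sum(np.logaddexp(0.0, s) - y * s)

print(loss_direct, loss_logits)
```

The training loop later in this tutorial passes the raw dot products to `F.binary_cross_entropy_with_logits`, which also works on logits. Note that PyTorch averages over the examples by default (`reduction='mean'`) instead of summing as in the formula, which only rescales the loss.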

This is identical to the link prediction formulation in [the previous tutorial on link prediction](4_link_predict.ipynb).

## Loading Dataset

This tutorial loads the dataset from the `ogb` package as in the [previous tutorial](L1_large_node_classification.ipynb).

In [1]:
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-products')

graph, node_labels = dataset[0]
print(graph)
print(node_labels)

node_features = graph.ndata['feat']
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)

idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']

Using backend: pytorch


Graph(num_nodes=2449029, num_edges=123718280,
      ndata_schemes={'feat': Scheme(shape=(100,), dtype=torch.float32)}
      edata_schemes={})
tensor([[0],
        [1],
        [2],
        ...,
        [8],
        [2],
        [4]])
Number of classes: 47
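
The printed graph reports 123,718,280 edges, while the introduction mentions about 61 million. The count is exactly twice 61,859,140, which is consistent with each undirected co-purchase link being stored once in each direction (an interpretation inferred here from the exact factor of two):

```python
num_directed_edges = 123_718_280          # num_edges printed by print(graph)
num_undirected_edges = num_directed_edges // 2

print(num_undirected_edges)  # 61859140
```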


## Defining Neighbor Sampler and Data Loader in DGL

Unlike the [link prediction tutorial for small graphs](4_link_predict.ipynb), you need to iterate over the edges in minibatches, since computing representations and scores for all edges at once is usually infeasible on a graph of this size. For each minibatch of edges, you compute the output representations of their incident nodes using neighbor sampling and a GNN, as in the [large-scale node classification tutorial](L1_large_node_classification.ipynb).

DGL provides `dgl.dataloading.EdgeDataLoader`, which lets you iterate over edges for edge classification or link prediction tasks.

To perform link prediction, you need to specify a negative sampler.  A negative sampler takes in a list of edges as positive examples and returns a list of negative examples.  In DGL, negative samplers can be any callable that has the following signature:

```python
def negative_sampler(g: DGLGraph, eids: Tensor) -> Tuple[Tensor, Tensor]:
    pass
```

The first argument is the original graph and the second argument is the minibatch of edge IDs. The function returns a pair of tensors $(u, v^-)$ holding the source and destination node IDs of the negative examples.
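
As an illustration of this contract, the following minimal NumPy sketch uses hypothetical arrays in place of a `DGLGraph`: `edge_src` holds the source node of every edge, so `edge_src[eids]` plays the role of the first tensor returned by `g.find_edges(eids)`. Destinations are drawn uniformly at random; this uniform sampler is a simpler stand-in, not the degree-based sampler used in this tutorial.

```python
import numpy as np

def uniform_negative_sampler(edge_src, num_nodes, eids, k, rng):
    """Return (src, dst) arrays with k negative examples per edge in `eids`.

    `edge_src` is a hypothetical stand-in for the graph: the source node ID
    of every edge, indexed by edge ID.
    """
    # Each positive edge (u, v) contributes k negative pairs (u, v^-).
    src = np.repeat(edge_src[eids], k)
    # Destinations v^- are drawn uniformly from all nodes, with replacement.
    dst = rng.integers(0, num_nodes, size=len(src))
    return src, dst

rng = np.random.default_rng(0)
edge_src = np.array([0, 0, 1, 2, 3])   # sources of 5 toy edges
eids = np.array([1, 3])                # a minibatch of two edge IDs
src, dst = uniform_negative_sampler(edge_src, num_nodes=4, eids=eids, k=3, rng=rng)
print(src)  # [0 0 0 2 2 2]
```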

<div class="alert alert-info">
    
**Note**: for heterogeneous graphs, the signature of the negative sampler will change.  See [here](https://todo) for more details.
    
</div>

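The following code implements a negative sampler that pairs each source node $u$ in the minibatch with `k` destination nodes $v^-$ drawn from the distribution $P^-(v) \propto d(v)^{0.75}$, where $d(v)$ is the in-degree of $v$. The sampler does not check whether a sampled pair is actually an edge of the graph; in a large, sparse graph such collisions are expected to be rare.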

In [2]:
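class NegativeSampler(object):
    def __init__(self, g, k):
        self.k = k
        # Unnormalized sampling weights d(v)^0.75 based on the in-degree.
        self.weights = g.in_degrees().float() ** 0.75
    def __call__(self, g, eids):
        # Source nodes of the positive edges in the minibatch.
        src, _ = g.find_edges(eids)
        # Repeat each source node k times: one entry per negative example.
        src = src.repeat_interleave(self.k)
        # Draw one destination per negative example, proportional to the weights.
        dst = self.weights.multinomial(len(src), replacement=True)
        return src, dst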
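
To see the effect of the exponent 0.75, the following worked example uses three made-up degrees (1, 16, and 81) and compares sampling proportional to $d(v)$ with sampling proportional to $d(v)^{0.75}$. As with `torch.multinomial` in the sampler above, which accepts weights that do not sum to one, only the relative weights matter.

```python
import numpy as np

degrees = np.array([1.0, 16.0, 81.0])  # hypothetical node degrees

# Sampling proportional to the raw degree: 1/98, 16/98, 81/98.
p_degree = degrees / degrees.sum()

# Sampling proportional to d(v)^0.75, as in NegativeSampler above.
weights = degrees ** 0.75               # approximately 1, 8, 27
p_smoothed = weights / weights.sum()    # approximately 1/36, 8/36, 27/36
```

The smoothed distribution still favors high-degree nodes, but the lowest-degree node becomes about 2.7 times as likely to be drawn (1/36 versus 1/98).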

After defining the negative sampler, you can define the edge data loader with neighbor sampling. This tutorial draws 5 negative examples per positive example.

In [3]:
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
k = 5
train_dataloader = dgl.dataloading.EdgeDataLoader(
    graph, torch.arange(graph.number_of_edges()), sampler,
    negative_sampler=NegativeSampler(graph, k),
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=4
)
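
With `drop_last=False`, the number of minibatches per epoch is the number of edges divided by the batch size, rounded up. The arithmetic below uses the edge count printed earlier: it gives 120,819 minibatches, each full one holding 1,024 positive and 5,120 negative edges. This matches the total reported by the progress bar in the training log later in this tutorial.

```python
import math

num_edges = 123_718_280   # graph.number_of_edges() as printed above
batch_size = 1024
k = 5

# drop_last=False keeps the final, partial minibatch.
num_minibatches = math.ceil(num_edges / batch_size)
num_negatives = batch_size * k   # negative edges in a full minibatch

print(num_minibatches)  # 120819
print(num_negatives)    # 5120
```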

You can peek at one minibatch from `train_dataloader` to see what it contains.

In [4]:
example_minibatch = next(iter(train_dataloader))
print(example_minibatch)

(tensor([2223486, 1488848, 1690274,  ..., 1601512,  288853,  258641]), Graph(num_nodes=7142, num_edges=1024,
      ndata_schemes={'feat': Scheme(shape=(100,), dtype=torch.float32), '_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}), Graph(num_nodes=7142, num_edges=5120,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), [Block(num_src_nodes=146105, num_dst_nodes=34484, num_edges=137581), Block(num_src_nodes=34484, num_dst_nodes=7142, num_edges=28455)])


The example minibatch consists of four elements:

* The IDs of the input nodes needed to compute the representations of the output nodes.
* A graph over the nodes involved in the minibatch (including those in the negative examples) that contains only the positive edges sampled in the minibatch.
* A graph over the same nodes that contains only the non-existent edges produced by the negative sampler.
* A list of bipartite graphs (blocks), one for each GNN layer.
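
The positive and negative graphs have the same number of nodes because both are defined over one compacted node set. The following NumPy sketch illustrates the idea with hypothetical global node IDs: it collects every node that appears in a positive or negative pair, assigns consecutive local IDs with `np.unique`, and rewrites both edge lists with those local IDs. It is a conceptual illustration, not DGL's internal implementation, which may order the nodes differently; `unique_nodes` plays a role similar to the `_ID` node data in the printout above, mapping local IDs back to global ones.

```python
import numpy as np

# Hypothetical minibatch expressed with global node IDs.
pos_src = np.array([2223486, 1488848])
pos_dst = np.array([17, 2223486])
neg_src = np.repeat(pos_src, 2)                  # k = 2 negatives per positive edge
neg_dst = np.array([42, 17, 99, 1488848])

# Every node touched by a positive or a negative pair, with consecutive local IDs.
all_nodes = np.concatenate([pos_src, pos_dst, neg_src, neg_dst])
unique_nodes, local_ids = np.unique(all_nodes, return_inverse=True)

# Split the local IDs back into the four endpoint arrays.
n_pos, n_neg = len(pos_src), len(neg_src)
pos_src_l, pos_dst_l, neg_src_l, neg_dst_l = np.split(
    local_ids, [n_pos, 2 * n_pos, 2 * n_pos + n_neg])

print(len(unique_nodes))  # 5 nodes shared by both graphs
```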

In [5]:
input_nodes, pos_graph, neg_graph, bipartites = example_minibatch
print('Number of input nodes:', len(input_nodes))
print('Positive graph # nodes:', pos_graph.number_of_nodes(), '# edges:', pos_graph.number_of_edges())
print('Negative graph # nodes:', neg_graph.number_of_nodes(), '# edges:', neg_graph.number_of_edges())
print(bipartites)

Number of input nodes: 146105
Positive graph # nodes: 7142 # edges: 1024
Negative graph # nodes: 7142 # edges: 5120
[Block(num_src_nodes=146105, num_dst_nodes=34484, num_edges=137581), Block(num_src_nodes=34484, num_dst_nodes=7142, num_edges=28455)]


## Defining Model for Node Representation

The model is almost identical to the one in the [node classification tutorial](L1_large_node_classification.ipynb). The only difference is that, since you are doing link prediction, the last layer outputs a node representation of size `h_feats` instead of one score per class.

In [6]:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type='mean')
        self.h_feats = h_feats
        
    def forward(self, bipartites, x):
        h = self.conv1(bipartites[0], x)
        h = F.relu(h)
        h = self.conv2(bipartites[1], h)
        return h
    
model = Model(num_features, 128).cuda()
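
Each `SAGEConv(..., aggregator_type='mean')` layer combines a node's own representation with the mean of its sampled neighbors' representations. The following NumPy sketch shows that update on one toy bipartite block. It is a simplified illustration, not DGL's implementation: the weight matrices `w_self`, `w_neigh` and the bias `b` are hypothetical, no normalization is applied, and it follows the DGL block convention that the destination nodes appear first among the source nodes.

```python
import numpy as np

def sage_mean_layer(x_src, num_dst, edges, w_self, w_neigh, b):
    """Simplified GraphSAGE-mean update on one bipartite block.

    x_src: features of all source nodes; the first `num_dst` rows are
           assumed to be the destination nodes themselves.
    edges: list of (src, dst) local index pairs, one per sampled message u -> v.
    """
    x_dst = x_src[:num_dst]
    neigh_sum = np.zeros((num_dst, x_src.shape[1]))
    deg = np.zeros(num_dst)
    for u, v in edges:
        neigh_sum[v] += x_src[u]
        deg[v] += 1
    # Mean over sampled in-neighbors; a node with no neighbors gets a zero vector.
    neigh_mean = neigh_sum / np.maximum(deg, 1)[:, None]
    return x_dst @ w_self + neigh_mean @ w_neigh + b

x_src = np.arange(12, dtype=float).reshape(4, 3)  # 4 source nodes, 3 features
edges = [(2, 0), (3, 0), (3, 1)]                  # dst 0 <- {2, 3}, dst 1 <- {3}
w_self = np.eye(3)[:, :2]                         # hypothetical 3 -> 2 projections
w_neigh = np.eye(3)[:, :2]
b = np.zeros(2)
out = sage_mean_layer(x_src, 2, edges, w_self, w_neigh, b)
print(out)  # [[ 7.5  9.5] [12.  14. ]]
```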

## Defining the Score Predictor for Edges

After computing the node representations for the minibatch, the last step is to predict a score for each edge in the positive graph and each non-existent edge in the negative graph. The `apply_edges` method does this directly. This tutorial computes the score as the dot product of the representations of the two incident nodes.

In [7]:
class ScorePredictor(nn.Module):
    def forward(self, subgraph, x):
        with subgraph.local_scope():
            subgraph.ndata['x'] = x
            subgraph.apply_edges(dgl.function.u_dot_v('x', 'x', 'score'))
            return subgraph.edata['score']
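
For intuition, the following NumPy sketch computes what `dgl.function.u_dot_v('x', 'x', 'score')` produces for each edge, assuming the edges are given as hypothetical `src` and `dst` index arrays. The result keeps a trailing dimension of size 1 to mirror the `(num_edges, 1)` shape that `ScorePredictor` returns for vector node features.

```python
import numpy as np

def u_dot_v_scores(src, dst, x):
    """Score every edge (src[e], dst[e]) by the dot product of its endpoint features."""
    # Row-wise dot product; keepdims gives shape (num_edges, 1).
    return np.sum(x[src] * x[dst], axis=1, keepdims=True)

x = np.array([[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]])  # 3 toy node representations
src = np.array([0, 1, 2])
dst = np.array([1, 2, 0])
scores = u_dot_v_scores(src, dst, x)
print(scores)  # [[ 0.5] [ 1.5] [-1. ]]
```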

## Evaluating Performance

There are various ways to evaluate the performance of link prediction. This tutorial follows the practice of the [GraphSAGE paper](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf): it evaluates the node embeddings learned by link prediction by training a linear classifier on top of them and measuring its node classification accuracy.
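
The following NumPy sketch shows this linear-probe protocol on hypothetical embeddings (three well-separated 2-D clusters). It swaps in plain gradient descent on the CPU for the LBFGS optimizer and GPU used by `evaluate` later in this tutorial; the learning rate and step count are illustrative choices.

```python
import numpy as np

def linear_probe_accuracy(emb_train, y_train, emb_test, y_test, num_classes,
                          lr=0.1, steps=200):
    """Fit a softmax (multinomial logistic) classifier on frozen embeddings
    with plain gradient descent and return its test accuracy."""
    n, d = emb_train.shape
    w = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[y_train]
    for _ in range(steps):
        logits = emb_train @ w + b
        logits -= logits.max(axis=1, keepdims=True)   # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n                    # gradient of mean cross entropy w.r.t. logits
        w -= lr * emb_train.T @ grad
        b -= lr * grad.sum(axis=0)
    pred = (emb_test @ w + b).argmax(axis=1)
    return np.mean(pred == y_test)

# Hypothetical embeddings: three well-separated clusters in 2-D.
rng = np.random.default_rng(0)
centers = np.array([[0.0, 4.0], [4.0, 0.0], [-4.0, -4.0]])
labels = rng.integers(0, 3, size=300)
emb = centers[labels] + rng.normal(scale=0.5, size=(300, 2))
acc = linear_probe_accuracy(emb[:200], labels[:200], emb[200:], labels[200:], 3)
print(acc)
```

Only the classifier is trained here; the embeddings stay fixed, so the accuracy reflects how linearly separable the learned representations are.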

To obtain the representations of all nodes, this tutorial uses neighbor sampling as introduced in the [node classification tutorial](L1_large_node_classification.ipynb).

<div class="alert alert-info">
    
**Note**: if you would like to obtain node representations without neighbor sampling during inference, please refer to this [user guide](https://docs.dgl.ai/guide/minibatch-inference.html).
    
</div>

In [8]:
def inference(model, graph, node_features):
    with torch.no_grad():
        nodes = torch.arange(graph.number_of_nodes())

        # Use the same fan-out as in training and visit every node exactly once.
        sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
        dataloader = dgl.dataloading.NodeDataLoader(
            graph, nodes, sampler,
            batch_size=1024,
            shuffle=False,  # keep the outputs in node ID order
            drop_last=False,
            num_workers=4)

        result = []
        for input_nodes, output_nodes, bipartites in dataloader:
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            inputs = node_features[input_nodes].cuda()
            result.append(model(bipartites, inputs))

        # Row i holds the representation of node i.
        return torch.cat(result)
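
`inference` keeps every node's output on the GPU and concatenates them. A quick size estimate, using the node count printed earlier and the hidden size 128 passed to `Model`, shows that the resulting float32 tensor takes about 1.25 GB:

```python
num_nodes = 2_449_029    # graph.number_of_nodes() printed above
hidden = 128             # output size of Model(..., 128)
bytes_per_float32 = 4

emb_bytes = num_nodes * hidden * bytes_per_float32
print(emb_bytes)                   # 1253902848
print(round(emb_bytes / 1e9, 2))   # 1.25
```

If GPU memory is tight, one option is to call `.cpu()` on each minibatch output before appending it to `result`; `evaluate` moves the embeddings back to the GPU with `.cuda()` where it needs them.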

In [9]:
import sklearn.metrics

def evaluate(emb, label, train_nids, valid_nids, test_nids):
    # Linear classifier on the frozen embeddings; labels run from 0 to max, hence the +1.
    classifier = nn.Linear(emb.shape[1], label.max().item() + 1).cuda()
    opt = torch.optim.LBFGS(classifier.parameters())
    
    def compute_loss():
        pred = classifier(emb[train_nids].cuda())
        loss = F.cross_entropy(pred, label[train_nids].cuda())
        return loss
    
    def closure():
        loss = compute_loss()
        opt.zero_grad()
        loss.backward()
        return loss
    
    # Run LBFGS until the training loss changes by less than 1e-4.
    prev_loss = float('inf')
    for i in range(1000):
        opt.step(closure)
        with torch.no_grad():
            loss = compute_loss().item()
            if np.abs(loss - prev_loss) < 1e-4:
                print('Converges at iteration', i)
                break
            else:
                prev_loss = loss
                
    with torch.no_grad():
        pred = classifier(emb.cuda()).cpu()
        valid_acc = sklearn.metrics.accuracy_score(label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1))
        test_acc = sklearn.metrics.accuracy_score(label[test_nids].numpy(), pred[test_nids].numpy().argmax(1))
    return valid_acc, test_acc

## Defining Training Loop

The following initializes the model and defines the optimizer.

In [10]:
model = Model(node_features.shape[1], 128).cuda()
predictor = ScorePredictor().cuda()
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))

The following training loop optimizes the link prediction objective without using any node labels, evaluates the learned embeddings every 1000 steps, and saves the model that performs best on the validation set:

In [None]:
import tqdm

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(1):
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, bipartites) in enumerate(tq):
            # Move the sampled structures and the input features to the GPU.
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            pos_graph = pos_graph.to(torch.device('cuda'))
            neg_graph = neg_graph.to(torch.device('cuda'))
            inputs = node_features[input_nodes].cuda()
            # Representations of the nodes shared by pos_graph and neg_graph.
            outputs = model(bipartites, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)
            
            # Existing edges are labeled 1, negative examples 0.
            score = torch.cat([pos_score, neg_score])
            label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
            loss = F.binary_cross_entropy_with_logits(score, label)
            
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            tq.set_postfix({'loss': '%.03f' % loss.item()}, refresh=False)
            
            # Every 1000 steps, evaluate the embeddings with a linear classifier.
            if step % 1000 == 0:
                model.eval()
                emb = inference(model, graph, node_features)
                valid_acc, test_acc = evaluate(emb, node_labels, train_nids, valid_nids, test_nids)
                print('Epoch {} Validation Accuracy {} Test Accuracy {}'.format(epoch, valid_acc, test_acc))
                if best_accuracy < valid_acc:
                    best_accuracy = valid_acc
                    torch.save(model.state_dict(), best_model_path)
                model.train()

  0%|          | 0/120819 [00:00<?, ?it/s]

Converges at iteration 19


  0%|          | 1/120819 [02:07<4281:11:42, 127.57s/it, loss=89.331]

Epoch 0 Validation Accuracy 0.7142384863820156 Test Accuracy 0.5608115527106657


  1%|          | 1000/120819 [05:41<6:40:05,  4.99it/s, loss=0.937]  

Converges at iteration 12


  1%|          | 1001/120819 [07:26<1052:04:04, 31.61s/it, loss=0.905]

Epoch 0 Validation Accuracy 0.5542049182412329 Test Accuracy 0.42179783840791


  2%|▏         | 2000/120819 [10:57<7:12:22,  4.58it/s, loss=0.727]   

Converges at iteration 14


  2%|▏         | 2001/120819 [12:33<955:05:57, 28.94s/it, loss=0.734]

Epoch 0 Validation Accuracy 0.596622841593978 Test Accuracy 0.459599266365459


  2%|▏         | 3000/120819 [16:01<6:49:50,  4.79it/s, loss=0.703]  

Converges at iteration 14


  2%|▏         | 3001/120819 [17:38<962:19:21, 29.40s/it, loss=0.697]

Epoch 0 Validation Accuracy 0.6528749078147649 Test Accuracy 0.5134461258032318


  3%|▎         | 4000/120819 [21:08<6:32:22,  4.96it/s, loss=0.690]  

Converges at iteration 16


  3%|▎         | 4001/120819 [22:44<942:44:40, 29.05s/it, loss=0.663]

Epoch 0 Validation Accuracy 0.6992854054878824 Test Accuracy 0.5486733261307375


  4%|▍         | 5000/120819 [26:16<7:49:07,  4.11it/s, loss=0.657]  

Converges at iteration 14


  4%|▍         | 5001/120819 [27:51<921:38:06, 28.65s/it, loss=0.669]

Epoch 0 Validation Accuracy 0.7449584212801669 Test Accuracy 0.5972452104319254


  5%|▍         | 6000/120819 [31:21<7:02:26,  4.53it/s, loss=0.651]  

Converges at iteration 14


  5%|▍         | 6001/120819 [32:54<902:48:12, 28.31s/it, loss=0.645]

Epoch 0 Validation Accuracy 0.7749408742974849 Test Accuracy 0.624674267800104


  6%|▌         | 7000/120819 [36:27<7:03:32,  4.48it/s, loss=0.645]  

Converges at iteration 12


  6%|▌         | 7001/120819 [38:10<984:17:22, 31.13s/it, loss=0.641]

Epoch 0 Validation Accuracy 0.7919792487856979 Test Accuracy 0.6426748832289318


  7%|▋         | 8000/120819 [41:38<6:14:32,  5.02it/s, loss=0.647]  

Converges at iteration 14


  7%|▋         | 8001/120819 [43:25<1012:19:35, 32.30s/it, loss=0.636]

Epoch 0 Validation Accuracy 0.7939373903313582 Test Accuracy 0.6474717939750331


  7%|▋         | 9000/120819 [46:52<6:13:01,  5.00it/s, loss=0.642]   

Converges at iteration 12


  7%|▋         | 9001/120819 [48:19<816:26:21, 26.29s/it, loss=0.641]

Epoch 0 Validation Accuracy 0.8002441319329654 Test Accuracy 0.6522131263468154


  8%|▊         | 10000/120819 [51:45<6:09:26,  5.00it/s, loss=0.631] 

Converges at iteration 11


  8%|▊         | 10001/120819 [53:22<894:46:42, 29.07s/it, loss=0.636]

Epoch 0 Validation Accuracy 0.8068051776314117 Test Accuracy 0.6554728205934596


  9%|▉         | 11000/120819 [56:50<6:14:44,  4.88it/s, loss=0.636]  

Converges at iteration 13


  9%|▉         | 11001/120819 [58:28<898:17:37, 29.45s/it, loss=0.635]

Epoch 0 Validation Accuracy 0.8087633191770719 Test Accuracy 0.6603591989665134


 10%|▉         | 12000/120819 [1:02:00<6:00:40,  5.03it/s, loss=0.639]

Converges at iteration 11


 10%|▉         | 12001/120819 [1:03:33<846:31:49, 28.01s/it, loss=0.634]

Epoch 0 Validation Accuracy 0.8088904712254914 Test Accuracy 0.6621693369138458


 11%|█         | 13000/120819 [1:06:59<6:22:16,  4.70it/s, loss=0.640]  

Converges at iteration 12


 11%|█         | 13001/120819 [1:08:29<813:03:15, 27.15s/it, loss=0.628]

Epoch 0 Validation Accuracy 0.8103145741677898 Test Accuracy 0.664467931955803


 12%|█▏        | 14000/120819 [1:11:58<6:30:34,  4.56it/s, loss=0.626]  

Converges at iteration 11


 12%|█▏        | 14001/120819 [1:13:32<842:16:26, 28.39s/it, loss=0.632]

Epoch 0 Validation Accuracy 0.815095491188363 Test Accuracy 0.669851804557517


 12%|█▏        | 15000/120819 [1:17:03<6:04:56,  4.83it/s, loss=0.631]  

Converges at iteration 11


 12%|█▏        | 15001/120819 [1:18:34<812:17:23, 27.63s/it, loss=0.640]

Epoch 0 Validation Accuracy 0.8154006561045698 Test Accuracy 0.6722141113944252


 13%|█▎        | 15511/120819 [1:20:20<5:52:29,  4.98it/s, loss=0.637]  

## Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE model for unsupervised learning via link prediction on a dataset too large for full-graph training on a GPU. Because only the sampled subgraphs of each minibatch are moved to the GPU, the method scales to much larger graphs on a single machine with a single GPU, as long as the graph structure and node features fit in host memory.