# Link Prediction using Graph Neural Networks

This tutorial teaches the basic workflow of using GNNs for link prediction, i.e. predicting whether an edge exists between two nodes. It again uses the Cora dataset, but this time the task is to predict citation relationships between pairs of papers in the graph.

Goals of this tutorial:

* Prepare training and testing sets for a link prediction task.
* Build a GNN-based link prediction model.
* Train the model and verify the result.

<div class="alert alert-info">
    <b>Note: </b>The Cora dataset provided by DGL is bidirectional: an edge only indicates that a citation relationship exists between two papers, not which paper cites the other.
</div>

In [2]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp

Using backend: pytorch


## Load graph and features

This tutorial again uses the Cora dataset, as in the [introduction](1_introduction.ipynb).

In [3]:
import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

Loading from cache failed, re-processing.
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.


## Prepare training and testing sets

In general, a link prediction dataset contains two types of edges: *positive edges* and *negative edges*. Positive edges are usually drawn from the existing edges in the graph. This tutorial randomly picks 1000 edges for testing and leaves the rest for training.

In [4]:
# Split edge set for training and testing
TEST_SIZE = 1000
u, v = g.edges()
eids = np.arange(g.num_edges())
eids = np.random.permutation(eids)
test_eids = eids[:TEST_SIZE]
train_eids = eids[TEST_SIZE:]
test_u, test_v = u[test_eids], v[test_eids]
train_u, train_v = u[train_eids], v[train_eids]
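
The same shuffle-and-slice split can be traced on a toy edge list. The sketch below uses only the Python standard library; the six-edge list, the seed, and the names `toy_edges` and `split_edges` are illustrative assumptions, not part of the tutorial's code.

```python
import random

def split_edges(edges, test_size, seed=0):
    """Shuffle edge IDs and split them into disjoint train/test ID lists."""
    eids = list(range(len(edges)))
    random.Random(seed).shuffle(eids)   # seeded so the split is reproducible
    return eids[test_size:], eids[:test_size]

# Hypothetical toy graph: each tuple is a (source, destination) pair.
toy_edges = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 2), (1, 3)]
train_ids, test_ids = split_edges(toy_edges, test_size=2)

# Every edge lands in exactly one of the two sets.
assert set(train_ids).isdisjoint(test_ids)
assert sorted(train_ids + test_ids) == list(range(len(toy_edges)))
print(len(train_ids), "training edges,", len(test_ids), "test edges")
```

Fixing a seed in the tutorial's own cell (for example with `np.random.seed`) would make the Cora split repeatable in the same way.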

Because the model will be asked whether the test edges exist, those edges must not appear in the training graph. We therefore train on a graph that consists only of the edges in the training set.

In [5]:
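# Keep every node so that node IDs stay unchanged; only the test edges are dropped.
# (Newer DGL releases appear to spell this option relabel_nodes=False.)
g = g.edge_subgraph(train_eids, preserve_nodes=True)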

## Define a GraphSAGE model

Our model will be a two-layer [GraphSAGE (Hamilton et al., 2017)](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf) network. DGL provides the GraphSAGE convolution as [`dgl.nn.SAGEConv`](https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv), alongside [many other graph convolution modules](https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv).

In [13]:
from dgl.nn import SAGEConv

# ----------- 2. create model -------------- #
# build a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')
    
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
    
# Create the model with hidden layer dimension 16.
net = GraphSAGE(g.ndata['feat'].shape[1], 16)

We then optimize the model using the following loss function.

$$
\hat{y}_{u\to v} = \sigma(h_u^T h_v)
$$

$$
\mathcal{L} = -\sum_{u\to v\in \mathcal{D}}\left( y_{u\to v}\log(\hat{y}_{u\to v}) + (1-y_{u\to v})\log(1-\hat{y}_{u\to v}) \right)
$$

Essentially, to predict whether an edge exists between two nodes, the model predicts a score by computing a dot product between both nodes' representations.  The model then minimizes the loss function above so that node pairs that have an edge (or *positive examples*) in between get a higher score, while the other node pairs that do not have an edge in between (or *negative examples*) get a lower score.
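
As a worked illustration of the two formulas above, the following sketch scores one positive and one negative pair with plain Python. The 2-dimensional representations and the helper names `sigmoid`, `dot`, and `bce` are hypothetical and chosen only to make the arithmetic easy to follow.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def bce(y, y_hat):
    """Binary cross-entropy for one node pair, as in the loss above."""
    return -(y * math.log(y_hat) + (1 - y) * math.log(1 - y_hat))

# Hypothetical 2-dimensional node representations.
h_u, h_v, h_w = [1.0, 2.0], [0.5, 1.0], [-1.0, 0.0]

y_hat_pos = sigmoid(dot(h_u, h_v))   # (u, v) is treated as a positive pair
y_hat_neg = sigmoid(dot(h_u, h_w))   # (u, w) is treated as a negative pair

# Sum over both pairs, following the formula for the loss.
loss = bce(1, y_hat_pos) + bce(0, y_hat_neg)
print(round(y_hat_pos, 4), round(y_hat_neg, 4), round(loss, 4))
```

The positive pair has a large dot product and so a probability close to 1, while the negative pair falls below 0.5. Note that PyTorch's built-in binary cross-entropy functions average over the pairs by default instead of summing, which only rescales the loss.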

Since the number of possible node pairs is large, one often samples a small number of node pairs and computes loss only on the sampled node pairs instead.  This is called *negative sampling*.  Here we simply pick the node pairs uniformly; more sophisticated negative sampling strategies are beyond the scope of this tutorial.

In [None]:
def generate_negative_examples(g):
    # Randomly pick as many negative examples as the positive examples
    # (i.e. all edges in the training graph).
    neg_u = torch.randint(0, g.num_nodes(), (g.num_edges(),))
    neg_v = torch.randint(0, g.num_nodes(), (g.num_edges(),))
    return neg_u, neg_v
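
Uniform sampling can occasionally draw a pair that is in fact connected, or a node paired with itself. The rough sketch below estimates how rare the first case is on Cora, using the node and edge counts printed when the dataset was loaded, and then shows the effect on a small scale. The helper `sample_uniform_pairs` and the toy edge set are hypothetical.

```python
import random

def sample_uniform_pairs(num_nodes, num_pairs, rng):
    """Draw node pairs uniformly at random, mirroring the function above."""
    return [(rng.randrange(num_nodes), rng.randrange(num_nodes))
            for _ in range(num_pairs)]

# Figures printed when the dataset was loaded.
num_nodes, num_edges = 2708, 10556

# Chance that one uniformly drawn ordered pair is an existing edge.
collision_prob = num_edges / (num_nodes * num_nodes)
print("collision probability: {:.4%}".format(collision_prob))

# Hypothetical toy graph, dense enough to make collisions visible.
toy_edges = {(0, 1), (1, 0), (1, 2), (2, 1)}
pairs = sample_uniform_pairs(num_nodes=4, num_pairs=1000, rng=random.Random(0))
false_negatives = sum(pair in toy_edges for pair in pairs)
print(false_negatives, "of", len(pairs), "sampled pairs are real edges")
```

On a sparse graph such as Cora the collision rate is well under one percent, which is why plain uniform sampling is a reasonable starting point here. The training graph has 1000 fewer edges than the full graph, so the rate during training is slightly lower still.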

In [17]:
# ----------- 3. set up loss and optimizer -------------- #
# The loss is computed inside the training loop, so only the optimizer is created here.
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(100):
    # Forward computation that computes the output embeddings of the nodes
    h = net(g, g.ndata['feat'])

    # Dot-product scores (logits) for the positive and negative node pairs
    neg_u, neg_v = generate_negative_examples(g)
    pred_pos = (h[train_u] * h[train_v]).sum(dim=1)
    pred_neg = (h[neg_u] * h[neg_v]).sum(dim=1)
    label_pos = torch.ones_like(pred_pos)
    label_neg = torch.zeros_like(pred_neg)
    pred = torch.cat([pred_pos, pred_neg])
    label = torch.cat([label_pos, label_neg])
    # compute loss; pred holds raw scores, so the sigmoid is applied inside the loss
    loss = F.binary_cross_entropy_with_logits(pred, label)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    all_logits.append(h.detach())

    if e % 5 == 0:
        print('In epoch {}, loss: {}'.format(e, loss))

In epoch 0, loss: 0.715376615524292
In epoch 5, loss: 0.6942322850227356
In epoch 10, loss: 0.6904299855232239
In epoch 15, loss: 0.683499276638031
In epoch 20, loss: 0.6637445092201233
In epoch 25, loss: 0.6305882334709167
In epoch 30, loss: 0.5955677032470703
In epoch 35, loss: 0.5764950513839722
In epoch 40, loss: 0.5546678900718689
In epoch 45, loss: 0.5326592922210693
In epoch 50, loss: 0.5074127912521362
In epoch 55, loss: 0.47796398401260376
In epoch 60, loss: 0.451864093542099
In epoch 65, loss: 0.42299002408981323
In epoch 70, loss: 0.3948816955089569
In epoch 75, loss: 0.36992791295051575
In epoch 80, loss: 0.3451216220855713
In epoch 85, loss: 0.32089129090309143
In epoch 90, loss: 0.2978387773036957
In epoch 95, loss: 0.27626192569732666


In [18]:
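# ----------- 5. check results ------------------------ #
# Assumption: evaluate on the held-out edges (label 1) plus as many random pairs (label 0).
test_neg_u, test_neg_v = torch.randint(0, g.num_nodes(), (2, TEST_SIZE))
test_label = torch.cat([torch.ones(TEST_SIZE), torch.zeros(TEST_SIZE)]).bool()
pred = torch.sigmoid((h[torch.cat([test_u, test_neg_u])] * h[torch.cat([test_v, test_neg_v])]).sum(dim=1))
print('Accuracy', ((pred >= 0.5) == test_label).sum().item() / len(pred))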

Accuracy 0.7495
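
The accuracy computation thresholds the predicted probability at 0.5, which is the same as checking whether the raw dot-product score is non-negative. A minimal sketch in plain Python, with made-up scores and labels and a hypothetical `accuracy` helper:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def accuracy(scores, labels, threshold=0.5):
    """Fraction of pairs whose thresholded probability matches the label."""
    preds = [sigmoid(s) >= threshold for s in scores]
    return sum(p == bool(y) for p, y in zip(preds, labels)) / len(labels)

# Hypothetical dot-product scores: two held-out edges, then two random pairs.
scores = [2.0, -0.5, -1.5, 0.3]
labels = [1, 1, 0, 0]

# Thresholding the probability at 0.5 equals thresholding the score at 0.
assert all((sigmoid(s) >= 0.5) == (s >= 0) for s in scores)
print(accuracy(scores, labels))
```

Here one positive pair and one negative pair are classified correctly, so the sketch reports an accuracy of one half. With balanced positives and negatives, 0.5 is also what random guessing would achieve, which gives a baseline for reading the tutorial's figure.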


## What's next?

If you wish to scale up your link prediction model, please see the tutorial [Stochastic Training of GNN for Link Prediction on Large Graphs](L2_large_link_prediction.ipynb).
* Training on large graphs differs from full-graph training, so we recommend going through the tutorial [Stochastic Training of GNN for Node Classification on Large Graphs](L1_large_node_classification.ipynb) first to get an idea of how large-graph training works.

If you have heterogeneous graphs, please see the tutorial [Link Prediction on Heterogeneous Graphs (TODO)](H4_link_predict.ipynb).
* We recommend going through the tutorial [Node Classification on Heterogeneous Graphs (TODO)](H1_node_classification.ipynb) first to get an idea of how heterogeneous graph training works.