# Link prediction on social network using DGL

In [None]:
# !pip uninstall -y dgl
# !pip install  dgl==2.2.1 -f https://data.dgl.ai/wheels/torch-2.3/repo.html

In [None]:
# Load the pickled social network graph from disk
# NOTE: unpickling can execute arbitrary code - only open files you trust
import pickle

with open('test.gpickle', mode='rb') as graph_file:
    Gnx = pickle.load(graph_file)

In [None]:
import dgl

# Convert the graph from networkx to a DGL graph. We are now ready to start learning.
# No node or edge attributes are requested in this call, so features are attached separately later on.
G = dgl.from_networkx(Gnx)

In the code below, we implement a GraphSAGE model to perform link prediction on a graph using the Deep Graph Library (DGL) and PyTorch. We start by setting up the computational device and initializing the node features as an identity matrix. The graph's edges are then split into training and testing sets to evaluate the model's performance on unseen data. Negative edges are sampled to serve as negative examples during training. We define a GraphSAGE model with two layers that aggregate neighbor information and a dot-product-based edge predictor to compute edge scores. The model is trained using binary cross-entropy loss, optimized with the Adam optimizer. After training for a specified number of epochs, we evaluate the model's performance using common metrics on the test set.

In [None]:
import dgl
import torch
import torch.nn.functional as F
from dgl.nn import SAGEConv
from sklearn.metrics import f1_score, precision_score, recall_score
import numpy as np
import scipy.sparse as sp
from torch import nn
import itertools
import dgl.function as fn

# Pick the computation device: CUDA when it is available, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# G is expected to be defined already; `graph` is G placed on the computation device
graph = G.to(device)

Once the graph is loaded, we need to perform the following steps:
- assign fake node features (i.e. the identity matrix, so that each node gets its own one-hot row)
- split the edges into training edges (90%) and test edges (10%)

In [None]:
# Give every node its own one-hot row of an identity matrix as feature vector,
# so that the model starts from distinguishable node inputs
node_features = torch.eye(graph.number_of_nodes(), device=device)
graph.ndata['features'] = node_features

# Prepare the train/test edge split: fetch the edge endpoints and shuffle the edge ids
# so that the held-out edges are picked at random
src_nodes, dst_nodes = graph.edges()
edge_ids = np.arange(graph.number_of_edges())
np.random.shuffle(edge_ids)

# 10% of the edges (rounded down) go to the test set, the rest to the training set
test_edge_count = int(0.1 * edge_ids.shape[0])
train_edge_count = edge_ids.shape[0] - test_edge_count

Next, we need to find negative (i.e. non-existent) edges, because we want to train the model to tell whether an edge exists... or not!
We will do this by building an adjacency matrix and randomly picking pairs of nodes that are not connected.

Finally, we create the training graph by removing the test edges from the original graph, so that the held-out edges can be used for model evaluation.

In [None]:
# Splitting edges into positive training and testing sets
# Positive edges are the real edges of the graph: the first `test_edge_count` shuffled
# edge ids are held out for testing, the remaining ones are used for training
test_pos_src, test_pos_dst = src_nodes[edge_ids[:test_edge_count]], dst_nodes[edge_ids[:test_edge_count]]
train_pos_src, train_pos_dst = src_nodes[edge_ids[test_edge_count:]], dst_nodes[edge_ids[test_edge_count:]]

# Creating an adjacency matrix and finding negative edges
# Negative edges are non-existent edges in the graph used for negative sampling
# The endpoint tensors live on `device`; copy them to host memory first, because
# calling .numpy() directly on a CUDA tensor raises an error
src_np, dst_np = src_nodes.cpu().numpy(), dst_nodes.cpu().numpy()
adj_matrix = sp.coo_matrix((np.ones(len(src_np)), (src_np, dst_np)), shape=(graph.number_of_nodes(), graph.number_of_nodes()))
# 1 - A - I equals 1 exactly for ordered pairs (u, v) with u != v that are not connected.
# Self-loops and duplicated edges produce negative entries, so select with `> 0`
# (a `!= 0` test would wrongly treat those existing edges as negatives)
neg_adj_matrix = 1 - adj_matrix.toarray() - np.eye(graph.number_of_nodes())
neg_src, neg_dst = np.where(neg_adj_matrix > 0)
# Draw as many distinct negative edges as there are positive edges
neg_edge_ids = np.random.choice(len(neg_src), size=graph.number_of_edges(), replace=False)

# Splitting negative edges into training and testing sets
# These edges serve as negative samples during training and testing
test_neg_src, test_neg_dst = neg_src[neg_edge_ids[:test_edge_count]], neg_dst[neg_edge_ids[:test_edge_count]]
train_neg_src, train_neg_dst = neg_src[neg_edge_ids[test_edge_count:]], neg_dst[neg_edge_ids[test_edge_count:]]

# Creating a training graph by removing test edges
# This prevents the model from training on test data and helps evaluate its generalization capability
# NOTE(review): if the graph stores every undirected link as two directed edges, only the
# sampled direction is removed here and the reverse edge stays in train_graph - confirm
# whether that leakage is acceptable for this experiment
train_graph = dgl.remove_edges(graph, edge_ids[:test_edge_count])

We are now ready to train the model.
The next steps are the following:
- create a GNN model (we choose a GraphSAGE model in this case)
- attach an edge predictor (in this case we compute the "existence" score of an edge by taking the dot product of the embeddings of its two end nodes)
- implement the training loop, which computes the predictions and the loss value, and applies backpropagation to update the model weights.

In [None]:
# GraphSAGE encoder
# Two stacked dgl.nn.SAGEConv layers configured with the 'mean' aggregator,
# separated by a ReLU non-linearity
class GraphSAGENetwork(nn.Module):
    """Two-layer GraphSAGE encoder turning node features into node embeddings."""

    def __init__(self, in_feats, hidden_feats):
        """Create both layers: in_feats -> hidden_feats -> hidden_feats."""
        super().__init__()
        self.conv1 = SAGEConv(in_feats, hidden_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(hidden_feats, hidden_feats, aggregator_type='mean')

    def forward(self, g, features):
        """Run both layers on the same graph `g` and return the output of the second one."""
        hidden = F.relu(self.conv1(g, features))
        return self.conv2(g, hidden)

# Edge predictor based on the dot product
# The score of an edge is the dot product between the embeddings of its two end nodes
class DotProductPredictor(nn.Module):
    """Scores every edge of `graph` from the embeddings of its endpoints."""

    def forward(self, graph, node_embeddings):
        """Return one score per edge of `graph`."""
        # local_scope() keeps the temporary 'h' and 'score' fields from leaking into the caller's graph
        with graph.local_scope():
            graph.ndata['h'] = node_embeddings
            graph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'))
            edge_scores = graph.edata['score']
        # Keep column 0 only, so that each edge is described by a single value
        return edge_scores[:, 0]

# Instantiate the encoder (input width taken from the node feature matrix, 16 hidden units)
# and the edge predictor, then place both on the computation device
sage_model = GraphSAGENetwork(in_feats=graph.ndata['features'].shape[1], hidden_feats=16)
sage_model = sage_model.to(device)
predictor = DotProductPredictor()
predictor = predictor.to(device)

# Loss function
# Positive and negative edge scores are scored together with binary cross-entropy on logits
def compute_loss(pos_scores, neg_scores):
    """Return the binary cross-entropy (with logits) of the given edge scores.

    Scores in `pos_scores` are given target 1 and scores in `neg_scores` target 0.
    """
    logits = torch.cat([pos_scores, neg_scores])
    # Targets follow the concatenation order: the leading entries are the positives
    targets = torch.zeros_like(logits)
    targets[:pos_scores.shape[0]] = 1
    return F.binary_cross_entropy_with_logits(logits, targets)

# Optimizer setup
# Adam updates the parameters of both the encoder and the predictor from the gradients
# computed during backpropagation
optimizer = torch.optim.Adam(
    [*sage_model.parameters(), *predictor.parameters()],
    lr=0.01,
)

# Training loop
# The graphs holding the positive and negative training edges never change between epochs,
# so they are built (and moved to the device) once instead of once per epoch
train_pos_graph = dgl.graph((train_pos_src, train_pos_dst), num_nodes=graph.number_of_nodes()).to(device)
train_neg_graph = dgl.graph((train_neg_src, train_neg_dst), num_nodes=graph.number_of_nodes()).to(device)

sage_model.train()

# The model is trained for a specified number of epochs
for epoch in range(100):
    # Compute node embeddings on the training graph (test edges removed)
    node_embeddings = sage_model(train_graph, train_graph.ndata['features'])

    # Compute scores for positive and negative edges
    pos_scores = predictor(train_pos_graph, node_embeddings)
    neg_scores = predictor(train_neg_graph, node_embeddings)

    # Compute loss
    loss = compute_loss(pos_scores, neg_scores)

    # Backward pass: compute gradients and update model parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print loss every 5 epochs
    if epoch % 5 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

Let's evaluate the model by means of f1-score, precision and recall.

In [None]:

def normalize(scores):
    """Min-max rescale `scores` so that the smallest value maps to 0 and the largest to 1.

    The result is undefined (division by zero) when all scores are equal.
    """
    lowest = scores.min()
    value_range = scores.max() - lowest
    return (scores - lowest) / value_range

# Define the score computation to evaluate model performance on classification tasks
def compute_scores(positive_scores, negative_scores):
    """Return (f1, precision, recall) for binary edge predictions.

    `positive_scores` holds the 0/1 predictions made for existing edges (true label 1) and
    `negative_scores` the 0/1 predictions made for non-existing edges (true label 0).
    """
    # Predictions may live on the GPU or be attached to the autograd graph;
    # detach and copy them to host memory before converting to NumPy
    scores = torch.cat([positive_scores, negative_scores]).detach().cpu().numpy()
    labels = torch.cat([torch.ones(positive_scores.shape[0]), torch.zeros(negative_scores.shape[0])]).numpy()
    return (f1_score(labels, scores),
            precision_score(labels, scores),
            recall_score(labels, scores))

# Graphs holding only the held-out positive edges / the sampled negative test edges
test_pos_graph = dgl.graph((test_pos_src, test_pos_dst), num_nodes=graph.number_of_nodes()).to(device)
test_neg_graph = dgl.graph((test_neg_src, test_neg_dst), num_nodes=graph.number_of_nodes()).to(device)

# Evaluate model performance using proper metrics
sage_model.eval()
with torch.no_grad():
    # Embeddings are recomputed with the final weights on the training graph:
    # the full graph still contains the test edges and would leak them into the embeddings
    test_node_embeddings = sage_model(train_graph, train_graph.ndata['features'])

    test_pos_scores = predictor(test_pos_graph, test_node_embeddings)
    test_neg_scores = predictor(test_neg_graph, test_node_embeddings)

    # The model is trained with binary cross-entropy on logits, so sigmoid(score) is the
    # predicted probability; the same 0.5 threshold is applied to positives and negatives
    pos_test_scores = (torch.sigmoid(test_pos_scores) > 0.5).long()
    neg_test_scores = (torch.sigmoid(test_neg_scores) > 0.5).long()

    # Predictions are moved to the CPU before being handed to the metric computation
    f1, prec, rec = compute_scores(pos_test_scores.cpu(), neg_test_scores.cpu())
    print(f'F1 Score: {f1}')
    print(f'Precision: {prec}')
    print(f'Recall: {rec}')

## Dealing with large graphs
In the previous example we have seen how to predict links using DGL. However, you may have noticed that we computed the scores of all edges at once during training, which is not feasible for large graphs.

To overcome this issue, we can use some functionalities provided by graph machine learning libraries, including DGL. In the next example, instead of fitting the whole graph in memory, we will be iterating over the edges in minibatches.

For readability we are not going to implement the validation and testing part, however it can be done as we have done above!

In [None]:
# DGL provides dgl.dataloading.DataLoader which, combined with a sampler wrapped by
# dgl.dataloading.as_edge_prediction_sampler, iterates over edges in minibatches for
# edge classification or link prediction tasks.
# For link prediction, we also need to specify a negative sampler
# builtin negative samplers (non-existing edges) such as dgl.dataloading.negative_sampler.Uniform can be used for this purpose.

# load 5 negative samples for each positive sample (existing edge)
negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)

# define the edge loader: full neighbourhoods for the 2 GNN layers, wrapped for edge prediction
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler)

dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to the DGL DataLoader.
    graph,                                      # The graph
    # The edges to iterate over. The ids are created on the graph's device because
    # the loader expects graph and indices on the same device (unless UVA is used)
    torch.arange(graph.number_of_edges(), device=graph.device),
    sampler,                                # The edge prediction sampler
    device=device,                          # Put the MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=128,    # Batch size
    shuffle=True,       # Whether to shuffle the edges for every epoch
    drop_last=False,    # Whether to drop the last incomplete batch
    num_workers=0       # Number of sampler processes
)

In [None]:
# Fetch a single minibatch from the loader to inspect what it yields
input_nodes, pos_graph, neg_graph, mfgs = next(iter(dataloader))
print('Number of input nodes:', len(input_nodes))
print('Positive graph # nodes:', pos_graph.number_of_nodes(), '# edges:', pos_graph.number_of_edges())
print('Negative graph # nodes:', neg_graph.number_of_nodes(), '# edges:', neg_graph.number_of_edges())

print(mfgs)
# Notice that the last element is a list of message flow graphs (MFGs) storing the computation dependencies for each GNN layer.
# The MFGs are used to compute the GNN outputs of the nodes involved in the positive/negative graphs.
# Check more on https://docs.dgl.ai/en/0.8.x/generated/dgl.dataloading.BlockSampler.html

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class GraphSAGENetwork(nn.Module):
    """Two-layer GraphSAGE encoder for minibatch training.

    `forward` receives an indexable collection `g` and feeds `g[0]` to the first
    layer and `g[1]` to the second one.
    """

    def __init__(self, in_feats, hidden_feats):
        """Create both 'mean' SAGEConv layers: in_feats -> hidden_feats -> hidden_feats."""
        super().__init__()
        self.conv1 = SAGEConv(in_feats, hidden_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(hidden_feats, hidden_feats, aggregator_type='mean')

    def forward(self, g, features):
        """Apply layer 1 on g[0], a ReLU, then layer 2 on g[1]; return the final embeddings."""
        hidden = F.relu(self.conv1(g[0], features))
        return self.conv2(g[1], hidden)

# Edge predictor based on the dot product
# The score of an edge is the dot product between the embeddings of its two end nodes
class DotProductPredictor(nn.Module):
    """Scores every edge of `graph` from the embeddings of its endpoints."""

    def forward(self, graph, node_embeddings):
        """Return one score per edge of `graph`."""
        # local_scope() keeps the temporary 'h' and 'score' fields from leaking into the caller's graph
        with graph.local_scope():
            graph.ndata['h'] = node_embeddings
            graph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'))
            edge_scores = graph.edata['score']
        # Keep column 0 only, so that each edge is described by a single value
        return edge_scores[:, 0]

# Instantiate a fresh encoder (input width taken from the node feature matrix, 16 hidden units)
# and a fresh edge predictor, then place both on the computation device
sage_model = GraphSAGENetwork(in_feats=graph.ndata['features'].shape[1], hidden_feats=16)
sage_model = sage_model.to(device)
predictor = DotProductPredictor()
predictor = predictor.to(device)

# Optimizer setup
# Adam updates the parameters of both the encoder and the predictor from the gradients
# computed during backpropagation
optimizer = torch.optim.Adam(
    [*sage_model.parameters(), *predictor.parameters()],
    lr=0.01,
)

# Minibatch training loop
# Every epoch iterates over all edge minibatches yielded by the dataloader and reports
# the loss averaged over all scored (positive and negative) edges
for epoch in range(5):
    total_loss = 0
    total_examples = 0
    for input_nodes, pos_graph, neg_graph, mfgs in dataloader:
        sage_model.train()

        # Input features are read from the source nodes of the first MFG
        input_features = mfgs[0].srcdata['features']

        # Node embeddings for the nodes of this minibatch
        node_embeddings = sage_model(mfgs, input_features)

        # Scores of the positive and of the sampled negative edges
        pos_scores = predictor(pos_graph, node_embeddings)
        neg_scores = predictor(neg_graph, node_embeddings)

        loss = compute_loss(pos_scores, neg_scores)

        # Backward pass: compute gradients and update model parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the number of scored edges to get a per-edge epoch average
        total_loss += loss.item() * (len(pos_scores) + len(neg_scores))
        total_examples += len(pos_scores) + len(neg_scores)

    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

# Link prediction on social network using PyG
We will now replicate the example using another popular library for graph machine learning: PyTorch Geometric (PyG).

In [None]:
!pip install torch_geometric

# Optional dependencies:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cpu.html

In [None]:
from torch_geometric.utils.convert import from_networkx
import torch_geometric.transforms as T
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F

# Convert the networkx graph into a PyTorch Geometric data object.
# NOTE: this rebinds `G`, which previously held the DGL graph built from the same networkx graph.
G = from_networkx(Gnx)

In [None]:
# Let's add fake features: an identity matrix with one row per node,
# i.e. every node gets its own one-hot feature vector
G.x = torch.eye(G.num_nodes)

In [None]:
# We first split the set of edges into training (80%), validation (10%),
# and testing edges (10%). For validation and testing we also generate fixed
# negative (non-existing) edges, 2 for every positive edge; training negatives
# are not generated here (add_negative_train_samples=False).
# disjoint_train_ratio=0.3 reserves 30% of the training edges for supervision only;
# they are not used for message passing.
# We can leverage the `RandomLinkSplit()` transform to perform all the steps:
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,
    # When the graph stores both directions of every link, tell the transform so:
    # otherwise the reverse copy of a held-out edge can end up in another split
    is_undirected=G.is_undirected(),
)
train_data, val_data, test_data = transform(G)

Similar to what we have done above, we will be using a mini-batch loader: our graph is quite small, so it is perfectly fine to load it in memory while training. However, for larger graphs, since computing the probability of all edges is usually not feasible, a mini-batch loader is required to load parts of the graph step by step.

PyG makes use of the loader.LinkNeighborLoader to sample multiple hops from both ends of a link and creates a subgraph from it.

In [None]:
# Define seed edges: the training supervision edges and their labels
edge_label_index = train_data.edge_label_index
edge_label = train_data.edge_label
# Minibatch loader over the seed edges of the training split
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 20],            # up to 20 sampled neighbors in each of the 2 hops
    neg_sampling_ratio=2.0,            # negative edges sampled per positive seed edge
    edge_label_index=edge_label_index,
    edge_label=edge_label,
    batch_size=128,                    # seed edges per minibatch
    shuffle=True,
)

In [None]:
# GraphSAGE encoder (PyTorch Geometric version)
# Two stacked SAGEConv layers with a ReLU non-linearity in between
class GraphSAGENetwork(nn.Module):
    """Two-layer GraphSAGE encoder turning node features into node embeddings."""

    def __init__(self, in_feats, hidden_feats):
        """Create both layers: in_feats -> hidden_feats -> hidden_feats."""
        super().__init__()
        self.conv1 = SAGEConv(in_feats, hidden_feats)
        self.conv2 = SAGEConv(hidden_feats, hidden_feats)

    def forward(self, x, edge_index):
        """Apply both layers over the same `edge_index` and return the final embeddings."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# Edge predictor based on the dot product
# The score of an edge is the dot product between the embeddings of its two end nodes
class DotProductPredictor(nn.Module):
    """Scores node pairs by the dot product of their embeddings."""

    def forward(self, z, edge_index):
        """Return one score per column of `edge_index`.

        `edge_index` unpacks into the source ids and the destination ids, which are
        used to look up rows of the embedding matrix `z`.
        """
        source_ids, target_ids = edge_index
        source_emb = z[source_ids]
        target_emb = z[target_ids]
        # Element-wise product summed over the embedding dimension == dot product
        return (source_emb * target_emb).sum(dim=-1)

# Instantiate the PyG encoder (input width = number of node features, 16 hidden units)
# and the edge predictor, then place both on the computation device
sage_model = GraphSAGENetwork(in_feats=G.num_features, hidden_feats=16)
sage_model = sage_model.to(device)
predictor = DotProductPredictor()
predictor = predictor.to(device)

In [None]:
# Loss function
# Binary cross-entropy computed directly on the raw (logit) scores
def compute_loss(pred, ground_truth):
    """Return the binary cross-entropy (with logits) between `pred` and `ground_truth`."""
    return F.binary_cross_entropy_with_logits(input=pred, target=ground_truth)

# Classification metrics for the predicted links
def compute_scores(labels, scores):
    """Return the tuple (f1, precision, recall) of predictions `scores` against `labels`."""
    metrics = (f1_score, precision_score, recall_score)
    return tuple(metric(labels, scores) for metric in metrics)

In [None]:
from tqdm import tqdm

# Optimizer setup
# Adam updates the parameters of both the encoder and the predictor from the gradients
# computed during backpropagation
optimizer = torch.optim.Adam(
    [*sage_model.parameters(), *predictor.parameters()],
    lr=0.01,
)

# Training loop
# Each epoch iterates over all minibatches of the training loader and reports the
# loss averaged over all scored edges
for epoch in range(1):
    sage_model.train()
    total_loss = 0
    total_examples = 0

    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        batch.to(device)

        # Embed the nodes of the sampled subgraph, then score its seed edges
        node_embeddings = sage_model(batch.x, batch.edge_index)
        scores = predictor(node_embeddings, batch.edge_label_index)

        loss = compute_loss(scores, batch.edge_label)

        # Backward pass: compute gradients and update model parameters
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the number of scored edges to get a per-edge epoch average
        total_loss += loss.item() * scores.numel()
        total_examples += scores.numel()

    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

Let's evaluate the model. To do this, we create a dedicated `LinkNeighborLoader` over the validation edges.

In [None]:
# Define the validation seed edges and their labels
edge_label_index = val_data.edge_label_index
edge_label = val_data.edge_label
# No negative sampling is configured here: the loader only serves the edges and labels
# stored in val_data. Order is kept fixed (shuffle=False).
val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[20, 20],            # up to 20 sampled neighbors in each of the 2 hops
    edge_label_index=edge_label_index,
    edge_label=edge_label,
    batch_size=128,                    # seed edges per minibatch
    shuffle=False,
)
# Peek at the first validation minibatch (displayed as the cell output)
sampled_data = next(iter(val_loader))
sampled_data

In [None]:
# Collect the raw scores and the true labels of every validation minibatch
preds = []
ground_truths = []

# Inference only: gradient tracking is switched off for the whole pass
with torch.no_grad():
    for batch in tqdm(val_loader):
        batch.to(device)

        # Embed the sampled subgraph and score its seed edges
        node_embeddings = sage_model(batch.x, batch.edge_index)
        scores = predictor(node_embeddings, batch.edge_label_index)

        preds.append(scores)
        ground_truths.append(batch.edge_label)

In [None]:
def normalize(scores):
    """Min-max rescale `scores` so that the smallest value maps to 0 and the largest to 1.

    The result is undefined (division by zero) when all scores are equal.
    """
    lowest = scores.min()
    value_range = scores.max() - lowest
    return (scores - lowest) / value_range

# Gather the scores and labels of all validation minibatches on the CPU
pred = torch.cat(preds, dim=0).cpu().numpy()
ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()

# The scores are logits (the training loss is binary cross-entropy with logits), so a
# score above 0 corresponds to a predicted probability above 0.5. Thresholding the logits
# directly keeps the decision independent of the largest/smallest score in the set.
pred = pred > 0
# Labels are 0/1 values; turn them into booleans
ground_truth = ground_truth > 0.5

f1, prec, rec = compute_scores(ground_truth, pred)

print(f'F1 Score: {f1}')
print(f'Precision: {prec}')
print(f'Recall: {rec}')