<a href="https://colab.research.google.com/github/Michaelmvh/bozdag-research-GNN/blob/main/Exploring_RGCN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RGCN Link Prediction Example

Source: [PyTorch Geometric Examples: RGCN Link Prediction](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rgcn_link_pred.py)

# Imports

In [1]:
# pip installs
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [2]:
import pandas as pd
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn.models import GCN
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import from_networkx
from sklearn.metrics import roc_auc_score
from torch_geometric.loader import DataLoader
import os.path as osp
import time
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from tqdm import tqdm
from torch_geometric.datasets import RelLinkPredDataset
from torch_geometric.nn import GAE, RGCNConv

# Setup

## Constants

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = "cpu"
print(device)

cuda


# Load Dataset

In [4]:
# # path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'RLPD')
# path = osp.curdir
# dataset = RelLinkPredDataset(path, 'FB15k-237')
# data = dataset[0].to(device)

In [19]:
import torch
from torch_geometric.datasets import RelLinkPredDataset
from torch_geometric.loader import DataLoader  # Use PyG's DataLoader
from torch.utils.data import Dataset

class LinkPredDataset(Dataset):
    def __init__(self, edge_index, edge_type):
        self.edge_index = edge_index
        self.edge_type = edge_type

    def __len__(self):
        return self.edge_index.size(1)

    def __getitem__(self, idx):
        return self.edge_index[:, idx], self.edge_type[idx]

# Load the dataset
path = osp.curdir  # Or your desired path
dataset = RelLinkPredDataset(path, 'FB15k-237')
data = dataset[0].to(device)

# Create train and test/validation datasets for edges
train_dataset = LinkPredDataset(data.train_edge_index, data.train_edge_type)
# Assuming you want to combine validation and test for evaluation
val_test_edge_index = torch.cat([data.valid_edge_index, data.test_edge_index], dim=1)
val_test_edge_type = torch.cat([data.valid_edge_type, data.test_edge_type])
val_test_dataset = LinkPredDataset(val_test_edge_index, val_test_edge_type)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_test_loader = DataLoader(val_test_dataset, batch_size=128, shuffle=False)  # No need to shuffle for evaluation

# Define model

In [6]:
class RGCNEncoder(torch.nn.Module):
    def __init__(self, num_nodes, hidden_channels, num_relations):
        super().__init__()
        self.node_emb = Parameter(torch.empty(num_nodes, hidden_channels)) # Treat x as learnable param
        self.conv1 = RGCNConv(hidden_channels, hidden_channels, num_relations,
                              num_bases=5)
        self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations,
                              num_bases=5)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.node_emb) # Initialize x with xavier uniform distribution
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, edge_index, edge_type):
        x = self.node_emb
        x = self.conv1(x, edge_index, edge_type).relu_()
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index, edge_type)
        return x

In [7]:
class DistMultDecoder(torch.nn.Module):
    def __init__(self, num_relations, hidden_channels):
        super().__init__()
        self.rel_emb = Parameter(torch.empty(num_relations, hidden_channels))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.rel_emb)

    def forward(self, z, edge_index, edge_type):
        z_src, z_dst = z[edge_index[0]], z[edge_index[1]]
        rel = self.rel_emb[edge_type]
        return torch.sum(z_src * rel * z_dst, dim=1)

In [8]:
def negative_sampling(edge_index, num_nodes):
    # Sample edges by corrupting either the subject or the object of each edge.
    mask_1 = torch.rand(edge_index.size(1)) < 0.5
    mask_2 = ~mask_1

    neg_edge_index = edge_index.clone()
    neg_edge_index[0, mask_1] = torch.randint(num_nodes, (mask_1.sum(), ),
                                              device=neg_edge_index.device)
    neg_edge_index[1, mask_2] = torch.randint(num_nodes, (mask_2.sum(), ),
                                              device=neg_edge_index.device)
    return neg_edge_index

# Define Train/Test

In [9]:
# def train(model, optimizer, train_loader):
#     model.train()
#     optimizer.zero_grad()


#     for batch_idx, (edge_index, edge_type) in enumerate(train_loader):
#       z = model.encode(data.edge_index, data.edge_type)

#       pos_out = model.decode(z, edge_index, edge_type)

#       neg_edge_index = negative_sampling(data.train_edge_index, data.num_nodes)
#       neg_out = model.decode(z, neg_edge_index, data.train_edge_type)

#       out = torch.cat([pos_out, neg_out])
#       gt = torch.cat([torch.ones_like(pos_out), torch.zeros_like(neg_out)])
#       cross_entropy_loss = F.binary_cross_entropy_with_logits(out, gt)
#       reg_loss = z.pow(2).mean() + model.decoder.rel_emb.pow(2).mean()
#       loss = cross_entropy_loss + 1e-2 * reg_loss

#     loss.backward()
#     torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
#     optimizer.step()

#     return float(loss)

In [10]:
def train(model, optimizer, train_loader, data): # Pass the full data object and device
    model.train()
    total_loss = 0
    num_batches = 0

    # --- Encode the full graph ONCE before the loop ---
    # This is still a potential memory bottleneck for very large graphs.
    # For truly large graphs, consider neighbor sampling techniques (e.g., NeighborLoader).
    with torch.no_grad(): # No need to track gradients for encoding if done once
         z = model.encode(data.edge_index, data.edge_type)
    # ----------------------------------------------------

    for batch_idx, (batch_edge_index, batch_edge_type) in enumerate(train_loader):
        batch_edge_index = batch_edge_index.T # Fix this later on
        optimizer.zero_grad()

        # --- Use pre-computed full graph embeddings 'z' ---
        # Decode positive edges from the current batch
        pos_out = model.decode(z, batch_edge_index, batch_edge_type)

        # Negative sampling based on the edges in the current batch
        # Note: negative_sampling function needs modification or replacement
        # if you want batch-specific negative sampling.
        # Current negative_sampling samples from the *entire* training set.
        # For simplicity here, we'll sample globally but decode using batch indices.
        # A more advanced approach would sample negatives only involving nodes in the batch.

        # Sample negatives globally (as in original code)
        neg_edge_index = negative_sampling(data.train_edge_index, data.num_nodes) # Sample globally
        neg_edge_index = neg_edge_index.to(device)

        # Select corresponding global edge types for the sampled negative edges
        # This is an approximation; ideally, you'd know the types for corrupted edges.
        # If negative_sampling corrupts based on existing edges, we might reuse types.
        # Let's assume we need to select types corresponding to the *indices* of the
        # original edges from which negatives were derived, which is complex.
        # A simpler (though less precise) approach might be to randomly assign types
        # or use the types corresponding to the *original* positive edges being corrupted.
        # Given the original `negative_sampling` structure, linking back type is hard.
        # Let's stick to the original logic's structure for now, using global train_edge_type:
        # This part might need refinement depending on how negative sampling should interact with types.
        neg_out = model.decode(z, neg_edge_index, data.train_edge_type) # Still using global types

        # --- Combine positive and negative outputs for loss calculation ---
        out = torch.cat([pos_out, neg_out])
        gt = torch.cat([
            torch.ones_like(pos_out),
            torch.zeros_like(neg_out)
        ]).to(device) # Ensure ground truth is on the correct device

        cross_entropy_loss = F.binary_cross_entropy_with_logits(out, gt)

        # Regularization (consider applying only to parameters used in the batch if needed)
        # Using the full 'z' means full regularization is applied based on all node embeddings.
        reg_loss = z.pow(2).mean() + model.decoder.rel_emb.pow(2).mean()
        loss = cross_entropy_loss + 1e-2 * reg_loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += float(loss)
        num_batches += 1

    return total_loss / num_batches # Return average loss per batch

In [11]:
@torch.no_grad()
def test(model, val_test_loader, data):
    model.eval()
    z = model.encode(data.edge_index, data.edge_type)

    valid_mrr = compute_mrr(z, data.valid_edge_index, data.valid_edge_type)
    test_mrr = compute_mrr(z, data.test_edge_index, data.test_edge_type)

    return valid_mrr, test_mrr

In [12]:
@torch.no_grad()
def compute_rank(ranks):
    # fair ranking prediction as the average
    # of optimistic and pessimistic ranking
    true = ranks[0]
    optimistic = (ranks > true).sum() + 1
    pessimistic = (ranks >= true).sum()
    return (optimistic + pessimistic).float() * 0.5

In [13]:
@torch.no_grad()
def compute_mrr(z, edge_index, edge_type):
    ranks = []
    for i in tqdm(range(edge_type.numel())):
        (src, dst), rel = edge_index[:, i], edge_type[i]

        # Try all nodes as tails, but delete true triplets:
        tail_mask = torch.ones(data.num_nodes, dtype=torch.bool)
        for (heads, tails), types in [
            (data.train_edge_index, data.train_edge_type),
            (data.valid_edge_index, data.valid_edge_type),
            (data.test_edge_index, data.test_edge_type),
        ]:
            tail_mask[tails[(heads == src) & (types == rel)]] = False

        tail = torch.arange(data.num_nodes)[tail_mask]
        tail = torch.cat([torch.tensor([dst]), tail])
        head = torch.full_like(tail, fill_value=src)
        eval_edge_index = torch.stack([head, tail], dim=0)
        eval_edge_type = torch.full_like(tail, fill_value=rel)

        out = model.decode(z, eval_edge_index, eval_edge_type)
        rank = compute_rank(out)
        ranks.append(rank)

        # Try all nodes as heads, but delete true triplets:
        head_mask = torch.ones(data.num_nodes, dtype=torch.bool)
        for (heads, tails), types in [
            (data.train_edge_index, data.train_edge_type),
            (data.valid_edge_index, data.valid_edge_type),
            (data.test_edge_index, data.test_edge_type),
        ]:
            head_mask[heads[(tails == dst) & (types == rel)]] = False

        head = torch.arange(data.num_nodes)[head_mask]
        head = torch.cat([torch.tensor([src]), head])
        tail = torch.full_like(head, fill_value=dst)
        eval_edge_index = torch.stack([head, tail], dim=0)
        eval_edge_type = torch.full_like(head, fill_value=rel)

        out = model.decode(z, eval_edge_index, eval_edge_type)
        rank = compute_rank(out)
        ranks.append(rank)

    return (1. / torch.tensor(ranks, dtype=torch.float)).mean()

# Run model

In [14]:
model = GAE(
    RGCNEncoder(data.num_nodes, 500, dataset.num_relations),
    DistMultDecoder(dataset.num_relations, 500),
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
times = []
for epoch in range(1, 200):
    start = time.time()
    loss = train(model, optimizer, train_loader, data)
    print(f'Epoch: {epoch:05d}, Loss: {loss:.4f}')
    if (epoch % 100) == 0:
        valid_mrr, test_mrr = test(model, val_test_loader, data)
        print(f'Val MRR: {valid_mrr:.4f}, Test MRR: {test_mrr:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")

Epoch: 00001, Loss: 0.6705
Epoch: 00002, Loss: 0.6661
Epoch: 00003, Loss: 0.6648
Epoch: 00004, Loss: 0.6649
Epoch: 00005, Loss: 0.6648
Epoch: 00006, Loss: 0.6647
Epoch: 00007, Loss: 0.6646
Epoch: 00008, Loss: 0.6645
Epoch: 00009, Loss: 0.6650
Epoch: 00010, Loss: 0.6645


# Archive

In [None]:
# def train(model, optimizer, train_loader, data): # Pass the full data object and device
model.train()
total_loss = 0
num_batches = 0

# --- Encode the full graph ONCE before the loop ---
# This is still a potential memory bottleneck for very large graphs.
# For truly large graphs, consider neighbor sampling techniques (e.g., NeighborLoader).
# with torch.no_grad(): # No need to track gradients for encoding if done once
      # z = model.encode(data.edge_index, data.edge_type)
# ----------------------------------------------------

for batch_idx, (batch_edge_index, batch_edge_type) in enumerate(train_loader):
    # print(batch_idx)
    optimizer.zero_grad()
    batch_edge_index = batch_edge_index.T

    # --- Use pre-computed full graph embeddings 'z' ---
    # Decode positive edges from the current batch
    print(z.shape)
    print(batch_edge_index.shape)
    # print(batch_edge_type.shape)
    pos_out = model.decode(z, batch_edge_index, batch_edge_type)
    print("Success: ")
    print(pos_out)
    break
    # Negative sampling based on the edges in the current batch
    # Note: negative_sampling function needs modification or replacement
    # if you want batch-specific negative sampling.
    # Current negative_sampling samples from the *entire* training set.
    # For simplicity here, we'll sample globally but decode using batch indices.
    # A more advanced approach would sample negatives only involving nodes in the batch.

    # Sample negatives globally (as in original code)
    neg_edge_index = negative_sampling(data.train_edge_index, data.num_nodes) # Sample globally
    neg_edge_index = neg_edge_index.to(device)

    # Select corresponding global edge types for the sampled negative edges
    # This is an approximation; ideally, you'd know the types for corrupted edges.
    # If negative_sampling corrupts based on existing edges, we might reuse types.
    # Let's assume we need to select types corresponding to the *indices* of the
    # original edges from which negatives were derived, which is complex.
    # A simpler (though less precise) approach might be to randomly assign types
    # or use the types corresponding to the *original* positive edges being corrupted.
    # Given the original `negative_sampling` structure, linking back type is hard.
    # Let's stick to the original logic's structure for now, using global train_edge_type:
    # This part might need refinement depending on how negative sampling should interact with types.
    neg_out = model.decode(z, neg_edge_index, data.train_edge_type) # Still using global types

    # --- Combine positive and negative outputs for loss calculation ---
    out = torch.cat([pos_out, neg_out])
    gt = torch.cat([
        torch.ones_like(pos_out),
        torch.zeros_like(neg_out)
    ]).to(device) # Ensure ground truth is on the correct device

    cross_entropy_loss = F.binary_cross_entropy_with_logits(out, gt)

    # Regularization (consider applying only to parameters used in the batch if needed)
    # Using the full 'z' means full regularization is applied based on all node embeddings.
    reg_loss = z.pow(2).mean() + model.decoder.rel_emb.pow(2).mean()
    loss = cross_entropy_loss + 1e-2 * reg_loss

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    total_loss += float(loss)
    num_batches += 1

print(f"Loss: {total_loss}, batches: {num_batches} Avg Loss: {total_loss/(1+num_batches)}")
# return total_loss / num_batches # Return average loss per batch

In [None]:
rel.shape

In [None]:
z_src, z_dst = z[edge_index.T[0]], z[edge_index.T[1]]

In [None]:
print(z_src.shape)
print(z_dst.shape)
print((z_src * z_dst).shape)

In [None]:
model.decoder.rel_emb.shape

In [None]:
for batch_idx, (batch_edge_index, batch_edge_type) in enumerate(train_loader):
  print(batch_edge_index.shape)
  print(batch_edge_index[0].shape)
  print(z[batch_edge_index[0]].shape)
  print(batch_edge_type.shape)
  print(model.decoder.rel_emb[batch_edge_type].shape)
  print(z[batch_edge_index].shape)
  break

In [None]:
(z_src * z_dst).shape

In [None]:
rel.T.shape

In [None]:
test = z_src * rel[0] * z_dst

In [None]:
test.shape

In [None]:
torch.sum(test, dim=1).shape

In [None]:
dataset.num_relations