# Dataset loading from txt


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from pathlib import Path
from tqdm import tqdm
from torch_geometric.utils import degree, add_self_loops
from torch_geometric.nn import MessagePassing
import os
import math
import random
from pathlib import Path
from typing import List, Tuple, Dict


# Dataset preparation

In [11]:
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Paths (adjust if needed) ----
_train_path = Path("../WN18RR/train.txt")
_valid_path = Path("../WN18RR/valid.txt")
_test_path  = Path("../WN18RR/test.txt")

def load_dataset(path: Path) -> List[Tuple[str, str, str]]:
    data = []
    with open(path, "r") as f:
        for line in f:
            h, r, t = line.strip().split("\t")
            data.append((h, r, t))
    return data

train_dataset = load_dataset(_train_path)
valid_dataset = load_dataset(_valid_path)
test_dataset  = load_dataset(_test_path)

# ---- Build ID maps ----
entities: set[str] = set()
relations: set[str] = set()
for h, r, t in (train_dataset + valid_dataset + test_dataset):
    entities.add(h); entities.add(t); relations.add(r)

ent2id: Dict[str, int] = {e: i for i, e in enumerate(sorted(entities))}
rel2id: Dict[str, int] = {r: i for i, r in enumerate(sorted(relations))}
id2ent = {v: k for k, v in ent2id.items()}
id2rel = {v: k for k, v in rel2id.items()}

num_entities  = len(ent2id)
num_relations = len(rel2id)
print(f"#entities={num_entities}, #relations={num_relations}")

def triples_to_tensor(triples: List[Tuple[str, str, str]]) -> torch.LongTensor:
    arr = np.array([(ent2id[h], rel2id[r], ent2id[t]) for h, r, t in triples], dtype=np.int64)
    return torch.from_numpy(arr)

train_triples = triples_to_tensor(train_dataset).to(device)  # [Ntr, 3] (h, r, t)
valid_triples = triples_to_tensor(valid_dataset).to(device)  # [Nv, 3]
test_triples  = triples_to_tensor(test_dataset).to(device)   # [Nt, 3]

# ---- Collapsed undirected graph for LightGCN encoder ----
edges = []
for h, r, t in train_triples.tolist():
    edges.append((h, t))
    edges.append((t, h))
edge_index = torch.tensor(edges, dtype=torch.long, device=device).t().contiguous()  # [2, E]
edge_index, _ = add_self_loops(edge_index, num_nodes=num_entities)  # optional self-loops
print("edge_index:", tuple(edge_index.size()))


#entities=40943, #relations=11
edge_index: (2, 214613)


# Model definition and decoder 

```LightGCN``` :expects 2 types of data when using hetero:


```x_dict```   : which is the dictionary of embedding related to that node (trainable)


```edge_index_dict```  : which is the ```{('entity', relation, 'entity'): tensor[[src],[dst]]}``` 
 

we also add new inverse type of relation on top of the 11 that already exists.
this allows for information to be passed around which originally didnt.
A -> B is one way, and there should be an inverse relationship (or some information) which is missed out.

In [12]:
import torch.nn as nn
import torch.nn.functional as F
import torch 
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree

# -------- LightGCN layer (parameter-free) --------
class LightGCNConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j: torch.Tensor, norm: torch.Tensor) -> torch.Tensor:
        return norm.view(-1, 1) * x_j

# -------- LightGCN encoder + dot-product decoder --------
class LightGCN(nn.Module):
    def __init__(self, num_nodes: int, emb_dim: int = 64, num_layers: int = 3):
        super().__init__()
        self.embedding = nn.Embedding(num_nodes, emb_dim)
        nn.init.xavier_uniform_(self.embedding.weight)
        self.convs = nn.ModuleList([LightGCNConv() for _ in range(num_layers)])
        self.num_layers = num_layers

    def encode(self, edge_index: torch.Tensor) -> torch.Tensor:
        x0 = self.embedding.weight
        out = x0
        x = x0
        for conv in self.convs:
            x = conv(x, edge_index)
            out = out + x
        return out / (self.num_layers + 1)   # layer-wise average

    @staticmethod
    def decode(z: torch.Tensor, pairs: torch.LongTensor) -> torch.Tensor:
        # pairs: [2, B] with [src; dst]
        return (z[pairs[0]] * z[pairs[1]]).sum(dim=1)

# -------- R-LightGCN: relation-aware propagation (learnable per-relation scalars) --------
class RLightGCNConv(MessagePassing):
    def __init__(self, num_relations: int):
        super().__init__(aggr='add')
        self.alpha = nn.Parameter(torch.ones(num_relations))  # per-relation scalar

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_type: torch.LongTensor) -> torch.Tensor:
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, edge_type=edge_type, norm=norm)

    def message(self, x_j: torch.Tensor, edge_type: torch.LongTensor, norm: torch.Tensor) -> torch.Tensor:
        w = self.alpha[edge_type].view(-1, 1)    # weight message by relation type
        return w * norm.view(-1, 1) * x_j

class RLightGCN(nn.Module):
    def __init__(self, num_nodes: int, num_relations: int, emb_dim: int = 64, num_layers: int = 3):
        super().__init__()
        self.embedding = nn.Embedding(num_nodes, emb_dim)
        nn.init.xavier_uniform_(self.embedding.weight)
        self.convs = nn.ModuleList([RLightGCNConv(num_relations) for _ in range(num_layers)])
        self.num_layers = num_layers

    def encode(self, edge_index: torch.Tensor, edge_type: torch.LongTensor) -> torch.Tensor:
        x0 = self.embedding.weight
        out = x0
        x = x0
        for conv in self.convs:
            x = conv(x, edge_index, edge_type)
            out = out + x
        return out / (self.num_layers + 1)

    @staticmethod
    def decode(z: torch.Tensor, pairs: torch.LongTensor) -> torch.Tensor:
        return (z[pairs[0]] * z[pairs[1]]).sum(dim=1)


# Negative Sampling

In [13]:
@torch.no_grad()
def pairs_from_triples(triples: torch.LongTensor) -> torch.LongTensor:
    """
    Convert (h, r, t) -> pairs [2, N] = (h, t) for decoding on collapsed graph.
    """
    return triples[:, [0, 2]].t().contiguous()  # [2, N]

@torch.no_grad()
def negative_sample_heads(triples: torch.LongTensor, num_nodes: int) -> torch.LongTensor:
    """
    Corrupt heads: (h, r, t) -> (h', t)
    Returns pairs [2, N].
    """
    N = triples.size(0)
    neg_h = torch.randint(0, num_nodes, (N,), device=triples.device)
    t = triples[:, 2]
    return torch.stack([neg_h, t], dim=0)

@torch.no_grad()
def negative_sample_tails(triples: torch.LongTensor, num_nodes: int) -> torch.LongTensor:
    """
    Corrupt tails: (h, r, t) -> (h, t')
    Returns pairs [2, N].
    """
    N = triples.size(0)
    h = triples[:, 0]
    neg_t = torch.randint(0, num_nodes, (N,), device=triples.device)
    return torch.stack([h, neg_t], dim=0)


# Training and Batch eval

In [14]:
def bpr_loss(pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
    # Bayesian Personalized Ranking loss
    return -torch.log(torch.sigmoid(pos_scores - neg_scores) + 1e-12).mean()

def logistic_pair_loss(pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
    # Alternative: pairwise logistic
    return -F.logsigmoid(pos_scores).mean() - F.logsigmoid(-neg_scores).mean()

@torch.no_grad()
def batch_scores(z: torch.Tensor, pairs: torch.LongTensor, batch_size: int = 4096) -> torch.Tensor:
    scores = []
    for i in range(0, pairs.size(1), batch_size):
        batch = pairs[:, i:i+batch_size]
        s = (z[batch[0]] * z[batch[1]]).sum(dim=1)
        scores.append(s)
    return torch.cat(scores, dim=0)


# Evaluation function 

In [15]:
from sklearn.metrics import roc_auc_score, average_precision_score
from typing import Dict


@torch.no_grad()
def evaluate_auc_ap(z: torch.Tensor,
                    pos_triples: torch.LongTensor,
                    num_nodes: int,
                    neg_mode: str = "head",
                    batch_size: int = 4096) -> Dict[str, float]:
    """
    Simple AUC/AP using equal number of sampled negatives.
    neg_mode: "head" or "tail" corruption.
    """
    pos_pairs = pairs_from_triples(pos_triples)
    if neg_mode == "head":
        neg_pairs = negative_sample_heads(pos_triples, num_nodes)
    else:
        neg_pairs = negative_sample_tails(pos_triples, num_nodes)

    pos = batch_scores(z, pos_pairs, batch_size=batch_size).detach().cpu().numpy()
    neg = batch_scores(z, neg_pairs, batch_size=batch_size).detach().cpu().numpy()

    y_true = np.concatenate([np.ones_like(pos), np.zeros_like(neg)])
    y_score = np.concatenate([pos, neg])

    return {
        "AUC": float(roc_auc_score(y_true, y_score)),
        "AP":  float(average_precision_score(y_true, y_score)),
    }


# Actual Training

In [20]:
# -------- Train LightGCN on collapsed graph (recommendation-style link prediction) --------
lr = 1e-3
epochs = 50
emb_dim = 64
num_layers = 3
eval_every = 5
batch_size = 4096  # scoring batch (not training mini-batch; training here uses full graph + full triples)

model = LightGCN(num_nodes=num_entities, emb_dim=emb_dim, num_layers=num_layers).to(device)
opt = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in range(1, epochs + 1):
    model.train()
    opt.zero_grad()

    z = model.encode(edge_index)

    pos_pairs = pairs_from_triples(train_triples)             # [2, N]
    neg_pairs = negative_sample_heads(train_triples, num_entities)  # [2, N] (you can mix head/tail)

    pos_scores = model.decode(z, pos_pairs)
    neg_scores = model.decode(z, neg_pairs)

    loss = F.binary_cross_entropy_with_logits(
    torch.cat([pos_scores, neg_scores]),
    torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])
    )

    loss.backward()
    opt.step()

    if epoch % eval_every == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            z = model.encode(edge_index)
            val_metrics_h = evaluate_auc_ap(z, valid_triples, num_entities, neg_mode="head")
            val_metrics_t = evaluate_auc_ap(z, valid_triples, num_entities, neg_mode="tail")
        print(f"[LightGCN] Epoch {epoch:03d} | Loss {loss.item():.4f} | "
              f"Val(AUC/AP head) {val_metrics_h['AUC']:.4f}/{val_metrics_h['AP']:.4f} | "
              f"Val(AUC/AP tail) {val_metrics_t['AUC']:.4f}/{val_metrics_t['AP']:.4f}")

# Final test
model.eval()
with torch.no_grad():
    z = model.encode(edge_index)
    test_h = evaluate_auc_ap(z, test_triples, num_entities, neg_mode="head")
    test_t = evaluate_auc_ap(z, test_triples, num_entities, neg_mode="tail")
print(f"[LightGCN][TEST] AUC/AP head {test_h['AUC']:.4f}/{test_h['AP']:.4f} | "
      f"AUC/AP tail {test_t['AUC']:.4f}/{test_t['AP']:.4f}")


[LightGCN] Epoch 001 | Loss 0.6931 | Val(AUC/AP head) 0.7077/0.7713 | Val(AUC/AP tail) 0.7098/0.7640
[LightGCN] Epoch 005 | Loss 0.6929 | Val(AUC/AP head) 0.7899/0.8407 | Val(AUC/AP tail) 0.7920/0.8474
[LightGCN] Epoch 010 | Loss 0.6924 | Val(AUC/AP head) 0.8292/0.8685 | Val(AUC/AP tail) 0.8313/0.8803
[LightGCN] Epoch 015 | Loss 0.6915 | Val(AUC/AP head) 0.8427/0.8766 | Val(AUC/AP tail) 0.8486/0.8915
[LightGCN] Epoch 020 | Loss 0.6898 | Val(AUC/AP head) 0.8533/0.8839 | Val(AUC/AP tail) 0.8618/0.8990
[LightGCN] Epoch 025 | Loss 0.6873 | Val(AUC/AP head) 0.8594/0.8849 | Val(AUC/AP tail) 0.8668/0.9017
[LightGCN] Epoch 030 | Loss 0.6839 | Val(AUC/AP head) 0.8668/0.8906 | Val(AUC/AP tail) 0.8735/0.9098
[LightGCN] Epoch 035 | Loss 0.6794 | Val(AUC/AP head) 0.8685/0.8946 | Val(AUC/AP tail) 0.8784/0.9078
[LightGCN] Epoch 040 | Loss 0.6738 | Val(AUC/AP head) 0.8748/0.8979 | Val(AUC/AP tail) 0.8818/0.9134
[LightGCN] Epoch 045 | Loss 0.6671 | Val(AUC/AP head) 0.8775/0.9008 | Val(AUC/AP tail) 0.88

# R-LightGCN

In [21]:
# -------- Build relation-aware edge_index and edge_type (with reverse edges) --------
rel_edges = []
rel_types = []
for h, r, t in train_triples.tolist():
    rel_edges.append((h, t)); rel_types.append(r)
    rel_edges.append((t, h)); rel_types.append(r)  # add reverse edge with same relation

rel_edge_index = torch.tensor(rel_edges, dtype=torch.long, device=device).t().contiguous()
edge_type = torch.tensor(rel_types, dtype=torch.long, device=device)

# ---- Add self-loops with a new relation id ----
self_loop_edges = torch.arange(num_entities, device=device)
self_loop_edges = torch.stack([self_loop_edges, self_loop_edges], dim=0)  # [2, N]
rel_edge_index = torch.cat([rel_edge_index, self_loop_edges], dim=1)

self_loop_types = torch.full((num_entities,), num_relations, dtype=torch.long, device=device)
edge_type = torch.cat([edge_type, self_loop_types], dim=0)

# update relation count
num_relations_with_loops = num_relations + 1

# -------- Train R-LightGCN --------
lr = 1e-3
epochs = 50
emb_dim = 64
num_layers = 3
eval_every = 5

rmodel = RLightGCN(
    num_nodes=num_entities,
    num_relations=num_relations_with_loops,  # ✅ updated
    emb_dim=emb_dim,
    num_layers=num_layers
).to(device)

ropt = torch.optim.Adam(rmodel.parameters(), lr=lr)

for epoch in range(1, epochs + 1):
    rmodel.train()
    ropt.zero_grad()

    z = rmodel.encode(rel_edge_index, edge_type)

    pos_pairs = pairs_from_triples(train_triples)
    neg_pairs = negative_sample_heads(train_triples, num_entities)

    pos_scores = rmodel.decode(z, pos_pairs)
    neg_scores = rmodel.decode(z, neg_pairs)

    loss = F.binary_cross_entropy_with_logits(
    torch.cat([pos_scores, neg_scores]),
    torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])
    )

    loss.backward()
    ropt.step()

    if epoch % eval_every == 0 or epoch == 1:
        rmodel.eval()
        with torch.no_grad():
            z = rmodel.encode(rel_edge_index, edge_type)
            val_metrics_h = evaluate_auc_ap(z, valid_triples, num_entities, neg_mode="head")
            val_metrics_t = evaluate_auc_ap(z, valid_triples, num_entities, neg_mode="tail")
        print(f"[R-LightGCN] Epoch {epoch:03d} | Loss {loss.item():.4f} | "
              f"Val(AUC/AP head) {val_metrics_h['AUC']:.4f}/{val_metrics_h['AP']:.4f} | "
              f"Val(AUC/AP tail) {val_metrics_t['AUC']:.4f}/{val_metrics_t['AP']:.4f}")

# Final test
rmodel.eval()
with torch.no_grad():
    z = rmodel.encode(rel_edge_index, edge_type)
    test_h = evaluate_auc_ap(z, test_triples, num_entities, neg_mode="head")
    test_t = evaluate_auc_ap(z, test_triples, num_entities, neg_mode="tail")
print(f"[R-LightGCN][TEST] AUC/AP head {test_h['AUC']:.4f}/{test_h['AP']:.4f} | "
      f"AUC/AP tail {test_t['AUC']:.4f}/{test_t['AP']:.4f}")


[R-LightGCN] Epoch 001 | Loss 0.6931 | Val(AUC/AP head) 0.7117/0.7706 | Val(AUC/AP tail) 0.7041/0.7594
[R-LightGCN] Epoch 005 | Loss 0.6929 | Val(AUC/AP head) 0.7971/0.8479 | Val(AUC/AP tail) 0.7976/0.8495
[R-LightGCN] Epoch 010 | Loss 0.6924 | Val(AUC/AP head) 0.8288/0.8665 | Val(AUC/AP tail) 0.8345/0.8809
[R-LightGCN] Epoch 015 | Loss 0.6914 | Val(AUC/AP head) 0.8494/0.8814 | Val(AUC/AP tail) 0.8483/0.8919
[R-LightGCN] Epoch 020 | Loss 0.6896 | Val(AUC/AP head) 0.8533/0.8802 | Val(AUC/AP tail) 0.8630/0.9005
[R-LightGCN] Epoch 025 | Loss 0.6868 | Val(AUC/AP head) 0.8604/0.8841 | Val(AUC/AP tail) 0.8653/0.9016
[R-LightGCN] Epoch 030 | Loss 0.6829 | Val(AUC/AP head) 0.8687/0.8948 | Val(AUC/AP tail) 0.8746/0.9083
[R-LightGCN] Epoch 035 | Loss 0.6777 | Val(AUC/AP head) 0.8705/0.8958 | Val(AUC/AP tail) 0.8763/0.9068
[R-LightGCN] Epoch 040 | Loss 0.6709 | Val(AUC/AP head) 0.8728/0.8962 | Val(AUC/AP tail) 0.8761/0.9048
[R-LightGCN] Epoch 045 | Loss 0.6626 | Val(AUC/AP head) 0.8764/0.9003 | V

In [23]:
# After training is done

# Save LightGCN
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": opt.state_dict(),
    "num_entities": num_entities,
    "emb_dim": emb_dim,
    "num_layers": num_layers
}, "lightgcn_wn18rr.pt")

# Save R-LightGCN
torch.save({
    "model_state_dict": rmodel.state_dict(),
    "optimizer_state_dict": ropt.state_dict(),
    "num_entities": num_entities,
    "num_relations": num_relations,
    "emb_dim": emb_dim,
    "num_layers": num_layers
}, "rlightgcn_wn18rr.pt")
