In [15]:
!pip install torch_geometric torchmetrics



# Import Libraries

In [16]:
# -----------------------------
# Standard library
# -----------------------------
import os
import json
import csv
import random
from pathlib import Path
from collections import defaultdict
from datetime import datetime
from typing import List, Tuple, Dict, Optional

# -----------------------------
# Third-party libraries
# -----------------------------
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from google.colab import drive

# -----------------------------
# PyTorch and PyG libraries
# -----------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch_geometric.utils import degree, add_self_loops
from torch_geometric.nn import MessagePassing
# from torch_scatter import scatter

# import sys
# from pathlib import Path
# current_dir = Path(__file__).parent if '__file__' in globals() else Path.cwd()
# parent_dir = current_dir.parent
# sys.path.append(str(parent_dir))
# from utility.dataset_loader import KGDataModuleCollapsed, KGDataModuleTyped

# device = torch.device('cpu')

In [17]:
# -----------------------
# Helpers
# -----------------------
def _read_triples(path: Path) -> List[Tuple[str, str, str]]:
    triples = []
    with open(path, "r", newline="") as f:
        reader = csv.reader(f, delimiter="\t")
        for row in reader:
            if not row:
                continue
            h, r, t = row
            triples.append((h, r, t))
    return triples


def _build_id_maps(
        train_p: Path,
        valid_p: Optional[Path] = None,
        test_p: Optional[Path] = None,
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Build ent2id and rel2id from all splits available."""
    ents, rels = set(), set()
    for p in [train_p, valid_p, test_p]:
        if p is None:
            continue
        for h, r, t in _read_triples(p):
            ents.add(h); ents.add(t); rels.add(r)
    ent2id = {e: i for i, e in enumerate(sorted(ents))}
    rel2id = {r: i for i, r in enumerate(sorted(rels))}
    return ent2id, rel2id


# -----------------------
# Datasets
# -----------------------
class _PairsDataset(Dataset):
    """Untyped pairs (h, t), label=1 for each positive edge."""
    def __init__(self, pairs: torch.LongTensor):
        # pairs: [N,2]
        self.pairs = pairs
        self.labels = torch.ones(pairs.size(0), dtype=torch.float32)

    def __len__(self):
        return self.pairs.size(0)

    def __getitem__(self, idx):
        return self.pairs[idx], self.labels[idx]


class _TriplesDataset(Dataset):
    """Typed triples (h, r, t), label=1 for each positive triple."""
    def __init__(self, triples: torch.LongTensor):
        # triples: [N,3]
        self.triples = triples
        self.labels = torch.ones(triples.size(0), dtype=torch.float32)

    def __len__(self):
        return self.triples.size(0)

    def __getitem__(self, idx):
        return self.triples[idx], self.labels[idx]


# ============================================================
# Collapsed edge types (untyped) datamodule
# ============================================================
class KGDataModuleCollapsed:
    """
    Produces untyped (h, t) pairs from KG triples.

    Public methods:
      - train_loader()
      - val_loader()
      - test_loader()

    Args:
        train_path, valid_path, test_path: Path to split files (FB15K format).
        batch_size: int
        shuffle: bool (applied to train loader only)
        num_workers: int (DataLoader)
        add_reverse: if True, also add (t, h) for every (h, t)
    """
    def __init__(
            self,
            train_path: Path,
            valid_path: Optional[Path] = None,
            test_path: Optional[Path] = None,
            batch_size: int = 2048,
            shuffle: bool = True,
            num_workers: int = 0,
            add_reverse: bool = True,
    ):
        self.train_path = Path(train_path)
        self.valid_path = Path(valid_path) if valid_path else None
        self.test_path  = Path(test_path)  if test_path  else None

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.add_reverse = add_reverse

        # Build ids
        self.ent2id, self.rel2id = _build_id_maps(self.train_path, self.valid_path, self.test_path)

        # Build tensors per split
        self._train_pairs = self._make_pairs(self.train_path)
        self._valid_pairs = self._make_pairs(self.valid_path) if self.valid_path else None
        self._test_pairs  = self._make_pairs(self.test_path)  if self.test_path  else None

        # Datasets
        self._train_ds = _PairsDataset(self._train_pairs)
        self._valid_ds = _PairsDataset(self._valid_pairs) if self._valid_pairs is not None else None
        self._test_ds  = _PairsDataset(self._test_pairs)  if self._test_pairs  is not None else None

    def _make_pairs(self, path: Path) -> torch.LongTensor:
        pairs = []
        for h, _, t in _read_triples(path):
            h_id, t_id = self.ent2id[h], self.ent2id[t]
            pairs.append((h_id, t_id))
            if self.add_reverse:
                pairs.append((t_id, h_id))
        if not pairs:
            return torch.empty(0, 2, dtype=torch.long)
        return torch.tensor(pairs, dtype=torch.long)

    # -------- public loaders --------
    def train_loader(self) -> DataLoader:
        return DataLoader(
            self._train_ds, batch_size=self.batch_size, shuffle=self.shuffle,
            num_workers=self.num_workers, pin_memory=False
        )

    def val_loader(self) -> Optional[DataLoader]:
        if self._valid_ds is None:
            return None
        return DataLoader(
            self._valid_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=False
        )

    def test_loader(self) -> Optional[DataLoader]:
        if self._test_ds is None:
            return None
        return DataLoader(
            self._test_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=False
        )


# ============================================================
# Typed (with edge types) datamodule
# ============================================================
class KGDataModuleTyped:
    """
    Produces typed (h, r, t) triples from KG files.

    Public methods:
      - train_loader()
      - val_loader()
      - test_loader()

    Args:
        train_path, valid_path, test_path: Path to split files.
        batch_size, shuffle, num_workers: DataLoader args.
        add_reverse: If True, add reverse links.
        reverse_relation_strategy:
            'duplicate_rel' -> create an inverse relation id per r (e.g., r#inv)
            'same_rel'      -> reuse the same r id for reverse triple
    """
    def __init__(
            self,
            train_path: Path,
            valid_path: Optional[Path] = None,
            test_path: Optional[Path] = None,
            batch_size: int = 2048,
            shuffle: bool = True,
            num_workers: int = 0,
            add_reverse: bool = True,
            reverse_relation_strategy: str = "duplicate_rel",
    ):
        assert reverse_relation_strategy in ("duplicate_rel", "same_rel")
        self.train_path = Path(train_path)
        self.valid_path = Path(valid_path) if valid_path else None
        self.test_path  = Path(test_path)  if test_path  else None

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.add_reverse = add_reverse
        self.reverse_relation_strategy = reverse_relation_strategy

        # Base id maps from original relations
        self.ent2id, base_rel2id = _build_id_maps(self.train_path, self.valid_path, self.test_path)

        # Possibly extend rel2id with inverse relations
        if add_reverse and reverse_relation_strategy == "duplicate_rel":
            # Make space for inverse ids
            self.rel2id = dict(base_rel2id)
            self.inv_of = {}  # map original rel -> inverse rel id
            next_id = max(self.rel2id.values(), default=-1) + 1
            # Ensure deterministic order
            for r in sorted(base_rel2id.keys()):
                inv_name = r + "_rev"
                if inv_name not in self.rel2id:
                    self.rel2id[inv_name] = next_id
                    self.inv_of[r] = next_id
                    next_id += 1
        else:
            self.rel2id = base_rel2id
            self.inv_of = None  # not used

        # Build tensors per split
        self._train_triples = self._make_triples(self.train_path)
        self._valid_triples = self._make_triples(self.valid_path) if self.valid_path else None
        self._test_triples  = self._make_triples(self.test_path)  if self.test_path  else None

        # Datasets
        self._train_ds = _TriplesDataset(self._train_triples)
        self._valid_ds = _TriplesDataset(self._valid_triples) if self._valid_triples is not None else None
        self._test_ds  = _TriplesDataset(self._test_triples)  if self._test_triples  is not None else None

    def _make_triples(self, path: Path) -> torch.LongTensor:
        if path is None:
            return torch.empty(0, 3, dtype=torch.long)

        rows = []
        for h, r, t in _read_triples(path):
            h_id, t_id = self.ent2id[h], self.ent2id[t]
            r_id = self.rel2id[r]  # original direction
            rows.append((h_id, r_id, t_id))

            if self.add_reverse:
                if self.reverse_relation_strategy == "duplicate_rel":
                    r_inv_id = self.inv_of[r]  # created in __init__
                    rows.append((t_id, r_inv_id, h_id))
                else:  # same_rel
                    rows.append((t_id, r_id, h_id))

        if not rows:
            return torch.empty(0, 3, dtype=torch.long)
        return torch.tensor(rows, dtype=torch.long)

    # -------- public loaders --------
    def train_loader(self) -> DataLoader:
        return DataLoader(
            self._train_ds, batch_size=self.batch_size, shuffle=self.shuffle,
            num_workers=self.num_workers, pin_memory=False
        )

    def val_loader(self) -> Optional[DataLoader]:
        if self._valid_ds is None:
            return None
        return DataLoader(
            self._valid_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=False
        )

    def test_loader(self) -> Optional[DataLoader]:
        if self._test_ds is None:
            return None
        return DataLoader(
            self._test_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=False
        )

# Dataset Loading


In [18]:
# Dataset loading using dataset_loader.py
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Mount Google Drive
drive.mount('/content/drive')

# Define the path to your dataset
base_path = Path("/content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/FB15K-237.2")

# Dataset paths
train_path = base_path / "train.txt"
valid_path = base_path / "valid.txt"
test_path = base_path / "test.txt"

# Initialize data modules
# Collapsed mode for LightGCN (untyped pairs)
dm_collapsed = KGDataModuleCollapsed(
    train_path=train_path,
    valid_path=valid_path,
    test_path=test_path,
    batch_size=4096,
    shuffle=True,
    add_reverse=True
)

# Typed mode for R-LightGCN (typed triples)
dm_typed = KGDataModuleTyped(
    train_path=train_path,
    valid_path=valid_path,
    test_path=test_path,
    batch_size=4096,
    shuffle=True,
    add_reverse=True,
    reverse_relation_strategy="duplicate_rel"
)

num_entities = len(dm_collapsed.ent2id)
num_relations = len(dm_collapsed.rel2id)  # Original relations
num_relations_with_inv = len(dm_typed.rel2id)  # With inverse relations

print(f"Dataset: FB15K")
print(f"Entities: {num_entities:,}")
print(f"Original Relations: {num_relations:,}")
print(f"Relations (with inverse): {num_relations_with_inv:,}")
print(f"Training pairs: {len(dm_collapsed._train_pairs):,}")
print(f"Training triples (typed): {len(dm_typed._train_triples):,}")

# Build edge_index for LightGCN (collapsed pairs)
edge_index = dm_collapsed._train_pairs.t().contiguous().to(device)
edge_index, _ = add_self_loops(edge_index, num_nodes=num_entities)

print(f"Graph edges (with self-loops): {edge_index.shape[1]:,}")
print(f"Graph edge_index shape: {tuple(edge_index.shape)}")

train_loader_collapsed = dm_collapsed.train_loader()
val_loader_collapsed = dm_collapsed.val_loader()
test_loader_collapsed = dm_collapsed.test_loader()

train_loader_typed = dm_typed.train_loader()
val_loader_typed = dm_typed.val_loader()
test_loader_typed = dm_typed.test_loader()

print(f"Data loaders created:")
print(f"Train batches (collapsed): {len(train_loader_collapsed)}")
print(f"Val batches (collapsed): {len(val_loader_collapsed) if val_loader_collapsed else 0}")
print(f"Test batches (collapsed): {len(test_loader_collapsed) if test_loader_collapsed else 0}")

Using device: cuda
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Dataset: FB15K
Entities: 14,541
Original Relations: 237
Relations (with inverse): 474
Training pairs: 544,230
Training triples (typed): 544,230
Graph edges (with self-loops): 558,771
Graph edge_index shape: (2, 558771)
Data loaders created:
Train batches (collapsed): 133
Val batches (collapsed): 9
Test batches (collapsed): 10


# Helper Functions

In [19]:
# Model Saving Utility Functions
def save_model_checkpoint(model, optimizer, hyperparameters, final_test_metrics,
                         training_history, model_name, filename):
    """
    Save model checkpoint with comprehensive information.

    Args:
        model: The trained model
        optimizer: The optimizer used
        hyperparameters: Dict with training hyperparameters
        final_test_metrics: Final test evaluation results
        training_history: Training metrics history
        model_name: Name of the model for display
        filename: Output filename for the checkpoint
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'hyperparameters': hyperparameters,
        'final_test_metrics': final_test_metrics,
        'training_history': training_history,
        'model_name': model_name
    }

    torch.save(checkpoint, filename)
    print(f"{model_name} model checkpoint saved to {filename}")

def load_model_checkpoint(filename, model_class, device, **model_kwargs):
    """
    Load model checkpoint and restore model state.

    Args:
        filename: Path to checkpoint file
        model_class: Model class to instantiate
        device: Device to load model on
        **model_kwargs: Arguments for model instantiation

    Returns:
        model: Loaded model
        checkpoint: Full checkpoint data
    """
    checkpoint = torch.load(filename, map_location=device)

    # Create model instance
    model = model_class(**model_kwargs).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"Model loaded from {filename}")
    print(f"Model: {checkpoint.get('model_name', 'Unknown')}")
    print(f"Hyperparameters: {checkpoint.get('hyperparameters', {})}")

    if 'final_test_metrics' in checkpoint:
        metrics = checkpoint['final_test_metrics']
        if 'head' in metrics and 'tail' in metrics:
            auc_avg = (metrics['head']['auc'] + metrics['tail']['auc']) / 2
            hits10_avg = (metrics['head']['hits@10'] + metrics['tail']['hits@10']) / 2
            print(f"Test AUC (avg): {auc_avg:.4f}")
            print(f"Test Hits@10 (avg): {hits10_avg:.4f}")

    return model, checkpoint

In [20]:
class MetricsTracker:
    """Simplified metrics tracker without plotting"""
    def __init__(self):
        self.metrics = defaultdict(list)

    def add(self, epoch, **kwargs):
        self.metrics['epoch'].append(epoch)
        for key, value in kwargs.items():
            self.metrics[key].append(value)

    def get_best_epoch(self, metric='val_auc_head'):
        """Get the epoch with the best performance for a given metric"""
        if metric not in self.metrics or not self.metrics[metric]:
            return 0

        # For AUC and Hits@K, higher is better
        best_idx = np.argmax(self.metrics[metric])
        return self.metrics['epoch'][best_idx]

    def save_to_file(self, filepath):
        with open(filepath, 'w') as f:
            f.write("Epoch\tLoss\tAUC\tHits@1\tHits@5\tHits@10\n")
            for i in range(len(self.metrics['epoch'])):
                epoch = self.metrics['epoch'][i]
                loss = self.metrics['loss'][i]
                auc = self.metrics['val_auc'][i]
                h1 = self.metrics['val_hits1'][i]
                h5 = self.metrics['val_hits5'][i]
                h10 = self.metrics['val_hits10'][i]

                f.write(f"{epoch}\t{loss:.6f}\t{auc:.6f}\t{h1:.6f}\t{h5:.6f}\t{h10:.6f}\n")


class EarlyStopping:
    def __init__(self, patience=10, metric='val_auc_head', mode='max', min_delta=0.001):
        self.patience = patience
        self.metric = metric
        self.mode = mode  # 'max' for AUC, Hits@K; 'min' for loss
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.should_stop = False

    def __call__(self, metrics_tracker):
        current_score = metrics_tracker.metrics[self.metric][-1] if self.metric in metrics_tracker.metrics else 0

        if self.best_score is None:
            self.best_score = current_score
        else:
            if self.mode == 'max':
                improved = current_score > self.best_score + self.min_delta
            else:  # mode == 'min'
                improved = current_score < self.best_score - self.min_delta

            if improved:
                self.best_score = current_score
                self.counter = 0
            else:
                self.counter += 1

        self.should_stop = self.counter >= self.patience
        return self.should_stop

In [21]:
# Training functions with data loaders
def train_one_epoch_lightgcn(model, data_loader, optimizer, edge_index, device, num_entities, max_grad_norm=5.0):
    """
    Stable LightGCN training loop (BPR loss).
    Handles NaNs, exploding gradients, and large score differences gracefully.

    Args:
        model: LightGCN model
        data_loader: DataLoader yielding positive pairs (batch_size, 2)
        optimizer: Optimizer instance
        edge_index: Graph edges tensor on the correct device
        device: torch.device to run on
        num_entities: Total number of entities (int)
        max_grad_norm: Optional gradient clipping value
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(data_loader):
        optimizer.zero_grad()

        # ------------------------------
        # Handle batch format
        # ------------------------------
        if isinstance(batch, (list, tuple)):
            pairs_tensor = batch[0]
        else:
            pairs_tensor = batch

        pairs_tensor = pairs_tensor.to(device)
        if pairs_tensor.dim() == 1:
            pairs_tensor = pairs_tensor.unsqueeze(0)

        pos_edges = pairs_tensor  # [batch_size, 2]
        num_pos = pos_edges.size(0)

        # ------------------------------
        # Create random negatives
        # ------------------------------
        neg_tail = torch.randint(0, num_entities, (num_pos,), device=device)
        neg_head = torch.randint(0, num_entities, (num_pos,), device=device)

        # ------------------------------
        # Encode embeddings (per-batch to avoid retain_graph issues)
        # ------------------------------
        embeddings = model.encode(edge_index)

        # Normalize to prevent exploding scores
        emb = embeddings / (embeddings.norm(dim=1, keepdim=True) + 1e-9)

        # ------------------------------
        # Compute scores
        # ------------------------------
        pos_scores = (emb[pos_edges[:, 0]] * emb[pos_edges[:, 1]]).sum(dim=1)
        neg_scores_tail = (emb[pos_edges[:, 0]] * emb[neg_tail]).sum(dim=1)
        neg_scores_head = (emb[neg_head] * emb[pos_edges[:, 1]]).sum(dim=1)

        # ------------------------------
        # Stable BPR loss
        # ------------------------------
        loss_tail = F.softplus(-(pos_scores - neg_scores_tail)).mean()
        loss_head = F.softplus(-(pos_scores - neg_scores_head)).mean()
        loss = loss_tail + loss_head

        # ------------------------------
        # NaN / inf guard
        # ------------------------------
        if not torch.isfinite(loss):
            print(f"[Batch {batch_idx}] NaN or inf detected — skipping this batch.")
            continue

        # ------------------------------
        # Backprop and update
        # ------------------------------
        loss.backward()

        # Optional gradient clipping (prevents explosions)
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        # Track loss
        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / max(1, num_batches)
    return avg_loss

### Negative Sampling
- A training technique used in knowledge graph link prediction to create "negative examples", triples that are likely to be false.
- Since knowledge graphs only contain positive facts (true triples), we need to artificially create negative examples for the model to learn what relationships are incorrect.
- Both head and tail corruptions are used to train the model to understand connections flowing in both directions. Model learns "What subjects fit this relation-object" and "What object fits this subject-relation".

In [22]:
@torch.no_grad()
def pairs_from_triples(triples: torch.LongTensor) -> torch.LongTensor:
    """
    Convert (h, r, t) -> pairs [2, N] = (h, t) for decoding on collapsed graph.
    """
    return triples[:, [0, 2]].t().contiguous()  # [2, N]

@torch.no_grad()
def negative_sample_heads(triples: torch.LongTensor, num_nodes: int) -> torch.LongTensor:
    """
    Corrupt heads: (h, r, t) -> (h', t)
    Returns pairs [2, N].
    """
    N = triples.size(0)
    neg_h = torch.randint(0, num_nodes, (N,), device=triples.device)
    t = triples[:, 2]
    return torch.stack([neg_h, t], dim=0)

@torch.no_grad()
def negative_sample_tails(triples: torch.LongTensor, num_nodes: int) -> torch.LongTensor:
    """
    Corrupt tails: (h, r, t) -> (h, t')
    Returns pairs [2, N].
    """
    N = triples.size(0)
    h = triples[:, 0]
    neg_t = torch.randint(0, num_nodes, (N,), device=triples.device)
    return torch.stack([h, neg_t], dim=0)


# Model 1 LightGCN
- We also add new inverse type of relation on top of the 11 that already exists.
- This allows for information to be passed around which originally did not.
- A -> B is one way, and there should be an inverse relationship (or some information) which is missed out.

In [23]:
# -------- LightGCN layer --------
class LightGCNConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')

    # Compute symmetric normalization term D^-0.5*A*D^-0.5 to propagate messages through normalized adjacency
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)

    # Scales the neighbor embeddings
    def message(self, x_j: torch.Tensor, norm: torch.Tensor) -> torch.Tensor:
        return norm.view(-1, 1) * x_j

# -------- LightGCN encoder + dot-product decoder --------
class LightGCN(nn.Module):
    # Initialize trainable node embeddings
    def __init__(self, num_nodes: int, emb_dim: int = 64, num_layers: int = 3):
        super().__init__()
        self.embedding = nn.Embedding(num_nodes, emb_dim)
        nn.init.xavier_uniform_(self.embedding.weight)
        self.convs = nn.ModuleList([LightGCNConv() for _ in range(num_layers)])
        self.num_layers = num_layers

    def encode(self, edge_index: torch.Tensor) -> torch.Tensor:
        x0 = self.embedding.weight
        out = x0
        x = x0
        # Each layer's output is accumulated and averaged
        for conv in self.convs:
            x = conv(x, edge_index)
            out = out + x
        return out / (self.num_layers + 1)

    # Compute dot product between node embeddings for each edge (positive or negative pair)
    @staticmethod
    def decode(z: torch.Tensor, pairs: torch.LongTensor) -> torch.Tensor:
        # pairs: [2, B] with [src; dst]
        return (z[pairs[0]] * z[pairs[1]]).sum(dim=1)

In [24]:
import torch
import numpy as np
from sklearn.metrics import roc_auc_score

def evaluate_auc_hits(model, triples, num_entities, edge_index=None, batch_size=4096, device=None):
    """
    Evaluate AUC and Hits@K (K=1,5,10) for LightGCN or R-LightGCN.
    Unfiltered version — each positive triple is compared to 99 random negatives.

    Args:
        model: LightGCN or R-LightGCN with .encode()
        triples: torch.Tensor [N, 3] or [N, 2] validation/test triples
        num_entities: total number of entities
        edge_index: graph structure for encoding
        batch_size: number of triples per batch
        device: torch.device to use
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    # -------------------------
    # Helper functions
    # -------------------------
    def batch_iter(tensor, size):
        for i in range(0, len(tensor), size):
            yield tensor[i:i + size]

    def sample_negatives(pos_batch, num_entities):
        """Corrupt tail entities randomly."""
        neg_batch = pos_batch.clone()
        neg_batch[:, -1] = torch.randint(0, num_entities, (pos_batch.size(0),), device=pos_batch.device)
        return neg_batch

    # -------------------------
    # 1️⃣ AUC computation
    # -------------------------
    scores_all, labels_all = [], []

    for pos in batch_iter(triples, batch_size):
        pos = pos.to(device)
        neg = sample_negatives(pos, num_entities)

        with torch.no_grad():
            if hasattr(model, 'encode'):
                emb = model.encode(edge_index)
                s_pos = (emb[pos[:, 0]] * emb[pos[:, -1]]).sum(dim=1)
                s_neg = (emb[neg[:, 0]] * emb[neg[:, -1]]).sum(dim=1)
            else:
                s_pos = model(pos)
                s_neg = model(neg)

            scores_all.append(torch.cat([s_pos, s_neg], dim=0))
            labels_all.append(torch.cat([torch.ones_like(s_pos), torch.zeros_like(s_neg)], dim=0))

    # If no scores were collected, return defaults
    if len(scores_all) == 0:
        return {"auc": 0.5, "hits@1": 0.0, "hits@5": 0.0, "hits@10": 0.0}

    scores_all = torch.cat(scores_all, dim=0)
    labels_all = torch.cat(labels_all, dim=0)
    # Compute ROC-AUC using torchmetrics (pure torch, avoids numpy)
    try:
        from torchmetrics.classification import BinaryAUROC
        auroc_metric = BinaryAUROC().to(scores_all.device)
        auc = auroc_metric(scores_all, labels_all.int()).item()
    except Exception:
        auc = 0.5

    # -------------------------
    # 2️⃣ Hits@K computation (1,5,10)
    # -------------------------
    hits_at = {1: 0, 5: 0, 10: 0}
    n_trials = 0

    with torch.no_grad():
        ent = model.encode(edge_index) if hasattr(model, 'encode') else model.encoder(edge_index)

        for pos in batch_iter(triples, batch_size):
            pos = pos.to(device)
            B = pos.size(0)
            true_t = pos[:, -1]
            rand_t = torch.randint(0, num_entities, (B, 99), device=device)
            tails = torch.cat([true_t.unsqueeze(1), rand_t], dim=1)  # [B,100]

            e_h = ent[pos[:, 0]]                                    # [B,d]
            e_candidates = ent[tails]                               # [B,100,d]
            s = (e_h.unsqueeze(1) * e_candidates).sum(dim=2)        # [B,100]

            # rank position of true tail
            ranks = (s.argsort(dim=1, descending=True) == 0).nonzero()[:, 1] + 1  # 1-based

            for k in hits_at.keys():
                hits_at[k] += (ranks <= k).sum().item()

            n_trials += B

    # Normalize
    hits_at = {f"hits@{k}": v / n_trials for k, v in hits_at.items()}

    return {"auc": float(auc), **hits_at}


In [None]:
# ===============================
# Hyperparameters
# ===============================
lr = 1e-3
epochs = 100
emb_dim = 64
num_layers = 3
eval_every = 5
batch_size = 2048
patience = 10

# ===============================
# Setup
# ===============================
lightgcn_metrics_tracker = MetricsTracker()
lightgcn_early_stopping = EarlyStopping(patience=patience, metric='val_auc')

# Create model + optimizer
lightgcn_model = LightGCN(num_nodes=num_entities, emb_dim=emb_dim, num_layers=num_layers).to(device)

optimizer = torch.optim.Adam(lightgcn_model.parameters(), lr=lr)

print("Starting LightGCN training...")
print(f"Max epochs: {epochs}, Early stopping patience: {patience}")

# ===============================
# Training Loop
# ===============================
for epoch in tqdm(range(1, epochs + 1), desc="Training LightGCN"):
    avg_loss = train_one_epoch_lightgcn(
        lightgcn_model, train_loader_collapsed, optimizer,
        edge_index, device, num_entities
    )

    if epoch <= 5 or epoch % 5 == 0:
        print(f"Epoch {epoch:3d}: Loss={avg_loss:.4f}")

    # ===============================
    # Evaluation
    # ===============================
    if epoch % eval_every == 0 or epoch == 1:
        print(f"\nEvaluating epoch {epoch}...")
        # Convert val_loader to flat tensor of pairs/triples
        all_val_data = []
        for batch in val_loader_collapsed:
            data_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
            if data_tensor.dim() == 1:
                data_tensor = data_tensor.unsqueeze(0)
            all_val_data.append(data_tensor)
        val_triples = torch.cat(all_val_data, dim=0)

        # Run simplified evaluation (AUC + Hits@K)
        val_metrics = evaluate_auc_hits(
            lightgcn_model,
            val_triples,
            num_entities=num_entities,
            edge_index=edge_index,
            batch_size=2048,
            device=device
        )

        # Print summary
        print(f"Epoch {epoch}: Loss={avg_loss:.4f}, "
              f"AUC={val_metrics['auc']:.4f}, "
              f"Hits@1={val_metrics['hits@1']:.4f}, "
              f"Hits@5={val_metrics['hits@5']:.4f}, "
              f"Hits@10={val_metrics['hits@10']:.4f}")

        # Log to tracker (simple flattening)
        lightgcn_metrics_tracker.add(
            epoch=epoch,
            loss=avg_loss,
            val_auc=val_metrics['auc'],
            val_hits1=val_metrics['hits@1'],
            val_hits5=val_metrics['hits@5'],
            val_hits10=val_metrics['hits@10']
        )

        # Early stopping
        if lightgcn_early_stopping(lightgcn_metrics_tracker):
            print(f"Early stopping at epoch {epoch} "
                  f"(Best AUC: {lightgcn_early_stopping.best_score:.4f})")
            break

print("\nLightGCN training completed!")

# ===============================
# Final Test Evaluation
# ===============================
print("\nRunning final test evaluation...")
all_test_data = []
for batch in test_loader_collapsed:
    data_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
    if data_tensor.dim() == 1:
        data_tensor = data_tensor.unsqueeze(0)
    all_test_data.append(data_tensor)
test_triples = torch.cat(all_test_data, dim=0)

lightgcn_final_test_metrics = evaluate_auc_hits(
    lightgcn_model,
    test_triples,
    num_entities=num_entities,
    edge_index=edge_index,
    batch_size=2048,
    device=device
)

print("\n[LightGCN FB15K TEST RESULTS]")
print(f"AUC: {lightgcn_final_test_metrics['auc']:.4f}")
print(f"Hits@1: {lightgcn_final_test_metrics['hits@1']:.4f}")
print(f"Hits@5: {lightgcn_final_test_metrics['hits@5']:.4f}")
print(f"Hits@10: {lightgcn_final_test_metrics['hits@10']:.4f}")

# ===============================
# Save model + metrics
# ===============================
save_dir = "/content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/FB15K"
os.makedirs(save_dir, exist_ok=True)

# --- Save final model ---
model_file = os.path.join(save_dir, "fb15k_lightgcn_model.pt")
torch.save(lightgcn_model.state_dict(), model_file)
print(f"Final model saved: {model_file}")

# --- Prepare and align metrics ---
metrics = lightgcn_metrics_tracker.metrics
min_len = min(len(v) for v in metrics.values())
print(f"Saving {min_len} consistent metric entries...")

# --- Save metrics safely ---
metrics_file = os.path.join(save_dir, "fb15k_lightgcn_metrics.txt")
with open(metrics_file, "w") as f:
    f.write("Epoch\tLoss\tVal_AUC\tHits@1\tHits@5\tHits@10\n")
    for i in range(min_len):
        f.write(
            f"{metrics['epoch'][i]}\t"
            f"{metrics['loss'][i]:.6f}\t"
            f"{metrics['val_auc'][i]:.6f}\t"
            f"{metrics['val_hits1'][i]:.6f}\t"
            f"{metrics['val_hits5'][i]:.6f}\t"
            f"{metrics['val_hits10'][i]:.6f}\n"
        )
print(f"Metrics saved to {metrics_file}")

# --- Optional: save complete training bundle ---
bundle_file = os.path.join(save_dir, "fb15k_lightgcn_full_bundle.pt")
torch.save({
    "model_state_dict": lightgcn_model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "hyperparameters": {
        "emb_dim": emb_dim,
        "num_layers": num_layers,
        "lr": lr,
        "num_entities": num_entities
    },
    "final_test_metrics": lightgcn_final_test_metrics,
    "training_history": metrics
}, bundle_file)
print(f"Full model + training history bundle saved: {bundle_file}")


Starting LightGCN training...
Max epochs: 100, Early stopping patience: 10


Training LightGCN:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch   1: Loss=0.8880

Evaluating epoch 1...


Training LightGCN:   1%|          | 1/100 [00:08<13:24,  8.13s/it]

Epoch 1: Loss=0.8880, AUC=0.8733, Hits@1=0.2725, Hits@5=0.5445, Hits@10=0.6833


Training LightGCN:   2%|▏         | 2/100 [00:15<12:38,  7.74s/it]

Epoch   2: Loss=0.8166


Training LightGCN:   3%|▎         | 3/100 [00:23<12:18,  7.62s/it]

Epoch   3: Loss=0.8101


Training LightGCN:   4%|▍         | 4/100 [00:30<12:12,  7.63s/it]

Epoch   4: Loss=0.8086
Epoch   5: Loss=0.8075

Evaluating epoch 5...


Training LightGCN:   5%|▌         | 5/100 [00:38<12:20,  7.79s/it]

Epoch 5: Loss=0.8075, AUC=0.8817, Hits@1=0.2779, Hits@5=0.5592, Hits@10=0.7063


Training LightGCN:   9%|▉         | 9/100 [01:08<11:26,  7.55s/it]

Epoch  10: Loss=0.8054

Evaluating epoch 10...


Training LightGCN:  10%|█         | 10/100 [01:16<11:34,  7.71s/it]

Epoch 10: Loss=0.8054, AUC=0.8834, Hits@1=0.2774, Hits@5=0.5606, Hits@10=0.7061


Training LightGCN:  14%|█▍        | 14/100 [01:46<10:50,  7.56s/it]

Epoch  15: Loss=0.8043

Evaluating epoch 15...


Training LightGCN:  15%|█▌        | 15/100 [01:55<10:56,  7.73s/it]

Epoch 15: Loss=0.8043, AUC=0.8862, Hits@1=0.2783, Hits@5=0.5620, Hits@10=0.7057


Training LightGCN:  19%|█▉        | 19/100 [02:25<10:13,  7.58s/it]

Epoch  20: Loss=0.8032

Evaluating epoch 20...


Training LightGCN:  20%|██        | 20/100 [02:33<10:18,  7.73s/it]

Epoch 20: Loss=0.8032, AUC=0.8836, Hits@1=0.2801, Hits@5=0.5622, Hits@10=0.7061


Training LightGCN:  24%|██▍       | 24/100 [03:03<09:31,  7.52s/it]

Epoch  25: Loss=0.8036

Evaluating epoch 25...


Training LightGCN:  25%|██▌       | 25/100 [03:11<09:35,  7.68s/it]

Epoch 25: Loss=0.8036, AUC=0.8840, Hits@1=0.2780, Hits@5=0.5611, Hits@10=0.7041


Training LightGCN:  29%|██▉       | 29/100 [03:41<08:57,  7.57s/it]

Epoch  30: Loss=0.8031

Evaluating epoch 30...


Training LightGCN:  30%|███       | 30/100 [03:49<09:00,  7.73s/it]

Epoch 30: Loss=0.8031, AUC=0.8831, Hits@1=0.2766, Hits@5=0.5604, Hits@10=0.7052


Training LightGCN:  34%|███▍      | 34/100 [04:19<08:21,  7.60s/it]

Epoch  35: Loss=0.8028

Evaluating epoch 35...


Training LightGCN:  35%|███▌      | 35/100 [04:27<08:23,  7.74s/it]

Epoch 35: Loss=0.8028, AUC=0.8838, Hits@1=0.2768, Hits@5=0.5614, Hits@10=0.7054


Training LightGCN:  39%|███▉      | 39/100 [04:57<07:40,  7.55s/it]

Epoch  40: Loss=0.8031

Evaluating epoch 40...


Training LightGCN:  40%|████      | 40/100 [05:05<07:42,  7.71s/it]

Epoch 40: Loss=0.8031, AUC=0.8854, Hits@1=0.2743, Hits@5=0.5604, Hits@10=0.7041


Training LightGCN:  44%|████▍     | 44/100 [05:35<07:03,  7.57s/it]

Epoch  45: Loss=0.8026

Evaluating epoch 45...


Training LightGCN:  45%|████▌     | 45/100 [05:43<07:05,  7.73s/it]

Epoch 45: Loss=0.8026, AUC=0.8845, Hits@1=0.2762, Hits@5=0.5594, Hits@10=0.7046


Training LightGCN:  49%|████▉     | 49/100 [06:14<06:26,  7.58s/it]

Epoch  50: Loss=0.8024

Evaluating epoch 50...


Training LightGCN:  50%|█████     | 50/100 [06:22<06:26,  7.72s/it]

Epoch 50: Loss=0.8024, AUC=0.8827, Hits@1=0.2740, Hits@5=0.5562, Hits@10=0.7040


Training LightGCN:  54%|█████▍    | 54/100 [06:52<05:46,  7.53s/it]

Epoch  55: Loss=0.8025

Evaluating epoch 55...


Training LightGCN:  55%|█████▌    | 55/100 [07:00<05:46,  7.71s/it]

Epoch 55: Loss=0.8025, AUC=0.8828, Hits@1=0.2732, Hits@5=0.5596, Hits@10=0.7030


Training LightGCN:  59%|█████▉    | 59/100 [07:30<05:10,  7.58s/it]

Epoch  60: Loss=0.8017

Evaluating epoch 60...


Training LightGCN:  60%|██████    | 60/100 [07:38<05:09,  7.73s/it]

Epoch 60: Loss=0.8017, AUC=0.8843, Hits@1=0.2740, Hits@5=0.5579, Hits@10=0.7045


Training LightGCN:  64%|██████▍   | 64/100 [08:08<04:33,  7.59s/it]

Epoch  65: Loss=0.8020

Evaluating epoch 65...


Training LightGCN:  64%|██████▍   | 64/100 [08:16<04:39,  7.76s/it]

Epoch 65: Loss=0.8020, AUC=0.8843, Hits@1=0.2723, Hits@5=0.5557, Hits@10=0.7034
Early stopping at epoch 65 (Best AUC: 0.8862)

LightGCN training completed!

Running final test evaluation...






[LightGCN FB15K TEST RESULTS]
AUC: 0.8821
Hits@1: 0.2720
Hits@5: 0.5579
Hits@10: 0.7034
Final model saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/FB15K/fb15k_lightgcn_model.pt
Saving 14 consistent metric entries...
Metrics saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/FB15K/fb15k_lightgcn_metrics.txt
✅ Full model + training history bundle saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/FB15K/fb15k_lightgcn_full_bundle.pt
