In [1]:
!pip uninstall -y torch torchvision torchaudio

# Install PyTorch 2.2.0 with CUDA 11.8
!pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118

# Install torch-scatter for PyTorch 2.2.0 + CUDA 11.8
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-2.2.0+cu118.html

!pip install torch_geometric torchmetrics

Found existing installation: torch 2.2.0+cu118
Uninstalling torch-2.2.0+cu118:
  Successfully uninstalled torch-2.2.0+cu118
Found existing installation: torchvision 0.17.0+cu118
Uninstalling torchvision-0.17.0+cu118:
  Successfully uninstalled torchvision-0.17.0+cu118
Found existing installation: torchaudio 2.2.0+cu118
Uninstalling torchaudio-2.2.0+cu118:
  Successfully uninstalled torchaudio-2.2.0+cu118
Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch==2.2.0
  Using cached https://download.pytorch.org/whl/cu118/torch-2.2.0%2Bcu118-cp312-cp312-linux_x86_64.whl (811.6 MB)
Collecting torchvision==0.17.0
  Using cached https://download.pytorch.org/whl/cu118/torchvision-0.17.0%2Bcu118-cp312-cp312-linux_x86_64.whl (6.2 MB)
Collecting torchaudio==2.2.0
  Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.2.0%2Bcu118-cp312-cp312-linux_x86_64.whl (3.3 MB)
Installing collected packages: torch, torchvision, torchaudio
Successfully installed torch-2.2.0+cu

In [None]:
# -----------------------------
# Standard library
# -----------------------------
import os
import csv
import random
from pathlib import Path
from collections import defaultdict
from typing import List, Tuple, Dict, Optional

# -----------------------------
# Third-party libraries
# -----------------------------
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from google.colab import drive

# -----------------------------
# PyTorch and PyG libraries
# -----------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch_geometric.utils import degree, add_self_loops
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter
from torchmetrics.classification import BinaryAUROC

# import sys
# current_dir = Path(__file__).parent if '__file__' in globals() else Path.cwd()
# parent_dir = current_dir.parent
# sys.path.append(str(parent_dir))
# from utility.dataset_loader import KGDataModuleCollapsed, KGDataModuleTyped

In [17]:
# -----------------------
# Helpers
# -----------------------
def _read_triples(path: Path) -> List[Tuple[str, str, str]]:
    triples = []
    with open(path, "r", newline="") as f:
        reader = csv.reader(f, delimiter="\t")
        for row in reader:
            if not row:
                continue
            h, r, t = row
            triples.append((h, r, t))
    return triples


def _build_id_maps(
        train_p: Path,
        valid_p: Optional[Path] = None,
        test_p: Optional[Path] = None,
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Build ent2id and rel2id from all splits available."""
    ents, rels = set(), set()
    for p in [train_p, valid_p, test_p]:
        if p is None:
            continue
        for h, r, t in _read_triples(p):
            ents.add(h); ents.add(t); rels.add(r)
    ent2id = {e: i for i, e in enumerate(sorted(ents))}
    rel2id = {r: i for i, r in enumerate(sorted(rels))}
    return ent2id, rel2id


# -----------------------
# Datasets
# -----------------------
class _PairsDataset(Dataset):
    """Untyped pairs (h, t), label=1 for each positive edge."""
    def __init__(self, pairs: torch.LongTensor):
        # pairs: [N,2]
        self.pairs = pairs
        self.labels = torch.ones(pairs.size(0), dtype=torch.float32)

    def __len__(self):
        return self.pairs.size(0)

    def __getitem__(self, idx):
        return self.pairs[idx], self.labels[idx]


class _TriplesDataset(Dataset):
    """Typed triples (h, r, t), label=1 for each positive triple."""
    def __init__(self, triples: torch.LongTensor):
        # triples: [N,3]
        self.triples = triples
        self.labels = torch.ones(triples.size(0), dtype=torch.float32)

    def __len__(self):
        return self.triples.size(0)

    def __getitem__(self, idx):
        return self.triples[idx], self.labels[idx]


# ============================================================
# Collapsed edge types (untyped) datamodule
# ============================================================
class KGDataModuleCollapsed:
    """
    Produces untyped (h, t) pairs from KG triples.

    Public methods:
      - train_loader()
      - val_loader()
      - test_loader()

    Args:
        train_path, valid_path, test_path: Path to split files (WN18RR format).
        batch_size: int
        shuffle: bool (applied to train loader only)
        num_workers: int (DataLoader)
        add_reverse: if True, also add (t, h) for every (h, t)
    """
    def __init__(
            self,
            train_path: Path,
            valid_path: Optional[Path] = None,
            test_path: Optional[Path] = None,
            batch_size: int = 2048,
            shuffle: bool = True,
            num_workers: int = 0,
            add_reverse: bool = True,
    ):
        self.train_path = Path(train_path)
        self.valid_path = Path(valid_path) if valid_path else None
        self.test_path  = Path(test_path)  if test_path  else None

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.add_reverse = add_reverse

        # Build ids
        self.ent2id, self.rel2id = _build_id_maps(self.train_path, self.valid_path, self.test_path)

        # Build tensors per split
        self._train_pairs = self._make_pairs(self.train_path)
        self._valid_pairs = self._make_pairs(self.valid_path) if self.valid_path else None
        self._test_pairs  = self._make_pairs(self.test_path)  if self.test_path  else None

        # Datasets
        self._train_ds = _PairsDataset(self._train_pairs)
        self._valid_ds = _PairsDataset(self._valid_pairs) if self._valid_pairs is not None else None
        self._test_ds  = _PairsDataset(self._test_pairs)  if self._test_pairs  is not None else None

    def _make_pairs(self, path: Path) -> torch.LongTensor:
        pairs = []
        for h, _, t in _read_triples(path):
            h_id, t_id = self.ent2id[h], self.ent2id[t]
            pairs.append((h_id, t_id))
            if self.add_reverse:
                pairs.append((t_id, h_id))
        if not pairs:
            return torch.empty(0, 2, dtype=torch.long)
        return torch.tensor(pairs, dtype=torch.long)

    # -------- public loaders --------
    def train_loader(self) -> DataLoader:
        return DataLoader(
            self._train_ds, batch_size=self.batch_size, shuffle=self.shuffle,
            num_workers=self.num_workers, pin_memory=False
        )

    def val_loader(self) -> Optional[DataLoader]:
        if self._valid_ds is None:
            return None
        return DataLoader(
            self._valid_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=False
        )

    def test_loader(self) -> Optional[DataLoader]:
        if self._test_ds is None:
            return None
        return DataLoader(
            self._test_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=False
        )


# ============================================================
# Typed (with edge types) datamodule
# ============================================================
class KGDataModuleTyped:
    """
    Produces typed (h, r, t) triples from KG files.

    Public methods:
      - train_loader()
      - val_loader()
      - test_loader()

    Args:
        train_path, valid_path, test_path: Path to split files.
        batch_size, shuffle, num_workers: DataLoader args.
        add_reverse: If True, add reverse links.
        reverse_relation_strategy:
            'duplicate_rel' -> create an inverse relation id per r (e.g., r#inv)
            'same_rel'      -> reuse the same r id for reverse triple
    """
    def __init__(
            self,
            train_path: Path,
            valid_path: Optional[Path] = None,
            test_path: Optional[Path] = None,
            batch_size: int = 2048,
            shuffle: bool = True,
            num_workers: int = 0,
            add_reverse: bool = True,
            reverse_relation_strategy: str = "duplicate_rel",
    ):
        assert reverse_relation_strategy in ("duplicate_rel", "same_rel")
        self.train_path = Path(train_path)
        self.valid_path = Path(valid_path) if valid_path else None
        self.test_path  = Path(test_path)  if test_path  else None

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.add_reverse = add_reverse
        self.reverse_relation_strategy = reverse_relation_strategy

        # Base id maps from original relations
        self.ent2id, base_rel2id = _build_id_maps(self.train_path, self.valid_path, self.test_path)

        # Possibly extend rel2id with inverse relations
        if add_reverse and reverse_relation_strategy == "duplicate_rel":
            # Make space for inverse ids
            self.rel2id = dict(base_rel2id)
            self.inv_of = {}  # map original rel -> inverse rel id
            next_id = max(self.rel2id.values(), default=-1) + 1
            # Ensure deterministic order
            for r in sorted(base_rel2id.keys()):
                inv_name = r + "_rev"
                if inv_name not in self.rel2id:
                    self.rel2id[inv_name] = next_id
                    self.inv_of[r] = next_id
                    next_id += 1
        else:
            self.rel2id = base_rel2id
            self.inv_of = None  # not used

        # Build tensors per split
        self._train_triples = self._make_triples(self.train_path)
        self._valid_triples = self._make_triples(self.valid_path) if self.valid_path else None
        self._test_triples  = self._make_triples(self.test_path)  if self.test_path  else None

        # Datasets
        self._train_ds = _TriplesDataset(self._train_triples)
        self._valid_ds = _TriplesDataset(self._valid_triples) if self._valid_triples is not None else None
        self._test_ds  = _TriplesDataset(self._test_triples)  if self._test_triples  is not None else None

    def _make_triples(self, path: Path) -> torch.LongTensor:
        if path is None:
            return torch.empty(0, 3, dtype=torch.long)

        rows = []
        for h, r, t in _read_triples(path):
            h_id, t_id = self.ent2id[h], self.ent2id[t]
            r_id = self.rel2id[r]  # original direction
            rows.append((h_id, r_id, t_id))

            if self.add_reverse:
                if self.reverse_relation_strategy == "duplicate_rel":
                    r_inv_id = self.inv_of[r]  # created in __init__
                    rows.append((t_id, r_inv_id, h_id))
                else:  # same_rel
                    rows.append((t_id, r_id, h_id))

        if not rows:
            return torch.empty(0, 3, dtype=torch.long)
        return torch.tensor(rows, dtype=torch.long)

    # -------- public loaders --------
    def train_loader(self) -> DataLoader:
        return DataLoader(
            self._train_ds, batch_size=self.batch_size, shuffle=self.shuffle,
            num_workers=self.num_workers, pin_memory=False
        )

    def val_loader(self) -> Optional[DataLoader]:
        if self._valid_ds is None:
            return None
        return DataLoader(
            self._valid_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=False
        )

    def test_loader(self) -> Optional[DataLoader]:
        if self._test_ds is None:
            return None
        return DataLoader(
            self._test_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=False
        )

In [4]:
# Dataset loading using dataset_loader.py
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Mount Google Drive
drive.mount('/content/drive')

# Define the path to your dataset
base_path = Path("/content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/WN18RR")

# Dataset paths
train_path = base_path / "train.txt"
valid_path = base_path / "valid.txt"
test_path = base_path / "test.txt"

# Initialize data modules
# Collapsed mode for LightGCN (untyped pairs)
dm_collapsed = KGDataModuleCollapsed(
    train_path=train_path,
    valid_path=valid_path,
    test_path=test_path,
    batch_size=4096,
    shuffle=True,
    add_reverse=True
)

# Typed mode for R-LightGCN (typed triples)
dm_typed = KGDataModuleTyped(
    train_path=train_path,
    valid_path=valid_path,
    test_path=test_path,
    batch_size=4096,
    shuffle=True,
    add_reverse=True,
    reverse_relation_strategy="duplicate_rel"
)

num_entities = len(dm_collapsed.ent2id)
num_relations = len(dm_collapsed.rel2id)  # Original relations
num_relations_with_inv = len(dm_typed.rel2id)  # With inverse relations

print(f"Dataset: WN18RR")
print(f"Entities: {num_entities:,}")
print(f"Original Relations: {num_relations:,}")
print(f"Relations (with inverse): {num_relations_with_inv:,}")
print(f"Training pairs: {len(dm_collapsed._train_pairs):,}")
print(f"Training triples (typed): {len(dm_typed._train_triples):,}")

# Build edge_index for LightGCN (collapsed pairs)
edge_index = dm_collapsed._train_pairs.t().contiguous().to(device)
edge_index, _ = add_self_loops(edge_index, num_nodes=num_entities)

print(f"Graph edges (with self-loops): {edge_index.shape[1]:,}")
print(f"Graph edge_index shape: {tuple(edge_index.shape)}")

train_loader_collapsed = dm_collapsed.train_loader()
val_loader_collapsed = dm_collapsed.val_loader()
test_loader_collapsed = dm_collapsed.test_loader()

train_loader_typed = dm_typed.train_loader()
val_loader_typed = dm_typed.val_loader()
test_loader_typed = dm_typed.test_loader()

print(f"Data loaders created:")
print(f"Train batches (collapsed): {len(train_loader_collapsed)}")
print(f"Val batches (collapsed): {len(val_loader_collapsed) if val_loader_collapsed else 0}")
print(f"Test batches (collapsed): {len(test_loader_collapsed) if test_loader_collapsed else 0}")

Using device: cuda
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Dataset: WN18RR
Entities: 40,943
Original Relations: 11
Relations (with inverse): 22
Training pairs: 173,670
Training triples (typed): 173,670
Graph edges (with self-loops): 214,613
Graph edge_index shape: (2, 214613)
Data loaders created:
Train batches (collapsed): 43
Val batches (collapsed): 2
Test batches (collapsed): 2


In [19]:
# Model Saving Utility Functions
def save_model_checkpoint(model, optimizer, hyperparameters, final_test_metrics,
                         training_history, model_name, filename):
    """
    Save model checkpoint with comprehensive information.

    Args:
        model: The trained model
        optimizer: The optimizer used
        hyperparameters: Dict with training hyperparameters
        final_test_metrics: Final test evaluation results
        training_history: Training metrics history
        model_name: Name of the model for display
        filename: Output filename for the checkpoint
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'hyperparameters': hyperparameters,
        'final_test_metrics': final_test_metrics,
        'training_history': training_history,
        'model_name': model_name
    }

    torch.save(checkpoint, filename)
    print(f"{model_name} model checkpoint saved to {filename}")

def load_model_checkpoint(filename, model_class, device, **model_kwargs):
    """
    Load model checkpoint and restore model state.

    Args:
        filename: Path to checkpoint file
        model_class: Model class to instantiate
        device: Device to load model on
        **model_kwargs: Arguments for model instantiation

    Returns:
        model: Loaded model
        checkpoint: Full checkpoint data
    """
    checkpoint = torch.load(filename, map_location=device)

    # Create model instance
    model = model_class(**model_kwargs).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"Model loaded from {filename}")
    print(f"Model: {checkpoint.get('model_name', 'Unknown')}")
    print(f"Hyperparameters: {checkpoint.get('hyperparameters', {})}")

    if 'final_test_metrics' in checkpoint:
        metrics = checkpoint['final_test_metrics']
        if 'head' in metrics and 'tail' in metrics:
            auc_avg = (metrics['head']['auc'] + metrics['tail']['auc']) / 2
            hits10_avg = (metrics['head']['hits@10'] + metrics['tail']['hits@10']) / 2
            print(f"Test AUC (avg): {auc_avg:.4f}")
            print(f"Test Hits@10 (avg): {hits10_avg:.4f}")

    return model, checkpoint

In [None]:
class EarlyStopping:
    """
    Stop training if validation metric does not improve for 'patience' epochs.
    """
    def __init__(self, patience=5, mode='max', min_delta=1e-4):
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, current_score):
        if self.best_score is None:
            self.best_score = current_score
            return False

        if self.mode == 'max':
            improvement = current_score - self.best_score
        else:
            improvement = self.best_score - current_score

        if improvement > self.min_delta:
            self.best_score = current_score
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True

        return self.early_stop

class MetricsTracker:
    """Simplified metrics tracker for training metrics without head/tail split"""
    def __init__(self):
        self.metrics = defaultdict(list)

    def add(self, epoch, **kwargs):
        """Add metrics for a given epoch"""
        self.metrics['epoch'].append(epoch)
        for key, value in kwargs.items():
            self.metrics[key].append(value)

    def get_best_epoch(self, metric='val_auc'):
        """Return the epoch with the best value for the given metric (higher is better)"""
        if metric not in self.metrics or not self.metrics[metric]:
            return 0
        best_idx = np.argmax(self.metrics[metric])
        return self.metrics['epoch'][best_idx]

    def save_to_file(self, filepath):
        """Save metrics to a text file in tabular format"""
        with open(filepath, 'w') as f:
            f.write("Epoch\tLoss\tVal_AUC\tHits@1\tHits@5\tHits@10\n")
            for i in range(len(self.metrics['epoch'])):
                epoch = self.metrics['epoch'][i]
                loss = self.metrics['loss'][i]
                val_auc = self.metrics['val_auc'][i]
                hits1 = self.metrics['hits1'][i]
                hits5 = self.metrics['hits5'][i]
                hits10 = self.metrics['hits10'][i]
                f.write(f"{epoch}\t{loss:.6f}\t{val_auc:.6f}\t{hits1:.6f}\t{hits5:.6f}\t{hits10:.6f}\n")


In [None]:
# -----------------------------
# Decoders
# -----------------------------
class DotProductDecoder(nn.Module):
    def forward(self, z: torch.Tensor, pairs: torch.LongTensor) -> torch.Tensor:
        return (z[pairs[:, 0]] * z[pairs[:, 1]]).sum(dim=1)

class DistMultDecoder(nn.Module):
    def __init__(self, num_relations: int, emb_dim: int):
        super().__init__()
        self.rel_emb = nn.Embedding(num_relations, emb_dim)
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, z: torch.Tensor, triples: torch.LongTensor) -> torch.Tensor:
        h, r, t = triples[:, 0], triples[:, 1], triples[:, 2]
        r_vec = self.rel_emb(r)
        return (z[h] * r_vec * z[t]).sum(dim=1)

    @torch.no_grad()
    def score_heads(self, z: torch.Tensor, r: int, t: int, entity_indices: torch.Tensor) -> torch.Tensor:
        r_vec = self.rel_emb.weight[r]
        chunk = z[entity_indices]
        return (chunk * r_vec * z[t]).sum(dim=1)

    @torch.no_grad()
    def score_tails(self, z: torch.Tensor, h: int, r: int, entity_indices: torch.Tensor) -> torch.Tensor:
        r_vec = self.rel_emb.weight[r]
        chunk = z[entity_indices]
        return (z[h] * r_vec * chunk).sum(dim=1)

# -----------------------------
# R-LightGCNConv with variants
# -----------------------------
class RLightGCNConvVar(MessagePassing):
    def __init__(self, num_relations: int, emb_dim: int, rel_emb: nn.Embedding,
                 relation_transform: str = 'scalar', norm_type: str = 'global'):
        super().__init__(aggr='add')
        assert relation_transform in {'scalar','diagonal'}
        assert norm_type in {'global','per_relation'}
        self.rel_emb = rel_emb
        self.emb_dim = emb_dim
        self.relation_transform = relation_transform
        self.norm_type = norm_type

        if relation_transform == 'scalar':
            self.rel_scale = nn.Embedding(num_relations, 1)
            nn.init.ones_(self.rel_scale.weight)
        else:
            self.rel_scale = nn.Embedding(num_relations, emb_dim)
            nn.init.ones_(self.rel_scale.weight)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_type: torch.LongTensor) -> torch.Tensor:
        row, col = edge_index
        if self.norm_type == 'global':
            deg = torch.bincount(col, minlength=x.size(0)).float()
            deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        else:  # per_relation
            num_nodes = x.size(0)
            num_rel = self.rel_emb.num_embeddings
            keys_col = edge_type.long() * num_nodes + col.long()
            keys_row = edge_type.long() * num_nodes + row.long()
            ones = torch.ones_like(col, dtype=x.dtype)
            dim_size = int(num_rel * num_nodes)
            deg_col_per_key = scatter(ones, keys_col, dim=0, dim_size=dim_size)
            deg_row_per_key = scatter(ones, keys_row, dim=0, dim_size=dim_size)
            deg_col = deg_col_per_key[keys_col]
            deg_row = deg_row_per_key[keys_row]
            norm = deg_row.clamp(min=1).pow(-0.5) * deg_col.clamp(min=1).pow(-0.5)
        return self.propagate(edge_index, x=x, edge_type=edge_type, norm=norm)

    def message(self, x_j: torch.Tensor, edge_type: torch.LongTensor, norm: torch.Tensor) -> torch.Tensor:
        if self.relation_transform == 'scalar':
            out = x_j * self.rel_scale(edge_type)
        else:
            out = x_j * self.rel_scale(edge_type)
        return norm.view(-1, 1) * out

# -----------------------------
# R-LightGCNVar model
# -----------------------------
class RLightGCNVar(nn.Module):
    def __init__(self, num_nodes: int, num_relations: int, emb_dim: int = 64, num_layers: int = 3,
                 relation_transform: str = 'scalar', norm_type: str = 'global', decoder: str = 'distmult'):
        super().__init__()
        self.embedding = nn.Embedding(num_nodes, emb_dim)
        nn.init.xavier_uniform_(self.embedding.weight)
        self.rel_emb = nn.Embedding(num_relations, emb_dim)
        nn.init.xavier_uniform_(self.rel_emb.weight)
        self.convs = nn.ModuleList([
            RLightGCNConvVar(num_relations, emb_dim, self.rel_emb, relation_transform, norm_type)
            for _ in range(num_layers)
        ])
        self.num_layers = num_layers
        self.decoder_type = decoder.lower()
        if self.decoder_type == 'distmult':
            self.decoder = DistMultDecoder(num_relations, emb_dim)
        elif self.decoder_type == 'dot':
            self.decoder = DotProductDecoder()
        else:
            raise ValueError("decoder must be 'dot' or 'distmult'")

    def encode(self, edge_index: torch.Tensor, edge_type: torch.LongTensor = None) -> torch.Tensor:
        x = self.embedding.weight
        out = x
        for conv in self.convs:
            x = conv(x, edge_index, edge_type)
            out = out + x
        return out / (self.num_layers + 1)

    def score_triples(self, z: torch.Tensor, triples: torch.LongTensor) -> torch.Tensor:
        if self.decoder_type == 'dot':
            pairs = triples[:, [0, 2]]
            return self.decoder(z, pairs)
        else:
            return self.decoder(z, triples)

# -----------------------------
# Training function
# -----------------------------
def train_one_epoch_r_lightgcn_var(model, data_loader, optimizer, rel_edge_index, edge_type, device,
                                   num_entities, max_grad_norm=5.0):
    model.train()
    total_loss = 0.0
    steps = 0
    for batch in data_loader:
        optimizer.zero_grad()
        triples = batch[0] if isinstance(batch, (list, tuple)) else batch
        triples = triples.to(device)
        if triples.dim() == 1: triples = triples.unsqueeze(0)

        z = model.encode(rel_edge_index, edge_type)
        z = z / (z.norm(dim=1, keepdim=True) + 1e-9)

        pos_scores = model.score_triples(z, triples)

        B = triples.size(0)
        neg_heads = torch.randint(0, num_entities, (B,), device=device)
        neg_tails = torch.randint(0, num_entities, (B,), device=device)
        neg_h_triples = triples.clone(); neg_h_triples[:, 0] = neg_heads
        neg_t_triples = triples.clone(); neg_t_triples[:, 2] = neg_tails
        neg_scores_h = model.score_triples(z, neg_h_triples)
        neg_scores_t = model.score_triples(z, neg_t_triples)

        loss = F.softplus(-(pos_scores - neg_scores_h)).mean() + \
               F.softplus(-(pos_scores - neg_scores_t)).mean()
        if not torch.isfinite(loss): continue
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        total_loss += loss.item()
        steps += 1
    return total_loss / max(1, steps)

# -----------------------------
# Evaluation
# -----------------------------
def evaluate_auc_hits(model, triples, num_entities, edge_index=None, edge_type=None,
                      batch_size=4096, device=None):
    """
    AUC + Hits@K evaluation (averaged, tail-corruption)
    Works for dot or distmult decoders
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    if triples.numel() == 0:
        # Handle empty validation set
        return {"auc": 0.5, "hits@1": 0.0, "hits@5": 0.0, "hits@10": 0.0}

    def batch_iter(tensor, size):
        for i in range(0, len(tensor), size):
            yield tensor[i:i+size]

    def sample_negatives(pos_batch):
        neg_batch = pos_batch.clone()
        neg_batch[:, -1] = torch.randint(0, num_entities, (pos_batch.size(0),), device=pos_batch.device)
        return neg_batch

    scores_all, labels_all = [], []

    with torch.no_grad():
        emb = model.encode(edge_index, edge_type)
        emb = emb / (emb.norm(dim=1, keepdim=True) + 1e-9)

        for pos in batch_iter(triples, batch_size):
            if pos.size(0) == 0:
                continue

            pos = pos.to(device)
            neg = sample_negatives(pos)

            if isinstance(model.decoder, DistMultDecoder):
                s_pos = model.decoder(emb, pos)
                s_neg = model.decoder(emb, neg)
            else:
                s_pos = (emb[pos[:,0]] * emb[pos[:,-1]]).sum(dim=1)
                s_neg = (emb[neg[:,0]] * emb[neg[:,-1]]).sum(dim=1)

            scores_all.append(torch.cat([s_pos, s_neg], dim=0))
            labels_all.append(torch.cat([torch.ones_like(s_pos), torch.zeros_like(s_neg)], dim=0))

    # If no scores were collected, return defaults
    if len(scores_all) == 0:
        return {"auc": 0.5, "hits@1": 0.0, "hits@5": 0.0, "hits@10": 0.0}

    scores_all = torch.cat(scores_all, dim=0)
    labels_all = torch.cat(labels_all, dim=0)

    # Compute ROC-AUC using torchmetrics (pure torch, avoids numpy)
    try:
        auroc_metric = BinaryAUROC().to(scores_all.device)
        auc = auroc_metric(scores_all, labels_all.int()).item()
    except Exception:
        auc = 0.5

    # Hits@K evaluation
    hits_at = {1:0, 5:0, 10:0}
    n_trials = 0

    with torch.no_grad():
        for pos in batch_iter(triples, batch_size):
            if pos.size(0) == 0:
                continue

            pos = pos.to(device)
            B = pos.size(0)
            true_t = pos[:, -1]
            rand_t = torch.randint(0, num_entities, (B, 99), device=device)
            tails = torch.cat([true_t.unsqueeze(1), rand_t], dim=1)

            e_h = emb[pos[:,0]]
            e_candidates = emb[tails]
            s = (e_h.unsqueeze(1) * e_candidates).sum(dim=2)
            ranks = (s.argsort(dim=1, descending=True) == 0).nonzero()[:,1] + 1

            for k in hits_at.keys():
                hits_at[k] += (ranks <= k).sum().item()
            n_trials += B

    hits_at = {f"hits@{k}": v / n_trials for k, v in hits_at.items()}
    return {"auc": float(auc), **hits_at}



In [None]:
def run_r_lightgcn_ablation(
    num_nodes, num_relations, train_loader, val_triples,
    edge_index, edge_type, emb_dim=64, num_layers=3,
    lr=0.001, epochs=100, batch_size=2048, patience=10,
    save_dir="/content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/WN18RR",
    ABLATION_CONFIGS=None, device=None
):
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(save_dir, exist_ok=True)
    all_metrics = {}

    # --- Sanity checks ---
    assert val_triples.max() < num_nodes, f"Validation node index {val_triples.max()} >= num_nodes {num_nodes}"
    assert edge_index.max() < num_nodes, f"Edge index max {edge_index.max()} >= num_nodes {num_nodes}"
    assert edge_type.max() < num_relations, f"Relation index max {edge_type.max()} >= num_relations {num_relations}"

    # --- Combined tracker for all models ---
    combined_tracker = MetricsTracker()

    for cfg in ABLATION_CONFIGS:
        print(f"\n=== Training {cfg['name']} ===")
        model = RLightGCNVar(
            num_nodes=num_nodes,
            num_relations=num_relations,
            emb_dim=emb_dim,
            num_layers=num_layers,
            relation_transform=cfg['relation_transform'],
            norm_type=cfg['norm_type'],
            decoder=cfg['decoder']
        ).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        best_val_auc = 0.0
        epochs_no_improve = 0

        # --- Initialize per-model tracker ---
        tracker = MetricsTracker()

        for epoch in tqdm(range(1, epochs+1), desc=f"Training {cfg['name']}"):
            loss = train_one_epoch_r_lightgcn_var(
                model, train_loader, optimizer,
                edge_index, edge_type,
                device, num_nodes
            )

            if epoch <= 3 or epoch % 5 == 0:
                print(f"Epoch {epoch:3d} | Loss: {loss:.4f}")

            # --- Evaluate ---
            val_metrics = evaluate_auc_hits(
                model, val_triples, num_nodes,
                edge_index=edge_index, edge_type=edge_type,
                batch_size=batch_size, device=device
            )
            val_auc = val_metrics["auc"]
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            # --- Track metrics (no head/tail split) ---
            tracker.add(
                epoch,
                loss=loss,
                val_auc=val_metrics.get('auc', 0.0),
                hits1=val_metrics.get('hits@1', 0.0),
                hits5=val_metrics.get('hits@5', 0.0),
                hits10=val_metrics.get('hits@10', 0.0)
            )

            # --- Add to combined tracker ---
            combined_tracker.add(
                epoch,
                model_name=cfg['name'],
                loss=loss,
                val_auc=val_metrics.get('auc', 0.0),
                hits1=val_metrics.get('hits@1', 0.0),
                hits5=val_metrics.get('hits@5', 0.0),
                hits10=val_metrics.get('hits@10', 0.0)
            )

            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

        print(f"Validation metrics for {cfg['name']}: {val_metrics}")

        # --- Save final model to drive ---
        model_file = os.path.join(save_dir, f"{cfg['name']}.pt")
        torch.save(model.state_dict(), model_file)
        print(f"Final model saved: {model_file}")

        # --- Save per-model metrics table ---
        metrics_file = os.path.join(save_dir, f"{cfg['name']}_metrics.txt")
        with open(metrics_file, "w") as f:
            f.write("Epoch\tLoss\tVal_AUC\tHits@1\tHits@5\tHits@10\n")
            for i in range(len(tracker.metrics['epoch'])):
                f.write(f"{tracker.metrics['epoch'][i]}\t{tracker.metrics['loss'][i]:.6f}\t"
                        f"{tracker.metrics['val_auc'][i]:.6f}\t{tracker.metrics['hits1'][i]:.6f}\t"
                        f"{tracker.metrics['hits5'][i]:.6f}\t{tracker.metrics['hits10'][i]:.6f}\n")
        print(f"Per-model metrics saved to {metrics_file}")

        all_metrics[cfg['name']] = val_metrics

    # --- Save combined table for all models ---
    combined_file = os.path.join(save_dir, "all_models_metrics.txt")
    with open(combined_file, "w") as f:
        f.write("Model\tEpoch\tLoss\tVal_AUC\tHits@1\tHits@5\tHits@10\n")
        for i in range(len(combined_tracker.metrics['epoch'])):
            f.write(f"{combined_tracker.metrics.get('model_name', [''])[i]}\t"
                    f"{combined_tracker.metrics['epoch'][i]}\t"
                    f"{combined_tracker.metrics['loss'][i]:.6f}\t"
                    f"{combined_tracker.metrics['val_auc'][i]:.6f}\t"
                    f"{combined_tracker.metrics['hits1'][i]:.6f}\t"
                    f"{combined_tracker.metrics['hits5'][i]:.6f}\t"
                    f"{combined_tracker.metrics['hits10'][i]:.6f}\n")
    print(f"Combined metrics table saved to {combined_file}")

    print(f"\nAll ablation experiments completed. Metrics saved in {save_dir}")
    return all_metrics


In [None]:
# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Define number of nodes and relations
# -----------------------------
num_nodes = num_entities        # e.g., 40943 for WN18RR
# num_relations = num_relations   # original number of relations, e.g., 11
# Use the number of relations including inverses from dm_typed for consistency
num_relations_typed = num_relations_with_inv # e.g., 22 for WN18RR with inverses

# -----------------------------
# Self-loops
# -----------------------------
# Create self-loop edges
self_loop_edges = torch.arange(num_nodes, device=device)
self_loop_edges = torch.stack([self_loop_edges, self_loop_edges], dim=0)  # [2, num_nodes]

# Assign a new relation ID for self-loops, outside the range of existing typed relations
self_loop_rel_id = num_relations_typed  # the next available relation id
self_loop_types = torch.full((num_nodes,), self_loop_rel_id, dtype=torch.long, device=device)

# -----------------------------
# Combine with training edges
# -----------------------------
# Use triples from dm_typed which already have relation IDs from 0 to num_relations_typed - 1
rel_edge_index = dm_typed._train_triples[:, [0, 2]].t().contiguous().to(device)  # [2, N]
edge_type = dm_typed._train_triples[:, 1].to(device)  # [N]

rel_edge_index = torch.cat([rel_edge_index, self_loop_edges], dim=1)
edge_type = torch.cat([edge_type, self_loop_types], dim=0)

num_relations_final = num_relations_typed + 1  # include original, inverse, and self-loop relation

# -----------------------------
# Validation triples
# -----------------------------
val_triples = torch.as_tensor(dm_typed._valid_triples, device=device)

# -----------------------------
# DataLoader for training triples
# -----------------------------
# Wrap training triples in a TensorDataset and DataLoader
train_triples_tensor = torch.as_tensor(dm_typed._train_triples, device=device)
train_dataset = TensorDataset(train_triples_tensor)
train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True)

# -----------------------------
# Ablation configs
# -----------------------------
ABLATION_CONFIGS = [
    {'name':'R-LightGCN-Basic-Global-Dot','relation_transform':'scalar','norm_type':'global','decoder':'dot'},
    {'name':'R-LightGCN-Diagonal-Global-Dot','relation_transform':'diagonal','norm_type':'global','decoder':'dot'},
    {'name':'R-LightGCN-Scalar-Global-DistMult','relation_transform':'scalar','norm_type':'global','decoder':'distmult'},
    {'name':'R-LightGCN-Diagonal-Global-DistMult','relation_transform':'diagonal','norm_type':'global','decoder':'distmult'},
    {'name':'R-LightGCN-Scalar-PerRel-Dot','relation_transform':'scalar','norm_type':'per_relation','decoder':'dot'},
    {'name':'R-LightGCN-Diagonal-PerRel-Dot','relation_transform':'diagonal','norm_type':'per_relation','decoder':'dot'},
    {'name':'R-LightGCN-Scalar-PerRel-DistMult','relation_transform':'scalar','norm_type':'per_relation','decoder':'distmult'},
    {'name':'R-LightGCN-Diagonal-PerRel-DistMult','relation_transform':'diagonal','norm_type':'per_relation','decoder':'distmult'}
]

# -----------------------------
# Run ablation
# -----------------------------
all_metrics = run_r_lightgcn_ablation(
    num_nodes=num_nodes,
    num_relations=num_relations_final, # Pass the final number of relations including self-loops
    train_loader=train_loader,
    val_triples=val_triples,
    edge_index=rel_edge_index,
    edge_type=edge_type,
    emb_dim=64,
    num_layers=3,
    lr=0.001,
    epochs=100,
    batch_size=2048,
    patience=10,
    save_dir="/content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/WN18RR",
    ABLATION_CONFIGS=ABLATION_CONFIGS,
    device=device
)


=== Training R-LightGCN-Basic-Global-Dot ===


Training R-LightGCN-Basic-Global-Dot:   1%|          | 1/100 [00:03<04:57,  3.00s/it]

Epoch   1 | Loss: 0.7934


Training R-LightGCN-Basic-Global-Dot:   2%|▏         | 2/100 [00:05<04:50,  2.97s/it]

Epoch   2 | Loss: 0.7083


Training R-LightGCN-Basic-Global-Dot:   3%|▎         | 3/100 [00:09<05:01,  3.11s/it]

Epoch   3 | Loss: 0.6929


Training R-LightGCN-Basic-Global-Dot:   5%|▌         | 5/100 [00:15<04:47,  3.03s/it]

Epoch   5 | Loss: 0.6811


Training R-LightGCN-Basic-Global-Dot:  10%|█         | 10/100 [00:30<04:30,  3.01s/it]

Epoch  10 | Loss: 0.6722


Training R-LightGCN-Basic-Global-Dot:  15%|█▌        | 15/100 [00:45<04:21,  3.08s/it]

Epoch  15 | Loss: 0.6693


Training R-LightGCN-Basic-Global-Dot:  20%|██        | 20/100 [01:01<03:59,  2.99s/it]

Epoch  20 | Loss: 0.6675


Training R-LightGCN-Basic-Global-Dot:  25%|██▌       | 25/100 [01:16<03:49,  3.06s/it]

Epoch  25 | Loss: 0.6662


Training R-LightGCN-Basic-Global-Dot:  30%|███       | 30/100 [01:32<03:43,  3.19s/it]

Epoch  30 | Loss: 0.6653


Training R-LightGCN-Basic-Global-Dot:  30%|███       | 30/100 [01:35<03:41,  3.17s/it]


Early stopping triggered at epoch 31
Validation metrics for R-LightGCN-Basic-Global-Dot: {'auc': 0.9374099969863892, 'hits@1': 0.6366183256427159, 'hits@5': 0.8190507580751483, 'hits@10': 0.8618984838497034}
Final model saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Basic-Global-Dot.pt
Per-model metrics saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Basic-Global-Dot_metrics.txt

=== Training R-LightGCN-Diagonal-Global-Dot ===


Training R-LightGCN-Diagonal-Global-Dot:   1%|          | 1/100 [00:03<05:14,  3.18s/it]

Epoch   1 | Loss: 0.7937


Training R-LightGCN-Diagonal-Global-Dot:   2%|▏         | 2/100 [00:06<05:14,  3.21s/it]

Epoch   2 | Loss: 0.7082


Training R-LightGCN-Diagonal-Global-Dot:   3%|▎         | 3/100 [00:10<05:34,  3.45s/it]

Epoch   3 | Loss: 0.6927


Training R-LightGCN-Diagonal-Global-Dot:   5%|▌         | 5/100 [00:16<05:07,  3.23s/it]

Epoch   5 | Loss: 0.6809


Training R-LightGCN-Diagonal-Global-Dot:  10%|█         | 10/100 [00:32<04:51,  3.23s/it]

Epoch  10 | Loss: 0.6724


Training R-LightGCN-Diagonal-Global-Dot:  15%|█▌        | 15/100 [00:49<04:50,  3.41s/it]

Epoch  15 | Loss: 0.6695


Training R-LightGCN-Diagonal-Global-Dot:  20%|██        | 20/100 [01:06<04:26,  3.33s/it]

Epoch  20 | Loss: 0.6681


Training R-LightGCN-Diagonal-Global-Dot:  25%|██▌       | 25/100 [01:22<04:05,  3.28s/it]

Epoch  25 | Loss: 0.6666


Training R-LightGCN-Diagonal-Global-Dot:  30%|███       | 30/100 [01:39<03:52,  3.32s/it]

Epoch  30 | Loss: 0.6665


Training R-LightGCN-Diagonal-Global-Dot:  35%|███▌      | 35/100 [01:55<03:32,  3.27s/it]

Epoch  35 | Loss: 0.6656


Training R-LightGCN-Diagonal-Global-Dot:  40%|████      | 40/100 [02:12<03:22,  3.37s/it]

Epoch  40 | Loss: 0.6654


Training R-LightGCN-Diagonal-Global-Dot:  41%|████      | 41/100 [02:19<03:20,  3.40s/it]


Early stopping triggered at epoch 42
Validation metrics for R-LightGCN-Diagonal-Global-Dot: {'auc': 0.9365990161895752, 'hits@1': 0.6331575477916941, 'hits@5': 0.8185563612392881, 'hits@10': 0.8630520764667106}
Final model saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Diagonal-Global-Dot.pt
Per-model metrics saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Diagonal-Global-Dot_metrics.txt

=== Training R-LightGCN-Scalar-Global-DistMult ===


Training R-LightGCN-Scalar-Global-DistMult:   1%|          | 1/100 [00:02<04:34,  2.78s/it]

Epoch   1 | Loss: 1.3084


Training R-LightGCN-Scalar-Global-DistMult:   2%|▏         | 2/100 [00:06<04:59,  3.06s/it]

Epoch   2 | Loss: 1.1803


Training R-LightGCN-Scalar-Global-DistMult:   3%|▎         | 3/100 [00:08<04:52,  3.02s/it]

Epoch   3 | Loss: 1.0829


Training R-LightGCN-Scalar-Global-DistMult:   5%|▌         | 5/100 [00:14<04:40,  2.95s/it]

Epoch   5 | Loss: 0.9256


Training R-LightGCN-Scalar-Global-DistMult:  10%|█         | 10/100 [00:30<04:39,  3.11s/it]

Epoch  10 | Loss: 0.6794


Training R-LightGCN-Scalar-Global-DistMult:  15%|█▌        | 15/100 [00:46<04:24,  3.11s/it]

Epoch  15 | Loss: 0.5384


Training R-LightGCN-Scalar-Global-DistMult:  20%|██        | 20/100 [01:01<04:09,  3.12s/it]

Epoch  20 | Loss: 0.4380


Training R-LightGCN-Scalar-Global-DistMult:  25%|██▌       | 25/100 [01:17<03:56,  3.16s/it]

Epoch  25 | Loss: 0.3647


Training R-LightGCN-Scalar-Global-DistMult:  30%|███       | 30/100 [01:32<03:38,  3.12s/it]

Epoch  30 | Loss: 0.3095


Training R-LightGCN-Scalar-Global-DistMult:  35%|███▌      | 35/100 [01:48<03:20,  3.09s/it]

Epoch  35 | Loss: 0.2601


Training R-LightGCN-Scalar-Global-DistMult:  40%|████      | 40/100 [02:04<03:07,  3.13s/it]

Epoch  40 | Loss: 0.2166


Training R-LightGCN-Scalar-Global-DistMult:  45%|████▌     | 45/100 [02:19<02:46,  3.02s/it]

Epoch  45 | Loss: 0.1804


Training R-LightGCN-Scalar-Global-DistMult:  50%|█████     | 50/100 [02:34<02:33,  3.07s/it]

Epoch  50 | Loss: 0.1522


Training R-LightGCN-Scalar-Global-DistMult:  55%|█████▌    | 55/100 [02:50<02:22,  3.18s/it]

Epoch  55 | Loss: 0.1311


Training R-LightGCN-Scalar-Global-DistMult:  60%|██████    | 60/100 [03:05<02:02,  3.06s/it]

Epoch  60 | Loss: 0.1156


Training R-LightGCN-Scalar-Global-DistMult:  65%|██████▌   | 65/100 [03:20<01:46,  3.03s/it]

Epoch  65 | Loss: 0.1032


Training R-LightGCN-Scalar-Global-DistMult:  70%|███████   | 70/100 [03:35<01:32,  3.07s/it]

Epoch  70 | Loss: 0.0920


Training R-LightGCN-Scalar-Global-DistMult:  75%|███████▌  | 75/100 [03:50<01:15,  3.00s/it]

Epoch  75 | Loss: 0.0830


Training R-LightGCN-Scalar-Global-DistMult:  80%|████████  | 80/100 [04:06<00:59,  2.96s/it]

Epoch  80 | Loss: 0.0727


Training R-LightGCN-Scalar-Global-DistMult:  85%|████████▌ | 85/100 [04:21<00:45,  3.07s/it]

Epoch  85 | Loss: 0.0657


Training R-LightGCN-Scalar-Global-DistMult:  90%|█████████ | 90/100 [04:36<00:30,  3.07s/it]

Epoch  90 | Loss: 0.0586


Training R-LightGCN-Scalar-Global-DistMult:  95%|█████████▌| 95/100 [04:52<00:15,  3.09s/it]

Epoch  95 | Loss: 0.0536


Training R-LightGCN-Scalar-Global-DistMult: 100%|██████████| 100/100 [05:07<00:00,  3.07s/it]


Epoch 100 | Loss: 0.0486
Validation metrics for R-LightGCN-Scalar-Global-DistMult: {'auc': 0.91287761926651, 'hits@1': 0.5545484508899143, 'hits@5': 0.7615359261700725, 'hits@10': 0.8111404087013843}
Final model saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Scalar-Global-DistMult.pt
Per-model metrics saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Scalar-Global-DistMult_metrics.txt

=== Training R-LightGCN-Diagonal-Global-DistMult ===


Training R-LightGCN-Diagonal-Global-DistMult:   1%|          | 1/100 [00:02<04:51,  2.95s/it]

Epoch   1 | Loss: 1.3071


Training R-LightGCN-Diagonal-Global-DistMult:   2%|▏         | 2/100 [00:05<04:50,  2.96s/it]

Epoch   2 | Loss: 1.1636


Training R-LightGCN-Diagonal-Global-DistMult:   3%|▎         | 3/100 [00:09<05:12,  3.22s/it]

Epoch   3 | Loss: 1.0663


Training R-LightGCN-Diagonal-Global-DistMult:   5%|▌         | 5/100 [00:16<05:06,  3.23s/it]

Epoch   5 | Loss: 0.9288


Training R-LightGCN-Diagonal-Global-DistMult:  10%|█         | 10/100 [00:32<04:56,  3.30s/it]

Epoch  10 | Loss: 0.6972


Training R-LightGCN-Diagonal-Global-DistMult:  15%|█▌        | 15/100 [00:48<04:37,  3.27s/it]

Epoch  15 | Loss: 0.5260


Training R-LightGCN-Diagonal-Global-DistMult:  20%|██        | 20/100 [01:05<04:25,  3.32s/it]

Epoch  20 | Loss: 0.4078


Training R-LightGCN-Diagonal-Global-DistMult:  25%|██▌       | 25/100 [01:22<04:10,  3.34s/it]

Epoch  25 | Loss: 0.3330


Training R-LightGCN-Diagonal-Global-DistMult:  30%|███       | 30/100 [01:38<03:47,  3.25s/it]

Epoch  30 | Loss: 0.2808


Training R-LightGCN-Diagonal-Global-DistMult:  35%|███▌      | 35/100 [01:54<03:28,  3.21s/it]

Epoch  35 | Loss: 0.2424


Training R-LightGCN-Diagonal-Global-DistMult:  40%|████      | 40/100 [02:11<03:23,  3.39s/it]

Epoch  40 | Loss: 0.2114


Training R-LightGCN-Diagonal-Global-DistMult:  45%|████▌     | 45/100 [02:27<03:02,  3.32s/it]

Epoch  45 | Loss: 0.1837


Training R-LightGCN-Diagonal-Global-DistMult:  50%|█████     | 50/100 [02:44<02:43,  3.28s/it]

Epoch  50 | Loss: 0.1595


Training R-LightGCN-Diagonal-Global-DistMult:  55%|█████▌    | 55/100 [03:00<02:29,  3.31s/it]

Epoch  55 | Loss: 0.1404


Training R-LightGCN-Diagonal-Global-DistMult:  57%|█████▋    | 57/100 [03:10<02:24,  3.35s/it]


Early stopping triggered at epoch 58
Validation metrics for R-LightGCN-Diagonal-Global-DistMult: {'auc': 0.9047621488571167, 'hits@1': 0.45187870797626895, 'hits@5': 0.7254449571522742, 'hits@10': 0.8050428477257745}
Final model saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Diagonal-Global-DistMult.pt
Per-model metrics saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Diagonal-Global-DistMult_metrics.txt

=== Training R-LightGCN-Scalar-PerRel-Dot ===


Training R-LightGCN-Scalar-PerRel-Dot:   1%|          | 1/100 [00:02<04:33,  2.77s/it]

Epoch   1 | Loss: 0.7271


Training R-LightGCN-Scalar-PerRel-Dot:   2%|▏         | 2/100 [00:05<04:29,  2.75s/it]

Epoch   2 | Loss: 0.6865


Training R-LightGCN-Scalar-PerRel-Dot:   3%|▎         | 3/100 [00:08<04:54,  3.04s/it]

Epoch   3 | Loss: 0.6791


Training R-LightGCN-Scalar-PerRel-Dot:   5%|▌         | 5/100 [00:14<04:45,  3.01s/it]

Epoch   5 | Loss: 0.6738


Training R-LightGCN-Scalar-PerRel-Dot:  10%|█         | 10/100 [00:30<04:37,  3.09s/it]

Epoch  10 | Loss: 0.6695


Training R-LightGCN-Scalar-PerRel-Dot:  15%|█▌        | 15/100 [00:45<04:15,  3.00s/it]

Epoch  15 | Loss: 0.6677


Training R-LightGCN-Scalar-PerRel-Dot:  18%|█▊        | 18/100 [00:58<04:24,  3.23s/it]


Early stopping triggered at epoch 19
Validation metrics for R-LightGCN-Scalar-PerRel-Dot: {'auc': 0.9362051486968994, 'hits@1': 0.6287079762689519, 'hits@5': 0.8127883981542519, 'hits@10': 0.8605800922874094}
Final model saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Scalar-PerRel-Dot.pt
Per-model metrics saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Scalar-PerRel-Dot_metrics.txt

=== Training R-LightGCN-Diagonal-PerRel-Dot ===


Training R-LightGCN-Diagonal-PerRel-Dot:   1%|          | 1/100 [00:03<05:12,  3.16s/it]

Epoch   1 | Loss: 0.7265


Training R-LightGCN-Diagonal-PerRel-Dot:   2%|▏         | 2/100 [00:06<05:21,  3.28s/it]

Epoch   2 | Loss: 0.6862


Training R-LightGCN-Diagonal-PerRel-Dot:   3%|▎         | 3/100 [00:09<05:13,  3.24s/it]

Epoch   3 | Loss: 0.6792


Training R-LightGCN-Diagonal-PerRel-Dot:   5%|▌         | 5/100 [00:16<05:05,  3.22s/it]

Epoch   5 | Loss: 0.6743


Training R-LightGCN-Diagonal-PerRel-Dot:  10%|█         | 10/100 [00:33<05:10,  3.45s/it]

Epoch  10 | Loss: 0.6699


Training R-LightGCN-Diagonal-PerRel-Dot:  15%|█▌        | 15/100 [00:49<04:41,  3.31s/it]

Epoch  15 | Loss: 0.6680


Training R-LightGCN-Diagonal-PerRel-Dot:  20%|██        | 20/100 [01:06<04:21,  3.27s/it]

Epoch  20 | Loss: 0.6672


Training R-LightGCN-Diagonal-PerRel-Dot:  25%|██▌       | 25/100 [01:22<04:08,  3.31s/it]

Epoch  25 | Loss: 0.6668


Training R-LightGCN-Diagonal-PerRel-Dot:  30%|███       | 30/100 [01:39<03:53,  3.34s/it]

Epoch  30 | Loss: 0.6660


Training R-LightGCN-Diagonal-PerRel-Dot:  34%|███▍      | 34/100 [01:56<03:45,  3.42s/it]


Epoch  35 | Loss: 0.6658
Early stopping triggered at epoch 35
Validation metrics for R-LightGCN-Diagonal-PerRel-Dot: {'auc': 0.9396750926971436, 'hits@1': 0.6290375741595253, 'hits@5': 0.8164139749505603, 'hits@10': 0.8628872775214239}
Final model saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Diagonal-PerRel-Dot.pt
Per-model metrics saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Diagonal-PerRel-Dot_metrics.txt

=== Training R-LightGCN-Scalar-PerRel-DistMult ===


Training R-LightGCN-Scalar-PerRel-DistMult:   1%|          | 1/100 [00:02<04:48,  2.91s/it]

Epoch   1 | Loss: 1.2738


Training R-LightGCN-Scalar-PerRel-DistMult:   2%|▏         | 2/100 [00:05<04:45,  2.91s/it]

Epoch   2 | Loss: 1.1435


Training R-LightGCN-Scalar-PerRel-DistMult:   3%|▎         | 3/100 [00:08<04:49,  2.99s/it]

Epoch   3 | Loss: 1.0574


Training R-LightGCN-Scalar-PerRel-DistMult:   5%|▌         | 5/100 [00:14<04:44,  3.00s/it]

Epoch   5 | Loss: 0.9243


Training R-LightGCN-Scalar-PerRel-DistMult:  10%|█         | 10/100 [00:30<04:29,  2.99s/it]

Epoch  10 | Loss: 0.6954


Training R-LightGCN-Scalar-PerRel-DistMult:  15%|█▌        | 15/100 [00:45<04:19,  3.05s/it]

Epoch  15 | Loss: 0.5361


Training R-LightGCN-Scalar-PerRel-DistMult:  20%|██        | 20/100 [01:00<03:58,  2.98s/it]

Epoch  20 | Loss: 0.4243


Training R-LightGCN-Scalar-PerRel-DistMult:  25%|██▌       | 25/100 [01:15<03:41,  2.95s/it]

Epoch  25 | Loss: 0.3581


Training R-LightGCN-Scalar-PerRel-DistMult:  30%|███       | 30/100 [01:31<03:36,  3.10s/it]

Epoch  30 | Loss: 0.3115


Training R-LightGCN-Scalar-PerRel-DistMult:  35%|███▌      | 35/100 [01:46<03:20,  3.09s/it]

Epoch  35 | Loss: 0.2730


Training R-LightGCN-Scalar-PerRel-DistMult:  40%|████      | 40/100 [02:01<03:03,  3.06s/it]

Epoch  40 | Loss: 0.2412


Training R-LightGCN-Scalar-PerRel-DistMult:  45%|████▌     | 45/100 [02:16<02:50,  3.10s/it]

Epoch  45 | Loss: 0.2129


Training R-LightGCN-Scalar-PerRel-DistMult:  50%|█████     | 50/100 [02:31<02:30,  3.02s/it]

Epoch  50 | Loss: 0.1886


Training R-LightGCN-Scalar-PerRel-DistMult:  55%|█████▌    | 55/100 [02:46<02:14,  2.99s/it]

Epoch  55 | Loss: 0.1672


Training R-LightGCN-Scalar-PerRel-DistMult:  58%|█████▊    | 58/100 [02:59<02:09,  3.09s/it]


Early stopping triggered at epoch 59
Validation metrics for R-LightGCN-Scalar-PerRel-DistMult: {'auc': 0.9100422263145447, 'hits@1': 0.4602834541858932, 'hits@5': 0.7318721160184575, 'hits@10': 0.8029004614370469}
Final model saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Scalar-PerRel-DistMult.pt
Per-model metrics saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Scalar-PerRel-DistMult_metrics.txt

=== Training R-LightGCN-Diagonal-PerRel-DistMult ===


Training R-LightGCN-Diagonal-PerRel-DistMult:   1%|          | 1/100 [00:03<05:13,  3.17s/it]

Epoch   1 | Loss: 1.2807


Training R-LightGCN-Diagonal-PerRel-DistMult:   2%|▏         | 2/100 [00:06<05:20,  3.27s/it]

Epoch   2 | Loss: 1.1546


Training R-LightGCN-Diagonal-PerRel-DistMult:   3%|▎         | 3/100 [00:09<05:16,  3.26s/it]

Epoch   3 | Loss: 1.0630


Training R-LightGCN-Diagonal-PerRel-DistMult:   5%|▌         | 5/100 [00:16<05:11,  3.28s/it]

Epoch   5 | Loss: 0.9233


Training R-LightGCN-Diagonal-PerRel-DistMult:  10%|█         | 10/100 [00:32<04:49,  3.22s/it]

Epoch  10 | Loss: 0.6973


Training R-LightGCN-Diagonal-PerRel-DistMult:  15%|█▌        | 15/100 [00:49<04:41,  3.32s/it]

Epoch  15 | Loss: 0.5520


Training R-LightGCN-Diagonal-PerRel-DistMult:  20%|██        | 20/100 [01:05<04:28,  3.36s/it]

Epoch  20 | Loss: 0.4557


Training R-LightGCN-Diagonal-PerRel-DistMult:  25%|██▌       | 25/100 [01:21<04:03,  3.24s/it]

Epoch  25 | Loss: 0.3858


Training R-LightGCN-Diagonal-PerRel-DistMult:  30%|███       | 30/100 [01:38<03:44,  3.21s/it]

Epoch  30 | Loss: 0.3337


Training R-LightGCN-Diagonal-PerRel-DistMult:  35%|███▌      | 35/100 [01:54<03:36,  3.33s/it]

Epoch  35 | Loss: 0.2945


Training R-LightGCN-Diagonal-PerRel-DistMult:  40%|████      | 40/100 [02:11<03:20,  3.33s/it]

Epoch  40 | Loss: 0.2610


Training R-LightGCN-Diagonal-PerRel-DistMult:  45%|████▌     | 45/100 [02:27<02:59,  3.27s/it]

Epoch  45 | Loss: 0.2337


Training R-LightGCN-Diagonal-PerRel-DistMult:  50%|█████     | 50/100 [02:44<02:45,  3.31s/it]

Epoch  50 | Loss: 0.2087


Training R-LightGCN-Diagonal-PerRel-DistMult:  53%|█████▎    | 53/100 [02:57<02:37,  3.34s/it]

Early stopping triggered at epoch 54
Validation metrics for R-LightGCN-Diagonal-PerRel-DistMult: {'auc': 0.9007736444473267, 'hits@1': 0.3917270929466051, 'hits@5': 0.700560316413975, 'hits@10': 0.7930125247198417}
Final model saved: /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Diagonal-PerRel-DistMult.pt
Per-model metrics saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/R-LightGCN-Diagonal-PerRel-DistMult_metrics.txt
Combined metrics table saved to /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/all_models_metrics.txt

All ablation experiments completed. Metrics saved in /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints





In [None]:
# -----------------------------
# 1️⃣ Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# 2️⃣ Define number of nodes and relations
# -----------------------------
num_nodes = num_entities        # e.g., 40943 for WN18RR
# num_relations = num_relations   # original number of relations, e.g., 11
# Use the number of relations including inverses from dm_typed for consistency
num_relations_typed = num_relations_with_inv # e.g., 22 for WN18RR with inverses

# -----------------------------
# Self-loops
# -----------------------------
# Create self-loop edges
self_loop_edges = torch.arange(num_nodes, device=device)
self_loop_edges = torch.stack([self_loop_edges, self_loop_edges], dim=0)  # [2, num_nodes]

# Assign a new relation ID for self-loops, outside the range of existing typed relations
self_loop_rel_id = num_relations_typed  # the next available relation id
self_loop_types = torch.full((num_nodes,), self_loop_rel_id, dtype=torch.long, device=device)

# -----------------------------
# Combine with training edges
# -----------------------------
# Use triples from dm_typed which already have relation IDs from 0 to num_relations_typed - 1
rel_edge_index = dm_typed._train_triples[:, [0, 2]].t().contiguous().to(device)  # [2, N]
edge_type = dm_typed._train_triples[:, 1].to(device)  # [N]

rel_edge_index = torch.cat([rel_edge_index, self_loop_edges], dim=1)
edge_type = torch.cat([edge_type, self_loop_types], dim=0)

num_relations_final = num_relations_typed + 1  # include original, inverse, and self-loop relation

# -----------------------------
# Validation triples
# -----------------------------
val_triples = torch.as_tensor(dm_typed._valid_triples, device=device)

# -----------------------------
# DataLoader for training triples
# -----------------------------
# Wrap training triples in a TensorDataset and DataLoader
train_triples_tensor = torch.as_tensor(dm_typed._train_triples, device=device)
train_dataset = TensorDataset(train_triples_tensor)
train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True)

# -----------------------------
# Ablation configs
# -----------------------------
ABLATION_CONFIGS = [
    {'name':'R-LightGCN-Basic-Global-Dot','relation_transform':'scalar','norm_type':'global','decoder':'dot'},
    {'name':'R-LightGCN-Diagonal-Global-Dot','relation_transform':'diagonal','norm_type':'global','decoder':'dot'},
    {'name':'R-LightGCN-Scalar-Global-DistMult','relation_transform':'scalar','norm_type':'global','decoder':'distmult'},
    {'name':'R-LightGCN-Diagonal-Global-DistMult','relation_transform':'diagonal','norm_type':'global','decoder':'distmult'},
    {'name':'R-LightGCN-Scalar-PerRel-Dot','relation_transform':'scalar','norm_type':'per_relation','decoder':'dot'},
    {'name':'R-LightGCN-Diagonal-PerRel-Dot','relation_transform':'diagonal','norm_type':'per_relation','decoder':'dot'},
    {'name':'R-LightGCN-Scalar-PerRel-DistMult','relation_transform':'scalar','norm_type':'per_relation','decoder':'distmult'},
    {'name':'R-LightGCN-Diagonal-PerRel-DistMult','relation_transform':'diagonal','norm_type':'per_relation','decoder':'distmult'}
]

# -----------------------------
# Run ablation
# -----------------------------
all_metrics = run_r_lightgcn_ablation(
    num_nodes=num_nodes,
    num_relations=num_relations_final, # Pass the final number of relations including self-loops
    train_loader=train_loader,
    val_triples=val_triples,
    edge_index=rel_edge_index,
    edge_type=edge_type,
    emb_dim=64,
    num_layers=3,
    lr=0.001,
    epochs=100,
    batch_size=2048,
    patience=10,
    save_dir="/content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints/WN18RR",
    ABLATION_CONFIGS=ABLATION_CONFIGS,
    device=device
)


=== Training R-LightGCN-Basic-Global-Dot ===


Training R-LightGCN-Basic-Global-Dot:   1%|          | 1/100 [00:03<05:19,  3.23s/it]

Epoch   1 | Loss: 0.7935


Training R-LightGCN-Basic-Global-Dot:   2%|▏         | 2/100 [00:06<05:06,  3.12s/it]

Epoch   2 | Loss: 0.7081


Training R-LightGCN-Basic-Global-Dot:   3%|▎         | 3/100 [00:09<05:06,  3.16s/it]

Epoch   3 | Loss: 0.6930


Training R-LightGCN-Basic-Global-Dot:   5%|▌         | 5/100 [00:15<04:56,  3.12s/it]

Epoch   5 | Loss: 0.6812


Training R-LightGCN-Basic-Global-Dot:  10%|█         | 10/100 [00:31<04:43,  3.15s/it]

Epoch  10 | Loss: 0.6721


Training R-LightGCN-Basic-Global-Dot:  15%|█▌        | 15/100 [00:46<04:28,  3.16s/it]

Epoch  15 | Loss: 0.6697


Training R-LightGCN-Basic-Global-Dot:  20%|██        | 20/100 [01:01<04:03,  3.05s/it]

Epoch  20 | Loss: 0.6673


Training R-LightGCN-Basic-Global-Dot:  25%|██▌       | 25/100 [01:16<03:45,  3.01s/it]

Epoch  25 | Loss: 0.6662


Training R-LightGCN-Basic-Global-Dot:  30%|███       | 30/100 [01:32<03:34,  3.07s/it]

Epoch  30 | Loss: 0.6656


Training R-LightGCN-Basic-Global-Dot:  35%|███▌      | 35/100 [01:47<03:13,  2.98s/it]

Epoch  35 | Loss: 0.6653


Training R-LightGCN-Basic-Global-Dot:  40%|████      | 40/100 [02:02<03:00,  3.02s/it]

Epoch  40 | Loss: 0.6647


Training R-LightGCN-Basic-Global-Dot:  42%|████▏     | 42/100 [02:11<03:02,  3.14s/it]


Early stopping triggered at epoch 43
Validation metrics for R-LightGCN-Basic-Global-Dot: {'auc': 0.9386716485023499, 'hits@1': 0.6293671720500988, 'hits@5': 0.8182267633487146, 'hits@10': 0.8637112722478576}

=== Training R-LightGCN-Diagonal-Global-Dot ===


Training R-LightGCN-Diagonal-Global-Dot:   1%|          | 1/100 [00:03<05:17,  3.20s/it]

Epoch   1 | Loss: 0.7944


Training R-LightGCN-Diagonal-Global-Dot:   2%|▏         | 2/100 [00:06<05:41,  3.49s/it]

Epoch   2 | Loss: 0.7084


Training R-LightGCN-Diagonal-Global-Dot:   3%|▎         | 3/100 [00:10<05:24,  3.34s/it]

Epoch   3 | Loss: 0.6931


Training R-LightGCN-Diagonal-Global-Dot:   5%|▌         | 5/100 [00:16<05:02,  3.19s/it]

Epoch   5 | Loss: 0.6809


Training R-LightGCN-Diagonal-Global-Dot:  10%|█         | 10/100 [00:33<04:59,  3.33s/it]

Epoch  10 | Loss: 0.6725


Training R-LightGCN-Diagonal-Global-Dot:  15%|█▌        | 15/100 [00:49<04:48,  3.39s/it]

Epoch  15 | Loss: 0.6698


Training R-LightGCN-Diagonal-Global-Dot:  20%|██        | 20/100 [01:06<04:21,  3.27s/it]

Epoch  20 | Loss: 0.6680


Training R-LightGCN-Diagonal-Global-Dot:  25%|██▌       | 25/100 [01:22<04:02,  3.24s/it]

Epoch  25 | Loss: 0.6674


Training R-LightGCN-Diagonal-Global-Dot:  29%|██▉       | 29/100 [01:39<04:03,  3.43s/it]


Epoch  30 | Loss: 0.6664
Early stopping triggered at epoch 30
Validation metrics for R-LightGCN-Diagonal-Global-Dot: {'auc': 0.9391459226608276, 'hits@1': 0.6359591298615689, 'hits@5': 0.8182267633487146, 'hits@10': 0.8620632827949901}

=== Training R-LightGCN-Scalar-Global-DistMult ===


Training R-LightGCN-Scalar-Global-DistMult:   1%|          | 1/100 [00:02<04:49,  2.93s/it]

Epoch   1 | Loss: 1.3142


Training R-LightGCN-Scalar-Global-DistMult:   2%|▏         | 2/100 [00:05<04:33,  2.79s/it]

Epoch   2 | Loss: 1.1777


Training R-LightGCN-Scalar-Global-DistMult:   3%|▎         | 3/100 [00:08<04:36,  2.85s/it]

Epoch   3 | Loss: 1.0781


Training R-LightGCN-Scalar-Global-DistMult:   5%|▌         | 5/100 [00:14<04:49,  3.04s/it]

Epoch   5 | Loss: 0.9285


Training R-LightGCN-Scalar-Global-DistMult:  10%|█         | 10/100 [00:30<04:40,  3.11s/it]

Epoch  10 | Loss: 0.6747


Training R-LightGCN-Scalar-Global-DistMult:  15%|█▌        | 15/100 [00:45<04:16,  3.02s/it]

Epoch  15 | Loss: 0.5147


Training R-LightGCN-Scalar-Global-DistMult:  20%|██        | 20/100 [01:00<03:59,  2.99s/it]

Epoch  20 | Loss: 0.4057


Training R-LightGCN-Scalar-Global-DistMult:  25%|██▌       | 25/100 [01:15<03:48,  3.04s/it]

Epoch  25 | Loss: 0.3331


Training R-LightGCN-Scalar-Global-DistMult:  30%|███       | 30/100 [01:30<03:26,  2.95s/it]

Epoch  30 | Loss: 0.2793


Training R-LightGCN-Scalar-Global-DistMult:  35%|███▌      | 35/100 [01:45<03:11,  2.94s/it]

Epoch  35 | Loss: 0.2381


Training R-LightGCN-Scalar-Global-DistMult:  40%|████      | 40/100 [02:00<03:04,  3.07s/it]

Epoch  40 | Loss: 0.2045


Training R-LightGCN-Scalar-Global-DistMult:  45%|████▌     | 45/100 [02:15<02:49,  3.07s/it]

Epoch  45 | Loss: 0.1766


Training R-LightGCN-Scalar-Global-DistMult:  50%|█████     | 50/100 [02:30<02:32,  3.05s/it]

Epoch  50 | Loss: 0.1544


Training R-LightGCN-Scalar-Global-DistMult:  55%|█████▌    | 55/100 [02:46<02:19,  3.10s/it]

Epoch  55 | Loss: 0.1361


Training R-LightGCN-Scalar-Global-DistMult:  60%|██████    | 60/100 [03:01<02:00,  3.00s/it]

Epoch  60 | Loss: 0.1195


Training R-LightGCN-Scalar-Global-DistMult:  65%|██████▌   | 65/100 [03:16<01:44,  2.97s/it]

Epoch  65 | Loss: 0.1064


Training R-LightGCN-Scalar-Global-DistMult:  67%|██████▋   | 67/100 [03:25<01:41,  3.06s/it]


Early stopping triggered at epoch 68
Validation metrics for R-LightGCN-Scalar-Global-DistMult: {'auc': 0.9130885601043701, 'hits@1': 0.5416941331575478, 'hits@5': 0.7656558998022412, 'hits@10': 0.819215557020435}

=== Training R-LightGCN-Diagonal-Global-DistMult ===


Training R-LightGCN-Diagonal-Global-DistMult:   1%|          | 1/100 [00:03<05:11,  3.14s/it]

Epoch   1 | Loss: 1.2911


Training R-LightGCN-Diagonal-Global-DistMult:   2%|▏         | 2/100 [00:06<05:07,  3.14s/it]

Epoch   2 | Loss: 1.1501


Training R-LightGCN-Diagonal-Global-DistMult:   3%|▎         | 3/100 [00:09<05:12,  3.22s/it]

Epoch   3 | Loss: 1.0530


Training R-LightGCN-Diagonal-Global-DistMult:   5%|▌         | 5/100 [00:15<05:04,  3.20s/it]

Epoch   5 | Loss: 0.9083


Training R-LightGCN-Diagonal-Global-DistMult:  10%|█         | 10/100 [00:32<04:47,  3.20s/it]

Epoch  10 | Loss: 0.6795


Training R-LightGCN-Diagonal-Global-DistMult:  15%|█▌        | 15/100 [00:48<04:36,  3.26s/it]

Epoch  15 | Loss: 0.5341


Training R-LightGCN-Diagonal-Global-DistMult:  20%|██        | 20/100 [01:05<04:23,  3.30s/it]

Epoch  20 | Loss: 0.4364


Training R-LightGCN-Diagonal-Global-DistMult:  25%|██▌       | 25/100 [01:21<04:05,  3.27s/it]

Epoch  25 | Loss: 0.3643


Training R-LightGCN-Diagonal-Global-DistMult:  30%|███       | 30/100 [01:37<03:51,  3.30s/it]

Epoch  30 | Loss: 0.3107


Training R-LightGCN-Diagonal-Global-DistMult:  35%|███▌      | 35/100 [01:53<03:28,  3.21s/it]

Epoch  35 | Loss: 0.2681


Training R-LightGCN-Diagonal-Global-DistMult:  40%|████      | 40/100 [02:10<03:11,  3.19s/it]

Epoch  40 | Loss: 0.2326


Training R-LightGCN-Diagonal-Global-DistMult:  45%|████▌     | 45/100 [02:26<03:06,  3.38s/it]

Epoch  45 | Loss: 0.2023


Training R-LightGCN-Diagonal-Global-DistMult:  45%|████▌     | 45/100 [02:29<03:03,  3.33s/it]


Early stopping triggered at epoch 46
Validation metrics for R-LightGCN-Diagonal-Global-DistMult: {'auc': 0.9006655216217041, 'hits@1': 0.4382003955174687, 'hits@5': 0.716381015161503, 'hits@10': 0.7928477257745551}

=== Training R-LightGCN-Scalar-PerRel-Dot ===


Training R-LightGCN-Scalar-PerRel-Dot:   1%|          | 1/100 [00:02<04:49,  2.92s/it]

Epoch   1 | Loss: 0.7258


Training R-LightGCN-Scalar-PerRel-Dot:   2%|▏         | 2/100 [00:05<04:46,  2.92s/it]

Epoch   2 | Loss: 0.6856


Training R-LightGCN-Scalar-PerRel-Dot:   3%|▎         | 3/100 [00:09<05:00,  3.09s/it]

Epoch   3 | Loss: 0.6788


Training R-LightGCN-Scalar-PerRel-Dot:   5%|▌         | 5/100 [00:14<04:43,  2.99s/it]

Epoch   5 | Loss: 0.6736


Training R-LightGCN-Scalar-PerRel-Dot:  10%|█         | 10/100 [00:30<04:28,  2.99s/it]

Epoch  10 | Loss: 0.6694


Training R-LightGCN-Scalar-PerRel-Dot:  15%|█▌        | 15/100 [00:45<04:19,  3.05s/it]

Epoch  15 | Loss: 0.6680


Training R-LightGCN-Scalar-PerRel-Dot:  20%|██        | 20/100 [01:00<03:58,  2.98s/it]

Epoch  20 | Loss: 0.6668


Training R-LightGCN-Scalar-PerRel-Dot:  25%|██▌       | 25/100 [01:16<03:47,  3.04s/it]

Epoch  25 | Loss: 0.6659


Training R-LightGCN-Scalar-PerRel-Dot:  26%|██▌       | 26/100 [01:22<03:54,  3.17s/it]


Early stopping triggered at epoch 27
Validation metrics for R-LightGCN-Scalar-PerRel-Dot: {'auc': 0.9368771910667419, 'hits@1': 0.6272247857613711, 'hits@5': 0.8195451549110085, 'hits@10': 0.8627224785761372}

=== Training R-LightGCN-Diagonal-PerRel-Dot ===


Training R-LightGCN-Diagonal-PerRel-Dot:   1%|          | 1/100 [00:02<04:54,  2.97s/it]

Epoch   1 | Loss: 0.7264


Training R-LightGCN-Diagonal-PerRel-Dot:   2%|▏         | 2/100 [00:06<05:02,  3.08s/it]

Epoch   2 | Loss: 0.6857


Training R-LightGCN-Diagonal-PerRel-Dot:   3%|▎         | 3/100 [00:09<05:20,  3.30s/it]

Epoch   3 | Loss: 0.6795


Training R-LightGCN-Diagonal-PerRel-Dot:   5%|▌         | 5/100 [00:16<05:04,  3.20s/it]

Epoch   5 | Loss: 0.6740


Training R-LightGCN-Diagonal-PerRel-Dot:  10%|█         | 10/100 [00:32<04:58,  3.31s/it]

Epoch  10 | Loss: 0.6697


Training R-LightGCN-Diagonal-PerRel-Dot:  15%|█▌        | 15/100 [00:49<04:44,  3.34s/it]

Epoch  15 | Loss: 0.6683


Training R-LightGCN-Diagonal-PerRel-Dot:  20%|██        | 20/100 [01:05<04:22,  3.28s/it]

Epoch  20 | Loss: 0.6671


Training R-LightGCN-Diagonal-PerRel-Dot:  25%|██▌       | 25/100 [01:22<04:08,  3.32s/it]

Epoch  25 | Loss: 0.6666


Training R-LightGCN-Diagonal-PerRel-Dot:  30%|███       | 30/100 [01:38<03:45,  3.22s/it]

Epoch  30 | Loss: 0.6662


Training R-LightGCN-Diagonal-PerRel-Dot:  35%|███▌      | 35/100 [01:55<03:33,  3.28s/it]

Epoch  35 | Loss: 0.6658


Training R-LightGCN-Diagonal-PerRel-Dot:  40%|████      | 40/100 [02:11<03:24,  3.40s/it]

Epoch  40 | Loss: 0.6655


Training R-LightGCN-Diagonal-PerRel-Dot:  43%|████▎     | 43/100 [02:24<03:12,  3.37s/it]


Early stopping triggered at epoch 44
Validation metrics for R-LightGCN-Diagonal-PerRel-Dot: {'auc': 0.9380567073822021, 'hits@1': 0.6371127224785761, 'hits@5': 0.8149307844429796, 'hits@10': 0.8647000659195782}

=== Training R-LightGCN-Scalar-PerRel-DistMult ===


Training R-LightGCN-Scalar-PerRel-DistMult:   1%|          | 1/100 [00:02<04:46,  2.89s/it]

Epoch   1 | Loss: 1.2792


Training R-LightGCN-Scalar-PerRel-DistMult:   2%|▏         | 2/100 [00:05<04:43,  2.89s/it]

Epoch   2 | Loss: 1.1477


Training R-LightGCN-Scalar-PerRel-DistMult:   3%|▎         | 3/100 [00:08<04:41,  2.90s/it]

Epoch   3 | Loss: 1.0602


Training R-LightGCN-Scalar-PerRel-DistMult:   5%|▌         | 5/100 [00:14<04:43,  2.98s/it]

Epoch   5 | Loss: 0.9222


Training R-LightGCN-Scalar-PerRel-DistMult:  10%|█         | 10/100 [00:29<04:28,  2.99s/it]

Epoch  10 | Loss: 0.6918


Training R-LightGCN-Scalar-PerRel-DistMult:  15%|█▌        | 15/100 [00:44<04:17,  3.03s/it]

Epoch  15 | Loss: 0.5246


Training R-LightGCN-Scalar-PerRel-DistMult:  20%|██        | 20/100 [00:59<03:57,  2.97s/it]

Epoch  20 | Loss: 0.4301


Training R-LightGCN-Scalar-PerRel-DistMult:  25%|██▌       | 25/100 [01:14<03:40,  2.94s/it]

Epoch  25 | Loss: 0.3666


Training R-LightGCN-Scalar-PerRel-DistMult:  30%|███       | 30/100 [01:30<03:33,  3.04s/it]

Epoch  30 | Loss: 0.3157


Training R-LightGCN-Scalar-PerRel-DistMult:  35%|███▌      | 35/100 [01:45<03:16,  3.03s/it]

Epoch  35 | Loss: 0.2703


Training R-LightGCN-Scalar-PerRel-DistMult:  40%|████      | 40/100 [02:00<03:03,  3.06s/it]

Epoch  40 | Loss: 0.2373


Training R-LightGCN-Scalar-PerRel-DistMult:  45%|████▌     | 45/100 [02:15<02:51,  3.11s/it]

Epoch  45 | Loss: 0.2131


Training R-LightGCN-Scalar-PerRel-DistMult:  50%|█████     | 50/100 [02:30<02:30,  3.01s/it]

Epoch  50 | Loss: 0.1913


Training R-LightGCN-Scalar-PerRel-DistMult:  53%|█████▎    | 53/100 [02:42<02:24,  3.07s/it]


Early stopping triggered at epoch 54
Validation metrics for R-LightGCN-Scalar-PerRel-DistMult: {'auc': 0.9073912501335144, 'hits@1': 0.4159525379037574, 'hits@5': 0.7127554383651945, 'hits@10': 0.8066908371786421}

=== Training R-LightGCN-Diagonal-PerRel-DistMult ===


Training R-LightGCN-Diagonal-PerRel-DistMult:   1%|          | 1/100 [00:02<04:51,  2.94s/it]

Epoch   1 | Loss: 1.2508


Training R-LightGCN-Diagonal-PerRel-DistMult:   2%|▏         | 2/100 [00:05<04:47,  2.94s/it]

Epoch   2 | Loss: 1.1156


Training R-LightGCN-Diagonal-PerRel-DistMult:   3%|▎         | 3/100 [00:09<05:06,  3.16s/it]

Epoch   3 | Loss: 1.0359


Training R-LightGCN-Diagonal-PerRel-DistMult:   5%|▌         | 5/100 [00:15<05:04,  3.21s/it]

Epoch   5 | Loss: 0.9130


Training R-LightGCN-Diagonal-PerRel-DistMult:  10%|█         | 10/100 [00:32<04:53,  3.26s/it]

Epoch  10 | Loss: 0.6913


Training R-LightGCN-Diagonal-PerRel-DistMult:  15%|█▌        | 15/100 [00:48<04:32,  3.21s/it]

Epoch  15 | Loss: 0.5462


Training R-LightGCN-Diagonal-PerRel-DistMult:  20%|██        | 20/100 [01:04<04:24,  3.30s/it]

Epoch  20 | Loss: 0.4524


Training R-LightGCN-Diagonal-PerRel-DistMult:  25%|██▌       | 25/100 [01:21<04:09,  3.33s/it]

Epoch  25 | Loss: 0.3691


Training R-LightGCN-Diagonal-PerRel-DistMult:  30%|███       | 30/100 [01:37<03:45,  3.23s/it]

Epoch  30 | Loss: 0.2954


Training R-LightGCN-Diagonal-PerRel-DistMult:  35%|███▌      | 35/100 [01:53<03:27,  3.19s/it]

Epoch  35 | Loss: 0.2478


Training R-LightGCN-Diagonal-PerRel-DistMult:  40%|████      | 40/100 [02:10<03:18,  3.30s/it]

Epoch  40 | Loss: 0.2109


Training R-LightGCN-Diagonal-PerRel-DistMult:  45%|████▌     | 45/100 [02:26<03:01,  3.31s/it]

Epoch  45 | Loss: 0.1797


Training R-LightGCN-Diagonal-PerRel-DistMult:  50%|█████     | 50/100 [02:42<02:42,  3.24s/it]

Epoch  50 | Loss: 0.1573


Training R-LightGCN-Diagonal-PerRel-DistMult:  55%|█████▌    | 55/100 [02:59<02:27,  3.28s/it]

Epoch  55 | Loss: 0.1394


Training R-LightGCN-Diagonal-PerRel-DistMult:  60%|██████    | 60/100 [03:15<02:07,  3.19s/it]

Epoch  60 | Loss: 0.1247


Training R-LightGCN-Diagonal-PerRel-DistMult:  61%|██████    | 61/100 [03:21<02:09,  3.31s/it]

Early stopping triggered at epoch 62
Validation metrics for R-LightGCN-Diagonal-PerRel-DistMult: {'auc': 0.9018722772598267, 'hits@1': 0.46209624258404747, 'hits@5': 0.728246539222149, 'hits@10': 0.7999340804218853}

All ablation experiments completed. Metrics saved in /content/drive/My Drive/NTU/Y3S1/SC4020 Data Analytics and Mining/checkpoints



