# Import Libraries

In [None]:
import random
from pathlib import Path
from collections import defaultdict

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import degree, add_self_loops
from torch_geometric.nn import MessagePassing
from sklearn.metrics import roc_auc_score

import sys
current_dir = Path(__file__).parent if '__file__' in globals() else Path.cwd()
parent_dir = current_dir.parent
sys.path.append(str(parent_dir))
from utility.dataset_loader import KGDataModuleCollapsed, KGDataModuleTyped

# Dataset Loading


In [47]:
# Dataset loading using dataset_loader.py
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dataset paths
train_path = Path("../WN18RR/train.txt")
valid_path = Path("../WN18RR/valid.txt")  
test_path = Path("../WN18RR/test.txt")

# Initialize data modules
# Collapsed mode for LightGCN (untyped pairs)
dm_collapsed = KGDataModuleCollapsed(
    train_path=train_path,
    valid_path=valid_path,
    test_path=test_path,
    batch_size=4096,
    shuffle=True,
    add_reverse=True
)

# Typed mode for R-LightGCN (typed triples)  
dm_typed = KGDataModuleTyped(
    train_path=train_path,
    valid_path=valid_path,
    test_path=test_path,
    batch_size=4096,
    shuffle=True,
    add_reverse=True,
    reverse_relation_strategy="duplicate_rel"
)

num_entities = len(dm_collapsed.ent2id)
num_relations = len(dm_collapsed.rel2id)  # Original relations
num_relations_with_inv = len(dm_typed.rel2id)  # With inverse relations

print(f"Dataset: WN18RR")
print(f"Entities: {num_entities:,}")
print(f"Original Relations: {num_relations:,}")
print(f"Relations (with inverse): {num_relations_with_inv:,}")
print(f"Training pairs: {len(dm_collapsed._train_pairs):,}")
print(f"Training triples (typed): {len(dm_typed._train_triples):,}")

# Build edge_index for LightGCN (collapsed pairs)
edge_index = dm_collapsed._train_pairs.t().contiguous().to(device)
edge_index, _ = add_self_loops(edge_index, num_nodes=num_entities)

print(f"Graph edges (with self-loops): {edge_index.shape[1]:,}")
print(f"Graph edge_index shape: {tuple(edge_index.shape)}")

train_loader_collapsed = dm_collapsed.train_loader()
val_loader_collapsed = dm_collapsed.val_loader()
test_loader_collapsed = dm_collapsed.test_loader()

train_loader_typed = dm_typed.train_loader()
val_loader_typed = dm_typed.val_loader()
test_loader_typed = dm_typed.test_loader()

print(f"Data loaders created:")
print(f"Train batches (collapsed): {len(train_loader_collapsed)}")
print(f"Val batches (collapsed): {len(val_loader_collapsed) if val_loader_collapsed else 0}")
print(f"Test batches (collapsed): {len(test_loader_collapsed) if test_loader_collapsed else 0}")

Using device: cpu
Dataset: WN18RR
Entities: 40,943
Original Relations: 11
Relations (with inverse): 22
Training pairs: 173,670
Training triples (typed): 173,670
Graph edges (with self-loops): 214,613
Graph edge_index shape: (2, 214613)
Data loaders created:
Train batches (collapsed): 43
Val batches (collapsed): 2
Test batches (collapsed): 2
Dataset: WN18RR
Entities: 40,943
Original Relations: 11
Relations (with inverse): 22
Training pairs: 173,670
Training triples (typed): 173,670
Graph edges (with self-loops): 214,613
Graph edge_index shape: (2, 214613)
Data loaders created:
Train batches (collapsed): 43
Val batches (collapsed): 2
Test batches (collapsed): 2


# Helper Functions

In [48]:
# Model Saving Utility Functions
def save_model_checkpoint(model, optimizer, hyperparameters, final_test_metrics, 
                         training_history, model_name, filename):
    """
    Save model checkpoint with comprehensive information.
    
    Args:
        model: The trained model
        optimizer: The optimizer used
        hyperparameters: Dict with training hyperparameters
        final_test_metrics: Final test evaluation results
        training_history: Training metrics history
        model_name: Name of the model for display
        filename: Output filename for the checkpoint
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'hyperparameters': hyperparameters,
        'final_test_metrics': final_test_metrics,
        'training_history': training_history,
        'model_name': model_name
    }
    
    torch.save(checkpoint, filename)
    print(f"{model_name} model checkpoint saved to {filename}")

def load_model_checkpoint(filename, model_class, device, **model_kwargs):
    """
    Load model checkpoint and restore model state.
    
    Args:
        filename: Path to checkpoint file
        model_class: Model class to instantiate
        device: Device to load model on
        **model_kwargs: Arguments for model instantiation
    
    Returns:
        model: Loaded model
        checkpoint: Full checkpoint data
    """
    checkpoint = torch.load(filename, map_location=device)
    
    # Create model instance
    model = model_class(**model_kwargs).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"Model loaded from {filename}")
    print(f"Model: {checkpoint.get('model_name', 'Unknown')}")
    print(f"Hyperparameters: {checkpoint.get('hyperparameters', {})}")
    
    if 'final_test_metrics' in checkpoint:
        metrics = checkpoint['final_test_metrics']
        if 'head' in metrics and 'tail' in metrics:
            auc_avg = (metrics['head']['auc'] + metrics['tail']['auc']) / 2
            hits10_avg = (metrics['head']['hits@10'] + metrics['tail']['hits@10']) / 2
            print(f"Test AUC (avg): {auc_avg:.4f}")
            print(f"Test Hits@10 (avg): {hits10_avg:.4f}")
    
    return model, checkpoint

In [59]:
class MetricsTracker:
    """Simplified metrics tracker without plotting"""
    def __init__(self):
        self.metrics = defaultdict(list)
    
    def add(self, epoch, **kwargs):
        self.metrics['epoch'].append(epoch)
        for key, value in kwargs.items():
            self.metrics[key].append(value)
    
    def get_best_epoch(self, metric='val_auc_head'):
        """Get the epoch with the best performance for a given metric"""
        if metric not in self.metrics or not self.metrics[metric]:
            return 0
        
        # For AUC and Hits@K, higher is better
        best_idx = np.argmax(self.metrics[metric])
        return self.metrics['epoch'][best_idx]
    
    def save_to_file(self, filepath):
        with open(filepath, 'w') as f:
            f.write("Epoch\tLoss\tAUC\tHits@1\tHits@5\tHits@10\n")
            for i in range(len(self.metrics['epoch'])):
                epoch = self.metrics['epoch'][i]
                loss = self.metrics['loss'][i]
                auc = self.metrics['val_auc'][i]
                h1 = self.metrics['val_hits1'][i]
                h5 = self.metrics['val_hits5'][i]
                h10 = self.metrics['val_hits10'][i]
                
                f.write(f"{epoch}\t{loss:.6f}\t{auc:.6f}\t{h1:.6f}\t{h5:.6f}\t{h10:.6f}\n")


class EarlyStopping:
    def __init__(self, patience=10, metric='val_auc_head', mode='max', min_delta=0.001):
        self.patience = patience
        self.metric = metric
        self.mode = mode  # 'max' for AUC, Hits@K; 'min' for loss
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.should_stop = False
        
    def __call__(self, metrics_tracker):
        current_score = metrics_tracker.metrics[self.metric][-1] if self.metric in metrics_tracker.metrics else 0
        
        if self.best_score is None:
            self.best_score = current_score
        else:
            if self.mode == 'max':
                improved = current_score > self.best_score + self.min_delta
            else:  # mode == 'min'
                improved = current_score < self.best_score - self.min_delta
                
            if improved:
                self.best_score = current_score
                self.counter = 0
            else:
                self.counter += 1
                
        self.should_stop = self.counter >= self.patience
        return self.should_stop

In [50]:
# Training functions with data loaders
def train_one_epoch_lightgcn(model, data_loader, optimizer, edge_index, device, num_entities, max_grad_norm=5.0):
    """
    Stable LightGCN training loop (BPR loss).
    Handles NaNs, exploding gradients, and large score differences gracefully.

    Args:
        model: LightGCN model
        data_loader: DataLoader yielding positive pairs (batch_size, 2)
        optimizer: Optimizer instance
        edge_index: Graph edges tensor on the correct device
        device: torch.device to run on
        num_entities: Total number of entities (int)
        max_grad_norm: Optional gradient clipping value
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(data_loader):
        optimizer.zero_grad()

        # ------------------------------
        # Handle batch format
        # ------------------------------
        if isinstance(batch, (list, tuple)):
            pairs_tensor = batch[0]
        else:
            pairs_tensor = batch

        pairs_tensor = pairs_tensor.to(device)
        if pairs_tensor.dim() == 1:
            pairs_tensor = pairs_tensor.unsqueeze(0)

        pos_edges = pairs_tensor  # [batch_size, 2]
        num_pos = pos_edges.size(0)

        # ------------------------------
        # Create random negatives
        # ------------------------------
        neg_tail = torch.randint(0, num_entities, (num_pos,), device=device)
        neg_head = torch.randint(0, num_entities, (num_pos,), device=device)

        # ------------------------------
        # Encode embeddings (per-batch to avoid retain_graph issues)
        # ------------------------------
        embeddings = model.encode(edge_index)

        # Normalize to prevent exploding scores
        emb = embeddings / (embeddings.norm(dim=1, keepdim=True) + 1e-9)

        # ------------------------------
        # Compute scores
        # ------------------------------
        pos_scores = (emb[pos_edges[:, 0]] * emb[pos_edges[:, 1]]).sum(dim=1)
        neg_scores_tail = (emb[pos_edges[:, 0]] * emb[neg_tail]).sum(dim=1)
        neg_scores_head = (emb[neg_head] * emb[pos_edges[:, 1]]).sum(dim=1)

        # ------------------------------
        # Stable BPR loss
        # ------------------------------
        loss_tail = F.softplus(-(pos_scores - neg_scores_tail)).mean()
        loss_head = F.softplus(-(pos_scores - neg_scores_head)).mean()
        loss = loss_tail + loss_head

        # ------------------------------
        # NaN / inf guard
        # ------------------------------
        if not torch.isfinite(loss):
            print(f"[Batch {batch_idx}] NaN or inf detected — skipping this batch.")
            continue

        # ------------------------------
        # Backprop and update
        # ------------------------------
        loss.backward()

        # Optional gradient clipping (prevents explosions)
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        # Track loss
        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / max(1, num_batches)
    return avg_loss

### Negative Sampling
- A training technique used in knowledge graph link prediction to create "negative examples", triples that are likely to be false. 
- Since knowledge graphs only contain positive facts (true triples), we need to artificially create negative examples for the model to learn what relationships are incorrect.
- Both head and tail corruptions are used to train the model to understand connections flowing in both directions. Model learns "What subjects fit this relation-object" and "What object fits this subject-relation".

In [51]:
@torch.no_grad()
def pairs_from_triples(triples: torch.LongTensor) -> torch.LongTensor:
    """
    Convert (h, r, t) -> pairs [2, N] = (h, t) for decoding on collapsed graph.
    """
    return triples[:, [0, 2]].t().contiguous()  # [2, N]

@torch.no_grad()
def negative_sample_heads(triples: torch.LongTensor, num_nodes: int) -> torch.LongTensor:
    """
    Corrupt heads: (h, r, t) -> (h', t)
    Returns pairs [2, N].
    """
    N = triples.size(0)
    neg_h = torch.randint(0, num_nodes, (N,), device=triples.device)
    t = triples[:, 2]
    return torch.stack([neg_h, t], dim=0)

@torch.no_grad()
def negative_sample_tails(triples: torch.LongTensor, num_nodes: int) -> torch.LongTensor:
    """
    Corrupt tails: (h, r, t) -> (h, t')
    Returns pairs [2, N].
    """
    N = triples.size(0)
    h = triples[:, 0]
    neg_t = torch.randint(0, num_nodes, (N,), device=triples.device)
    return torch.stack([h, neg_t], dim=0)


In [52]:
# -------- LightGCN layer --------
class LightGCNConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')

    # Compute symmetric normalization term D^-0.5*A*D^-0.5 to propagate messages through normalized adjacency
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)

    # Scales the neighbor embeddings
    def message(self, x_j: torch.Tensor, norm: torch.Tensor) -> torch.Tensor:
        return norm.view(-1, 1) * x_j

# -------- LightGCN encoder + dot-product decoder --------
class LightGCN(nn.Module):
    # Initialize trainable node embeddings
    def __init__(self, num_nodes: int, emb_dim: int = 64, num_layers: int = 3):
        super().__init__()
        self.embedding = nn.Embedding(num_nodes, emb_dim)
        nn.init.xavier_uniform_(self.embedding.weight)
        self.convs = nn.ModuleList([LightGCNConv() for _ in range(num_layers)])
        self.num_layers = num_layers

    def encode(self, edge_index: torch.Tensor) -> torch.Tensor:
        x0 = self.embedding.weight
        out = x0
        x = x0
        # Each layer's output is accumulated and averaged
        for conv in self.convs:
            x = conv(x, edge_index)
            out = out + x
        return out / (self.num_layers + 1)

    # Compute dot product between node embeddings for each edge (positive or negative pair)
    @staticmethod
    def decode(z: torch.Tensor, pairs: torch.LongTensor) -> torch.Tensor:
        # pairs: [2, B] with [src; dst]
        return (z[pairs[0]] * z[pairs[1]]).sum(dim=1)

In [None]:
def evaluate_auc_hits(model, triples, num_entities, edge_index=None, batch_size=4096, device=None):
    """
    Evaluate AUC and Hits@K (K=1,5,10) for LightGCN or R-LightGCN.
    Unfiltered version — each positive triple is compared to 99 random negatives.

    Args:
        model: LightGCN or R-LightGCN with .encode()
        triples: torch.Tensor [N, 3] or [N, 2] validation/test triples
        num_entities: total number of entities
        edge_index: graph structure for encoding
        batch_size: number of triples per batch
        device: torch.device to use
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    # -------------------------
    # Helper functions
    # -------------------------
    def batch_iter(tensor, size):
        for i in range(0, len(tensor), size):
            yield tensor[i:i + size]

    def sample_negatives(pos_batch, num_entities):
        """Corrupt tail entities randomly."""
        neg_batch = pos_batch.clone()
        neg_batch[:, -1] = torch.randint(0, num_entities, (pos_batch.size(0),), device=pos_batch.device)
        return neg_batch

    # -------------------------
    # AUC computation
    # -------------------------
    scores_all, labels_all = [], []

    for pos in batch_iter(triples, batch_size):
        pos = pos.to(device)
        neg = sample_negatives(pos, num_entities)

        with torch.no_grad():
            if hasattr(model, 'encode'):
                emb = model.encode(edge_index)
                s_pos = (emb[pos[:, 0]] * emb[pos[:, -1]]).sum(dim=1)
                s_neg = (emb[neg[:, 0]] * emb[neg[:, -1]]).sum(dim=1)
            else:
                s_pos = model(pos)
                s_neg = model(neg)

        scores_all.append(torch.cat([s_pos, s_neg], 0).cpu().numpy())
        labels_all.append(np.concatenate([np.ones(len(s_pos)), np.zeros(len(s_neg))], 0))

    scores_all = np.concatenate(scores_all, 0)
    labels_all = np.concatenate(labels_all, 0)
    auc = roc_auc_score(labels_all, scores_all)

    # -------------------------
    # Hits@K computation (1,5,10)
    # -------------------------
    hits_at = {1: 0, 5: 0, 10: 0}
    n_trials = 0

    with torch.no_grad():
        ent = model.encode(edge_index) if hasattr(model, 'encode') else model.encoder(edge_index)

        for pos in batch_iter(triples, batch_size):
            pos = pos.to(device)
            B = pos.size(0)
            true_t = pos[:, -1]
            rand_t = torch.randint(0, num_entities, (B, 99), device=device)
            tails = torch.cat([true_t.unsqueeze(1), rand_t], dim=1)  # [B,100]

            e_h = ent[pos[:, 0]]                                    # [B,d]
            e_candidates = ent[tails]                               # [B,100,d]
            s = (e_h.unsqueeze(1) * e_candidates).sum(dim=2)        # [B,100]

            # rank position of true tail
            ranks = (s.argsort(dim=1, descending=True) == 0).nonzero()[:, 1] + 1  # 1-based

            for k in hits_at.keys():
                hits_at[k] += (ranks <= k).sum().item()

            n_trials += B

    # Normalize
    hits_at = {f"hits@{k}": v / n_trials for k, v in hits_at.items()}

    return {"auc": float(auc), **hits_at}


In [61]:
# ===============================
# Hyperparameters
# ===============================
lr = 1e-3
epochs = 100
emb_dim = 64
num_layers = 3
eval_every = 5
batch_size = 2048
patience = 10

# ===============================
# Setup
# ===============================
lightgcn_metrics_tracker = MetricsTracker()
lightgcn_early_stopping = EarlyStopping(patience=patience, metric='val_auc')

# Create model + optimizer
lightgcn_model = LightGCN(num_nodes=num_entities, emb_dim=emb_dim, num_layers=num_layers).to(device)

optimizer = torch.optim.Adam(lightgcn_model.parameters(), lr=lr)

print("Starting LightGCN training...")
print(f"Max epochs: {epochs}, Early stopping patience: {patience}")

# ===============================
# Training Loop
# ===============================
for epoch in tqdm(range(1, epochs + 1), desc="Training LightGCN"):
    avg_loss = train_one_epoch_lightgcn(
        lightgcn_model, train_loader_collapsed, optimizer, 
        edge_index, device, num_entities
    )

    if epoch <= 5 or epoch % 5 == 0:
        print(f"Epoch {epoch:3d}: Loss={avg_loss:.4f}")

    # ===============================
    # Evaluation
    # ===============================
    if epoch % eval_every == 0 or epoch == 1:
        print(f"\nEvaluating epoch {epoch}...")
        # Convert val_loader to flat tensor of pairs/triples
        all_val_data = []
        for batch in val_loader_collapsed:
            data_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
            if data_tensor.dim() == 1:
                data_tensor = data_tensor.unsqueeze(0)
            all_val_data.append(data_tensor)
        val_triples = torch.cat(all_val_data, dim=0)

        # Run simplified evaluation (AUC + Hits@K)
        val_metrics = evaluate_auc_hits(
            lightgcn_model,
            val_triples,
            num_entities=num_entities,
            edge_index=edge_index,
            batch_size=2048,
            device=device
        )

        # Print summary
        print(f"Epoch {epoch}: Loss={avg_loss:.4f}, "
              f"AUC={val_metrics['auc']:.4f}, "
              f"Hits@1={val_metrics['hits@1']:.4f}, "
              f"Hits@5={val_metrics['hits@5']:.4f}, "
              f"Hits@10={val_metrics['hits@10']:.4f}")

        # Log to tracker (simple flattening)
        lightgcn_metrics_tracker.add(
            epoch=epoch,
            loss=avg_loss,
            val_auc=val_metrics['auc'],
            val_hits1=val_metrics['hits@1'],
            val_hits5=val_metrics['hits@5'],
            val_hits10=val_metrics['hits@10']
        )

        # Early stopping
        if lightgcn_early_stopping(lightgcn_metrics_tracker):
            print(f"Early stopping at epoch {epoch} "
                  f"(Best AUC: {lightgcn_early_stopping.best_score:.4f})")
            break

print("\nLightGCN training completed!")

# ===============================
# Final Test Evaluation
# ===============================
print("\nRunning final test evaluation...")
all_test_data = []
for batch in test_loader_collapsed:
    data_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
    if data_tensor.dim() == 1:
        data_tensor = data_tensor.unsqueeze(0)
    all_test_data.append(data_tensor)
test_triples = torch.cat(all_test_data, dim=0)

lightgcn_final_test_metrics = evaluate_auc_hits(
    lightgcn_model,
    test_triples,
    num_entities=num_entities,
    edge_index=edge_index,
    batch_size=2048,
    device=device
)

print("\n[LightGCN WN18RR TEST RESULTS]")
print(f"AUC: {lightgcn_final_test_metrics['auc']:.4f}")
print(f"Hits@1: {lightgcn_final_test_metrics['hits@1']:.4f}")
print(f"Hits@5: {lightgcn_final_test_metrics['hits@5']:.4f}")
print(f"Hits@10: {lightgcn_final_test_metrics['hits@10']:.4f}")

# ===============================
# Save model + metrics
# ===============================
torch.save({
    'model_state_dict': lightgcn_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'hyperparameters': {
        'emb_dim': emb_dim,
        'num_layers': num_layers,
        'lr': lr,
        'num_entities': num_entities
    },
    'final_test_metrics': lightgcn_final_test_metrics,
    'training_history': lightgcn_metrics_tracker.metrics
}, "wn18rr_lightgcn_model.pt")

print("Model saved to wn18rr_lightgcn_model.pt")

lightgcn_metrics_tracker.save_to_file("wn18rr_lightgcn_metrics.txt")
print("Metrics saved to wn18rr_lightgcn_metrics.txt")


Starting LightGCN training...
Max epochs: 100, Early stopping patience: 10


Training LightGCN:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch   1: Loss=0.8249

Evaluating epoch 1...


Training LightGCN:   1%|          | 1/100 [00:07<12:30,  7.58s/it]

Epoch 1: Loss=0.8249, AUC=0.8821, Hits@1=0.6091, Hits@5=0.7426, Hits@10=0.7810


Training LightGCN:   2%|▏         | 2/100 [00:15<12:40,  7.76s/it]

Epoch   2: Loss=0.7204


Training LightGCN:   3%|▎         | 3/100 [00:22<12:12,  7.55s/it]

Epoch   3: Loss=0.7016


Training LightGCN:   4%|▍         | 4/100 [00:32<13:22,  8.36s/it]

Epoch   4: Loss=0.6920
Epoch   5: Loss=0.6867

Evaluating epoch 5...
Epoch   5: Loss=0.6867

Evaluating epoch 5...


Training LightGCN:   5%|▌         | 5/100 [00:40<13:01,  8.23s/it]

Epoch 5: Loss=0.6867, AUC=0.9153, Hits@1=0.6386, Hits@5=0.8090, Hits@10=0.8497


Training LightGCN:   9%|▉         | 9/100 [01:15<12:42,  8.38s/it]

Epoch  10: Loss=0.6751

Evaluating epoch 10...


Training LightGCN:  10%|█         | 10/100 [01:23<12:33,  8.38s/it]

Epoch 10: Loss=0.6751, AUC=0.9249, Hits@1=0.6177, Hits@5=0.8230, Hits@10=0.8645


Training LightGCN:  14%|█▍        | 14/100 [01:59<12:31,  8.74s/it]

Epoch  15: Loss=0.6723

Evaluating epoch 15...


Training LightGCN:  15%|█▌        | 15/100 [02:08<12:37,  8.91s/it]

Epoch 15: Loss=0.6723, AUC=0.9280, Hits@1=0.6119, Hits@5=0.8240, Hits@10=0.8724


Training LightGCN:  19%|█▉        | 19/100 [02:43<11:39,  8.64s/it]

Epoch  20: Loss=0.6703

Evaluating epoch 20...


Training LightGCN:  20%|██        | 20/100 [02:52<11:54,  8.93s/it]

Epoch 20: Loss=0.6703, AUC=0.9309, Hits@1=0.6056, Hits@5=0.8284, Hits@10=0.8724


Training LightGCN:  24%|██▍       | 24/100 [03:27<11:02,  8.71s/it]

Epoch  25: Loss=0.6694

Evaluating epoch 25...


Training LightGCN:  25%|██▌       | 25/100 [03:36<11:04,  8.86s/it]

Epoch 25: Loss=0.6694, AUC=0.9299, Hits@1=0.6088, Hits@5=0.8279, Hits@10=0.8726


Training LightGCN:  29%|██▉       | 29/100 [04:12<10:37,  8.98s/it]

Epoch  30: Loss=0.6683

Evaluating epoch 30...


Training LightGCN:  30%|███       | 30/100 [04:20<10:17,  8.83s/it]

Epoch 30: Loss=0.6683, AUC=0.9291, Hits@1=0.5992, Hits@5=0.8293, Hits@10=0.8743


Training LightGCN:  34%|███▍      | 34/100 [04:51<08:38,  7.86s/it]

Epoch  35: Loss=0.6680

Evaluating epoch 35...


Training LightGCN:  35%|███▌      | 35/100 [04:59<08:38,  7.98s/it]

Epoch 35: Loss=0.6680, AUC=0.9308, Hits@1=0.6046, Hits@5=0.8268, Hits@10=0.8733


Training LightGCN:  39%|███▉      | 39/100 [05:29<07:49,  7.69s/it]

Epoch  40: Loss=0.6672

Evaluating epoch 40...


Training LightGCN:  40%|████      | 40/100 [05:37<07:48,  7.80s/it]

Epoch 40: Loss=0.6672, AUC=0.9309, Hits@1=0.5948, Hits@5=0.8296, Hits@10=0.8738


Training LightGCN:  44%|████▍     | 44/100 [06:09<07:28,  8.01s/it]

Epoch  45: Loss=0.6666

Evaluating epoch 45...


Training LightGCN:  45%|████▌     | 45/100 [06:17<07:26,  8.11s/it]

Epoch 45: Loss=0.6666, AUC=0.9294, Hits@1=0.6032, Hits@5=0.8304, Hits@10=0.8716


Training LightGCN:  49%|████▉     | 49/100 [06:47<06:33,  7.71s/it]

Epoch  50: Loss=0.6666

Evaluating epoch 50...


Training LightGCN:  50%|█████     | 50/100 [06:57<06:48,  8.17s/it]

Epoch 50: Loss=0.6666, AUC=0.9301, Hits@1=0.5999, Hits@5=0.8319, Hits@10=0.8736


Training LightGCN:  54%|█████▍    | 54/100 [07:28<06:02,  7.87s/it]

Epoch  55: Loss=0.6665

Evaluating epoch 55...


Training LightGCN:  55%|█████▌    | 55/100 [07:36<05:53,  7.84s/it]

Epoch 55: Loss=0.6665, AUC=0.9328, Hits@1=0.5951, Hits@5=0.8289, Hits@10=0.8752


Training LightGCN:  59%|█████▉    | 59/100 [08:07<05:22,  7.87s/it]

Epoch  60: Loss=0.6659

Evaluating epoch 60...


Training LightGCN:  60%|██████    | 60/100 [08:16<05:32,  8.31s/it]

Epoch 60: Loss=0.6659, AUC=0.9322, Hits@1=0.5974, Hits@5=0.8298, Hits@10=0.8771


Training LightGCN:  64%|██████▍   | 64/100 [08:47<04:42,  7.84s/it]

Epoch  65: Loss=0.6658

Evaluating epoch 65...


Training LightGCN:  65%|██████▌   | 65/100 [08:56<04:42,  8.08s/it]

Epoch 65: Loss=0.6658, AUC=0.9295, Hits@1=0.6018, Hits@5=0.8238, Hits@10=0.8744


Training LightGCN:  69%|██████▉   | 69/100 [09:27<04:08,  8.01s/it]

Epoch  70: Loss=0.6662

Evaluating epoch 70...


Training LightGCN:  70%|███████   | 70/100 [09:36<04:07,  8.26s/it]

Epoch 70: Loss=0.6662, AUC=0.9329, Hits@1=0.5944, Hits@5=0.8261, Hits@10=0.8785


Training LightGCN:  74%|███████▍  | 74/100 [10:08<03:28,  8.03s/it]

Epoch  75: Loss=0.6657

Evaluating epoch 75...


Training LightGCN:  75%|███████▌  | 75/100 [10:16<03:19,  7.98s/it]

Epoch 75: Loss=0.6657, AUC=0.9331, Hits@1=0.5892, Hits@5=0.8296, Hits@10=0.8780


Training LightGCN:  79%|███████▉  | 79/100 [10:45<02:38,  7.54s/it]

Epoch  80: Loss=0.6654

Evaluating epoch 80...


Training LightGCN:  80%|████████  | 80/100 [10:54<02:37,  7.86s/it]

Epoch 80: Loss=0.6654, AUC=0.9324, Hits@1=0.5948, Hits@5=0.8314, Hits@10=0.8779


Training LightGCN:  84%|████████▍ | 84/100 [11:28<02:16,  8.51s/it]

Epoch  85: Loss=0.6656

Evaluating epoch 85...


Training LightGCN:  85%|████████▌ | 85/100 [11:36<02:06,  8.45s/it]

Epoch 85: Loss=0.6656, AUC=0.9306, Hits@1=0.5984, Hits@5=0.8281, Hits@10=0.8800


Training LightGCN:  89%|████████▉ | 89/100 [12:09<01:28,  8.03s/it]

Epoch  90: Loss=0.6651

Evaluating epoch 90...


Training LightGCN:  90%|█████████ | 90/100 [12:17<01:20,  8.10s/it]

Epoch 90: Loss=0.6651, AUC=0.9365, Hits@1=0.5972, Hits@5=0.8339, Hits@10=0.8823


Training LightGCN:  94%|█████████▍| 94/100 [12:47<00:46,  7.67s/it]

Epoch  95: Loss=0.6651

Evaluating epoch 95...


Training LightGCN:  95%|█████████▌| 95/100 [12:56<00:39,  7.96s/it]

Epoch 95: Loss=0.6651, AUC=0.9367, Hits@1=0.5943, Hits@5=0.8316, Hits@10=0.8809


Training LightGCN:  99%|█████████▉| 99/100 [13:26<00:07,  7.72s/it]

Epoch 100: Loss=0.6649

Evaluating epoch 100...


Training LightGCN: 100%|██████████| 100/100 [13:34<00:00,  8.15s/it]

Epoch 100: Loss=0.6649, AUC=0.9349, Hits@1=0.5944, Hits@5=0.8317, Hits@10=0.8802

LightGCN training completed!

Running final test evaluation...






[LightGCN WN18RR TEST RESULTS]
AUC: 0.9319
Hits@1: 0.5994
Hits@5: 0.8336
Hits@10: 0.8805
Model saved to wn18rr_lightgcn_model.pt
Metrics saved to wn18rr_lightgcn_metrics.txt
