In [15]:
# 05 - Training and Evaluation of LightGCN

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.utils import degree
import numpy as np
from tqdm.notebook import tqdm
import psutil
import gc
import os



In [16]:
## 1. Load Graph

from torch_geometric.data.data import DataTensorAttr, DataEdgeAttr
from torch_geometric.data.storage import GlobalStorage
import torch.serialization

torch.serialization.add_safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage])

data = torch.load('data/processed/graph_data.pt')
num_users = data.num_users
num_books = data.num_books
num_nodes = data.num_nodes

print(data)

Data(x=[5582, 64], edge_index=[2, 95220], num_nodes=5582, num_users=3404, num_books=2178)


In [17]:
# ...existing code...
os.makedirs('data/processed', exist_ok=True)

required_files = [
    'data/processed/train_edge_index.pt',
    'data/processed/train_pos.pt',
    'data/processed/val_pos.pt',
    'data/processed/test_pos.pt'
]

def create_splits():
    print("Creating reproducible split...")
    # Ensure user → book direction for positive edges
    user_to_book = data.edge_index[:, data.edge_index[0] < num_users]  # only user → book

    num_pos = user_to_book.size(1)
    perm = torch.randperm(num_pos, generator=torch.Generator().manual_seed(42))

    train_size = int(0.8 * num_pos)
    val_size = int(0.1 * num_pos)

    train_pos = user_to_book[:, perm[:train_size]]
    val_pos = user_to_book[:, perm[train_size:train_size+val_size]]
    test_pos = user_to_book[:, perm[train_size+val_size:]]

    # Build full undirected graphs for propagation
    def make_undirected(pos_edges):
        rev = torch.stack([pos_edges[1], pos_edges[0]])
        return torch.cat([pos_edges, rev], dim=1)

    train_edge_index = make_undirected(train_pos)
    val_edge_index = make_undirected(val_pos)   # not used for prop
    test_edge_index = make_undirected(test_pos)

    # Save only the positive directions for evaluation
    torch.save(train_pos, 'data/processed/train_pos.pt')
    torch.save(val_pos, 'data/processed/val_pos.pt')
    torch.save(test_pos, 'data/processed/test_pos.pt')
    torch.save(train_edge_index, 'data/processed/train_edge_index.pt')

    print(f"Saved splits: Train {train_pos.size(1)}, Val {val_pos.size(1)}, Test {test_pos.size(1)}")
    return train_pos, val_pos, test_pos, train_edge_index

# Create splits if any required file is missing, else try to load (fall back to recreate on error)
if not all(os.path.exists(f) for f in required_files):
    train_pos, val_pos, test_pos, train_edge_index = create_splits()
else:
    try:
        print("Loading splits...")
        train_pos = torch.load('data/processed/train_pos.pt')
        val_pos = torch.load('data/processed/val_pos.pt')
        test_pos = torch.load('data/processed/test_pos.pt')
        train_edge_index = torch.load('data/processed/train_edge_index.pt')
    except Exception as e:
        print(f"Error loading splits ({e}), recreating...")
        train_pos, val_pos, test_pos, train_edge_index = create_splits()

print(f"Positive edges - Train: {train_pos.size(1)}, Val: {val_pos.size(1)}, Test: {test_pos.size(1)}")
# ...existing code...

Loading splits...
Positive edges - Train: 38088, Val: 4761, Test: 4761


In [18]:
## 3. LightGCN Model

class LightGCN(nn.Module):
    def __init__(self, num_users, num_books, embedding_dim=64, num_layers=3):
        super().__init__()
        self.num_users = num_users
        self.num_books = num_books
        self.embedding_dim = embedding_dim
        
        self.num_layers = num_layers
        
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_books, embedding_dim)
        
        nn.init.normal_(self.user_embedding.weight, std=0.01)
        nn.init.normal_(self.item_embedding.weight, std=0.01)
    
    def forward(self, edge_index):
        x = torch.cat([self.user_embedding.weight, self.item_embedding.weight], dim=0)
        
        out_list = [x]
        for _ in range(self.num_layers):
            x = self.propagate(x, edge_index)
            out_list.append(x)
        
        final = sum(out_list) / (self.num_layers + 1)
        user_emb, item_emb = torch.split(final, [self.num_users, self.num_books])
        return user_emb, item_emb
    
    def propagate(self, x, edge_index):
        row, col = edge_index
        deg = degree(row, num_nodes=x.size(0))
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        
        # Sparse multiplication
        edge_index_sparse = torch.sparse_coo_tensor(edge_index, norm, (x.size(0), x.size(0)))
        return edge_index_sparse @ x

device = torch.device('cpu')
model = LightGCN(num_users, num_books, embedding_dim=64, num_layers=3).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-6)

print(model)
print(f"Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

LightGCN(
  (user_embedding): Embedding(3404, 64)
  (item_embedding): Embedding(2178, 64)
)
Parameters: 357,248


In [19]:
## 4. BPR Loss and Training Epoch

def bpr_loss(pos_scores, neg_scores):
    return -torch.log(torch.sigmoid(pos_scores - neg_scores) + 1e-8).mean()

@torch.no_grad()
def get_embeddings(model, edge_index):
    model.eval()
    return model(edge_index)

def train_epoch(model, edge_index, batch_size=2048):
    model.train()
    total_loss = 0.0
    
    # Use positive direction only for batching
    pos_edges = edge_index[:, ::2]  # one direction
    perm = torch.randperm(pos_edges.size(1))
    pos_edges = pos_edges[:, perm]
    
    for start in range(0, pos_edges.size(1), batch_size):
        end = start + batch_size
        batch_pos = pos_edges[:, start:end]
        
        # Random negative items (same number)
        neg_items = torch.randint(0, num_books, (batch_pos.size(1),), device=device)
        
        # Full forward (on whole graph - efficient for small data)
        user_emb, item_emb = get_embeddings(model, edge_index)
        
        # Scores
        pos_u_emb = user_emb[batch_pos[0]]
        pos_i_emb = item_emb[batch_pos[1] - num_users]
        neg_i_emb = item_emb[neg_items]
        
        pos_scores = (pos_u_emb * pos_i_emb).sum(dim=1)
        neg_scores = (pos_u_emb * neg_i_emb).sum(dim=1)
        
        loss = bpr_loss(pos_scores, neg_scores)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # stability
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / (pos_edges.size(1) // batch_size + 1)

In [20]:
## 5. Evaluation (Fixed)

@torch.no_grad()
def evaluate(model, pos_edges, K=10):
    model.eval()
    user_emb, item_emb = model(train_edge_index)
    scores = user_emb @ item_emb.t()  # [num_users, num_books]
    
    # Mask training positives
    mask = torch.zeros(num_users, num_books, dtype=torch.bool)
    mask[train_pos[0], train_pos[1] - num_users] = True
    scores[mask] = -1e8
    
    _, topk = torch.topk(scores, K, dim=1)
    
    users = pos_edges[0].cpu()
    true_items = pos_edges[1].cpu() - num_users
    
    recall = 0.0
    ndcg = 0.0
    num_users_with_test = users.unique().size(0)
    
    for u in users.unique():
        u_topk = topk[u]
        u_true = true_items[users == u]
        hits = torch.isin(u_topk, u_true)
        if hits.any():
            recall += 1
            rank = hits.nonzero(as_tuple=True)[0][0] + 1
            ndcg += 1 / np.log2(rank + 1)
    
    return recall / num_users_with_test, ndcg / len(pos_edges)


In [21]:
def train_epoch(model, edge_index, batch_size=2048, pos_edges=None):
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Use provided pos_edges if available, otherwise fall back to slicing
    if pos_edges is None:
        pos_edges = edge_index[:, ::2]  # one direction if not provided

    pos_edges = pos_edges.to(edge_index.device)
    perm = torch.randperm(pos_edges.size(1))
    pos_edges = pos_edges[:, perm]

    for start in range(0, pos_edges.size(1), batch_size):
        end = start + batch_size
        batch_pos = pos_edges[:, start:end]
        if batch_pos.size(1) == 0:
            continue

        # Random negative items
        neg_items = torch.randint(0, num_books, (batch_pos.size(1),), device=edge_index.device)

        # Forward pass on full training graph
        user_emb, item_emb = model(edge_index)

        # Positive scores
        users = batch_pos[0]
        pos_items = batch_pos[1] - num_users
        pos_scores = (user_emb[users] * item_emb[pos_items]).sum(dim=1)

        # Negative scores
        neg_scores = (user_emb[users] * item_emb[neg_items]).sum(dim=1)

        # BPR loss
        loss = -torch.log(torch.sigmoid(pos_scores - neg_scores) + 1e-8).mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0.0

In [22]:
# 6. Training Loop - This Will Work

os.makedirs('models', exist_ok=True)

best_recall = 0.0
patience = 7
wait = 0
max_epochs = 50

print("Starting LightGCN training...\n")

for epoch in range(1, max_epochs + 1):
    loss = train_epoch(model, train_edge_index, pos_edges=train_pos, batch_size=2048)
    
    val_recall, val_ndcg = evaluate(model, val_pos, K=10)
    
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val Recall@10: {val_recall:.4f} | Val NDCG@10: {val_ndcg:.4f}")
    
    if val_recall > best_recall:
        best_recall = val_recall
        torch.save(model.state_dict(), 'models/best_lightgcn.pt')
        wait = 0
        print("  >>> Best model saved!")
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping.")
            break

print("\nTraining done!")
model.load_state_dict(torch.load('models/best_lightgcn.pt'))
test_recall, test_ndcg = evaluate(model, test_pos, K=10)
print(f"Test Recall@10: {test_recall:.4f} | Test NDCG@10: {test_ndcg:.4f}")

Starting LightGCN training...



  ndcg += 1 / np.log2(rank + 1)


Epoch 01 | Loss: 0.6930 | Val Recall@10: 0.1177 | Val NDCG@10: 63.0130
  >>> Best model saved!
Epoch 02 | Loss: 0.6920 | Val Recall@10: 0.1098 | Val NDCG@10: 59.6696
Epoch 03 | Loss: 0.6872 | Val Recall@10: 0.1039 | Val NDCG@10: 56.0294
Epoch 04 | Loss: 0.6762 | Val Recall@10: 0.0990 | Val NDCG@10: 53.5635
Epoch 05 | Loss: 0.6592 | Val Recall@10: 0.0975 | Val NDCG@10: 52.8414
Epoch 06 | Loss: 0.6384 | Val Recall@10: 0.0975 | Val NDCG@10: 53.3798
Epoch 07 | Loss: 0.6178 | Val Recall@10: 0.0931 | Val NDCG@10: 51.6983
Epoch 08 | Loss: 0.5975 | Val Recall@10: 0.0955 | Val NDCG@10: 52.5229
Early stopping.

Training done!
Test Recall@10: 0.1203 | Test NDCG@10: 65.6868
