## GraphSAGE (Neighbor Sampling) for Ranking on Graph3 (Goodbooks-10k)

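In this notebook we train a scalable GNN recommender for Goodbooks-10k, starting from the Graph3 augmented graph. The experiments below use only the user–book part of the graph (see Cells 4 and 7); the remaining edge types are loaded but not yet used for message passing.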

Graph3 contains multiple node and edge types.

### Nodes

- users
- books
- tags
- authors
- language
- year_bin

### Edges

- user–book (train only)
- book–tag (weighted: log1p(count))
- book–book similarity (TF-IDF cosine)
- book–author
- book–language
- book–year_bin

### Why GraphSAGE here?

Our previous R-GCN run showed a strong mismatch between the training objective and the evaluation metric:

- it was trained with BCE + negative sampling (binary link prediction),
- it was evaluated with global ranking (Leave-One-Out Hit@K / NDCG@K),
- the resulting metrics were near zero despite a decreasing loss.

GraphSAGE is a better fit because:

- it supports mini-batch neighbor sampling (practical training on large graphs),
- we can optimize a ranking loss (BPR), which is aligned with the LOO ranking metrics.
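
For reference, the pairwise objective that `bpr_loss` in Cell 9 computes can be written as below. The symbols are introduced here for illustration only: $\mathcal{B}$ is a mini-batch of (user, positive item, negative item) triples, $\mathbf{e}_u, \mathbf{e}_i, \mathbf{e}_j$ are the corresponding embeddings, and $\sigma$ is the logistic sigmoid.

$$
\mathcal{L}_{\mathrm{BPR}} = -\frac{1}{|\mathcal{B}|} \sum_{(u,i,j) \in \mathcal{B}} \ln \sigma\left(\mathbf{e}_u^{\top}\mathbf{e}_i - \mathbf{e}_u^{\top}\mathbf{e}_j\right)
$$

The code additionally adds `1e-8` inside the logarithm for numerical safety and, when `reg > 0`, a small penalty on the mean squared entries of the three embedding batches. The loss only depends on the score difference between a positive and a negative item, which is why it lines up with ranking metrics better than a per-pair BCE does.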

### Evaluation

We keep the same evaluation protocol as the LightGCN baseline, using Leave-One-Out ground-truth pairs:

- Hit@10/20/50
- NDCG@10/20/50
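
As a quick illustration of how these metrics behave under Leave-One-Out, here is a minimal sketch for a single user. The function name `loo_metrics` is hypothetical and not part of the notebook. It assumes exactly one held-out item per user, so the ideal DCG is 1 and NDCG reduces to a single discount term.

```python
import math

def loo_metrics(rank_1based: int, ks=(10, 20, 50)) -> dict:
    """Hit@K and NDCG@K for one user with a single held-out item.

    rank_1based is the position of the held-out item in the ranked list (1 = top).
    With one relevant item, NDCG@K = 1 / log2(rank + 1) if rank <= K, else 0.
    """
    out = {}
    for k in ks:
        hit = rank_1based <= k
        out[f"Hit@{k}"] = 1.0 if hit else 0.0
        out[f"NDCG@{k}"] = 1.0 / math.log2(rank_1based + 1) if hit else 0.0
    return out

example = loo_metrics(rank_1based=3)
print(example)
```

Averaging these per-user values over all evaluated users gives the reported Hit@K and NDCG@K.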

### Artifacts / Inputs

We load the frozen Graph3 bundle from:

`D:/ML/GNN/graph_recsys/artifacts/v2_proper/graph3_bundle/`

Bundle files:

- `graph3_state.pt`: graph tensors (edges / edge types / A_norm / offsets / vocab)
- `splits_ui.npz`: train / val / test LOO splits

In [1]:
# ============================
# Cell 1: Imports + device + paths
# Purpose:
# - Define project paths
# - Setup torch device
# - Basic imports used throughout the notebook
# ============================

import copy
import json
import math
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from pathlib import Path
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from tqdm.auto import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

PROJECT_ROOT = Path(r"D:/ML/GNN/graph_recsys")
ARTIFACTS = PROJECT_ROOT / "artifacts" / "v2_proper"
BUNDLE_DIR = ARTIFACTS / "graph3_bundle"

print("BUNDLE_DIR:", BUNDLE_DIR)
assert BUNDLE_DIR.exists(), f"Missing bundle: {BUNDLE_DIR}"

DEVICE: cuda
BUNDLE_DIR: D:\ML\GNN\graph_recsys\artifacts\v2_proper\graph3_bundle


In [2]:
# ============================
# Cell 2: Load Graph3 bundle (graph tensors + LOO splits)
# Purpose:
# - Load saved graph state (sparse adjacency, edge types, offsets, etc.)
# - Load train/val/test user-item splits (LOO protocol)
# Notes:
# - We keep everything on CPU for now; sampling will move mini-batches to GPU later.
# ============================

g = torch.load(BUNDLE_DIR / "graph3_state.pt", map_location="cpu")
z = np.load(BUNDLE_DIR / "splits_ui.npz", allow_pickle=True)

train_ui = z["train_ui"].astype(np.int64)
val_ui   = z["val_ui"].astype(np.int64)
test_ui  = z["test_ui"].astype(np.int64)

# naming: U = n_users, B = n_books/items
U = int(z["U"])
B = int(z["B"])

# Main sparse graph representation (COO)
A_norm = g["A_norm"].coalesce()   # sparse COO on CPU
num_nodes = int(g["num_nodes"])
offsets = g["offsets"]            # dict with node-type offsets in the global ID space

print(f"Loaded: num_nodes={num_nodes}, U={U}, B={B}")
print("train/val/test:", train_ui.shape, val_ui.shape, test_ui.shape)
print("A_norm nnz:", int(A_norm._nnz()))
print("offsets:", offsets)


Loaded: num_nodes=74285, U=53398, B=9999
train/val/test: (4926384, 2) (53398, 2) (53398, 2)
A_norm nnz: 11260518
offsets: {'user_offset': 0, 'book_offset': 53398, 'tag_offset': 63397, 'author_offset': 68411, 'lang_offset': 74252, 'year_offset': 74278}
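
The offsets printed above describe where each node type starts in the global ID space. A minimal sketch of mapping a global node id back to its type is shown below. The helper `node_type` is hypothetical, and the sketch assumes each node type occupies one contiguous block that ends where the next one starts.

```python
from bisect import bisect_right

# Offsets and node count as printed by Cell 2.
offsets = {
    "user_offset": 0, "book_offset": 53398, "tag_offset": 63397,
    "author_offset": 68411, "lang_offset": 74252, "year_offset": 74278,
}
num_nodes = 74285

# Sort blocks by starting id so they can be binary-searched.
starts = sorted((start, name[: -len("_offset")]) for name, start in offsets.items())
start_ids = [s for s, _ in starts]

def node_type(global_id: int):
    """Return (node type, local id within that type) for a global node id."""
    if not 0 <= global_id < num_nodes:
        raise ValueError(f"id {global_id} outside [0, {num_nodes})")
    pos = bisect_right(start_ids, global_id) - 1
    start, name = starts[pos]
    return name, global_id - start

print(node_type(53398))   # first book node
print(node_type(74284))   # last node in the graph
```

Under the same contiguity assumption, the block sizes follow from consecutive offsets, for example 63397 - 53398 = 9999 books, which matches `B`.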


In [3]:
# ============================
# Cell 3: Build train positives + LOO ground truth (val/test)
# Purpose:
# - train_pos[u] = set(items) used to:
#   1) filter already-seen items during ranking
#   2) avoid sampling positives as negatives
# - val_gt/test_gt: dict user -> held-out item (LOO)
# ============================

from collections import defaultdict

train_pos = defaultdict(set)
for u, i in train_ui:
    train_pos[int(u)].add(int(i))

val_gt  = {int(u): int(i) for u, i in val_ui}
test_gt = {int(u): int(i) for u, i in test_ui}

print("train_pos users:", len(train_pos))
print("val_gt:", len(val_gt), "test_gt:", len(test_gt))

train_pos users: 53398
val_gt: 53398 test_gt: 53398


In [4]:
# ============================
# Cell 4: Build bipartite user-book edge_index for training
# Purpose:
# - GraphSAGE with BPR is typically trained on the user-item interaction graph
# - We build an undirected edge_index (u<->(i+U)) in global node ID space
# Notes:
# - items are shifted by +U so that users: [0..U-1], items: [U..U+B-1]
# - This matches the convention used in LightGCN baseline
# ============================

u = torch.from_numpy(train_ui[:, 0]).long()
i = torch.from_numpy(train_ui[:, 1]).long() + U

row = torch.cat([u, i], dim=0)
col = torch.cat([i, u], dim=0)

edge_index_ui = torch.stack([row, col], dim=0)  # [2, 2*E]
num_nodes_ui = U + B

print("edge_index_ui:", tuple(edge_index_ui.shape), "num_nodes_ui:", num_nodes_ui)

edge_index_ui: (2, 9852768) num_nodes_ui: 63397
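
To make the ID shift and the edge doubling concrete, here is the same construction on a toy example, written with NumPy only. The sizes `U_toy` and `B_toy` and the interactions are made up for illustration.

```python
import numpy as np

# Toy interactions: 3 users, 2 items; rows are (user, item) in per-type local ids.
U_toy, B_toy = 3, 2
ui = np.array([[0, 0], [1, 0], [2, 1]], dtype=np.int64)

u = ui[:, 0]
i = ui[:, 1] + U_toy            # shift items into the global id range [U, U+B)

# Each interaction contributes two directed edges: u -> i and i -> u.
edge_index = np.stack([np.concatenate([u, i]), np.concatenate([i, u])])

print(edge_index)
```

The second dimension is twice the number of interactions, which is why the real graph above has 2 × 4,926,384 = 9,852,768 directed edges.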


In [5]:
# ============================
# Cell 5: Sanity checks
# Purpose:
# - Verify that val/test items are not present in train for the same user (LOO assumption)
# - Basic stats for debugging
# ============================

def check_loo_splits(train_pos, gt_dict, name="val"):
    bad = 0
    for u, gt_i in gt_dict.items():
        if gt_i in train_pos.get(u, set()):
            bad += 1
    print(f"[{name}] leaks (gt in train_pos): {bad} / {len(gt_dict)}")

check_loo_splits(train_pos, val_gt, "val")
check_loo_splits(train_pos, test_gt, "test")

# interaction density quick look
num_interactions = train_ui.shape[0]
print("train interactions:", num_interactions)
print("avg train interactions per user:", num_interactions / max(1, U))

[val] leaks (gt in train_pos): 0 / 53398
[test] leaks (gt in train_pos): 0 / 53398
train interactions: 4926384
avg train interactions per user: 92.25783737218623


In [6]:
# ============================
# Cell 6: Config + reproducibility
# Purpose:
# - Central place for hyperparameters for GraphSAGE+BPR
# - Fix seeds for stable comparisons
# ============================

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

CFG = {
    "embedding_dim": 64,
    "num_layers": 2,
    "dropout": 0.1,
    "batch_size_users": 2048,      
    "neighbors": [15, 10],         # neighbor sampling fanout per layer
    "lr": 1e-3,
    "weight_decay": 1e-6,
    "epochs": 10,
    "bpr_reg": 1e-6,
    "eval_every": 1,
}
CFG

{'embedding_dim': 64,
 'num_layers': 2,
 'dropout': 0.1,
 'batch_size_users': 2048,
 'neighbors': [15, 10],
 'lr': 0.001,
 'weight_decay': 1e-06,
 'epochs': 10,
 'bpr_reg': 1e-06,
 'eval_every': 1}
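
A short back-of-the-envelope check of what one "epoch" means with this config. The training loops in Cells 10 and 13 iterate over seed users and draw one positive per user, so an epoch is a pass over users, not over interactions. The numbers below are taken from the Cell 2 output.

```python
import math

U = 53398                      # number of users (Cell 2 output)
train_interactions = 4926384   # train user-book pairs (Cell 2 output)
batch_size_users = 2048        # CFG["batch_size_users"]

steps_per_epoch = math.ceil(U / batch_size_users)

# One (user, positive, negative) triple per seed user per epoch.
triples_per_epoch = U
coverage = triples_per_epoch / train_interactions

print(steps_per_epoch, f"{coverage:.2%}")
```

Each epoch therefore touches roughly 1% of the training interactions, which helps explain why the metrics keep improving slowly over many epochs.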

In [7]:
# ============================
# Cell 7: Build PyG Data + NeighborLoader
# Purpose:
# - Wrap edge_index in torch_geometric.data.Data
# - Prepare a NeighborLoader that samples neighbors for a mini-batch of *seed nodes*
# Notes:
# - We'll train only on the bipartite user-item graph first.
# - Seed nodes will be USERS; positives/negatives are sampled as ITEMS.
# ============================

data_ui = Data(edge_index=edge_index_ui, num_nodes=num_nodes_ui)

# Train loader samples neighborhoods for a set of seed nodes.
# We'll pass user node ids (0..U-1) as input_nodes.
train_user_nodes = torch.arange(U, dtype=torch.long)

train_loader = NeighborLoader(
    data_ui,
    input_nodes=train_user_nodes,
    num_neighbors=CFG["neighbors"],   # e.g. [15, 10]
    batch_size=CFG["batch_size_users"],
    shuffle=True,
    num_workers=0,                    # set >0 later if you want
    persistent_workers=False
)

# quick smoke: one batch
batch = next(iter(train_loader))
print(batch)
print("batch.num_nodes:", batch.num_nodes, "batch.edge_index:", batch.edge_index.shape)

Data(edge_index=[2, 104485], num_nodes=46092, n_id=[46092], e_id=[104485], num_sampled_nodes=[3], num_sampled_edges=[2], input_id=[2048], batch_size=2048)
batch.num_nodes: 46092 batch.edge_index: torch.Size([2, 104485])


In [8]:
# ============================
# Cell 8: GraphSAGE model
# Purpose:
# - Node embeddings table for all nodes (users + items)
# - GraphSAGE layers produce context-aware embeddings
# ============================

class GraphSAGERec(nn.Module):
    def __init__(self, num_nodes: int, dim: int = 64, num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.num_nodes = num_nodes
        self.dim = dim
        self.dropout = dropout

        self.emb = nn.Embedding(num_nodes, dim)
        nn.init.normal_(self.emb.weight, std=0.1)

        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(SAGEConv(dim, dim))

    def forward(self, x, edge_index):
        # x is node ids (LongTensor)
        h = self.emb(x)
        for conv in self.convs:
            h = conv(h, edge_index)
            h = F.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
        return h

model = GraphSAGERec(
    num_nodes=num_nodes_ui,
    dim=CFG["embedding_dim"],
    num_layers=CFG["num_layers"],
    dropout=CFG["dropout"],
).to(DEVICE)

print(model)

GraphSAGERec(
  (emb): Embedding(63397, 64)
  (convs): ModuleList(
    (0-1): 2 x SAGEConv(64, 64, aggr=mean)
  )
)


In [9]:
# ============================
# Cell 9: BPR helpers
# Purpose:
# - Sample negatives for (user, positive_item) pairs
# - Compute BPR loss from user/item embeddings
# Notes:
# - We must avoid sampling already-seen train positives as negatives.
# ============================

def sample_negatives(users_np: np.ndarray, num_items: int, train_pos, rng: np.random.Generator):
    """
    For each user, sample one negative item not in train_pos[user].
    Returns: neg_items (np.int64)
    """
    neg = np.empty_like(users_np, dtype=np.int64)
    for idx, u in enumerate(users_np):
        seen = train_pos[int(u)]
        while True:
            j = int(rng.integers(0, num_items))
            if j not in seen:
                neg[idx] = j
                break
    return neg

def bpr_loss(u_emb, p_emb, n_emb, reg=0.0):
    """
    BPR loss: -log sigma(u·p - u·n)
    """
    pos_scores = (u_emb * p_emb).sum(dim=-1)
    neg_scores = (u_emb * n_emb).sum(dim=-1)
    loss = -torch.log(torch.sigmoid(pos_scores - neg_scores) + 1e-8).mean()
    if reg > 0:
        loss = loss + reg * (u_emb.pow(2).mean() + p_emb.pow(2).mean() + n_emb.pow(2).mean())
    return loss
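
The per-user Python loop in `sample_negatives` is simple but slow for large batches. One possible alternative, sketched here as an assumption and not used elsewhere in the notebook, is batched rejection sampling. The function `sample_negatives_vectorized` and the `seen_keys` encoding (`user * num_items + item`) are hypothetical.

```python
import numpy as np

def sample_negatives_vectorized(users, num_items, seen_keys, rng, max_rounds=100):
    """Draw one unseen item per user with batched rejection sampling.

    seen_keys is a 1-D int64 array of user * num_items + item for every
    (user, item) training pair, so membership is a single np.isin call.
    """
    users = np.asarray(users, dtype=np.int64)
    neg = rng.integers(0, num_items, size=users.shape[0], dtype=np.int64)
    for _ in range(max_rounds):
        clash = np.isin(users * num_items + neg, seen_keys)
        if not clash.any():
            return neg
        # Redraw only the entries that collided with a training positive.
        neg[clash] = rng.integers(0, num_items, size=int(clash.sum()), dtype=np.int64)
    raise RuntimeError("could not find unseen items; is some user saturated?")

# Toy data: 3 users, 5 items.
train_pairs = np.array([[0, 0], [0, 1], [1, 2], [2, 3], [2, 4]], dtype=np.int64)
num_items = 5
seen_keys = np.sort(train_pairs[:, 0] * num_items + train_pairs[:, 1])

rng = np.random.default_rng(0)
users = np.array([0, 1, 2, 0, 2])
neg = sample_negatives_vectorized(users, num_items, seen_keys, rng)
print(neg)
```

Because users here have seen at most 197 of 9,999 books, collisions are rare and the loop would usually finish in one or two rounds.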

In [11]:
# ============================
# Cell 10: Training loop (smoke run)
# Purpose:
# - One epoch training with NeighborLoader + BPR
# Strategy:
# - For each batch of seed users, we sample ONE positive per user from train_pos
# - Then sample one negative item per user
# - Run GraphSAGE on the sampled subgraph and compute BPR on user/item embeddings
# ============================

optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])
rng = np.random.default_rng(SEED)

def train_one_epoch():
    model.train()
    total_loss = 0.0
    steps = 0

    for batch in tqdm(train_loader, desc="train"):
        # batch contains a subgraph. batch.n_id maps local nodes -> global node ids.
        batch = batch.to(DEVICE)
        n_id = batch.n_id  # global node ids for nodes in this batch-subgraph

        # NeighborLoader places the seed nodes first in n_id (the first batch.batch_size entries).
        # batch.input_id holds indices into input_nodes; since input_nodes = arange(U),
        # these indices coincide with the global user node ids.
        seed_users = batch.input_id  # global user node ids
        seed_users_np = seed_users.detach().cpu().numpy()

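        # take 1 positive per user (deterministic here: always the first element the set
        # iterator yields; Cell 11 replaces this with uniform random sampling)
        pos_items = np.array([next(iter(train_pos[int(u)])) for u in seed_users_np], dtype=np.int64)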

        # sample 1 negative per user
        neg_items = sample_negatives(seed_users_np, B, train_pos, rng)

        # convert to global item node ids
        pos_nodes = torch.from_numpy(pos_items).long().to(DEVICE) + U
        neg_nodes = torch.from_numpy(neg_items).long().to(DEVICE) + U

        # NOTE: sampled positives/negatives are not guaranteed to be in this batch subgraph.
        # Items outside the subgraph get no message-passing output from forward(),
        # so for them we fall back to the raw nn.Embedding row (see get_item_emb below).
        # A more robust fix (e.g. forcing those nodes into the subgraph) is left for later.
        h = model(n_id, batch.edge_index)  # embeddings for nodes in n_id order

        # Map global ids -> local positions in n_id with a dense lookup tensor:
        # size num_nodes_ui, -1 everywhere except at n_id, which holds the local index.
        idx_map = torch.full((num_nodes_ui,), -1, device=DEVICE, dtype=torch.long)
        idx_map[n_id] = torch.arange(n_id.size(0), device=DEVICE)

        u_loc = idx_map[seed_users]
        p_loc = idx_map[pos_nodes]
        n_loc = idx_map[neg_nodes]

        # fallback: if p_loc or n_loc == -1, take direct embedding (no message passing)
        u_emb = h[u_loc]

        def get_item_emb(loc_idx, global_nodes):
            mask = loc_idx >= 0
            out = torch.empty((global_nodes.size(0), CFG["embedding_dim"]), device=DEVICE)
            out[mask] = h[loc_idx[mask]]
            out[~mask] = model.emb(global_nodes[~mask])  # fallback
            return out

        p_emb = get_item_emb(p_loc, pos_nodes)
        n_emb = get_item_emb(n_loc, neg_nodes)

        loss = bpr_loss(u_emb, p_emb, n_emb, reg=CFG["bpr_reg"])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += float(loss.detach().cpu())
        steps += 1

    return total_loss / max(1, steps)

avg_loss = train_one_epoch()
print("avg train loss:", avg_loss)

avg train loss: 0.5631215947645681
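
The global-to-local lookup and the embedding-table fallback used in the training loop can be hard to follow inside the loop body. The sketch below reproduces the same idea on toy NumPy arrays. All array values are made up, and `h` and `table` are stand-ins for the GraphSAGE outputs and the raw embedding table.

```python
import numpy as np

num_nodes_total = 10
n_id = np.array([7, 2, 5])            # global ids of the nodes in a sampled subgraph

# Dense lookup: -1 means "not in this subgraph".
idx_map = np.full(num_nodes_total, -1, dtype=np.int64)
idx_map[n_id] = np.arange(n_id.size)

query = np.array([5, 7, 9])           # global ids we want embeddings for
loc = idx_map[query]
in_batch = loc >= 0

h = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])   # stand-in outputs, in n_id order
table = np.zeros((num_nodes_total, 2))               # stand-in raw embedding table

# Start from the fallback rows, then overwrite the entries found in the subgraph.
out = table[query].copy()
out[in_batch] = h[loc[in_batch]]

print(loc.tolist(), in_batch.tolist())
print(out)
```

Node 9 is not in the subgraph, so it keeps its raw table row, while nodes 5 and 7 pick up their message-passing outputs.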


In [12]:
# ============================
# Cell 11: Prepare fast random positive sampling
# Purpose:
# - Convert train_pos sets -> arrays to sample positives uniformly
# ============================

train_pos_arr = {}
train_pos_size = np.zeros(U, dtype=np.int32)

for u in range(U):
    items = np.fromiter(train_pos[u], dtype=np.int64)
    train_pos_arr[u] = items
    train_pos_size[u] = items.size

print("min/mean/max train_pos size:", train_pos_size.min(), train_pos_size.mean(), train_pos_size.max())

def sample_positives(users_np: np.ndarray, rng: np.random.Generator):
    pos = np.empty_like(users_np, dtype=np.int64)
    for idx, u in enumerate(users_np):
        arr = train_pos_arr[int(u)]
        pos[idx] = int(arr[rng.integers(0, arr.size)])
    return pos

min/mean/max train_pos size: 3 92.25783737218623 197
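
`sample_positives` still loops over users in Python. If that becomes a bottleneck, one option is to pack the per-user item lists into a CSR-style layout and sample with array indexing. This is a sketch under the assumption that every user has at least one training item (true here, since the minimum size above is 3); the names `items_flat`, `indptr` and `sample_positives_csr` are hypothetical.

```python
import numpy as np

# Toy (user, item) pairs; users are 0..U_toy-1 and each user has at least one pair.
train_ui_toy = np.array([[0, 4], [1, 7], [0, 2], [2, 5], [1, 1], [0, 9]], dtype=np.int64)
U_toy = 3

order = np.argsort(train_ui_toy[:, 0], kind="stable")
items_flat = train_ui_toy[order, 1]                      # items grouped by user
counts = np.bincount(train_ui_toy[:, 0], minlength=U_toy)
indptr = np.concatenate([[0], np.cumsum(counts)])        # user u owns items_flat[indptr[u]:indptr[u+1]]

def sample_positives_csr(users, rng):
    """Uniformly pick one training item per user without a Python loop."""
    users = np.asarray(users, dtype=np.int64)
    offs = rng.integers(0, counts[users])                # per-user offset in [0, count_u)
    return items_flat[indptr[users] + offs]

rng_toy = np.random.default_rng(0)
users_toy = np.array([0, 1, 2, 0, 1, 2])
pos_toy = sample_positives_csr(users_toy, rng_toy)
print(pos_toy)
```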


In [13]:
# ============================
# Cell 12: Fast LOO evaluation on sampled candidates (subset users)
# Purpose:
# - Quick check that BPR training improves ranking-like metrics
# - Evaluate on candidate set: [1 gt + (C-1) negatives]
# ============================

@torch.no_grad()
def eval_loo_sampled(model, gt_dict, users_subset, C=200, Ks=(10, 20, 50), seed=42):
    """
    Candidate-based LOO eval:
    For each user:
      candidates = [gt_item] + sampled negatives (not in train_pos)
    Score by dot(user_emb, item_emb) using raw embedding table (fast).
    """
    model.eval()
    rng = np.random.default_rng(seed)

    hits = {k: 0 for k in Ks}
    ndcgs = {k: 0.0 for k in Ks}

    # Use the base embedding table (no message passing) for speed in this quick eval.
    # Caveat: training scores users with GraphSAGE outputs, so this is only a proxy;
    # a more faithful eval would score with the propagated embeddings.
    emb = model.emb.weight.detach()  # [num_nodes_ui, dim], on the same device as the model

    for u in tqdm(users_subset, desc="eval(sampled)"):
        gt = int(gt_dict[int(u)])

        # build candidates
        negs = []
        seen = train_pos[int(u)]
        while len(negs) < C - 1:
            j = int(rng.integers(0, B))
            if (j not in seen) and (j != gt):
                negs.append(j)

        cand_items = np.array([gt] + negs, dtype=np.int64)
        cand_nodes = torch.from_numpy(cand_items).long().to(DEVICE) + U

        u_node = torch.tensor([int(u)], device=DEVICE, dtype=torch.long)
        u_vec = emb[u_node]                         # [1, dim]
        i_vec = emb[cand_nodes]                     # [C, dim]
        scores = (u_vec * i_vec).sum(dim=-1)        # [C]

        # rank descending
        rank = torch.argsort(scores, descending=True)
        # position of gt is where rank == 0 (since gt is at index 0 in cand_items)
        gt_pos = (rank == 0).nonzero(as_tuple=False).item()  # 0-based

        for k in Ks:
            if gt_pos < k:
                hits[k] += 1
                ndcgs[k] += 1.0 / math.log2(gt_pos + 2)  # 1-based rank = gt_pos + 1; DCG discount is log2(rank + 1)

    n = len(users_subset)
    out = {f"Hit@{k}": hits[k] / n for k in Ks}
    out.update({f"NDCG@{k}": ndcgs[k] / n for k in Ks})
    return out

# subset users for fast check
subset = np.random.default_rng(SEED).choice(np.arange(U), size=2000, replace=False)
metrics = eval_loo_sampled(model, val_gt, subset, C=200, Ks=(10,20,50), seed=SEED)
metrics

{'Hit@10': 0.062,
 'Hit@20': 0.109,
 'Hit@50': 0.2575,
 'NDCG@10': 0.029169205182808054,
 'NDCG@20': 0.04087964021528664,
 'NDCG@50': 0.0698140560566304}
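
To read sampled-candidate metrics it helps to know the chance level. If a scorer carried no signal, the held-out item's rank among `C` candidates would be uniform on 1..C (assuming no ties), which gives the expectations below. The helper `random_baseline` is hypothetical and only serves as a reference point.

```python
import math

def random_baseline(C: int, ks=(10, 20, 50)) -> dict:
    """Expected metrics when the held-out item's rank is uniform on 1..C."""
    out = {}
    for k in ks:
        k_eff = min(k, C)
        out[f"Hit@{k}"] = k_eff / C
        out[f"NDCG@{k}"] = sum(1.0 / math.log2(r + 1) for r in range(1, k_eff + 1)) / C
    return out

baselines = {C: random_baseline(C) for C in (200, 1000)}
for C, b in baselines.items():
    print(C, b)
```

With C=200 the chance level is Hit@10 = 0.05, Hit@20 = 0.10 and Hit@50 = 0.25, so the values above, measured after the single smoke epoch, sit only slightly above chance.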

In [15]:
# ============================
# Cell 13: Train for a few epochs + fast eval
# Purpose:
# - See if metrics move in the right direction
# ============================

def train_one_epoch_v2():
    model.train()
    total_loss = 0.0
    steps = 0
    for batch in tqdm(train_loader, desc="train"):
        batch = batch.to(DEVICE)
        n_id = batch.n_id
        seed_users = batch.input_id
        seed_users_np = seed_users.detach().cpu().numpy()

        pos_items = sample_positives(seed_users_np, rng)
        neg_items = sample_negatives(seed_users_np, B, train_pos, rng)

        pos_nodes = torch.from_numpy(pos_items).long().to(DEVICE) + U
        neg_nodes = torch.from_numpy(neg_items).long().to(DEVICE) + U

        h = model(n_id, batch.edge_index)

        idx_map = torch.full((num_nodes_ui,), -1, device=DEVICE, dtype=torch.long)
        idx_map[n_id] = torch.arange(n_id.size(0), device=DEVICE)

        u_loc = idx_map[seed_users]
        p_loc = idx_map[pos_nodes]
        n_loc = idx_map[neg_nodes]

        u_emb = h[u_loc]

        def get_item_emb(loc_idx, global_nodes):
            mask = loc_idx >= 0
            out = torch.empty((global_nodes.size(0), CFG["embedding_dim"]), device=DEVICE)
            out[mask] = h[loc_idx[mask]]
            out[~mask] = model.emb(global_nodes[~mask])
            return out

        p_emb = get_item_emb(p_loc, pos_nodes)
        n_emb = get_item_emb(n_loc, neg_nodes)

        loss = bpr_loss(u_emb, p_emb, n_emb, reg=CFG["bpr_reg"])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += float(loss.detach().cpu())
        steps += 1

    return total_loss / max(1, steps)

EPOCHS = 5
for ep in range(1, EPOCHS + 1):
    avg_loss = train_one_epoch_v2()
    m = eval_loo_sampled(model, val_gt, subset, C=200, Ks=(10,20,50), seed=SEED)
    print(f"epoch={ep} loss={avg_loss:.4f} metrics={m}")

epoch=1 loss=0.5884 metrics={'Hit@10': 0.0625, 'Hit@20': 0.117, 'Hit@50': 0.2535, 'NDCG@10': 0.02991845318821254, 'NDCG@20': 0.04348749062164587, 'NDCG@50': 0.07003587265525034}
epoch=2 loss=0.5073 metrics={'Hit@10': 0.07, 'Hit@20': 0.1245, 'Hit@50': 0.268, 'NDCG@10': 0.03478295679550602, 'NDCG@20': 0.04849160576371658, 'NDCG@50': 0.07627352996153926}
epoch=3 loss=0.4674 metrics={'Hit@10': 0.0815, 'Hit@20': 0.1325, 'Hit@50': 0.2705, 'NDCG@10': 0.0392722566159134, 'NDCG@20': 0.05202523012743539, 'NDCG@50': 0.07878876052453557}
epoch=4 loss=0.4606 metrics={'Hit@10': 0.084, 'Hit@20': 0.1335, 'Hit@50': 0.285, 'NDCG@10': 0.042342207566231846, 'NDCG@20': 0.05470456798833096, 'NDCG@50': 0.08413747918220779}
epoch=5 loss=0.4531 metrics={'Hit@10': 0.0865, 'Hit@20': 0.1415, 'Hit@50': 0.295, 'NDCG@10': 0.04413208817029047, 'NDCG@20': 0.057961201592542574, 'NDCG@50': 0.08777030653335029}


In [16]:
# ============================
# Cell 14: Train for a few epochs + tougher sampled eval
# ============================

EPOCHS = 5
subset_10k = np.random.default_rng(SEED).choice(np.arange(U), size=10000, replace=False)

for ep in range(1, EPOCHS + 1):
    avg_loss = train_one_epoch_v2()

    m200  = eval_loo_sampled(model, val_gt, subset,      C=200,  Ks=(10,20,50), seed=SEED)
    m1000 = eval_loo_sampled(model, val_gt, subset_10k,  C=1000, Ks=(10,20,50), seed=SEED)

    print(f"epoch={ep} loss={avg_loss:.4f}")
    print("  eval C=200  (2k users): ", m200)
    print("  eval C=1000 (10k users):", m1000)

epoch=1 loss=0.4459
  eval C=200  (2k users):  {'Hit@10': 0.09, 'Hit@20': 0.145, 'Hit@50': 0.308, 'NDCG@10': 0.04630391165631217, 'NDCG@20': 0.06012681450828042, 'NDCG@50': 0.09173653350296644}
  eval C=1000 (10k users): {'Hit@10': 0.0264, 'Hit@20': 0.0456, 'Hit@50': 0.0892, 'NDCG@10': 0.013610082739373195, 'NDCG@20': 0.01842730767125949, 'NDCG@50': 0.026915586022710103}
epoch=2 loss=0.4398
  eval C=200  (2k users):  {'Hit@10': 0.095, 'Hit@20': 0.1515, 'Hit@50': 0.3105, 'NDCG@10': 0.049146213333516243, 'NDCG@20': 0.06331539809195257, 'NDCG@50': 0.09437044772852086}
  eval C=1000 (10k users): {'Hit@10': 0.0276, 'Hit@20': 0.0477, 'Hit@50': 0.0915, 'NDCG@10': 0.014129746711627, 'NDCG@20': 0.019154261606535945, 'NDCG@50': 0.027713866980726243}
epoch=3 loss=0.4374
  eval C=200  (2k users):  {'Hit@10': 0.0995, 'Hit@20': 0.1605, 'Hit@50': 0.324, 'NDCG@10': 0.05029672998679614, 'NDCG@20': 0.06557543170751791, 'NDCG@50': 0.09749333114187213}
  eval C=1000 (10k users): {'Hit@10': 0.0286, 'Hit@20': 0.0494, 'Hit@50': 0.0947, 'NDCG@10': 0.014593363020934964, 'NDCG@20': 0.01978401294028719, 'NDCG@50': 0.028645318968704443}
epoch=4 loss=0.4398
  eval C=200  (2k users):  {'Hit@10': 0.108, 'Hit@20': 0.16, 'Hit@50': 0.326, 'NDCG@10': 0.05389124179687326, 'NDCG@20': 0.06690024800750159, 'NDCG@50': 0.09938436311400906}
  eval C=1000 (10k users): {'Hit@10': 0.0306, 'Hit@20': 0.0498, 'Hit@50': 0.1002, 'NDCG@10': 0.015388048796209912, 'NDCG@20': 0.020197617784262753, 'NDCG@50': 0.030085525428997954}
epoch=5 loss=0.4277
  eval C=200  (2k users):  {'Hit@10': 0.1075, 'Hit@20': 0.1665, 'Hit@50': 0.33, 'NDCG@10': 0.055178454182531644, 'NDCG@20': 0.07008618294428924, 'NDCG@50': 0.1019793030013096}
  eval C=1000 (10k users): {'Hit@10': 0.0311, 'Hit@20': 0.0525, 'Hit@50': 0.1041, 'NDCG@10': 0.015959581089605195, 'NDCG@20': 0.021329786600736856, 'NDCG@50': 0.03141121857691474}


In [17]:
# ============================
# Cell 15: Early stopping + checkpoint helpers
# ============================

@dataclass
class EarlyStopper:
    patience: int = 5
    min_delta: float = 1e-4
    best: float = -1e9
    best_epoch: int = -1
    bad_count: int = 0
    best_state: dict = None

    def step(self, metric_value: float, model: torch.nn.Module, epoch: int) -> bool:
        """
        Returns True if should stop.
        """
        improved = metric_value > (self.best + self.min_delta)
        if improved:
            self.best = metric_value
            self.best_epoch = epoch
            self.bad_count = 0
            # store best weights (GPU-safe)
            self.best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            self.bad_count += 1
        return self.bad_count >= self.patience

    def load_best(self, model: torch.nn.Module, device=DEVICE):
        assert self.best_state is not None, "No best_state saved"
        model.load_state_dict({k: v.to(device) for k, v in self.best_state.items()})
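
The patience rule in `EarlyStopper.step` can be checked in isolation. The sketch below applies the same comparison (`value > best + min_delta`, with the counter reset on every improvement) to a plain list of metric values, without any model. The function `first_stop_epoch` and the metric curve are hypothetical.

```python
def first_stop_epoch(values, patience=5, min_delta=1e-4):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch metric values.

    An epoch counts as an improvement only if it beats the best value so far by
    more than min_delta. stop_epoch is None if patience is never exhausted.
    """
    best, best_epoch, bad = float("-inf"), -1, 0
    for epoch, v in enumerate(values, start=1):
        if v > best + min_delta:
            best, best_epoch, bad = v, epoch, 0
        else:
            bad += 1
            if bad >= patience:
                return epoch, best_epoch
    return None, best_epoch

# Made-up metric curve: improves for three epochs, then plateaus.
curve = [0.010, 0.012, 0.015, 0.01505, 0.0149, 0.0150, 0.0150, 0.0150]
result = first_stop_epoch(curve, patience=5)
print(result)
```

Note that epoch 4 in this toy curve is higher than epoch 3 but by less than `min_delta`, so it does not reset the counter and the best epoch stays at 3.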

In [18]:
# ============================
# Cell 16: Train up to 30 epochs with early stopping on val NDCG@10 (C=1000)
# ============================

CFG["epochs"] = 30
EARLY = EarlyStopper(patience=5, min_delta=1e-4)

# Fix the user subsets so that comparisons across epochs are fair
subset_2k = subset
# subset_10k was already defined in Cell 14 and is reused as-is

history = []

for ep in range(1, CFG["epochs"] + 1):
    avg_loss = train_one_epoch_v2()

    # quick check (2k users, C=200)
    m200 = eval_loo_sampled(model, val_gt, subset_2k, C=200, Ks=(10,20,50), seed=SEED)

    # main signal for early stopping (10k users, C=1000)
    m1000 = eval_loo_sampled(model, val_gt, subset_10k, C=1000, Ks=(10,20,50), seed=SEED)
    val_ndcg10 = float(m1000["NDCG@10"])

    row = {
        "epoch": ep,
        "loss": avg_loss,
        **{f"val200_{k}": v for k, v in m200.items()},
        **{f"val1000_{k}": v for k, v in m1000.items()},
    }
    history.append(row)

    print(f"epoch={ep:02d} loss={avg_loss:.4f} | val NDCG@10 (C=1000, 10k users) = {val_ndcg10:.6f}")

    if EARLY.step(val_ndcg10, model, ep):
        print(f"Early stopping at epoch {ep}. Best epoch={EARLY.best_epoch} best NDCG@10={EARLY.best:.6f}")
        break

# load best model weights
EARLY.load_best(model, device=DEVICE)
print("Loaded best checkpoint:", EARLY.best_epoch, "best val NDCG@10:", EARLY.best)

train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=01 loss=0.4339 | val NDCG@10 (C=1000, 10k users) = 0.016822


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=02 loss=0.4268 | val NDCG@10 (C=1000, 10k users) = 0.017561


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=03 loss=0.4281 | val NDCG@10 (C=1000, 10k users) = 0.018519


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=04 loss=0.4271 | val NDCG@10 (C=1000, 10k users) = 0.019001


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=05 loss=0.4226 | val NDCG@10 (C=1000, 10k users) = 0.019171


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=06 loss=0.4145 | val NDCG@10 (C=1000, 10k users) = 0.019954


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=07 loss=0.4132 | val NDCG@10 (C=1000, 10k users) = 0.020055


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=08 loss=0.4059 | val NDCG@10 (C=1000, 10k users) = 0.020539


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=09 loss=0.4030 | val NDCG@10 (C=1000, 10k users) = 0.021259


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=10 loss=0.4008 | val NDCG@10 (C=1000, 10k users) = 0.021671


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=11 loss=0.3926 | val NDCG@10 (C=1000, 10k users) = 0.021925


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=12 loss=0.3909 | val NDCG@10 (C=1000, 10k users) = 0.021864


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=13 loss=0.3920 | val NDCG@10 (C=1000, 10k users) = 0.022408


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=14 loss=0.3913 | val NDCG@10 (C=1000, 10k users) = 0.022953


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=15 loss=0.3847 | val NDCG@10 (C=1000, 10k users) = 0.023623


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=16 loss=0.3820 | val NDCG@10 (C=1000, 10k users) = 0.023593


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=17 loss=0.3804 | val NDCG@10 (C=1000, 10k users) = 0.024029


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=18 loss=0.3772 | val NDCG@10 (C=1000, 10k users) = 0.024243


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=19 loss=0.3701 | val NDCG@10 (C=1000, 10k users) = 0.024632


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=20 loss=0.3730 | val NDCG@10 (C=1000, 10k users) = 0.024948


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=21 loss=0.3746 | val NDCG@10 (C=1000, 10k users) = 0.025256


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=22 loss=0.3674 | val NDCG@10 (C=1000, 10k users) = 0.026053


train:   0%|          | 0/27 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/2000 [00:00<?, ?it/s]

eval(sampled):   0%|          | 0/10000 [00:00<?, ?it/s]

epoch=23 loss=0.3602 | val NDCG@10 (C=1000, 10k users) = 0.026580


epoch=24 loss=0.3615 | val NDCG@10 (C=1000, 10k users) = 0.026863


epoch=25 loss=0.3594 | val NDCG@10 (C=1000, 10k users) = 0.026483


epoch=26 loss=0.3688 | val NDCG@10 (C=1000, 10k users) = 0.026936


epoch=27 loss=0.3576 | val NDCG@10 (C=1000, 10k users) = 0.027151


epoch=28 loss=0.3610 | val NDCG@10 (C=1000, 10k users) = 0.027448


epoch=29 loss=0.3589 | val NDCG@10 (C=1000, 10k users) = 0.027557


epoch=30 loss=0.3541 | val NDCG@10 (C=1000, 10k users) = 0.027834
Loaded best checkpoint: 30 best val NDCG@10: 0.027834138642069708


In [19]:
# ============================
# Cell 17: Final TEST evaluation (sampled candidates)
# ============================

# Fix a test subset (10k users) for speed and stability
test_subset_10k = np.random.default_rng(SEED + 123).choice(np.arange(U), size=10000, replace=False)

test_m1000 = eval_loo_sampled(model, test_gt, test_subset_10k, C=1000, Ks=(10,20,50), seed=SEED + 123)
print("TEST metrics (C=1000, 10k users):", test_m1000)

# If runtime permits, a slightly more honest (stricter) evaluation with more candidates:
test_m2000 = eval_loo_sampled(model, test_gt, test_subset_10k, C=2000, Ks=(10,20,50), seed=SEED + 123)
print("TEST metrics (C=2000, 10k users):", test_m2000)

TEST metrics (C=1000, 10k users): {'Hit@10': 0.0577, 'Hit@20': 0.0922, 'Hit@50': 0.1715, 'NDCG@10': 0.02978591198187386, 'NDCG@20': 0.03842829609518776, 'NDCG@50': 0.0540368908993595}


TEST metrics (C=2000, 10k users): {'Hit@10': 0.0356, 'Hit@20': 0.0572, 'Hit@50': 0.1067, 'NDCG@10': 0.01825064652190281, 'NDCG@20': 0.023635743952412133, 'NDCG@50': 0.033412209287885984}


In [20]:
# ============================
# Cell 18: Save history + best checkpoint info
# ============================

hist_df = pd.DataFrame(history)
out_dir = ARTIFACTS / "ablation_runs" / "graphsage_sampling"
out_dir.mkdir(parents=True, exist_ok=True)

hist_path = out_dir / "history_graphsage_bpr_sampled_eval.csv"
hist_df.to_csv(hist_path, index=False)

meta = {
    "best_epoch": EARLY.best_epoch,
    "best_val_ndcg10_C1000_10k": float(EARLY.best),
    "config": CFG,
    "bundle_dir": str(BUNDLE_DIR),
}
with open(out_dir / "run_meta.json", "w", encoding="utf-8") as f:
    json.dump(meta, f, ensure_ascii=False, indent=2)

# Save the (best) weights
ckpt_path = out_dir / "graphsage_bpr_best.pt"
torch.save({"state_dict": EARLY.best_state, "meta": meta}, ckpt_path)

print("Saved:", hist_path)
print("Saved:", ckpt_path)
print("Saved:", out_dir / "run_meta.json")

Saved: D:\ML\GNN\graph_recsys\artifacts\v2_proper\ablation_runs\graphsage_sampling\history_graphsage_bpr_sampled_eval.csv
Saved: D:\ML\GNN\graph_recsys\artifacts\v2_proper\ablation_runs\graphsage_sampling\graphsage_bpr_best.pt
Saved: D:\ML\GNN\graph_recsys\artifacts\v2_proper\ablation_runs\graphsage_sampling\run_meta.json


## Results & Conclusions
### What we did

Loaded the frozen Graph3 bundle (Goodbooks-10k) and kept the original leave-one-out (LOO) splits.

Switched from binary link prediction (BCE) to a ranking-aligned objective:

GraphSAGE with neighbor sampling

BPR loss (pairwise ranking)

Trained the model in mini-batches of users with sampled neighborhoods.

Evaluated with candidate-based LOO ranking to approximate global ranking:

C=200, C=1000, C=2000 candidates per user (1 GT + negatives)

Metrics: Hit@K / NDCG@K
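
As an illustration of the candidate-based protocol above, here is a minimal, framework-free sketch of how Hit@K and NDCG@K can be computed when each user has exactly one ground-truth item ranked against sampled negatives. It is not the notebook's `eval_loo_sampled`; the function names, the tie handling (ties count against the ground truth), and the toy scores are assumptions made for the example.

```python
import math

def rank_of_ground_truth(gt_score, negative_scores):
    """1-based rank of the held-out item among 1 GT + sampled negatives.
    Assumption: a tie with a negative counts against the ground truth."""
    return 1 + sum(1 for s in negative_scores if s >= gt_score)

def hit_and_ndcg_at_k(rank, k):
    """With a single relevant item IDCG = 1, so NDCG@K = 1 / log2(rank + 1) if rank <= K."""
    if rank > k:
        return 0.0, 0.0
    return 1.0, 1.0 / math.log2(rank + 1)

def evaluate(per_user_scores, ks=(10, 20, 50)):
    """per_user_scores: list of (gt_score, negative_scores), one entry per user."""
    totals = {}
    for gt_score, negative_scores in per_user_scores:
        rank = rank_of_ground_truth(gt_score, negative_scores)
        for k in ks:
            hit, ndcg = hit_and_ndcg_at_k(rank, k)
            totals[f"Hit@{k}"] = totals.get(f"Hit@{k}", 0.0) + hit
            totals[f"NDCG@{k}"] = totals.get(f"NDCG@{k}", 0.0) + ndcg
    n_users = len(per_user_scores)
    return {name: value / n_users for name, value in totals.items()}

# Toy data (hypothetical scores): two users, three negatives each.
toy_users = [
    (0.9, [0.1, 0.5, 0.95]),  # GT is ranked 2nd
    (0.2, [0.1, 0.3, 0.4]),   # GT is ranked 3rd
]
print(evaluate(toy_users, ks=(1, 2)))
```

Because only the rank of the single held-out item matters, the metrics depend on the number of candidates C through how many negatives can outrank it.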

### Key findings

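Training is stable: over epochs 15–30 the loss falls from 0.3847 to 0.3541, with only small epoch-to-epoch fluctuations.

Ranking quality improves almost monotonically: val NDCG@10 rises from 0.0236 (epoch 15) to 0.0278 (epoch 30), with minor dips at epochs 16 and 25. Because the best checkpoint is the final epoch, validation quality had not yet plateaued, so a longer run may improve it further.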

Best checkpoint (epoch 30) achieved on validation:

Val NDCG@10 (C=1000, 10k users) ≈ 0.02783

Final test metrics (10k users):

TEST (C=1000): Hit@10 = 0.0577, NDCG@10 = 0.02979

TEST (C=2000): Hit@10 = 0.0356, NDCG@10 = 0.01825

### Interpretation

Unlike the R-GCN experiment (BCE + negative sampling), whose ranking metrics were near zero, GraphSAGE trained with BPR produces a usable ranking signal: test Hit@10 = 0.0577 and NDCG@10 = 0.02979 at C=1000.

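This is consistent with the diagnosis that R-GCN underperformed mainly due to an objective–metric mismatch (binary classification vs global ranking evaluation). The comparison is indirect, though: R-GCN was scored with global ranking while the numbers here use sampled candidates, and the architecture changed together with the loss.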
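
The objective–metric mismatch can be illustrated numerically. The sketch below is not the notebook's training code; it is a framework-free toy with hypothetical scores. It shows that BPR depends only on the difference between a positive and a negative score, so adding a constant to all of a user's scores leaves it unchanged (just as it leaves the ranking unchanged), whereas BCE penalizes the absolute score values.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bpr_loss(pos_scores, neg_scores):
    """Mean of -log(sigmoid(pos - neg)) over (positive, negative) pairs."""
    pairs = list(zip(pos_scores, neg_scores))
    return sum(softplus(n - p) for p, n in pairs) / len(pairs)

def bce_loss(pos_scores, neg_scores):
    """Mean binary cross-entropy on logits: label 1 for positives, 0 for negatives."""
    total = sum(softplus(-p) for p in pos_scores) + sum(softplus(n) for n in neg_scores)
    return total / (len(pos_scores) + len(neg_scores))

# Hypothetical scores; the shifted version has the same item ordering.
pos, neg = [2.0, 0.5], [1.0, -0.5]
shift = 3.0
pos_shifted = [p + shift for p in pos]
neg_shifted = [n + shift for n in neg]

print("BPR:", bpr_loss(pos, neg), "->", bpr_loss(pos_shifted, neg_shifted))
print("BCE:", bce_loss(pos, neg), "->", bce_loss(pos_shifted, neg_shifted))
```

The BPR value is identical before and after the shift, while the BCE value changes even though no item changed its position in the ranking.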

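Candidate-based evaluation becomes stricter as C increases: doubling C from 1000 to 2000 lowers test Hit@10 from 0.0577 to 0.0356 and NDCG@10 from 0.02979 to 0.01825. The drop is expected, because the ground-truth item competes against twice as many negatives, and the ranking signal stays consistent across both settings.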
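
To put the candidate-based numbers in context, it helps to compare them with a ranker that orders the C candidates uniformly at random. In that case the single ground-truth item is equally likely to land at any rank, so the expected Hit@K is K/C and the expected NDCG@K is the average of 1/log2(r+1) over ranks r ≤ K, divided by C. The sketch below computes these baselines; the helper names are hypothetical, and the observed values are the test metrics reported above.

```python
import math

def random_hit_at_k(k, c):
    """Expected Hit@K when 1 ground-truth item is ranked uniformly among c candidates."""
    return min(k, c) / c

def random_ndcg_at_k(k, c):
    """Expected NDCG@K under the same assumption (single relevant item, IDCG = 1)."""
    return sum(1.0 / math.log2(r + 1) for r in range(1, min(k, c) + 1)) / c

# Observed test metrics from this run: C -> (Hit@10, NDCG@10)
observed = {1000: (0.0577, 0.02979), 2000: (0.0356, 0.01825)}

for c, (hit10, ndcg10) in observed.items():
    base_hit = random_hit_at_k(10, c)
    base_ndcg = random_ndcg_at_k(10, c)
    print(
        f"C={c}: Hit@10 lift={hit10 / base_hit:.1f}x, "
        f"NDCG@10 lift={ndcg10 / base_ndcg:.1f}x"
    )
```

Under this assumption the observed Hit@10 is several times the random baseline at both candidate sizes, and the lift does not shrink when C doubles even though the raw metric does.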

### Limitations (important)

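Current evaluation is candidate-based, not full-ranking over all items, so the reported values are not directly comparable with full-ranking metrics and are expected to be higher than them.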

During training, some sampled positives/negatives may fall outside the sampled subgraph and fall back to raw embeddings (a practical workaround for a scalable smoke-to-full pipeline).

### Next steps

Move to the next model notebook to compare architectures under the same ranking objective.

In a later iteration, improve training/evaluation fidelity by:

forcing inclusion of positive/negative items in sampled subgraphs,

running a batched inference step to compute embeddings for all nodes and performing full-ranking LOO evaluation.
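
The full-ranking step from the last item could look roughly like the NumPy sketch below. It assumes final user and item embeddings have already been computed for all nodes, that scoring is a dot product, and that each user's training items are excluded from the ranking; all names and the toy data are hypothetical and not taken from the notebook.

```python
import numpy as np

def full_ranking_ranks(user_emb, item_emb, gt_items, seen_items):
    """1-based rank of each user's held-out item among all items not seen in training.

    user_emb: (U, d) array, item_emb: (I, d) array,
    gt_items: (U,) int array, seen_items: list of U int arrays (training items).
    """
    scores = user_emb @ item_emb.T            # (U, I) dot-product scores (assumed scoring)
    for u, seen in enumerate(seen_items):
        scores[u, seen] = -np.inf             # exclude training interactions
    gt_scores = scores[np.arange(len(gt_items)), gt_items]
    # Rank = number of items scoring strictly higher than the ground truth, plus one.
    return 1 + (scores > gt_scores[:, None]).sum(axis=1)

def hit_ndcg_at_k(ranks, k):
    hits = ranks <= k
    ndcg = np.where(hits, 1.0 / np.log2(ranks + 1), 0.0)
    return float(hits.mean()), float(ndcg.mean())

# Toy example: 2 users, 4 items, 2-dimensional embeddings.
user_emb = np.array([[1.0, 0.0], [0.0, 1.0]])
item_emb = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.9, 0.1]])
gt_items = np.array([3, 1])
seen_items = [np.array([0]), np.array([], dtype=int)]

ranks = full_ranking_ranks(user_emb, item_emb, gt_items, seen_items)
print(ranks, hit_ndcg_at_k(ranks, k=1), hit_ndcg_at_k(ranks, k=2))
```

For the real dataset the score matrix would be built in user batches rather than all at once to keep memory bounded; the ranking logic stays the same.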