# Training the simple GNN recommender

This notebook is a companion to `training_lightgcn.ipynb` that trains the simpler SAGE-based encoder from `modeling/models/simple.py` instead. It fixes the loss computation used in `training_tests.py` by explicitly separating the positive and negative scores before calling `BPR_loss`.


In [16]:
import time

import matplotlib.pyplot as plt
import torch
from torch import optim
from torch_geometric.data import HeteroData

from modeling.losses import BPR_loss
from modeling.metrics import calculate_metrics
from modeling.models.simple import Model
from modeling.sampling import prepare_training_data, sample_minibatch

# Seed PyTorch's global RNG so torch-driven randomness is repeatable across runs.
# The trailing `;` stops the notebook from echoing the returned Generator object.
# NOTE(review): only torch is seeded here; if the sampling helpers draw from
# `random` or NumPy, those need their own seeds -- confirm.
torch.manual_seed(1);


<torch._C.Generator at 0x76dc2b2b3eb0>

In [17]:
# Load data
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so this file must come from a trusted source.
data: HeteroData = torch.load("data/hetero_data_no_coauthor.pt", weights_only=False)

# Paper nodes carry feature vectors in `.x`.
paper_ids, paper_embeddings = data["paper"].node_id, data["paper"].x

# Author nodes get a constant all-ones vector with the same width as the
# paper features, so both node types can later share one embedding table.
author_ids = data["author"].node_id
author_embeddings = torch.ones(data["author"].num_nodes, paper_embeddings.shape[1])

# Author -> paper "writes" edges, one column per edge.
edge_index = data["author", "writes", "paper"].edge_index

print(f"Number of authors: {len(author_ids)}")
print(f"Number of papers: {len(paper_ids)}")
print(f"Number of edges: {edge_index.shape[1]}")


Number of authors: 90941
Number of papers: 63854
Number of edges: 320187


In [5]:
# Train/val/test split and message-passing vs supervision edges
(
    message_passing_edge_index,
    supervision_edge_index,
    val_edge_index_raw,
    test_edge_index_raw,
) = prepare_training_data(edge_index)

num_authors, num_papers = len(author_ids), len(paper_ids)

# Non-offset training edges (message passing + supervision), kept for
# evaluation where author and paper ids index separate tables.
train_edge_index_raw = torch.cat((message_passing_edge_index, supervision_edge_index), dim=1)

# One joint embedding table: author rows first, paper rows after them.
node_embeddings = torch.cat((author_embeddings, paper_embeddings), dim=0)

# Shift only the paper row of each edge index by the number of author rows so
# paper ids point into the joint table; the author row is shifted by 0.
edge_index_offset = torch.tensor([0, author_embeddings.shape[0]])
offset_column = edge_index_offset.view(2, 1)

# NOTE: from here on `message_passing_edge_index` and `supervision_edge_index`
# hold OFFSET paper ids; their non-offset versions are no longer available
# under their own names.
message_passing_edge_index = message_passing_edge_index + offset_column
supervision_edge_index = supervision_edge_index + offset_column
val_edge_index = val_edge_index_raw + offset_column
test_edge_index = test_edge_index_raw + offset_column


In [15]:
# Display the node counts; `num_authors` is the boundary used later to split
# the joint embedding table into author rows and paper rows.
(num_authors, num_papers)

(90941, 63854)

In [None]:
# Hyperparameters
ITERATIONS = 10_000      # number of optimisation steps in the training loop
BATCH_SIZE = 512         # batch size handed to sample_minibatch
LR = 0.005               # Adam learning rate
NEG_SAMPLE_RATIO = 5     # forwarded to sample_minibatch as neg_sample_ratio
ITERS_PER_EVAL = 1_000   # run evaluation every this many iterations
K = 20                   # cut-off for recall@K / precision@K


In [None]:
# Setup
model = Model(embedding_dim=paper_embeddings.shape[1])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}.")

# Module.to() and Module.train() both return the module, so they chain.
model = model.to(device).train()
optimizer = optim.Adam(model.parameters(), lr=LR)

# Move the embedding table and every edge index (offset and raw) to the
# training device in one pass; the order of targets and sources must match.
(
    node_embeddings,
    message_passing_edge_index,
    supervision_edge_index,
    val_edge_index,
    test_edge_index,
    train_edge_index_raw,
    val_edge_index_raw,
    test_edge_index_raw,
) = (
    tensor.to(device)
    for tensor in (
        node_embeddings,
        message_passing_edge_index,
        supervision_edge_index,
        val_edge_index,
        test_edge_index,
        train_edge_index_raw,
        val_edge_index_raw,
        test_edge_index_raw,
    )
)


In [None]:
# Training loop
# Runs ITERATIONS optimisation steps of BPR training and, every ITERS_PER_EVAL
# steps, reports ranking metrics on the supervision (train) and validation
# edges. Per-phase wall-clock timings are collected in `timings`; the loss at
# each evaluation point is collected in `train_losses`.
train_losses = []
timings = {"batching": [], "forward": [], "loss": [], "backward": []}


def _now() -> float:
    """Return a monotonic timestamp after draining queued CUDA work.

    CUDA kernels are launched asynchronously, so without a synchronize the
    timer would mostly measure launch overhead and charge the real cost to
    whichever later phase happens to block.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return time.perf_counter()


# The metric calls below index `item_embedding`, which holds paper rows only,
# so they need NON-offset paper ids. The training edge indices carry paper ids
# shifted into the joint table; undo that shift once, outside the loop.
# NOTE(review): assumes the shift applied earlier equals `num_authors`, the
# same boundary used to slice `node_emb` below -- confirm.
paper_id_offset = torch.tensor([[0], [num_authors]], device=supervision_edge_index.device)
supervision_edge_index_eval = supervision_edge_index - paper_id_offset
message_passing_edge_index_eval = message_passing_edge_index - paper_id_offset

# `iteration` rather than `iter`, so the builtin is not shadowed in the kernel.
for iteration in range(ITERATIONS):
    # Mini-batch sampling (returns positive + negative supervision edges)
    start_time = _now()
    pos_edge_index, neg_edge_index = sample_minibatch(
        supervision_edge_index,
        BATCH_SIZE,
        neg_sample_ratio=NEG_SAMPLE_RATIO,
    )
    pos_edge_index = pos_edge_index.to(device)
    neg_edge_index = neg_edge_index.to(device)
    # Score positives and negatives in one forward pass: positives first.
    batch_edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1)
    timings["batching"].append(_now() - start_time)

    # Forward pass
    start_time = _now()
    scores = model(
        node_embeddings,
        message_passing_edge_index,
        batch_edge_index,
    )
    # Split the scores back at the positive/negative boundary.
    num_pos = pos_edge_index.shape[1]
    pos_scores = scores[:num_pos]
    neg_scores = scores[num_pos:]
    timings["forward"].append(_now() - start_time)

    # Correct BPR loss: compare positive vs negative scores
    start_time = _now()
    train_loss = BPR_loss(pos_scores, neg_scores)
    timings["loss"].append(_now() - start_time)

    # Backward
    start_time = _now()
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    timings["backward"].append(_now() - start_time)

    if (iteration + 1) % ITERS_PER_EVAL == 0:
        model.eval()
        with torch.no_grad():
            node_emb = model.get_node_embeddings(node_embeddings, message_passing_edge_index)
            # Joint table layout: author rows first, paper rows after.
            user_embedding = node_emb[:num_authors]
            item_embedding = node_emb[num_authors:]

        # Validation: rank papers per author, excluding all training edges.
        val_recall, val_precision = calculate_metrics(
            user_embedding,
            item_embedding,
            val_edge_index_raw,
            [train_edge_index_raw],
            K,
            batch_size=512,
            device=device,
        )

        # Train: same protocol on the supervision edges, excluding the
        # message-passing edges, using non-offset ids like the validation call.
        train_recall, train_precision = calculate_metrics(
            user_embedding,
            item_embedding,
            supervision_edge_index_eval,
            [message_passing_edge_index_eval],
            K,
            batch_size=1024,
            device=device,
        )

        print(
            f"[Iter {iteration + 1}/{ITERATIONS}] loss: {train_loss.item():.5f}, "
            f"train_recall@{K}: {train_recall:.5f}, train_precision@{K}: {train_precision:.5f}, "
            f"val_recall@{K}: {val_recall:.5f}, val_precision@{K}: {val_precision:.5f}"
        )
        # One loss sample per evaluation point (the loss of the latest batch).
        train_losses.append(train_loss.item())
        model.train()

print("Training done.")


In [None]:
# Loss and timing curves
# Each plot gets its own figure: drawing both through the implicit pyplot
# state would stack the timing curves on top of the loss curve and save the
# combined picture as the timing figure.

# train_losses has one entry per evaluation, taken after iteration
# ITERS_PER_EVAL, 2 * ITERS_PER_EVAL, ... -- hence (i + 1), not i.
iters = [(i + 1) * ITERS_PER_EVAL for i in range(len(train_losses))]
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(iters, train_losses, label="train")
loss_ax.set(xlabel="iteration", ylabel="loss", title="training loss")
loss_ax.legend()
loss_ax.grid()
loss_fig.savefig("training_simple_loss.png")

# Skip the first few iterations, whose timings are not representative, and
# keep the x axis aligned with the true iteration index.
WARMUP_ITERS = 5
timing_fig, timing_ax = plt.subplots()
for phase, label in (
    ("batching", "batching"),
    ("forward", "forwarding"),
    ("loss", "loss computation"),
    ("backward", "backwarding"),
):
    phase_times = timings[phase][WARMUP_ITERS:]
    phase_iters = range(WARMUP_ITERS, WARMUP_ITERS + len(phase_times))
    timing_ax.plot(phase_iters, phase_times, label=label)
timing_ax.set(xlabel="iteration", ylabel="time (s)", title="time per operation")
timing_ax.legend()
timing_ax.grid()
timing_fig.savefig("training_simple_timing.png")


In [None]:
# Final test evaluation
# Recompute the node embeddings with the trained model in eval mode, then
# score the held-out test edges while excluding every edge already seen in
# training or validation.
model.eval()
with torch.no_grad():
    node_emb = model.get_node_embeddings(node_embeddings, message_passing_edge_index)
    # Joint table layout: author rows first, paper rows after.
    user_embedding = node_emb[:num_authors]
    item_embedding = node_emb[num_authors:]

test_recall, test_precision = calculate_metrics(
    user_embedding,
    item_embedding,
    test_edge_index_raw,
    [train_edge_index_raw, val_edge_index_raw],
    K,
    batch_size=512,
    device=device,
)

# Same fixed-width formatting as the training log; the original message had an
# unbalanced opening bracket.
print(f"[Test] test_recall@{K}: {test_recall:.5f}, test_precision@{K}: {test_precision:.5f}")
