# LightGCN

## 1) Setup & configuration


In [None]:
import os, json, random, math
import numpy as np
import pandas as pd
import torch
from torch import optim

from model import LightGCN
from data_utils import load_dataset, build_graph, sample_batch
from evaluator import rank_and_metrics

# Global RNG seed; applied to Python, NumPy and torch so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Base experiment configuration. `K` is passed to LightGCN as its layer count
# (the model cell reads CFGm['K'], so it must exist here or be supplied by
# apply_method_cfg); 'mf_like' overrides it to 0.
CFG = dict(
    dataset='yelp2018',
    data_dir='data/yelp2018',
    seed=SEED, device=device,
    epochs=10, batch_size=4096,
    embed_dim=64, K=3, lr=1e-3, weight_decay=1e-4,
    negatives_per_pos=1,
    node_dropout_p=0.0, edge_dropout_p=0.0,
    eval_ks=(10, 20),
)

def apply_method_cfg(cfg, method, default_k=3):
    """Return a shallow copy of ``cfg`` specialised for ``method``.

    ``'mf_like'`` forces ``K = 0`` (with no graph propagation layers the model
    presumably reduces to plain matrix factorisation, hence the name). For any
    other method an existing ``cfg['K']`` is kept, and a missing one is set to
    ``default_k``, so downstream code can always read ``cfg['K']``.

    Args:
        cfg: base configuration dict; it is not modified.
        method: method name, e.g. ``'mf_like'`` or ``'lightgcn'``.
        default_k: value used for ``K`` when ``cfg`` does not define it.

    Returns:
        A new dict with the method-specific settings applied.
    """
    cfg = cfg.copy()
    if method == 'mf_like':
        cfg['K'] = 0
    else:
        # Without this, the model constructor's cfg['K'] lookup raises KeyError.
        cfg.setdefault('K', default_k)
    return cfg

def resolve_data_dir(cfg):
    """Locate the dataset directory for ``cfg``.

    ``cfg['data_dir']`` is tried first. Only if that path does not exist is the
    fallback ``LightGCN-PyTorch/data/<cfg['dataset']>`` built and checked.

    Returns:
        The first of the two paths that exists.

    Raises:
        FileNotFoundError: if neither path exists.
    """
    primary = cfg['data_dir']
    if not os.path.exists(primary):
        fallback = os.path.join('LightGCN-PyTorch', 'data', cfg['dataset'])
        if not os.path.exists(fallback):
            raise FileNotFoundError(f'Dataset folder not found: {primary} or {fallback}')
        return fallback
    return primary

# Variant to run. 'mf_like' makes apply_method_cfg force K=0; any other value
# keeps the configured K. Must be defined before apply_method_cfg is called.
METHOD = 'lightgcn'
CFGm = apply_method_cfg(CFG, METHOD)


## 2) Data loading & graph build

In [None]:
# Load the splits and build the training graph on the target device.
data_dir = resolve_data_dir(CFGm)
train, valid, test, n_users, n_items = load_dataset(CFGm['dataset'], data_dir)

graph = build_graph(train, n_users, n_items).to(device)

# `train` maps each user to their training items; count all interactions.
num_inter = sum(len(v) for v in train.values())
density = num_inter / (n_users * n_items) if n_users * n_items > 0 else 0
print(
    f"{CFGm['dataset']}: {n_users:,} users, {n_items:,} items, "
    f"{num_inter:,} train interactions, density {density:.6f}"
)


## 3) Model (single class)

In [None]:
# Build LightGCN over the training graph. Adam gets no weight_decay here; the
# training loop passes CFGm['weight_decay'] to model.bpr_loss instead.
model_kwargs = dict(
    n_users=n_users,
    n_items=n_items,
    graph=graph,
    embed_dim=CFGm['embed_dim'],
    K=CFGm['K'],
    node_dropout_p=CFGm['node_dropout_p'],
    edge_dropout_p=CFGm['edge_dropout_p'],
)
model = LightGCN(**model_kwargs).to(device)
optimizer = optim.Adam(model.parameters(), lr=CFGm['lr'])


## 4) Training loop (BPR)

In [None]:
def to_device(*tensors):
    """Move each tensor to the notebook-wide ``device``; return them as a list."""
    return [t.to(device) for t in tensors]

history = []
epochs = CFGm['epochs']
# One epoch draws enough BPR mini-batches to cover every training interaction
# once in expectation, not just a single mini-batch.
steps_per_epoch = max(1, math.ceil(num_inter / CFGm['batch_size']))
log_every = max(1, epochs // 5)

for epoch in range(1, epochs + 1):
    model.train()
    loss_sum, reg_sum = 0.0, 0.0
    for _step in range(steps_per_epoch):
        users, pos, neg = sample_batch(train, n_items, CFGm['batch_size'], CFGm['negatives_per_pos'])
        users, pos, neg = to_device(users, pos, neg)
        optimizer.zero_grad()
        loss, reg = model.bpr_loss(users, pos, neg, weight_decay=CFGm['weight_decay'])
        loss.backward()
        optimizer.step()
        # .item() works whether or not the tensor requires grad or lives on GPU.
        loss_sum += loss.item()
        reg_sum += reg.item()
    epoch_loss = loss_sum / steps_per_epoch
    epoch_reg = reg_sum / steps_per_epoch
    history.append({'epoch': epoch, 'loss': epoch_loss, 'reg': epoch_reg})
    if epoch % log_every == 0 or epoch == 1:
        print(f'epoch {epoch:3d}/{epochs} | loss {epoch_loss:.4f} | reg {epoch_reg:.6f}')

hist_df = pd.DataFrame(history)
hist_df.tail()


## 5) Evaluation

In [None]:
# Switch to eval mode and score the test split at every cutoff in
# CFGm['eval_ks']. `train` is also passed, presumably so already-seen items
# are excluded from the ranking (TODO confirm in evaluator.rank_and_metrics).
model.eval()
eval_ks = CFGm['eval_ks']
metrics = rank_and_metrics(
    model, test, train, n_items,
    ks=eval_ks, batch_size=2048, device=device,
)
metrics


## 6) Ablations (K, dropout, weight decay)

In [None]:
from copy import deepcopy

def run_one(cfg, method):
    """Train a fresh LightGCN for ``cfg``/``method`` and return its test metrics.

    Every call reloads the dataset and rebuilds the graph. It also re-seeds
    Python/NumPy/torch from ``cfg['seed']`` when present, so ablation runs
    differ only in their configuration.

    Args:
        cfg: configuration dict with the same keys as ``CFG``.
        method: method name forwarded to ``apply_method_cfg`` (e.g. ``'mf_like'``).

    Returns:
        Whatever ``rank_and_metrics`` returns for the test split.
    """
    cfgm = apply_method_cfg(cfg, method)
    seed = cfgm.get('seed')
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    dd = resolve_data_dir(cfgm)
    tr, _va, te, U, I = load_dataset(cfgm['dataset'], dd)
    g = build_graph(tr, U, I).to(device)
    # Keyword arguments, matching the main model cell, so the call does not
    # depend on the constructor's positional order.
    m = LightGCN(
        n_users=U, n_items=I, graph=g,
        embed_dim=cfgm['embed_dim'], K=cfgm['K'],
        node_dropout_p=cfgm['node_dropout_p'], edge_dropout_p=cfgm['edge_dropout_p'],
    ).to(device)
    opt = optim.Adam(m.parameters(), lr=cfgm['lr'])

    # One epoch covers every training interaction once in expectation.
    n_inter = sum(len(v) for v in tr.values())
    steps_per_epoch = max(1, math.ceil(n_inter / cfgm['batch_size']))
    for _ep in range(cfgm['epochs']):
        m.train()
        for _step in range(steps_per_epoch):
            u, p, n = sample_batch(tr, I, cfgm['batch_size'], cfgm['negatives_per_pos'])
            u, p, n = u.to(device), p.to(device), n.to(device)
            opt.zero_grad()
            loss, _ = m.bpr_loss(u, p, n, weight_decay=cfgm['weight_decay'])
            loss.backward()
            opt.step()
    m.eval()
    return rank_and_metrics(m, te, tr, I, ks=cfgm['eval_ks'], batch_size=2048, device=device)

# Candidate hyper-parameter values for the ablation study. No visible cell
# consumes this yet. NOTE(review): wire it into run_one calls.
grid = dict(
    K=[0, 1, 2, 3],
    edge_dropout_p=[0.0, 0.2],
    node_dropout_p=[0.0, 0.2],
    weight_decay=[0.0, 1e-4, 1e-3],
)



## 7) Runs on two datasets

## 8) Plots & qualitative cases

In [None]:
import matplotlib.pyplot as plt

# Training-loss curve, one point per recorded epoch in hist_df.
fig, ax = plt.subplots()
ax.plot(hist_df['epoch'], hist_df['loss'])
ax.set(title='Train Loss', xlabel='epoch', ylabel='loss')
plt.show()


## 9) Results tables (for report)

## 10) Export for report

In [None]:
def _to_jsonable(obj):
    """``json`` fallback that converts array/tensor-like values via ``tolist()``.

    This covers NumPy scalars and arrays and torch tensors (all of which expose
    ``tolist``). Any other unsupported type still raises TypeError, as ``json``
    would by default.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


os.makedirs('outputs', exist_ok=True)
hist_df.to_csv('outputs/train_history.csv', index=False)

# Serialize before opening the file, so a serialization error cannot leave a
# truncated summary.txt behind.
summary = json.dumps(dict(config=CFGm, metrics=metrics), indent=2, default=_to_jsonable)
with open('outputs/summary.txt', 'w', encoding='utf-8') as f:
    f.write(summary)
