
# 04 — Ranker & End‑to‑End (**FAST GPU**) — Recall@10 / NDCG@10

This version removes per‑sample I/O, uses **FAISS‑GPU batched searches**, **GPU feature building with sharded Parquet**, and **AMP** for ranker training.

**Highlights**
- Candidate generation on GPU (batched).
- Feature engineering on GPU with **optional negative subsampling** (`neg_per_query`).
- Ranker training per‑**shard** (load once → many batches), not per‑sample file reads.
- Mixed precision (**torch.cuda.amp**), gradient clipping.


In [1]:

# --- 0) Config & paths --------------------------------------------------------
import os, json, math, time, gc, glob, bisect
from pathlib import Path
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

torch.set_float32_matmul_precision('high')

PROJECT_ROOT = Path.home() / "KAUST-Project"   # /home/kamalyy/KAUST-Project
OUT_DIR = PROJECT_ROOT / "data" / "processed" / "online_retail_II"
OUT_DIR.mkdir(parents=True, exist_ok=True)
print("OUT_DIR:", OUT_DIR)

READ_KW  = dict(engine="fastparquet")
WRITE_KW = dict(engine="fastparquet", index=False)

CFG = {
    # Retrieval
    "cand_topk": 100,
    "cand_batch": 8192,          # FAISS search batch size
    # Histories & features
    "hist_max": 50,
    # Feature building
    "feat_batch_q": 2048,        # queries per GPU batch when building features
    "shard_rows": 2_000_000,     # approx rows per Parquet shard (features)
    "neg_per_query": 20,         # keep 1 pos + N hard negatives per query (set None to keep all K)
    "hard_negatives": True,      # choose hardest by dot_uv; False=random
    # Ranker
    "batch_size": 4096,          # larger thanks to AMP
    "epochs": 15,
    "patience": 3,
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "dropout": 0.2,
    "hidden": 512,
    "eval_topk": 10,
    "seed": 42,
    "use_text": True,
    # FAISS GPU
    "faiss_use_all_gpus": False, # set True to use all GPUs
    "faiss_device": 0,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
torch.manual_seed(CFG["seed"])
if device.type == "cuda":
    torch.cuda.manual_seed_all(CFG["seed"])

# Sequences
seq_train = pd.read_parquet(OUT_DIR/'sequences_train.parquet', **READ_KW)
seq_val   = pd.read_parquet(OUT_DIR/'sequences_val.parquet', **READ_KW)
seq_test  = pd.read_parquet(OUT_DIR/'sequences_test.parquet', **READ_KW)
print("Train/Val/Test:", seq_train.shape, seq_val.shape, seq_test.shape)

# Popularity for features/fallback
pop_counts = seq_train['pos_item_idx'].value_counts()
pop_norm = (pop_counts - pop_counts.min()) / (pop_counts.max() - pop_counts.min() + 1e-9)

# Item vectors from retriever (already normalized)
V_PATH = OUT_DIR/'retriever_item_vectors.npy'
assert V_PATH.exists(), "Missing retriever_item_vectors.npy from notebook 03"
ITEM_VECS = np.load(V_PATH).astype('float32')
n_items, d_vec = ITEM_VECS.shape
print("Item vectors:", ITEM_VECS.shape)

# Optional text embeddings (normalize)
TXT_PATH = OUT_DIR/'item_text_emb.npy'
HAS_TEXT = (CFG["use_text"] and TXT_PATH.exists())
ITEM_TXT = np.load(TXT_PATH).astype('float32') if HAS_TEXT else None
if HAS_TEXT:
    ITEM_TXT /= np.maximum(np.linalg.norm(ITEM_TXT, axis=1, keepdims=True), 1e-6)
    print("Item text emb:", ITEM_TXT.shape)

# GPU tensors for items
ITEM_VECS_T = torch.from_numpy(ITEM_VECS).to(device, non_blocking=True)
ITEM_TXT_T  = torch.from_numpy(ITEM_TXT).to(device, non_blocking=True) if HAS_TEXT else None

# Popularity & optional price_z tensors
pop_vec = torch.zeros(n_items, dtype=torch.float32, device=device)
pop_idx = torch.tensor(pop_counts.index.values, dtype=torch.long, device=device)
pop_val = torch.tensor(pop_norm.loc[pop_counts.index].values, dtype=torch.float32, device=device)
pop_vec[pop_idx] = pop_val

price_z = None
items_clean_path = OUT_DIR/'items_clean.parquet'
if items_clean_path.exists():
    items_clean = pd.read_parquet(items_clean_path, **READ_KW)
    if 'price_median' in items_clean.columns and 'item_idx' in items_clean.columns:
        m = items_clean[['item_idx','price_median']].dropna()
        mu, sigma = m['price_median'].mean(), m['price_median'].std() + 1e-6
        z = ((m['price_median'] - mu) / sigma).astype(float)
        price_z = torch.zeros(n_items, dtype=torch.float32, device=device)
        ii = torch.tensor(m['item_idx'].astype(int).values, dtype=torch.long, device=device)
        price_z[ii] = torch.tensor(z.values, dtype=torch.float32, device=device)

def parse_hist(s):
    if not isinstance(s, str) or not s.strip():
        return []
    return [int(x) for x in s.strip().split()]


OUT_DIR: /home/kamalyy/KAUST-Project/data/processed/online_retail_II
Device: cuda
Train/Val/Test: (597298, 6) (75190, 6) (75583, 6)
Item vectors: (4446, 512)


## 1) FAISS‑GPU index & batched retrieval

In [2]:

def load_faiss_gpu_index():
    import faiss
    idx_meta_path = OUT_DIR/'index_meta.json'
    cpu_index = None
    if idx_meta_path.exists():
        meta = json.load(open(idx_meta_path))
        if meta.get("type","").startswith("faiss"):
            cpu_index = faiss.read_index(str(OUT_DIR/'items.faiss'))
    if cpu_index is None:
        d = ITEM_VECS.shape[1]
        cpu_index = faiss.IndexFlatIP(d)
        cpu_index.add(ITEM_VECS)
    if CFG["faiss_use_all_gpus"]:
        return faiss.index_cpu_to_all_gpus(cpu_index)
    else:
        res = faiss.StandardGpuResources()
        return faiss.index_cpu_to_gpu(res, CFG["faiss_device"], cpu_index)

try:
    import faiss
    ANN = load_faiss_gpu_index()
    print("FAISS GPU index ready.")
except Exception as e:
    print("[warn] FAISS GPU init failed; CPU fallback:", e)
    ANN = None

@torch.no_grad()
def user_vecs_from_hist_batch(hist_tensor):
    # hist_tensor: LongTensor [B, L], with -1 as PAD → mean of valid item vectors
    B, L = hist_tensor.shape
    safe_idx = hist_tensor.clamp(min=0)  # replace -1 with 0 for gather
    H = ITEM_VECS_T.index_select(0, safe_idx.view(-1)).view(B, L, -1)  # [B,L,d]
    mask = (hist_tensor >= 0).float().unsqueeze(-1)                    # [B,L,1]
    U = (H * mask).sum(1) / mask.sum(1).clamp_min(1e-6)                # [B,d]
    U = F.normalize(U, dim=-1)
    return U

def build_hist_tensor(series, L):
    B = len(series)
    H = torch.full((B, L), -1, dtype=torch.long)
    for i, s in enumerate(series):
        h = parse_hist(s)
        if len(h) > L: h = h[-L:]
        if h:
            H[i, -len(h):] = torch.tensor(h, dtype=torch.long)
    return H

def batched_faiss_search(U_np, topk, batch=8192):
    if ANN is None:
        with torch.no_grad():
            U = torch.from_numpy(U_np).to(device)
            sims = U @ ITEM_VECS_T.t()
            return torch.topk(sims, k=topk, dim=1).indices.detach().cpu().numpy().astype('int32')
    I_all = []
    for i in range(0, U_np.shape[0], batch):
        I = ANN.search(U_np[i:i+batch], topk)[1]
        I_all.append(I)
    return np.vstack(I_all)


FAISS GPU index ready.


## 2) Candidate generation on GPU (batched)

In [3]:

def gen_candidates_gpu(df, topk=100, batch_q=8192):
    hist_series = df['history_idx'].astype(str).tolist()
    H = build_hist_tensor(hist_series, CFG["hist_max"])  # CPU
    U_chunks = []
    for i in range(0, H.size(0), batch_q):
        Ub = user_vecs_from_hist_batch(H[i:i+batch_q].to(device))
        U_chunks.append(Ub.detach().cpu())
    U = torch.cat(U_chunks, 0).numpy().astype('float32')

    I_full = batched_faiss_search(U, topk=topk, batch=CFG["cand_batch"])  # [N,topk]

    pos_list = df['pos_item_idx'].astype(int).tolist()
    ts_list  = df['ts'].astype(str).tolist() if 'ts' in df.columns else ['']*len(pos_list)
    rows = []
    for pos, cand_idx, ts, h_s in zip(pos_list, I_full, ts_list, hist_series):
        cand = cand_idx.tolist()
        if pos not in cand:
            cand[-1] = int(pos)
        rows.append((h_s, int(pos), " ".join(map(str,cand)), ts))
    return pd.DataFrame(rows, columns=['history_idx','pos_item_idx','cands','ts'])

# Build & save if missing
CAND_TRAIN_PATH = OUT_DIR/'candidates_train.parquet'
CAND_VAL_PATH   = OUT_DIR/'candidates_val.parquet'
CAND_TEST_PATH  = OUT_DIR/'candidates_test.parquet'

if not (CAND_TRAIN_PATH.exists() and CAND_VAL_PATH.exists() and CAND_TEST_PATH.exists()):
    print("Generating training candidates (GPU)...")
    gen_candidates_gpu(seq_train, topk=CFG["cand_topk"], batch_q=CFG["cand_batch"]).to_parquet(CAND_TRAIN_PATH, **WRITE_KW)
    print("Generating validation candidates (GPU)...")
    gen_candidates_gpu(seq_val,   topk=CFG["cand_topk"], batch_q=CFG["cand_batch"]).to_parquet(CAND_VAL_PATH, **WRITE_KW)
    print("Generating test candidates (GPU)...")
    gen_candidates_gpu(seq_test,  topk=CFG["cand_topk"], batch_q=CFG["cand_batch"]).to_parquet(CAND_TEST_PATH, **WRITE_KW)

cand_train = pd.read_parquet(CAND_TRAIN_PATH, **READ_KW)
cand_val   = pd.read_parquet(CAND_VAL_PATH, **READ_KW)
cand_test  = pd.read_parquet(CAND_TEST_PATH, **READ_KW)
print("Candidates:", cand_train.shape, cand_val.shape, cand_test.shape)


Candidates: (597298, 4) (75190, 4) (75583, 4)


## 3) Feature engineering on GPU (sharded) + optional negative subsampling

In [4]:

def _pack_batch_features(Ub, Hb, Cb, Pb):
    B, d = Ub.shape
    K = Cb.size(1)
    L = Hb.size(1)
    # Candidate item vectors
    Vc = ITEM_VECS_T.index_select(0, Cb.view(-1)).view(B, K, d)
    # dot(u,v)
    dot_uv = (Ub.unsqueeze(1) * Vc).sum(-1)                        # [B,K]
    # Max sim to recent
    safe_hist = Hb.clamp(min=0)
    Hvec = ITEM_VECS_T.index_select(0, safe_hist.view(-1)).view(B, L, d)
    Hvec = F.normalize(Hvec, dim=-1); Vc_n = F.normalize(Vc, dim=-1)
    sims = torch.matmul(Hvec, Vc_n.transpose(1,2))                  # [B,L,K]
    maskL = (Hb >= 0).unsqueeze(-1).float()
    sims = sims + (maskL - 1.0) * 1e9
    max_sim_recent = sims.max(dim=1).values                         # [B,K]
    # pop & hist_len
    pop = pop_vec.index_select(0, Cb.view(-1)).view(B, K)
    hlen = (Hb >= 0).float().sum(1) / float(CFG["hist_max"])
    hlen = hlen.unsqueeze(1).expand(B, K)
    # price_z
    if isinstance(globals().get('price_z', None), torch.Tensor):
        price = price_z.index_select(0, Cb.view(-1)).view(B, K)
    else:
        price = torch.zeros((B, K), device=Ub.device)
    # text_sim
    if ITEM_TXT_T is not None:
        Th = ITEM_TXT_T.index_select(0, safe_hist.view(-1)).view(B, L, -1)
        mask = (Hb >= 0).float().unsqueeze(-1)
        Th_mean = (Th * mask).sum(1) / mask.sum(1).clamp_min(1e-6)
        Tc = ITEM_TXT_T.index_select(0, Cb.view(-1)).view(B, K, -1)
        text_sim = torch.matmul(Th_mean.unsqueeze(1), Tc.transpose(1,2)).squeeze(1)
    else:
        text_sim = torch.zeros((B, K), device=Ub.device)
    # labels
    labels = (Cb == Pb.view(-1,1)).float()
    return {"dot_uv": dot_uv, "max_sim_recent": max_sim_recent, "pop": pop,
            "hist_len": hlen, "price_z": price, "text_sim": text_sim,
            "label": labels, "item_idx": Cb.float()}

def _select_negatives(feats):
    # Keep 1 positive + N negatives per query if configured
    if CFG["neg_per_query"] is None:
        return feats
    B, K = feats["label"].shape
    pos_col = torch.argmax(feats["label"], dim=1, keepdim=True)     # [B,1]
    if CFG["hard_negatives"]:
        neg_scores = feats["dot_uv"].clone()
        neg_scores.scatter_(1, pos_col, -1e9)
        _, neg_idx = torch.topk(neg_scores, k=min(CFG["neg_per_query"], K-1), dim=1)
    else:
        rnd = torch.rand_like(feats["dot_uv"])
        rnd.scatter_(1, pos_col, 1e9)
        _, neg_idx = torch.topk(-rnd, k=min(CFG["neg_per_query"], K-1), dim=1)
    keep_cols = torch.cat([pos_col, neg_idx], dim=1)                # [B,1+N]
    for k in ["dot_uv","max_sim_recent","pop","hist_len","price_z","text_sim","label","item_idx"]:
        feats[k] = torch.gather(feats[k], 1, keep_cols)
    return feats

def build_feats_gpu_sharded(cand_df, split_name):
    N = len(cand_df)
    L = CFG["hist_max"]
    bq = CFG["feat_batch_q"]
    out_dir = OUT_DIR / f"ranker_feats_{split_name}_shards"
    out_dir.mkdir(parents=True, exist_ok=True)

    def write_shard(idx, feats_dict):
        cpu = {k: v.detach().float().view(-1).to('cpu').numpy() for k,v in feats_dict.items()}
        df = pd.DataFrame(cpu); df['item_idx'] = df['item_idx'].astype(np.int32)
        df.to_parquet(out_dir / f"part_{idx:03d}.parquet", **WRITE_KW)

    hist_series = cand_df['history_idx'].astype(str).tolist()
    pos_list = cand_df['pos_item_idx'].astype(int).tolist()
    cands_series = cand_df['cands'].astype(str).tolist()

    rows_written = 0; shard_idx = 0; buf = None
    for i in range(0, N, bq):
        H = build_hist_tensor(hist_series[i:i+bq], L).to(device, non_blocking=True)
        P = torch.tensor(pos_list[i:i+bq], dtype=torch.long, device=device)
        C = torch.tensor([[int(x) for x in s.split()] for s in cands_series[i:i+bq]],
                         dtype=torch.long, device=device)
        U = user_vecs_from_hist_batch(H)
        feats = _pack_batch_features(U, H, C, P)
        feats = _select_negatives(feats)

        if buf is None:
            buf = {k: v.detach().clone() for k,v in feats.items()}
        else:
            for k in buf.keys():
                buf[k] = torch.cat([buf[k], feats[k]], dim=0)

        rows_in_buf = int(buf["label"].numel())
        if rows_in_buf >= CFG["shard_rows"]:
            write_shard(shard_idx, buf); shard_idx += 1
            for k in list(buf.keys()): del buf[k]
            buf = None; torch.cuda.empty_cache()

        rows_written += int(feats["label"].numel())
        if (i//bq) % 20 == 0:
            print(f"[{split_name}] Built ~{rows_written/1e6:.2f}M rows...")

    if buf is not None:
        write_shard(shard_idx, buf); shard_idx += 1
        for k in list(buf.keys()): del buf[k]
        buf = None; torch.cuda.empty_cache()

    print(f"[{split_name}] Done. Rows ~{rows_written:,}. Shards -> {out_dir}")
    return out_dir

# Build shards if missing
train_shards_dir = OUT_DIR / "ranker_feats_train_shards"
val_shards_dir   = OUT_DIR / "ranker_feats_val_shards"
test_shards_dir  = OUT_DIR / "ranker_feats_test_shards"

if not train_shards_dir.exists():
    train_shards_dir = build_feats_gpu_sharded(cand_train, "train")
if not val_shards_dir.exists():
    val_shards_dir = build_feats_gpu_sharded(cand_val, "val")
if not test_shards_dir.exists():
    test_shards_dir = build_feats_gpu_sharded(cand_test, "test")

print("Shard dirs:", train_shards_dir, val_shards_dir, test_shards_dir)


Shard dirs: /home/kamalyy/KAUST-Project/data/processed/online_retail_II/ranker_feats_train_shards /home/kamalyy/KAUST-Project/data/processed/online_retail_II/ranker_feats_val_shards /home/kamalyy/KAUST-Project/data/processed/online_retail_II/ranker_feats_test_shards


## 4) Ranker (per‑shard training, AMP)

In [5]:

RANKER_COLS = ["dot_uv","max_sim_recent","pop","hist_len","price_z","text_sim"]

def shard_batches(files, batch_size):
    for f in sorted(glob.glob(str(Path(files)/"part_*.parquet")) if isinstance(files, (str, Path)) else files):
        df = pd.read_parquet(f, engine="fastparquet", columns=RANKER_COLS+["label"])

        # ✅ build tensors without non_blocking, then .to(device, non_blocking=True)
        X_np = df[RANKER_COLS].to_numpy(dtype='float32', copy=False)
        y_np = df["label"].to_numpy(dtype='float32', copy=False)
        X = torch.from_numpy(X_np).to(device, non_blocking=True)
        y = torch.from_numpy(y_np).to(device, non_blocking=True)

        perm = torch.randperm(X.size(0), device=device)
        for i in range(0, X.size(0), batch_size):
            idx = perm[i:i+batch_size]
            yield X[idx], y[idx]

        del X, y, df
        torch.cuda.empty_cache()


class RankerMLP(nn.Module):
    def __init__(self, d_in, hidden=512, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, hidden//2), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden//2, 1)
        )
    def forward(self, x): return self.net(x).squeeze(-1)

model = RankerMLP(len(RANKER_COLS), hidden=CFG["hidden"], dropout=CFG["dropout"]).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])
loss_fn = nn.BCEWithLogitsLoss()
use_amp = (device.type == "cuda")
amp_device = "cuda" if use_amp else "cpu"

# New GradScaler (old: torch.cuda.amp.GradScaler)
scaler = torch.amp.GradScaler(amp_device, enabled=use_amp)
torch.backends.cudnn.benchmark = True


## 5) Train & Eval (end‑to‑end Recall@10 / NDCG@10)

In [6]:

def recall_at_k(preds, truth): return float(truth in preds)
def ndcg_at_k(preds, truth):
    try:
        r = preds.index(truth) + 1; return 1.0 / math.log2(r + 1.0)
    except ValueError: return 0.0

@torch.no_grad()
def rerank_one_batch(cand_df_slice):
    H = build_hist_tensor(cand_df_slice['history_idx'].astype(str).tolist(), CFG["hist_max"]).to(device)
    P = torch.tensor(cand_df_slice['pos_item_idx'].astype(int).tolist(), dtype=torch.long, device=device)
    C = torch.tensor([[int(x) for x in s.split()] for s in cand_df_slice['cands'].astype(str).tolist()],
                     dtype=torch.long, device=device)
    U = user_vecs_from_hist_batch(H)
    feats = _pack_batch_features(U, H, C, P)
    # No negative subsampling at inference: rank all provided candidates
    X = torch.stack([feats[c] for c in RANKER_COLS], dim=-1) # [B,K,6]
    scores = model(X.view(-1, len(RANKER_COLS))).view(X.size(0), X.size(1))
    topk = min(CFG["eval_topk"], C.size(1))
    vals, idx = torch.topk(scores, k=topk, dim=1)
    reranked = [C[i][idx[i]].tolist() for i in range(C.size(0))]
    return reranked

@torch.no_grad()
def eval_reranked(cand_df, split="val"):
    model.eval()
    hits = 0; ndcgs = 0.0; tot = 0
    B = 2048
    for i in range(0, len(cand_df), B):
        batch_df = cand_df.iloc[i:i+B]
        reranked = rerank_one_batch(batch_df)
        for pos, rr in zip(batch_df['pos_item_idx'].tolist(), reranked):
            pos = int(pos)
            hits += float(pos in rr)
            if pos in rr:
                r = rr.index(pos) + 1
                ndcgs += 1.0 / math.log2(r + 1.0)
            tot += 1
    return hits/max(1,tot), ndcgs/max(1,tot)


def train_ranker():
    best_recall = -1.0; bad = 0
    for ep in range(1, CFG["epochs"]+1):
        model.train()
        total_loss = 0.0; nobs = 0
        for Xb, yb in shard_batches(train_shards_dir, CFG["batch_size"]):
            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast(amp_device, enabled=use_amp, dtype=torch.float16):
                logits = model(Xb)
                loss = loss_fn(logits, yb)
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt); scaler.update()
            total_loss += float(loss.item()) * yb.numel()
            nobs += yb.numel()

        # Evaluate every epoch (can change to every 2 for even faster)
        val_recall, val_ndcg = eval_reranked(cand_val, split="val")
        print(f"Epoch {ep:02d} | train BCE {total_loss/max(1,nobs):.4f} | "
              f"val Recall@{CFG['eval_topk']} {val_recall:.4f} | val NDCG@{CFG['eval_topk']} {val_ndcg:.4f}")
        if val_recall > best_recall + 1e-4:
            best_recall, bad = val_recall, 0
            torch.save(model.state_dict(), OUT_DIR/'ranker_best.pt')
        else:
            bad += 1
            if bad >= CFG["patience"]:
                print("Early stopping on Recall@10."); break
    print("Best val Recall@{} = {:.4f}".format(CFG["eval_topk"], best_recall))

train_ranker()
model.load_state_dict(torch.load(OUT_DIR/'ranker_best.pt', map_location=device, weights_only=True))
test_recall, test_ndcg = eval_reranked(cand_test, split="test")
print("TEST — Recall@{}: {:.4f}, NDCG@{}: {:.4f}".format(CFG["eval_topk"], test_recall, CFG["eval_topk"], test_ndcg))


Epoch 01 | train BCE 0.0258 | val Recall@10 0.7633 | val NDCG@10 0.7210
Epoch 02 | train BCE 0.0253 | val Recall@10 0.7628 | val NDCG@10 0.7213
Epoch 03 | train BCE 0.0252 | val Recall@10 0.7482 | val NDCG@10 0.7064
Epoch 04 | train BCE 0.0252 | val Recall@10 0.7508 | val NDCG@10 0.7099
Early stopping on Recall@10.
Best val Recall@10 = 0.7633
TEST — Recall@10: 0.7978, NDCG@10: 0.7563
