
# 03 — Two‑Tower Retriever & ANN Index

This notebook trains a **two‑tower retrieval model** on your sequences and builds an **ANN index** for fast candidate generation.

**What it does**
- Loads `item_id_map.parquet`, `customer_id_map.parquet`, and `sequences_{train,val,test}.parquet` from `OUT_DIR`.
- (Optional) Loads `item_text_emb.npy` to enrich items with FM/LLM **text embeddings**.
- Trains a **user tower** (pooled history) and an **item tower** (ID + projected text embedding).
- Uses an **in‑batch softmax** loss (InfoNCE‑style) for scalable training.
- Exports **item vectors** and builds an ANN index with **FAISS** (falls back to a simple brute‑force/sklearn index if FAISS is unavailable).


In [2]:

# --- Config & paths -----------------------------------------------------------
import os, math, gc, json
from pathlib import Path
import numpy as np
import pandas as pd

# Resolve the artifact directory shared with the previous notebooks.
# Precedence: an OUT_DIR already defined in this kernel, then the OUT_DIR
# environment variable, then the original hard-coded default. The environment
# override lets the notebook run on another machine without editing this cell.
_DEFAULT_OUT_DIR = r"C:\KAUST-Project\data\processed\online_retail_II"
OUT_DIR = Path(globals().get('OUT_DIR') or os.environ.get('OUT_DIR') or _DEFAULT_OUT_DIR)
OUT_DIR.mkdir(parents=True, exist_ok=True)
print("OUT_DIR:", OUT_DIR)

READ_KW = dict(engine="fastparquet")  # keep using fastparquet given pyarrow issues

# Training config (tune later)
CFG = {
    "d_model": 128,        # embedding dimension
    "hist_max": 30,        # max history length
    "batch_size": 1024,    # adjust to your GPU/CPU memory
    "epochs": 2,           # start small; scale up after smoke test
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "dropout": 0.1,
    "use_country": True,   # include country embedding in user tower
    "text_proj": True,     # include text embeddings if available
    "norm_user": True,     # L2-normalize user vectors
    "norm_item": True,     # L2-normalize item vectors
    "seed": 42,
}

# Detect device
import torch
# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# Seed the default generator, plus every CUDA generator when running on GPU.
torch.manual_seed(CFG["seed"])
if device.type == "cuda":
    torch.cuda.manual_seed_all(CFG["seed"])

# Load the id maps and the train/val/test sequence tables from OUT_DIR.
def _read_artifact(filename):
    """Read one parquet artifact from OUT_DIR using the shared reader options."""
    return pd.read_parquet(OUT_DIR / filename, **READ_KW)

items = _read_artifact('item_id_map.parquet')
users = _read_artifact('customer_id_map.parquet')
seq_train = _read_artifact('sequences_train.parquet')
seq_val = _read_artifact('sequences_val.parquet')
seq_test = _read_artifact('sequences_test.parquet')

n_items, n_users = int(len(items)), int(len(users))
print(f"Items: {n_items} | Users: {n_users}")
print("Train/Val/Test:", seq_train.shape, seq_val.shape, seq_test.shape)

# Country indexing (if enabled): ids follow the sorted order of the distinct
# country strings seen across all three splits.
country_map = {}
if CFG["use_country"]:
    countries = (
        pd.concat([seq_train['country'], seq_val['country'], seq_test['country']])
        .astype(str)
        .unique()
    )
    country_map = {name: pos for pos, name in enumerate(sorted(countries))}

# Load optional text embeddings
emb_path = OUT_DIR/'item_text_emb.npy'
idx_path = OUT_DIR/'item_text_index.parquet'
has_text = False
text_dim = 0
if CFG["text_proj"] and emb_path.exists() and idx_path.exists():
    try:
        raw_text_emb = np.load(emb_path)
        item_text_idx = pd.read_parquet(idx_path, **READ_KW)['item_idx'].to_numpy()
        if raw_text_emb.ndim != 2:
            raise ValueError(f"expected a 2-D embedding matrix, got shape {raw_text_emb.shape}")

        # The model looks text rows up by item index (text_bank[pos]), so the
        # matrix must have exactly n_items rows, with row i describing item i.
        # Items without a text embedding keep an all-zero row.
        item_text_emb = np.zeros((n_items, raw_text_emb.shape[1]), dtype=np.float32)
        if len(item_text_idx) == raw_text_emb.shape[0]:
            # NOTE(review): assumes item_text_idx[r] is the item index that row r
            # of the embedding file belongs to -- confirm against the notebook
            # that wrote item_text_index.parquet.
            row_item = item_text_idx.astype(np.int64)
            in_range = (row_item >= 0) & (row_item < n_items)
            item_text_emb[row_item[in_range]] = raw_text_emb[in_range]
            n_covered = int(np.unique(row_item[in_range]).size)
        else:
            # Index unusable: keep the previous positional behaviour (row i -> item i).
            print("[warn] item_text_index length does not match embedding rows; assuming row i belongs to item i")
            n_rows = min(n_items, raw_text_emb.shape[0])
            item_text_emb[:n_rows] = raw_text_emb[:n_rows]
            n_covered = n_rows
        if n_covered < n_items:
            print(f"[warn] {n_items - n_covered} items have no text embedding; using zeros for them")

        text_dim = int(item_text_emb.shape[1])
        has_text = True
        print("Loaded text embeddings:", item_text_emb.shape)
    except Exception as e:
        # Text features are optional: report the problem and train on ids only.
        print("[warn] Could not load text embeddings, continuing without:", e)
        has_text = False
        text_dim = 0

def parse_hist(s):
    """Parse a whitespace-separated history string into a list of item indices.

    Args:
        s: History as a string such as ``"12 7 301"``. Any non-string value
            (None, NaN, ...) is treated as an empty history.

    Returns:
        list[int]: The item indices in their original order; ``[]`` when the
        history is empty or missing.

    Raises:
        ValueError: If a token is not an integer literal.
    """
    if not isinstance(s, str):
        return []
    text = s.strip()
    # `Series.astype(str)` renders missing values as these literal strings;
    # treat them as "no history" instead of failing in int().
    if not text or text.lower() in ("nan", "none", "<na>"):
        return []
    return [int(tok) for tok in text.split()]

# Helper that turns a raw country value into its integer id
def country_to_idx(c):
    """Return the id of country ``c`` from ``country_map``; unknown values map to 0."""
    key = str(c)
    if key in country_map:
        return country_map[key]
    return 0


OUT_DIR: C:\KAUST-Project\data\processed\online_retail_II
Device: cpu
Items: 4446 | Users: 5748
Train/Val/Test: (597299, 6) (75190, 6) (75583, 6)



## Dataset & DataLoader

We build batches of `(user_history → positive item)` and rely on **in‑batch negatives**, i.e., each item in the batch serves as a negative for the other users.


In [None]:

# --- PyTorch dataset ----------------------------------------------------------
from torch.utils.data import Dataset, DataLoader

class SeqDataset(Dataset):
    """Rows of (history, positive item, country) for two-tower training.

    Args:
        df: DataFrame with columns ``pos_item_idx`` (int), ``history_idx``
            (whitespace-separated item indices as a string, possibly missing)
            and ``country``.
        hist_max: Keep at most this many of the most recent history items.

    Each sample is a dict with ``hist`` (int64 array, possibly empty),
    ``pos`` (int64 scalar) and ``country`` (int id, 0 when countries are off).
    """

    def __init__(self, df, hist_max=30):
        self.hist_max = hist_max
        self.pos = df['pos_item_idx'].astype(int).to_numpy()
        # Replace missing histories with "" BEFORE stringifying; otherwise
        # astype(str) yields "nan"/"None", which is not a parseable history.
        hist_col = df['history_idx']
        self.hist = hist_col.where(hist_col.notna(), "").astype(str).tolist()
        self.country = df['country'].astype(str).tolist()

    def __len__(self):
        return len(self.pos)

    def __getitem__(self, idx):
        h = parse_hist(self.hist[idx])
        if len(h) > self.hist_max:
            # Explicit start offset: h[-0:] would keep everything when hist_max == 0.
            h = h[len(h) - self.hist_max:]
        return {
            "hist": np.array(h, dtype=np.int64),
            "pos": np.int64(self.pos[idx]),
            "country": country_to_idx(self.country[idx]) if country_map else 0
        }

def collate_batch(batch):
    """Collate dataset samples into padded tensors.

    Histories are left-padded with 0 to the longest history in the batch, so
    the most recent item always sits in the last column. The padded width is
    at least 1, even when every history in the batch is empty.

    Args:
        batch: List of dicts with keys ``hist`` (1-D int sequence), ``pos``
            and ``country``.

    Returns:
        dict with int64 tensors ``hist`` [B, L], ``pos`` [B], ``country`` [B].
    """
    # `default` only covers an empty batch; a batch of empty histories would
    # still give 0 here, hence the explicit floor of 1 below.
    longest = max((len(x["hist"]) for x in batch), default=0)
    maxL = max(1, longest)
    H = np.zeros((len(batch), maxL), dtype=np.int64)
    for i, x in enumerate(batch):
        h = x["hist"]
        if len(h):
            H[i, -len(h):] = h  # right-align: padding goes on the left
    pos = np.array([x["pos"] for x in batch], dtype=np.int64)
    country = np.array([x["country"] for x in batch], dtype=np.int64)
    return {
        "hist": torch.from_numpy(H),
        "pos": torch.from_numpy(pos),
        "country": torch.from_numpy(country)
    }

# Full training table; validation is capped at 100k sampled rows.
_hist_max = CFG["hist_max"]
_val_rows = seq_val.sample(min(100_000, len(seq_val)), random_state=CFG["seed"])
train_ds = SeqDataset(seq_train, hist_max=_hist_max)
val_ds = SeqDataset(_val_rows, hist_max=_hist_max)

# Both loaders share everything except shuffling.
_loader_kw = dict(batch_size=CFG["batch_size"], num_workers=0, collate_fn=collate_batch)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kw)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kw)

(len(train_ds), len(val_ds))



## Two‑Tower model

- **Item tower:** an `Embedding(n_items, d)` ID vector, optionally concatenated with a linear projection of the item's **text embedding**, then passed through a small MLP that maps back to `d`.
- **User tower:** mean pool of the recent items' ID embeddings (sharing the item tower's ID embedding weights), optionally concatenated with a **country embedding**, then MLP → `d`.
- **Loss:** in‑batch softmax over item dot‑products (InfoNCE‑style).


In [None]:

# --- Model --------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoTower(nn.Module):
    """Two-tower retriever trained with an in-batch softmax loss.

    Item tower: ID embedding, optionally concatenated with a linear projection
    of a text embedding, followed by an MLP. User tower: mean of the history's
    ID embeddings (same embedding table as the item tower), optionally
    concatenated with a country embedding, followed by an MLP.

    Args:
        n_items: Number of rows in the item ID embedding table.
        n_countries: Number of country ids; 0 disables the country feature.
        d: Output dimension of both towers.
        text_dim: Dimension of the raw text embeddings; 0 disables text.
        dropout: Dropout probability inside both MLPs.
        norm_user: L2-normalize user vectors.
        norm_item: L2-normalize item vectors.

    NOTE(review): when both towers are normalized the logits lie in [-1, 1],
    which makes the softmax very flat; dividing the logits by a temperature is
    a common remedy and is worth evaluating.
    """

    def __init__(self, n_items, n_countries, d=128, text_dim=0, dropout=0.1, norm_user=True, norm_item=True):
        super().__init__()
        self.d = d
        self.norm_user = norm_user
        self.norm_item = norm_item

        # Item ID embedding (shared for user history)
        self.item_id_emb = nn.Embedding(n_items, d)

        # Optional text projection
        self.has_text = text_dim > 0
        if self.has_text:
            self.text_proj = nn.Linear(text_dim, d, bias=False)

        # Country embedding (small)
        self.has_country = n_countries > 0
        if self.has_country:
            self.country_emb = nn.Embedding(n_countries, d // 4)

        # Small MLP heads; input widths depend on which optional features are on
        self.user_mlp = nn.Sequential(
            nn.Linear(d + (d//4 if self.has_country else 0), d),
            nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d, d)
        )
        self.item_mlp = nn.Sequential(
            nn.Linear(d + (d if self.has_text else 0), d),
            nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d, d)
        )

        # Init
        nn.init.normal_(self.item_id_emb.weight, std=0.02)
        if self.has_text:
            nn.init.xavier_uniform_(self.text_proj.weight)

    def item_vec(self, item_ids, item_text=None):
        """Embed items.

        Args:
            item_ids: Long tensor of item indices, shape [B] (or [N]).
            item_text: Optional float tensor [B, text_dim] with the matching
                text embeddings. Ignored when the model has no text branch.
                When the model has a text branch and this is None, zeros are
                used so the MLP still receives the width it was built for.

        Returns:
            Float tensor [B, d], L2-normalized when ``norm_item`` is set.
        """
        v = self.item_id_emb(item_ids)  # [B, d]
        if self.has_text:
            if item_text is not None:
                t = self.text_proj(item_text)  # [B, d]
            else:
                t = v.new_zeros(v.shape)       # neutral stand-in for missing text
            v = torch.cat([v, t], dim=-1)
        v = self.item_mlp(v)
        if self.norm_item:
            v = F.normalize(v, dim=-1)
        return v

    def user_vec(self, hist_item_ids, country_ids=None, text_bank=None):
        """Embed users from their (left-padded) item histories.

        Args:
            hist_item_ids: Long tensor [B, L]; shorter histories are padded
                with 0 on the left, so the last column holds the newest item.
            country_ids: Optional long tensor [B]. When the model has a country
                branch and this is None, a zero country feature is used.
            text_bank: Unused; kept for interface compatibility.

        Returns:
            Float tensor [B, d], L2-normalized when ``norm_user`` is set.

        Raises:
            RuntimeError: If L == 0.
        """
        B, L = hist_item_ids.shape
        if L == 0:
            raise RuntimeError("Empty history")
        h = self.item_id_emb(hist_item_ids)  # [B, L, d]

        # 0 is used as padding but may also be a real item index, so a plain
        # `!= 0` mask is not safe. Only the run of zeros at the LEFT edge is
        # treated as padding: a position counts once any non-zero id has
        # appeared at or before it. The last column is always kept, because
        # with left padding it holds the newest real item of any non-empty
        # history; this also keeps the denominator >= 1.
        # Limitation: real occurrences of item 0 at the very start of a
        # history cannot be told apart from padding and are dropped.
        valid = (hist_item_ids != 0).to(torch.long).cumsum(dim=1) > 0  # [B, L]
        valid[:, -1] = True
        w = valid.unsqueeze(-1).to(h.dtype)                            # [B, L, 1]
        h_mean = (h * w).sum(dim=1) / w.sum(dim=1)                     # [B, d]

        feats = [h_mean]
        if self.has_country:
            if country_ids is not None:
                feats.append(self.country_emb(country_ids))  # [B, d/4]
            else:
                # Keep the MLP input width stable when the caller has no country.
                feats.append(h_mean.new_zeros(B, self.country_emb.embedding_dim))
        u = torch.cat(feats, dim=-1) if len(feats) > 1 else feats[0]
        u = self.user_mlp(u)
        if self.norm_user:
            u = F.normalize(u, dim=-1)
        return u

    def forward(self, batch, text_bank=None):
        """Compute the in-batch softmax loss for one batch.

        Args:
            batch: Dict with tensors ``hist`` [B, L], ``pos`` [B] and
                ``country`` [B].
            text_bank: Optional float tensor indexed by item id that holds the
                text embeddings; must live on the same device as the model.

        Returns:
            (loss, logits, u, v_pos): scalar cross-entropy loss, raw [B, B]
            user-item scores, user vectors [B, d] and item vectors [B, d].
        """
        # Follow the module's own device rather than a notebook-level global.
        dev = self.item_id_emb.weight.device
        hist = batch["hist"].to(device=dev, dtype=torch.long)     # [B, L]
        pos  = batch["pos"].to(device=dev, dtype=torch.long)      # [B]
        country = batch["country"].to(device=dev, dtype=torch.long) if self.has_country else None

        # Build user vectors
        u = self.user_vec(hist, country_ids=country, text_bank=text_bank)  # [B, d]

        # Positive item vectors
        item_text = text_bank[pos] if (self.has_text and text_bank is not None) else None
        v_pos = self.item_vec(pos, item_text=item_text)  # [B, d]

        # In-batch negatives: every other row's positive item is a candidate.
        logits = u @ v_pos.t()  # [B, B]
        labels = torch.arange(logits.size(0), device=logits.device)

        # Another row whose positive is the SAME item is not a true negative;
        # exclude those columns from the softmax. The diagonal is never masked,
        # so each row keeps one finite target logit.
        same_item = pos.unsqueeze(0) == pos.unsqueeze(1)  # [B, B]
        off_diag = ~torch.eye(logits.size(0), dtype=torch.bool, device=logits.device)
        loss_logits = logits.masked_fill(same_item & off_diag, float("-inf"))

        loss = F.cross_entropy(loss_logits, labels)
        return loss, logits, u, v_pos



## Train

We use in‑batch negatives. Start with a couple of epochs and check validation loss.


In [None]:

# --- Train --------------------------------------------------------------------
from torch.optim import AdamW

# Number of country ids; an empty map disables the country branch.
n_countries = len(country_map or ())

model = TwoTower(
    n_items=n_items,
    n_countries=n_countries,
    d=CFG["d_model"],
    text_dim=text_dim if has_text else 0,
    dropout=CFG["dropout"],
    norm_user=CFG["norm_user"],
    norm_item=CFG["norm_item"],
)
model = model.to(device)

# Item text embeddings as a float32 tensor on the training device (None when unavailable)
text_bank = None
if has_text:
    text_bank = torch.from_numpy(item_text_emb).to(device=device, dtype=torch.float32)

opt = AdamW(model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])

def run_epoch(loader, training=True):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    Args:
        loader: Iterable of batches accepted by ``model``.
        training: When True, put the model in train mode and take an optimizer
            step per batch (gradient norm clipped to 1.0). When False, put the
            model in eval mode and only measure the loss.

    Returns:
        float: Mean cross-entropy over all samples seen (0.0 for an empty loader).
    """
    model.train(training)
    total_loss = 0.0
    n = 0
    # Disable autograd during evaluation so no graph is built or kept alive.
    with torch.set_grad_enabled(training):
        for batch in loader:
            if training:
                opt.zero_grad(set_to_none=True)
            loss, logits, u, v_pos = model(batch, text_bank=text_bank)
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            # Weight by batch size so a short final batch does not skew the mean.
            total_loss += loss.item() * logits.size(0)
            n += logits.size(0)
    return total_loss / max(1, n)

# One training pass followed by one validation pass per epoch.
for ep in range(1, CFG["epochs"]+1):
    tr_loss, val_loss = (
        run_epoch(train_loader, training=True),
        run_epoch(val_loader, training=False),
    )
    print(f"Epoch {ep:02d} | train CE: {tr_loss:.4f} | val CE: {val_loss:.4f}")



## Export item vectors

We compute item vectors for all items and save them to disk for indexing.


In [None]:

# --- Export item vectors ------------------------------------------------------
# Embed every item id once, in eval mode and without autograd.
model.eval()
with torch.no_grad():
    item_ids = torch.arange(n_items, device=device, dtype=torch.long)
    item_text = None
    if has_text:
        # Zero-pad the text bank if it has fewer rows than there are items.
        tb = text_bank
        n_short = n_items - tb.shape[0]
        if n_short > 0:
            pad = torch.zeros((n_short, tb.shape[1]), device=device)
            tb = torch.cat([tb, pad], dim=0)
        item_text = tb[:n_items]
    item_vecs = model.item_vec(item_ids, item_text=item_text).detach().cpu().numpy()

vectors_path = OUT_DIR / "retriever_item_vectors.npy"
np.save(vectors_path, item_vecs)
print("Saved item vectors:", item_vecs.shape, "->", vectors_path)



## Build ANN index (FAISS if available)

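We build a cosine‑similarity index (inner product on L2‑normalized vectors). With **FAISS** we use `IndexFlatIP`, which performs an exact, exhaustive inner‑product search rather than an approximate one; that is fine for a catalogue of this size, and an approximate index can be swapped in for larger catalogues. If **FAISS** is not installed, we fall back to a simple brute‑force index using sklearn `NearestNeighbors` and save it with joblib.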


In [None]:

# --- Build ANN index ----------------------------------------------------------
item_vecs = np.load(OUT_DIR / "retriever_item_vectors.npy")
# Ensure unit length for inner product ≈ cosine
from numpy.linalg import norm
norms = np.maximum(norm(item_vecs, axis=1, keepdims=True), 1e-6)  # floor avoids division by zero
# FAISS expects C-contiguous float32 input, so convert once up front.
item_vecs = np.ascontiguousarray(item_vecs / norms, dtype=np.float32)

# Describes which index was written, so later cells/notebooks know what to load.
index_artifacts = {}

try:
    import faiss
    d = item_vecs.shape[1]
    index = faiss.IndexFlatIP(d)  # exact inner-product search
    index.add(item_vecs)
    faiss.write_index(index, str(OUT_DIR / "items.faiss"))
    index_artifacts["type"] = "faiss.IndexFlatIP"
    index_artifacts["path"] = str(OUT_DIR / "items.faiss")
    print("FAISS index saved ->", OUT_DIR / "items.faiss", "| ntotal:", index.ntotal)
except Exception as e:
    # Best effort: any FAISS problem (missing package, write error) falls back to sklearn.
    print("[warn] FAISS not available or failed:", e)
    print("Falling back to sklearn NearestNeighbors (brute force).")
    from sklearn.neighbors import NearestNeighbors
    import joblib
    # Do NOT call this `nn`: that would overwrite the notebook-level
    # `torch.nn as nn` alias used by the model cell.
    nn_index = NearestNeighbors(metric="cosine", algorithm="brute")
    nn_index.fit(item_vecs)
    joblib.dump(nn_index, OUT_DIR / "items_sklearn_nn.joblib")
    np.save(OUT_DIR / "retriever_item_vectors_normed.npy", item_vecs)
    index_artifacts["type"] = "sklearn.NearestNeighbors"
    index_artifacts["path"] = str(OUT_DIR / "items_sklearn_nn.joblib")
    print("Sklearn NN index saved ->", OUT_DIR / "items_sklearn_nn.joblib")

with open(OUT_DIR / "index_meta.json", "w", encoding="utf-8") as f:
    json.dump(index_artifacts, f, indent=2)



## Retrieval sanity check

We embed a few validation histories and retrieve top‑10 items to ensure the pipeline runs end‑to‑end.


In [None]:

# --- Retrieval sanity check ---------------------------------------------------
def recommend_from_hist(hist_ids, topk=10, country=None):
    """Retrieve the top-k item indices for one user history.

    Args:
        hist_ids: Item indices, oldest first. Only the last
            ``CFG["hist_max"]`` entries are used.
        topk: Number of items to retrieve.
        country: Optional raw country value for the user. Only used when the
            model has a country branch; it is mapped with ``country_to_idx``.

    Returns:
        list[int]: Retrieved item indices, best first; ``[]`` for an empty history.
    """
    if len(hist_ids) == 0:
        return []
    # Build the user vector with the trained user tower.
    model.eval()
    with torch.no_grad():
        h = torch.tensor([hist_ids[-CFG["hist_max"]:]], dtype=torch.long, device=device)
        # A country-aware user tower needs a country id to get the input width
        # it was built for; omitting it fails inside user_mlp.
        country_ids = None
        if getattr(model, "has_country", False):
            country_ids = torch.tensor([country_to_idx(country)], dtype=torch.long, device=device)
        u = model.user_vec(h, country_ids=country_ids).detach().cpu().numpy()  # [1, d]
    u = u / max(np.linalg.norm(u), 1e-6)

    try:
        import faiss
    except ImportError:
        faiss = None
    faiss_path = OUT_DIR / "items.faiss"
    if faiss is not None and faiss_path.exists():
        try:
            index = faiss.read_index(str(faiss_path))
            D, I = index.search(u.astype(np.float32), topk)
            # FAISS pads with -1 when fewer than topk items exist.
            return [int(j) for j in I[0].tolist() if j >= 0]
        except Exception as e:
            # Stay best-effort, but say why the FAISS path was abandoned.
            print("[warn] FAISS search failed, falling back to sklearn index:", e)

    # sklearn fallback
    import joblib
    nn_index = joblib.load(OUT_DIR / "items_sklearn_nn.joblib")
    dist, nbrs = nn_index.kneighbors(u, n_neighbors=topk)
    return [int(j) for j in nbrs[0].tolist()]

# Try on a small sample
sample = seq_val.sample(min(5, len(seq_val)), random_state=0)
for _, row in sample.iterrows():
    hist = parse_hist(row['history_idx'])
    recs = recommend_from_hist(hist, topk=10, country=row['country'])
    print("GT:", int(row['pos_item_idx']), "| Recs:", recs[:10])



## Save model weights


In [None]:

# --- Save model ---------------------------------------------------------------
# Persist the trained weights and the config they were trained with.
weights_path = OUT_DIR / "two_tower.pt"
torch.save(model.state_dict(), weights_path)
(OUT_DIR / "two_tower_config.json").write_text(json.dumps(CFG, indent=2), encoding="utf-8")
print("Saved model ->", weights_path)



### Next
Proceed to **04_ranker_and_eval.ipynb** to train a DIN/MLP ranker that scores retrieved candidates using richer features (country, recency, popularity, price buckets, and optionally text vectors).
