
# 03 — Two‑Tower Retriever & ANN Index

This notebook trains a **two‑tower retrieval model** on your sequences and builds an **ANN index** for fast candidate generation.

**What it does**
- Loads `item_id_map.parquet`, `customer_id_map.parquet`, and `sequences_{train,val,test}.parquet` from `OUT_DIR`.
- (Optional) Loads `item_text_emb.npy` to enrich items with FM/LLM **text embeddings**.
- Trains a **user tower** (pooled history) and an **item tower** (ID + projected text embedding).
- Uses an **in‑batch softmax** loss (InfoNCE‑style) for scalable training.
- Exports **item vectors** and builds an ANN index with **FAISS** (falls back to a simple brute‑force/sklearn index if FAISS is unavailable).


In [1]:

# --- 0) Config & paths --------------------------------------------------------
import os, math, gc, json
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Set your processed data path
PROJECT_ROOT = Path.home() / "KAUST-Project"   # /home/kamalyy/KAUST-Project
OUT_DIR = PROJECT_ROOT / "data" / "processed" / "online_retail_II"
OUT_DIR.mkdir(parents=True, exist_ok=True)
READ_KW = dict(engine="fastparquet")

CFG = {
    "d_model": 512,          # embedding dim (reduced for interactions)
    "hist_max": 10,           # not used for interactions
    "batch_size": 2048,      # larger batches for interactions
    "accum_steps": 1,        # no need for accumulation with larger batches
    "epochs": 30,            # interactions train faster
    "patience": 8,           # slightly higher patience
    "lr": 1e-3,              # higher LR for interactions
    "weight_decay": 0.0,    # lower weight decay
    "dropout": 0.2,          # lower dropout
    "use_country": True,    # not needed for interactions
    "text_proj": True,      # simplify for interactions
    "eval_topk": 10,         # Recall@10
    "eval_sample": None,     # sample for faster evaluation
    "seed": 42,
    "k_neg": 20,             # fewer negatives for interactions
    "fixed_logit_scale": 10.0 # lower temperature
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
torch.manual_seed(CFG["seed"])
if device.type == "cuda":
    torch.cuda.manual_seed_all(CFG["seed"])

# Load artifacts from 01/02
items      = pd.read_parquet(OUT_DIR/'item_id_map.parquet', **READ_KW)
users      = pd.read_parquet(OUT_DIR/'customer_id_map.parquet', **READ_KW)
seq_train  = pd.read_parquet(OUT_DIR/'sequences_train.parquet', **READ_KW)
seq_val    = pd.read_parquet(OUT_DIR/'sequences_val.parquet', **READ_KW)
seq_test   = pd.read_parquet(OUT_DIR/'sequences_test.parquet', **READ_KW)

n_items = int(len(items)); n_users = int(len(users))
print(f"Items: {n_items} | Users: {n_users}")
print("Train/Val/Test:", seq_train.shape, seq_val.shape, seq_test.shape)

# Country indexing
if CFG["use_country"]:
    countries = pd.concat([seq_train['country'], seq_val['country'], seq_test['country']]).astype(str).unique()
    country_map = {c:i for i,c in enumerate(sorted(countries))}
else:
    country_map = {}

# Optional text embeddings
emb_path = OUT_DIR/'item_text_emb.npy'
has_text = False
text_dim = 0
if CFG["text_proj"] and emb_path.exists():
    try:
        item_text_emb = np.load(emb_path)
        if item_text_emb.shape[0] >= n_items:
            item_text_emb = item_text_emb[:n_items]
        text_dim = int(item_text_emb.shape[1])
        has_text = True
        print("Loaded text embeddings:", item_text_emb.shape)
    except Exception as e:
        print("[warn] Text embeddings not loaded:", e)
        has_text = False

def parse_hist(s):
    if not isinstance(s, str) or not s.strip():
        return []
    return [int(x) for x in s.strip().split()]

def country_to_idx(c):
    return country_map.get(str(c), 0)

OUT_DIR.mkdir(parents=True, exist_ok=True)
print("OUT_DIR:", OUT_DIR)

READ_KW  = dict(engine="fastparquet")

Device: cuda
Items: 4446 | Users: 5748
Train/Val/Test: (592183, 6) (74819, 6) (75377, 6)
Loaded text embeddings: (4446, 384)
OUT_DIR: /home/kamalyy/KAUST-Project/data/processed/online_retail_II



## Dataset & DataLoader

We build batches of `(user_history → positive item)` and rely on **in‑batch negatives**, i.e., each item in the batch serves as a negative for the other users.


In [2]:

class SeqDataset(Dataset):
    def __init__(self, df, hist_max=30, n_items=0, k_neg=50):
        self.hist_max = hist_max
        self.pos = df['pos_item_idx'].astype(int).to_numpy()
        self.hist = df['history_idx'].astype(str).tolist()
        self.country = df['country'].astype(str).tolist()
        self.n_items = n_items
        self.k_neg = k_neg

    def __len__(self): return len(self.pos)

    def __getitem__(self, idx):
        h = parse_hist(self.hist[idx])
        if len(h) > self.hist_max:
            h = h[-self.hist_max:]
        # +1 shift so 0=PAD
        h = [x+1 for x in h]
        pos = int(self.pos[idx]) + 1
        cidx = country_to_idx(self.country[idx]) if country_map else 0

        # Sample K negatives uniformly (1..n_items), avoiding pos; allow seen in history to keep it simple
        negs = []
        if self.k_neg > 0 and self.n_items > 1:
            import random
            for _ in range(self.k_neg):
                r = random.randint(1, self.n_items)  # inclusive
                while r == pos:
                    r = random.randint(1, self.n_items)
                negs.append(r)
        return {"hist": np.array(h, dtype=np.int64),
                "pos": np.int64(pos),
                "negs": np.array(negs, dtype=np.int64) if negs else np.zeros((0,), dtype=np.int64),
                "country": np.int64(cidx)}

def collate_batch(batch):
    maxL = max((len(x["hist"]) for x in batch), default=1)
    B = len(batch)
    H = np.zeros((B, maxL), dtype=np.int64)  # 0 = PAD
    for i, x in enumerate(batch):
        h = x["hist"]
        if len(h):
            H[i, -len(h):] = h
    pos = np.array([x["pos"] for x in batch], dtype=np.int64)
    country = np.array([x["country"] for x in batch], dtype=np.int64)
    # Negatives as [B, K]
    if len(batch[0]["negs"]) > 0:
        K = len(batch[0]["negs"])
        NE = np.stack([x["negs"] for x in batch], axis=0)
    else:
        K = 0
        NE = np.zeros((B, 0), dtype=np.int64)
    return {"hist": torch.from_numpy(H),
            "pos": torch.from_numpy(pos),
            "negs": torch.from_numpy(NE),
            "country": torch.from_numpy(country)}

train_ds = SeqDataset(seq_train, hist_max=CFG["hist_max"], n_items=n_items, k_neg=CFG["k_neg"])
val_src = seq_val if CFG["eval_sample"] is None else seq_val.sample(min(CFG["eval_sample"], len(seq_val)), random_state=CFG["seed"])
val_ds  = SeqDataset(val_src, hist_max=CFG["hist_max"], n_items=n_items, k_neg=0)  # no negatives needed for eval

train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True,  num_workers=0, collate_fn=collate_batch)
val_loader   = DataLoader(val_ds,  batch_size=CFG["batch_size"], shuffle=False, num_workers=0, collate_fn=collate_batch)



## Two‑Tower model

- **Item tower:** `Embedding(n_items, d)` + optional linear projection of **text embedding** concatenated and projected to `d`.
- **User tower:** mean pool of recent item embeddings (share item ID embedding weights) + optional **country embedding**, then MLP → `d`.
- **Loss:** in‑batch softmax over item dot‑products (InfoNCE‑style).


In [3]:

class MLP(nn.Module):
    def __init__(self, d_in, d_out, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_out),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_out, d_out),
        )
    def forward(self, x): return self.net(x)

class TwoTower(nn.Module):
    def __init__(self, n_items, n_countries, d=256, text_dim=0, dropout=0.2, norm_user=True, norm_item=True):
        super().__init__()
        self.norm_user = norm_user
        self.norm_item = norm_item

        # ID embeddings (+1 because 0 is PAD)
        self.item_id_emb = nn.Embedding(n_items + 1, d, padding_idx=0)

        # Optional projected text embedding concatenated to ID emb
        self.has_text = text_dim > 0
        if self.has_text:
            self.text_proj = nn.Linear(text_dim, d, bias=False)

        # Country embedding (small) as user context
        self.has_country = n_countries > 0
        if self.has_country:
            self.country_emb = nn.Embedding(n_countries, d // 4)

        # Towers as shallow MLPs (to mirror your Keras Dense towers)
        self.user_tower = MLP(d + (d//4 if self.has_country else 0), d, dropout=dropout)
        self.item_tower = MLP(d + (d if self.has_text else 0), d, dropout=dropout)

        # Fixed/bounded temperature scale for stability
        self.register_buffer("logit_scale", torch.tensor(float(math.log(CFG["fixed_logit_scale"]))))

        nn.init.normal_(self.item_id_emb.weight, std=0.02)
        if self.has_text:
            nn.init.xavier_uniform_(self.text_proj.weight)

    def user_vec(self, hist_item_ids, country_ids=None):
        # hist_item_ids: [B, L], 0 = PAD
        h = self.item_id_emb(hist_item_ids)            # [B,L,d]
        mask = (hist_item_ids != 0).float().unsqueeze(-1)
        h_sum = (h * mask).sum(dim=1)
        lengths = mask.sum(dim=1).clamp_min(1.0)
        h_mean = h_sum / lengths                       # [B,d]
    
        feats = [h_mean]
        if self.has_country:
            if country_ids is None:
                # use zeros when country is not provided
                cemb = torch.zeros(hist_item_ids.size(0),
                                   self.country_emb.embedding_dim,
                                   device=hist_item_ids.device, dtype=h_mean.dtype)
            else:
                cemb = self.country_emb(country_ids)
            feats.append(cemb)
    
        u = self.user_tower(torch.cat(feats, dim=-1))
        if self.norm_user:
            u = F.normalize(u, dim=-1)
        return u


    def item_vec(self, item_ids, item_text=None):
        v = self.item_id_emb(item_ids)                               # [*,d]
        feats = [v]
        if self.has_text and item_text is not None:
            feats.append(self.text_proj(item_text))
        v = torch.cat(feats, dim=-1) if len(feats) > 1 else feats[0]
        v = self.item_tower(v)
        if self.norm_item: v = F.normalize(v, dim=-1)
        return v

    def forward(self, batch, text_bank=None, device=None):
        hist = batch["hist"].to(torch.long).to(device)
        pos  = batch["pos"].to(torch.long).to(device)
        country = batch["country"].to(torch.long).to(device) if self.has_country else None
        negs = batch["negs"].to(torch.long).to(device)               # [B,K] or [B,0]

        u = self.user_vec(hist, country_ids=country)                  # [B,d]

        # Gather positive and negative item vectors (with optional text)
        pos_txt = text_bank[(pos-1).clamp_min(0)] if (self.has_text and text_bank is not None) else None
        v_pos = self.item_vec(pos, item_text=pos_txt)                 # [B,d]

        if negs.numel() > 0:
            # flatten negatives and embed
            negs_flat = negs.reshape(-1)                              # [B*K]
            neg_txt = text_bank[(negs_flat-1).clamp_min(0)] if (self.has_text and text_bank is not None) else None
            v_neg = self.item_vec(negs_flat, item_text=neg_txt).view(negs.size(0), negs.size(1), -1)  # [B,K,d]
        else:
            v_neg = None

        return u, v_pos, v_neg



## Train

We use in‑batch negatives. Start with a couple of epochs and check validation loss.


In [4]:

n_countries = len(country_map) if country_map else 0
model = TwoTower(n_items=n_items, n_countries=n_countries, d=CFG["d_model"],
                 text_dim=(text_dim if has_text else 0),
                 dropout=CFG["dropout"]).to(device)

text_bank = torch.from_numpy(item_text_emb).to(torch.float32).to(device) if has_text else None

opt = AdamW(model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])
sched = CosineAnnealingLR(opt, T_max=CFG["epochs"])
use_amp = (device.type == "cuda")
amp_device = "cuda" if use_amp else "cpu"

# New GradScaler (old: torch.cuda.amp.GradScaler)
scaler = torch.amp.GradScaler(amp_device, enabled=use_amp)

@torch.no_grad()
def compute_item_matrix(m):
    ids = torch.arange(1, n_items+1, device=device, dtype=torch.long)
    if has_text and text_bank is not None:
        vecs = m.item_vec(ids, item_text=text_bank)
    else:
        vecs = m.item_vec(ids)
    return F.normalize(vecs, dim=-1).detach()

@torch.no_grad()
def recall_at_k_val(m, K=10, loader=val_loader):
    m.eval()
    item_mat = compute_item_matrix(m)                 # [n_items,d]
    hits = 0; n = 0
    for batch in loader:
        hist = batch["hist"].to(torch.long).to(device)
        pos  = batch["pos"].to(torch.long).to(device)
        country = batch["country"].to(torch.long).to(device) if m.has_country else None
        u = m.user_vec(hist, country_ids=country)     # [B,d]
        scores = u @ item_mat.t()                     # [B,n_items]
        topk = scores.topk(k=min(K, scores.size(1)), dim=1).indices
        pos0 = (pos - 1).clamp_min(0).unsqueeze(1)
        hit = (topk == pos0).any(dim=1).float()
        hits += hit.sum().item(); n += hit.numel()
    return hits / max(1, n)

def nce_loss(u, v_pos, v_neg, scale):
    # u: [B,d], v_pos: [B,d], v_neg: [B,K,d] or None; scale: float
    # logits per example: [1 + K]
    pos_logit = (u * v_pos).sum(dim=-1, keepdim=True)         # [B,1]
    if v_neg is not None and v_neg.numel() > 0:
        # [B,K] = u · v_neg_k
        neg_logit = torch.bmm(v_neg, u.unsqueeze(-1)).squeeze(-1)  # [B,K]
        logits = torch.cat([pos_logit, neg_logit], dim=1)      # [B,1+K]
    else:
        logits = pos_logit                                     # [B,1]
    logits = logits * scale
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)  # pos at index 0
    return F.cross_entropy(logits, labels)

def run_epoch(loader, training=True):
    model.train(training)
    total = 0.0; nobs = 0; step = 0
    if training: opt.zero_grad(set_to_none=True)
    for batch in loader:
        with torch.amp.autocast(amp_device, enabled=use_amp):
            u, v_pos, v_neg = model(batch, text_bank=text_bank, device=device)
            loss = nce_loss(u, v_pos, v_neg, scale=CFG["fixed_logit_scale"]) / CFG["accum_steps"]

        if training:
            scaler.scale(loss).backward()
            step += 1
            if step % CFG["accum_steps"] == 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True)

        total += loss.item() * u.size(0) * CFG["accum_steps"]
        nobs  += u.size(0)
    return total / max(1, nobs)

best_recall = -1.0
bad = 0
for ep in range(1, CFG["epochs"]+1):
    tr_ce = run_epoch(train_loader, training=True)
    sched.step()
    val_ce = run_epoch(val_loader,   training=False)  # CE on val (proxy)
    val_rec = recall_at_k_val(model, K=CFG["eval_topk"], loader=val_loader)
    print(f"Epoch {ep:02d} | train CE {tr_ce:.4f} | val CE {val_ce:.4f} | val Recall@{CFG['eval_topk']}: {val_rec:.4f}")
    if val_rec > best_recall + 1e-4:
        best_recall = val_rec; bad = 0
        torch.save(model.state_dict(), OUT_DIR / "two_tower_best.pt")
    else:
        bad += 1
        if bad >= CFG["patience"]:
            print("Early stopping on Recall@K."); break
print("Best val Recall@{}: {:.4f}".format(CFG["eval_topk"], best_recall))


Epoch 01 | train CE 1.9301 | val CE 0.0000 | val Recall@10: 0.0818
Epoch 02 | train CE 1.2600 | val CE 0.0000 | val Recall@10: 0.1234
Epoch 03 | train CE 1.0809 | val CE 0.0000 | val Recall@10: 0.1493
Epoch 04 | train CE 0.9884 | val CE 0.0000 | val Recall@10: 0.1713
Epoch 05 | train CE 0.9293 | val CE 0.0000 | val Recall@10: 0.1807
Epoch 06 | train CE 0.8886 | val CE 0.0000 | val Recall@10: 0.1940
Epoch 07 | train CE 0.8568 | val CE 0.0000 | val Recall@10: 0.1963
Epoch 08 | train CE 0.8307 | val CE 0.0000 | val Recall@10: 0.2030
Epoch 09 | train CE 0.8078 | val CE 0.0000 | val Recall@10: 0.2118
Epoch 10 | train CE 0.7899 | val CE 0.0000 | val Recall@10: 0.2159
Epoch 11 | train CE 0.7721 | val CE 0.0000 | val Recall@10: 0.2178
Epoch 12 | train CE 0.7581 | val CE 0.0000 | val Recall@10: 0.2233
Epoch 13 | train CE 0.7427 | val CE 0.0000 | val Recall@10: 0.2262
Epoch 14 | train CE 0.7290 | val CE 0.0000 | val Recall@10: 0.2292
Epoch 15 | train CE 0.7163 | val CE 0.0000 | val Recall@10: 0.


## Export item vectors

We compute item vectors for all items and save them to disk for indexing.


In [5]:
# --- Export item vectors from best checkpoint ---------------------------------
ckpt_path = OUT_DIR / "two_tower_best.pt"
if ckpt_path.exists():
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()
with torch.no_grad():
    ids = torch.arange(1, n_items+1, device=device, dtype=torch.long)
    if has_text:
        vecs = model.item_vec(ids, item_text=text_bank).detach().cpu().numpy()
    else:
        vecs = model.item_vec(ids).detach().cpu().numpy()

# Normalize and save
from numpy.linalg import norm
vecs = vecs / np.maximum(norm(vecs, axis=1, keepdims=True), 1e-6)
np.save(OUT_DIR / "retriever_item_vectors.npy", vecs)
print("Saved item vectors:", vecs.shape)


Saved item vectors: (4446, 512)



## Build ANN index (FAISS if available)

We build a cosine‑similarity (inner‑product on normalized vectors) index. If **FAISS** is not installed, we fall back to a simple brute‑force index using sklearn `NearestNeighbors` and save it with joblib.


In [6]:

# --- Build ANN index ----------------------------------------------------------
item_vecs = np.load(OUT_DIR / "retriever_item_vectors.npy")
# Ensure unit length for inner product ≈ cosine
from numpy.linalg import norm
norms = np.maximum(norm(item_vecs, axis=1, keepdims=True), 1e-6)
item_vecs = item_vecs / norms

index_artifacts = {}

try:
    import faiss
    d = item_vecs.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(item_vecs.astype(np.float32))
    faiss.write_index(index, str(OUT_DIR / "items.faiss"))
    index_artifacts["type"] = "faiss.IndexFlatIP"
    index_artifacts["path"] = str(OUT_DIR / "items.faiss")
    print("FAISS index saved ->", OUT_DIR / "items.faiss", "| ntotal:", index.ntotal)
except Exception as e:
    print("[warn] FAISS not available or failed:", e)
    print("Falling back to sklearn NearestNeighbors (brute force).")
    from sklearn.neighbors import NearestNeighbors
    import joblib
    nn = NearestNeighbors(metric="cosine", algorithm="brute")
    nn.fit(item_vecs)
    joblib.dump(nn, OUT_DIR / "items_sklearn_nn.joblib")
    np.save(OUT_DIR / "retriever_item_vectors_normed.npy", item_vecs)
    index_artifacts["type"] = "sklearn.NearestNeighbors"
    index_artifacts["path"] = str(OUT_DIR / "items_sklearn_nn.joblib")
    print("Sklearn NN index saved ->", OUT_DIR / "items_sklearn_nn.joblib")

with open(OUT_DIR / "index_meta.json", "w", encoding="utf-8") as f:
    json.dump(index_artifacts, f, indent=2)


FAISS index saved -> /home/kamalyy/KAUST-Project/data/processed/online_retail_II/items.faiss | ntotal: 4446



## Retrieval sanity check

We embed a few validation histories and retrieve top‑10 items to ensure the pipeline runs end‑to‑end.


In [7]:

# --- Retrieval sanity check ---------------------------------------------------
def recommend_from_hist(hist_ids, topk=10, default_country="United Kingdom"):
    model.eval()
    with torch.no_grad():
        # +1 shift so 0 = PAD
        h = torch.tensor([[x+1 for x in hist_ids[-CFG["hist_max"]:]]],
                         dtype=torch.long, device=device)

        # provide a country id if the model uses country
        if getattr(model, "has_country", False):
            cidx = torch.tensor([country_map.get(default_country, 0)],
                                dtype=torch.long, device=device)
            u = model.user_vec(h, country_ids=cidx).detach().cpu().numpy()
        else:
            u = model.user_vec(h).detach().cpu().numpy()

    u = u / max(np.linalg.norm(u), 1e-6)
    try:
        import faiss
        index = faiss.read_index(str(OUT_DIR / "items.faiss"))
        D, I = index.search(u.astype(np.float32), topk)
        return I[0].tolist()
    except Exception:
        import joblib
        from sklearn.neighbors import NearestNeighbors
        nn = joblib.load(OUT_DIR / "items_sklearn_nn.joblib")
        dist, nbrs = nn.kneighbors(u, n_neighbors=topk)
        return [int(j) for j in nbrs[0].tolist()]


# Try on a small sample
sample = seq_val.sample(5, random_state=0)
for _, row in sample.iterrows():
    hist = [int(x) for x in row['history_idx'].split()] if isinstance(row['history_idx'], str) else []
    recs = recommend_from_hist(hist, topk=10)
    print("GT:", int(row['pos_item_idx']), "| Recs:", recs[:10])


GT: 1673 | Recs: [390, 1136, 1312, 389, 1038, 1135, 1358, 1039, 236, 1555]
GT: 1387 | Recs: [1444, 1394, 1377, 1423, 1446, 1421, 1445, 1461, 1359, 1393]
GT: 722 | Recs: [1038, 1242, 1039, 1012, 1248, 1135, 1139, 1255, 1137, 814]
GT: 3848 | Recs: [3931, 3884, 3879, 3858, 3880, 3870, 3868, 3869, 3783, 3797]
GT: 2115 | Recs: [3884, 3797, 3866, 3783, 3870, 3931, 3796, 3808, 3807, 3910]



## Save model weights


In [8]:

# --- Save model ---------------------------------------------------------------
torch.save(model.state_dict(), OUT_DIR / "two_tower.pt")
with open(OUT_DIR / "two_tower_config.json", "w", encoding="utf-8") as f:
    json.dump(CFG, f, indent=2)
print("Saved model ->", OUT_DIR / "two_tower.pt")


Saved model -> /home/kamalyy/KAUST-Project/data/processed/online_retail_II/two_tower.pt


In [10]:
# --- Save Embeddings for Notebook 04 ---
ITEM_MAP_PATH = OUT_DIR / 'item_id_map.parquet'
CUSTOMER_MAP_PATH = OUT_DIR / 'customer_id_map.parquet'
customer_map = pd.read_parquet(OUT_DIR / 'customer_id_map.parquet', engine="fastparquet")
item_map = pd.read_parquet(OUT_DIR / 'item_id_map.parquet', engine="fastparquet")

print("\nSaving embeddings for notebook 04...")

model.eval()
with torch.no_grad():
    # Save item embeddings with text features
    all_item_ids = torch.arange(1, len(item_map)+1, device=device, dtype=torch.long)
    
    if has_text and item_text_emb is not None:
        # Use actual text features
        text_bank = torch.from_numpy(item_text_emb).to(device)
        item_embeddings = model.item_vec(all_item_ids, item_text=text_bank).cpu().numpy()
    else:
        # Create dummy text features
        dummy_text_dim = 512  # or whatever dimension the model expects
        text_bank = torch.zeros(len(item_map), dummy_text_dim, device=device)
        item_embeddings = model.item_vec(all_item_ids, item_text=text_bank).cpu().numpy()
    
    np.save(OUT_DIR / 'item_embeddings.npy', item_embeddings)
    
    # Save user embeddings by computing them from sequences
    user_embeddings = []
    user_ids = []
    
    # Get all unique users from sequences
    all_users = pd.concat([seq_train, seq_val, seq_test])['user_idx'].unique()
    
    for user_idx in all_users:
        # Get user's history from sequences
        user_seqs = seq_train[seq_train['user_idx'] == user_idx]
        if len(user_seqs) == 0:
            user_seqs = seq_val[seq_val['user_idx'] == user_idx]
        if len(user_seqs) == 0:
            user_seqs = seq_test[seq_test['user_idx'] == user_idx]
        
        if len(user_seqs) > 0:
            # Use the most recent history
            latest_seq = user_seqs.iloc[-1]
            hist_str = latest_seq['history_idx']
            hist_items = parse_hist(hist_str)
            
            if len(hist_items) > 0:
                # Convert to tensor format (+1 shift, 0=PAD)
                hist_tensor = torch.tensor([[x+1 for x in hist_items[-CFG["hist_max"]:]]], 
                                         dtype=torch.long, device=device)
                
                # Compute user embedding
                user_emb = model.user_vec(hist_tensor).cpu().numpy()
                user_embeddings.append(user_emb[0])  # Remove batch dimension
                user_ids.append(user_idx)
    
    # Convert to numpy array
    user_embeddings = np.array(user_embeddings)
    
    # Create full user embedding matrix (fill missing users with zeros)
    full_user_embeddings = np.zeros((len(customer_map), user_embeddings.shape[1]))
    for i, user_idx in enumerate(user_ids):
        full_user_embeddings[user_idx] = user_embeddings[i]
    
    np.save(OUT_DIR / 'user_embeddings.npy', full_user_embeddings)

print(f"Saved user embeddings: {full_user_embeddings.shape}")
print(f"Saved item embeddings: {item_embeddings.shape}")
print("✅ Embeddings ready for notebook 04!")


Saving embeddings for notebook 04...
Saved user embeddings: (5748, 512)
Saved item embeddings: (4446, 512)
✅ Embeddings ready for notebook 04!



### Next
Proceed to **04_ranker_and_eval.ipynb** to train a DIN/MLP ranker that scores retrieved candidates using richer features (country, recency, popularity, price buckets, and optionally text vectors).
