# 03: Two-Tower Retriever

This notebook implements and trains a Two-Tower retrieval model:

1. **Configuration & Imports**: set hyperparameters and load libraries
2. **Data Loading**: read sequences and maps from preprocessed data
3. **Model Architecture**: define user and item towers
4. **Dataset & DataLoader**: prepare PyTorch datasets and collators
5. **Initialize & Compile**: instantiate model, optimizer, scheduler
6. **Training Loop**: train with in-batch negatives and early stopping
7. **Save Embeddings**: export trained user/item embeddings
8. **Evaluation & Comparison**: compute Recall@K and compare to baselines
9. **Results Persistence**: save training logs and final metrics

## 1. Configuration & Imports

- Define paths for input/output
- Set model, training, and optimization parameters

In [1]:
import os, json
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings('ignore')

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths
OUT_DIR = Path('../data/processed/jarir/')
SEQ_TRAIN = OUT_DIR / 'sequences_train.parquet'
SEQ_VAL   = OUT_DIR / 'sequences_val.parquet'
ITEM_MAP   = OUT_DIR / 'item_id_map.parquet'
CUST_MAP   = OUT_DIR / 'customer_id_map.parquet'
BASELINE_RESULTS = OUT_DIR / 'baseline_results.json'

# Model & Training Config
CFG = {
    'd_model':256,
    'batch_size':512,
    'accum_steps':2,
    'epochs':50,
    'patience':6,
    'lr':5e-4,
    'weight_decay':1e-4,
    'dropout':0.2,
    'eval_topk':10,
    'seed':42,
    'k_neg':50,
    'fixed_logit_scale':10.0
}
# Reproducibility
torch.manual_seed(CFG['seed'])
np.random.seed(CFG['seed'])

## 2. Data Loading

- Load train/val sequences and id maps
- Print dataset sizes

In [2]:
print("Loading data...")
seq_train = pd.read_parquet(SEQ_TRAIN, engine='fastparquet')
seq_val   = pd.read_parquet(SEQ_VAL,   engine='fastparquet')
item_map  = pd.read_parquet(ITEM_MAP, engine='fastparquet')
cust_map  = pd.read_parquet(CUST_MAP, engine='fastparquet')
print(f"Train sequences: {len(seq_train)}")
print(f"Val   sequences: {len(seq_val)}")
print(f"# items: {len(item_map)}, # users: {len(cust_map)}")

# Load baseline results for comparison later
with open(BASELINE_RESULTS) as f:
    baseline_results = json.load(f)['held_out_interactions']

Loading data...
Train sequences: 1108
Val   sequences: 169
# items: 1735, # users: 929


## 3. Two-Tower Model Architecture

- Two MLP towers with embedding + LayerNorm + ReLU + Dropout
- Normalize final vectors for cosine similarity scoring

In [3]:
class TwoTowerModel(nn.Module):
    def __init__(self, n_users, n_items, d_model, dropout):
        super().__init__()
        # Embeddings
        self.user_emb = nn.Embedding(n_users, d_model)
        self.item_emb = nn.Embedding(n_items, d_model)
        # MLP towers
        def make_tower():
            return nn.Sequential(
                nn.Linear(d_model, d_model), nn.LayerNorm(d_model), nn.ReLU(), nn.Dropout(dropout),
                nn.Linear(d_model, d_model), nn.LayerNorm(d_model), nn.ReLU(), nn.Dropout(dropout)
            )
        self.user_mlp = make_tower()
        self.item_mlp = make_tower()
        # init
        nn.init.normal_(self.user_emb.weight, 0, 0.1)
        nn.init.normal_(self.item_emb.weight, 0, 0.1)
    def user_vec(self, uids):
        u = self.user_emb(uids); return F.normalize(self.user_mlp(u), p=2, dim=1)
    def item_vec(self, iids):
        v = self.item_emb(iids); return F.normalize(self.item_mlp(v), p=2, dim=1)
    def forward(self, uids, pos_i, neg_i=None):
        u_vec = self.user_vec(uids)
        pos_vec = self.item_vec(pos_i)
        pos_scores = (u_vec * pos_vec).sum(1) * CFG['fixed_logit_scale']
        if neg_i is None:
            return pos_scores
        neg_vecs = self.item_vec(neg_i.view(-1)).view(uids.size(0), -1, u_vec.size(-1))
        neg_scores = torch.sum(u_vec.unsqueeze(1) * neg_vecs, dim=2) * CFG['fixed_logit_scale']
        return pos_scores, neg_scores

## 4. Dataset & DataLoader

- Custom `Dataset` sampling `k_neg` negatives per example
- Collate with padding for variable-length history

In [4]:
class SequenceDataset(Dataset):
    def __init__(self, df, n_items, k_neg):
        self.df = df; self.n_items = n_items; self.k_neg = k_neg
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        u = torch.tensor(int(row['user_idx']), dtype=torch.long)
        pos = torch.tensor(int(row['pos_item_idx']), dtype=torch.long)
        hist = []
        if pd.notna(row['history_idx']): hist = [int(x) for x in row['history_idx'].split()]
        # sample negatives
        forbid = set(hist + [int(pos)])
        avail = list(set(range(self.n_items)) - forbid)
        neg = np.random.choice(avail, size=self.k_neg, replace=len(avail)<self.k_neg)
        return {'user':u, 'pos':pos, 'neg':torch.tensor(neg, dtype=torch.long), 'hist':torch.tensor(hist, dtype=torch.long)}

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    users = torch.stack([b['user'] for b in batch])
    pos   = torch.stack([b['pos'] for b in batch])
    neg   = torch.stack([b['neg'] for b in batch])
    hists = pad_sequence([b['hist'] for b in batch], batch_first=True, padding_value=-1)
    mask  = (hists >= 0)
    return {'user':users, 'pos':pos, 'neg':neg, 'hist':hists, 'mask':mask}

# Create loaders
train_ds = SequenceDataset(seq_train, len(item_map), CFG['k_neg'])
val_ds   = SequenceDataset(seq_val,   len(item_map), CFG['k_neg'])
train_loader = DataLoader(train_ds, batch_size=CFG['batch_size'], shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=CFG['batch_size'], shuffle=False, collate_fn=collate_fn)
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

Train batches: 3, Val batches: 1


## 5. Initialize Model, Optimizer & Scheduler

- Use `AdamW` and `ReduceLROnPlateau` on Recall@K

In [5]:
model = TwoTowerModel(len(cust_map), len(item_map), CFG['d_model'], CFG['dropout']).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

# Loss: in-batch negative sampling

def in_batch_loss(pos_s, neg_s):
    logits = torch.cat([pos_s.unsqueeze(1), neg_s], dim=1)
    labels = torch.zeros(pos_s.size(0), dtype=torch.long, device=device)
    return F.cross_entropy(logits, labels)

## 6. Training Loop

- Train for up to `epochs`, early stop on validation Recall@K
- Track training loss and validation Recall

In [6]:
best_recall = 0.0; patience_counter = 0
train_losses = []; val_recalls = []
for epoch in range(1, CFG['epochs']+1):
    model.train(); total_loss = 0.0
    for i, batch in enumerate(train_loader, 1):
        u = batch['user'].to(device)
        p = batch['pos'].to(device)
        n = batch['neg'].to(device)
        pos_s, neg_s = model(u, p, n)
        loss = in_batch_loss(pos_s, neg_s) / CFG['accum_steps']
        loss.backward()
        if i % CFG['accum_steps'] == 0:
            optimizer.step(); optimizer.zero_grad()
        total_loss += loss.item() * CFG['accum_steps']
    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)
    # Validation
    model.eval(); hits = []
    with torch.no_grad():
        for batch in val_loader:
            u = batch['user'].to(device)
            p = batch['pos'].to(device)
            u_vec = model.user_vec(u)
            all_items = torch.arange(len(item_map), device=device)
            i_vecs = model.item_vec(all_items)
            scores = u_vec @ i_vecs.T
            topk = scores.topk(CFG['eval_topk'], dim=1).indices
            hits.append((topk == p.unsqueeze(1)).any(1).float().mean().item())
    val_recall = np.mean(hits); val_recalls.append(val_recall)
    scheduler.step(val_recall)
    print(f"Epoch {epoch}: Loss={avg_loss:.4f}, Recall@{CFG['eval_topk']}={val_recall:.4f}")
    if val_recall > best_recall:
        best_recall = val_recall; patience_counter = 0
        torch.save(model.state_dict(), OUT_DIR/'best_twotower_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= CFG['patience']:
            print("Early stopping."); break
print(f"Best validation Recall@{CFG['eval_topk']}: {best_recall:.4f}")

Epoch 1: Loss=4.0241, Recall@10=0.0059
Epoch 2: Loss=3.9228, Recall@10=0.0178
Epoch 3: Loss=3.8927, Recall@10=0.0296
Epoch 4: Loss=3.8334, Recall@10=0.0533
Epoch 5: Loss=3.8122, Recall@10=0.0828
Epoch 6: Loss=3.7632, Recall@10=0.0947
Epoch 7: Loss=3.7373, Recall@10=0.1065
Epoch 8: Loss=3.6646, Recall@10=0.1065
Epoch 9: Loss=3.6348, Recall@10=0.1065
Epoch 10: Loss=3.5969, Recall@10=0.1065
Epoch 11: Loss=3.5321, Recall@10=0.1065
Epoch 12: Loss=3.4783, Recall@10=0.1065
Epoch 13: Loss=3.5307, Recall@10=0.1124
Epoch 14: Loss=3.5071, Recall@10=0.1124
Epoch 15: Loss=3.4558, Recall@10=0.1124
Epoch 16: Loss=3.4055, Recall@10=0.1124
Epoch 17: Loss=3.4233, Recall@10=0.1124
Epoch 18: Loss=3.3313, Recall@10=0.1124
Epoch 19: Loss=3.3434, Recall@10=0.1124
Early stopping.
Best validation Recall@10: 0.1124


## 7. Save Embeddings

- Generate and save all user and item embeddings

In [7]:
model.load_state_dict(torch.load(OUT_DIR/'best_twotower_model.pth'))
model.eval()
with torch.no_grad():
    users = torch.arange(len(cust_map), device=device)
    items = torch.arange(len(item_map), device=device)
    u_emb = model.user_vec(users).cpu().numpy()
    i_emb = model.item_vec(items).cpu().numpy()
    np.save(OUT_DIR/'user_embeddings.npy', u_emb)
    np.save(OUT_DIR/'item_embeddings.npy', i_emb)
print(f"Saved embeddings shapes: {u_emb.shape}, {i_emb.shape}")

Saved embeddings shapes: (929, 256), (1735, 256)


## 8. Evaluation & Comparison

- Reload best model and compute final Recall@K
- Compare to baselines loaded earlier

In [8]:
# Final eval on validation (or test if available)
model.eval(); hits = []
with torch.no_grad():
    for batch in val_loader:
        u=batch['user'].to(device); p=batch['pos'].to(device)
        u_vec=model.user_vec(u); scores=u_vec @ i_vecs.T
        topk=scores.topk(CFG['eval_topk'],dim=1).indices
        hits.append((topk==p.unsqueeze(1)).any(1).float().mean().item())
final_recall = np.mean(hits)
recall_key = f"Recall@{CFG['eval_topk']}"
best_baseline = max(
    baseline_results.items(),
    key=lambda x: x[1][recall_key]
)[1][recall_key]
print(f"Two-Tower Recall@{CFG['eval_topk']}: {final_recall:.4f}")
print(f"Best baseline    : {best_baseline:.4f}")

Two-Tower Recall@10: 0.1124
Best baseline    : 0.0947


## 9. Save Results

- Persist model config, training curves, and comparison metrics to JSON

In [9]:
results = {
    'model_config': CFG,
    'training': {'train_losses': train_losses, 'val_recalls': val_recalls},
    'best_recall': best_recall,
    'final_recall': final_recall,
    'baseline_best_recall': best_baseline
}
with open(OUT_DIR/'twotower_results.json','w') as f:
    json.dump(results, f, indent=2)
print("Saved Two-Tower results to twotower_results.json")

Saved Two-Tower results to twotower_results.json
