In [6]:
import random
from typing import Dict, List, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Feature Engineering

Feature engineering underpins every two-tower recommender, since each tower must transform raw user or item signals into embeddings whose similarity is high for true matches and low for non-matches. In practice that means curating heterogeneous feature types: categorical IDs (users, items, genres, locales), numerical stats (recency, dwell time, purchase counts), text or image embeddings from pretrained models, and contextual signals such as device, time of day, or geo buckets. The two-tower loss sees only the vectors built from these features; weak or noisy features collapse the embedding space, while thoughtful encodings (e.g., hashing high-cardinality IDs, normalizing numerics, sharing vocabularies across towers) give the towers enough structure to learn. The user and item towers are trained jointly, but each encodes its own side independently and the two meet only through a dot product, so the completeness, consistency, and scale of the engineered features directly control retrieval quality. Investing in robust pipelines for sourcing, cleaning, scaling, and regularly updating these feature sets is therefore critical for a performant recommendation system.

## ID Features
IDs including User and Item IDs can be converted to embeddings. Typical dimensions for IDs:
* User ID -> 16 - 128 dims
* Item ID -> 32 - 256 dims

## Categorical Features
Category features can be converted to embeddings. Common categorical features include category, brand, country, device_type, gender, seller, shop, etc. Example dimensions:
* category: 8 - 32
* brand: 8 - 32
* country: 4 - 8

## Numerical Features
Numerical or ordinal features can be normalised. Common normalisations include min-max scaling, Z-score standardisation, and a log transform for heavy-tailed features. After preprocessing, the numerical features can be concatenated with other feature embeddings and sent to the model.

```
# After preprocessing
x = torch.cat([user_id_emb, category_emb, torch.tensor([norm_price])], dim=-1)
```
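The preprocessing step above can be sketched as follows. This is a minimal illustration, assuming a heavy-tailed `price` feature; the helper `log_zscore`, the batch of prices, and the embedding sizes are hypothetical and not part of the model defined later.

```python
import torch

def log_zscore(values: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Log-transform a non-negative, heavy-tailed feature, then standardise it."""
    logged = torch.log1p(values)  # compress the long tail
    return (logged - logged.mean()) / (logged.std() + eps)

# Hypothetical batch of 4 prices with one extreme value
price = torch.tensor([9.99, 14.50, 20.00, 1200.00])
norm_price = log_zscore(price)  # shape (4,)

# Stand-ins for real embedding lookups (assumed shapes)
user_id_emb = torch.randn(4, 16)
category_emb = torch.randn(4, 8)

# A numeric feature needs a trailing dimension before it can be concatenated
x = torch.cat([user_id_emb, category_emb, norm_price.unsqueeze(-1)], dim=-1)  # (4, 25)
```

In a real pipeline the mean and standard deviation would normally be computed once on the training set and reused at serving time, rather than recomputed per batch as in this sketch.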

## Multi-hot Features
Some features are better encoded as multi-hot vectors, such as long-term user preferences. More examples of this type include:
* Categories the user clicked in the past 30 days
* Brands the user purchased frequently
* Tags associated with an item
* Interests inferred from long-term behaviour

These features can be encoded with an `nn.EmbeddingBag`, which pools the embeddings of all active IDs into a single vector.

```
# Define embedding bag
self.user_hist_cat_emb = nn.EmbeddingBag(n_categories, d, mode="mean")
# Usage of embedding bag to create a behavioral embedding for the user
hist_emb = self.user_hist_cat_emb(hist_category_ids, offsets)
```
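To make the `offsets` argument concrete, here is a small self-contained sketch. The sizes and the two click histories are made up for illustration; only `nn.EmbeddingBag` itself comes from PyTorch.

```python
import torch
import torch.nn as nn

n_categories, d = 10, 4  # hypothetical sizes
bag = nn.EmbeddingBag(n_categories, d, mode="mean")

# Two users' histories flattened into one 1-D tensor:
# user 0 clicked categories [1, 2, 3], user 1 clicked [7, 8]
hist_category_ids = torch.tensor([1, 2, 3, 7, 8])
offsets = torch.tensor([0, 3])  # start index of each user's bag

hist_emb = bag(hist_category_ids, offsets)  # shape (2, d)

# Same result as averaging the individual embeddings by hand
manual_user0 = bag.weight[[1, 2, 3]].mean(dim=0)
```

Flattening with offsets avoids padding every history to the same length, which is why it suits multi-hot features whose number of active IDs varies per user.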

## Sequential Features

Many features require a sequence model to encode, such as text descriptions or short-term user behaviours. User behaviour is dynamic: the last few clicks strongly influence short-term interests, the last few impressions help track current intent, and order matters (“running shoes → socks → sports watch”). While long-term user preferences can be modelled with multi-hot embeddings, short-term behaviours such as the last 10 minutes of clicks, the last 20 queries, or the last 50 video watches require a sequence model. Common sequence models are:
* Transformers (most common today)
    - capture order
    - capture dependencies between items
    - powerful for session-based signals

* GRU/LSTM (cheaper, still used)
    - excellent trade-off between cost and quality
    - widely used in Amazon/Alibaba retrieval stacks

* 1D CNN (fast approximation)
    - surprisingly effective for local sequences

Sequential features are often used in production retrieval systems to capture user intent drift and sudden preference shifts.

In [None]:
# Example: User History Encoder using GRU
class UserHistoryEncoder(nn.Module):
    def __init__(self, n_items, emb_dim=32, hidden_dim=64):
        super().__init__()
        self.item_emb = nn.Embedding(n_items, emb_dim)
        self.gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)

    def forward(self, history_item_ids):
        # history_item_ids: [B, T]
        x = self.item_emb(history_item_ids)  # [B, T, emb_dim]
        _, h = self.gru(x)                   # h: [1, B, hidden_dim]
        return h.squeeze(0)                  # [B, hidden_dim]

The output of the `UserHistoryEncoder` can be concatenated with other features.

```
user_vec = concat(
    user_id_emb,
    country_emb,
    device_emb,
    long_term_pooled_emb,
    short_term_sequence_emb
)
```

## Text Tower / Text Embedding

Text-based signals are dominant in many systems:
* product title
* product description
* movie synopsis
* document content
* app category description
* video captions

Common industry practices are:
* DistilBERT is a common choice
* TinyBERT or MiniLM for speed
* Precomputed embeddings are sometimes shared across towers

There are two main approaches:

| Approach | Description | Pros | Cons |
| --- | --- | --- | --- |
| Precompute text embeddings | Generate all item text embeddings offline, store them as item features, and have the ItemTower simply load them. | Fast inference; text encoder cost removed from serving graph. | No end-to-end training, so text features cannot adapt during tower training. |
| Fine-tune text encoder | Embed a trainable text encoder directly inside the ItemTower to produce item embeddings on the fly. | Better retrieval quality; embeddings adapt to domain; text semantics align with user behavior. | Training becomes expensive; embedding updates force offline re-indexing of the item catalog. |



In [5]:
# Example: text encoder using the [CLS] token embedding
from transformers import DistilBertModel

class ItemTextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")

    def forward(self, input_ids, attention_mask):
        outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]   # [CLS] embedding


## Feature Interactions

This follows the Wide & Deep design. The MLP is the deep part and feature interactions are the wide part. Feature interactions are manually created crosses that can capture highly predictive combinations that an MLP alone struggles to discover. Examples of wide crosses include:
* country x device_type
* category x price_bucket
* brand x gender
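One way to feed such crosses into a tower is to hash each pair into a fixed number of buckets and look up an embedding. The sketch below assumes string-valued features; `cross_bucket`, `N_CROSS_BUCKETS`, and the embedding size are hypothetical choices. Note that the classic Wide & Deep formulation uses a linear layer over crossed features, so the embedding lookup here is a variation, not the original recipe.

```python
import zlib

import torch
import torch.nn as nn

N_CROSS_BUCKETS = 1000  # hypothetical hash-table size

def cross_bucket(a: str, b: str, n_buckets: int = N_CROSS_BUCKETS) -> int:
    """Map the pair (a, b) to a stable bucket id via CRC32 hashing."""
    key = f"{a}_x_{b}".encode("utf-8")
    return zlib.crc32(key) % n_buckets

cross_emb = nn.Embedding(N_CROSS_BUCKETS, 8)

# country x device_type crosses for three example rows
pairs = [("US", "mobile"), ("US", "desktop"), ("DE", "mobile")]
bucket_ids = torch.tensor([cross_bucket(c, d) for c, d in pairs])
cross_vec = cross_emb(bucket_ids)  # shape (3, 8), one vector per cross
```

`zlib.crc32` is used instead of Python's built-in `hash` because the latter is salted per process for strings, which would make bucket ids differ between training and serving.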

# Two-Towers Model

In [None]:
n_users = 10000
n_items = 50000
n_countries = 50
n_categories = 200
emb_dim = 32 # embedding dimension per feature
tower_dim = 64 # output dimension of each tower

In [None]:
# ---------------------------
# 2. Define the two towers
# ---------------------------
class UserTower(nn.Module):
    def __init__(self):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, emb_dim)
        self.country_emb = nn.Embedding(n_countries, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim * 2, tower_dim),
            nn.ReLU(),
            nn.Linear(tower_dim, tower_dim)
        )
    
    def forward(self, user_id, country_id):
        """
        Perform forward pass of user tower.

        Args:
            user_id: Tensor of shape (batch_size,)
            country_id: Tensor of shape (batch_size,)
        """
        u_id = self.user_emb(user_id) # (batch_size, emb_dim)
        u_country = self.country_emb(country_id) # (batch_size, emb_dim)
        x = torch.cat([u_id, u_country], dim=-1) # (batch_size, emb_dim * 2)
        out = self.mlp(x) # (batch_size, tower_dim)
        # optional: L2-normalise
        out = F.normalize(out, p=2, dim=-1)
        return out

class ItemTower(nn.Module):
    def __init__(self):
        super().__init__()
        self.item_emb = nn.Embedding(n_items, emb_dim)
        self.cat_emb = nn.Embedding(n_categories, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim * 2, tower_dim),
            nn.ReLU(),
            nn.Linear(tower_dim, tower_dim)
        )
    
    def forward(self, item_id, category_id):
        """
        Perform forward pass of item tower for the batch of items. The items contain both
        positive and negative samples.

        Args:
            item_id: Tensor of shape (batch_size, N)
            category_id: Tensor of shape (batch_size, N)
        """
        i_id = self.item_emb(item_id) # (batch_size, N, emb_dim)
        i_cat = self.cat_emb(category_id) # (batch_size, N, emb_dim)
        x = torch.cat([i_id, i_cat], dim=-1) # (batch_size, N, emb_dim * 2)
        out = self.mlp(x) # (batch_size, N, tower_dim)
        # optional: L2-normalise
        out = F.normalize(out, p=2, dim=-1)
        return out

In [None]:
class TwoTowerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.user_tower = UserTower()
        self.item_tower = ItemTower()
    
    def forward(self, user_id, country_id, item_id, category_id):
        user_vec = self.user_tower(user_id, country_id) # (batch_size, tower_dim)
        item_vec = self.item_tower(item_id, category_id) # (batch_size, N, tower_dim)
        # Broadcast user_vec across the N candidate items
        user_exp_vec = user_vec.unsqueeze(1).expand(-1, item_vec.size(1), -1) # (batch_size, N, tower_dim)
        # Compute dot product similarity
        logits = torch.sum(user_exp_vec * item_vec, dim=-1) # (batch_size, N)
        return logits

# Negative Sampling Dataset

In [None]:
# Negative sampling dataset
class TwoTowerNegSamplingDataset(Dataset):
    def __init__(self, 
                 interactions: Sequence[Tuple], # Sequence of (user_id, pos_item_id)
                 user_to_hard_negs: Dict[int, List[int]],
                 all_item_ids: Sequence[int], # Catalog of all item IDs
                 num_hard_negs: int = 2, # number of hard negatives per positive
                 num_random_negs: int = 2, # number of random negatives per positive
                 user_to_pos_items: Dict[int, List[int]] = None # optional mapping of user to their positive items
    ):
        """
        Dataset for two-tower model with negative sampling (1 positive + hard negatives + random negatives).
        Args:
            interactions: List of (user_id, pos_item_id) tuples.
            user_to_hard_negs: Dict mapping user_id to list of hard negative item_ids.
            all_item_ids: Sequence of all item IDs in the catalog.
            num_hard_negs: Number of hard negatives to sample per positive.
            num_random_negs: Number of random negatives to sample per positive.
            user_to_pos_items: Optional dict mapping user_id to their positive item_ids to avoid sampling positives as negatives.
        """
        self.interactions = list(interactions)
        self.user_to_hard_negs = user_to_hard_negs
        self.all_item_ids = all_item_ids
        self.num_hard_negs = num_hard_negs
        self.num_random_negs = num_random_negs
        self.user_to_pos_items = user_to_pos_items or self._build_user_pos_items()
        # Convert all_item_ids to a set for faster lookup
        self.all_item_id_set = set(all_item_ids)
    
    def _build_user_pos_items(self) -> Dict[int, List[int]]:
        user_pos = {}
        for u_id, i_id in self.interactions:
            user_pos.setdefault(u_id, []).append(i_id)
        return user_pos

    def __len__(self):
        return len(self.interactions)

    def _sample_hard_negs(self, user_id: int, pos_item_id: int) -> List[int]:
        hard_negs = self.user_to_hard_negs.get(user_id, [])
        # Remove the positive items from hard negatives
        filtered = [i for i in hard_negs if i != pos_item_id]
        if not filtered:
            return []
        if len(filtered) <= self.num_hard_negs:
            # If not enough hard negatives, return all available
            return filtered
        else:
            # otherwise, randomly sample the required number
            sampled = random.sample(filtered, min(self.num_hard_negs, len(filtered)))
            return sampled
    
    def _sample_random_negs(self, user_id: int, pos_item_id: int) -> List[int]:
        """
        Sample random negatives ensuring they are not in user's positive items.
        - Randomly sample from all_item_ids until we have enough valid negatives.
        - If a sampled item is already a positive for the user, drop it and sample again.
        """
        user_pos_items = set(self.user_to_pos_items.get(user_id, [])) # items this user has interacted with
        negatives = []
        attempts = 0
        max_attempts = self.num_random_negs * 10  # to avoid infinite loop
        while len(negatives) < self.num_random_negs and attempts < max_attempts:
            sampled_item = random.choice(self.all_item_ids)
            if sampled_item != pos_item_id and sampled_item not in user_pos_items:
                negatives.append(sampled_item)
            attempts += 1
        return negatives

    def __getitem__(self, idx):
        """
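        For each interaction, build one training example containing:
        - the positive item (label 1.0)
        - up to num_hard_negs hard negative item_ids (label 0.0)
        - up to num_random_negs random negative item_ids (label 0.0)

        Returns:
            Dict with 'user_id' (scalar tensor), 'item_ids' and 'labels'
            (1-D tensors of equal length). The default DataLoader collate
            needs every example to have the same length, so each user
            should have at least num_hard_negs hard negatives.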
        """
        user_id, pos_item_id = self.interactions[idx]

        # 1) Positive
        item_ids = [pos_item_id]
        labels = [1.0]

        # 2) Hard negatives
        hard_negs = self._sample_hard_negs(user_id, pos_item_id)
        item_ids.extend(hard_negs)
        labels.extend([0.0] * len(hard_negs))

        # 3) Random negatives
        random_negs = self._sample_random_negs(user_id, pos_item_id)
        item_ids.extend(random_negs)
        labels.extend([0.0] * len(random_negs))

        # 4) convert to tensors
        user_id_tensor = torch.tensor(user_id, dtype=torch.long)
        item_ids_tensor = torch.tensor(item_ids, dtype=torch.long)
        labels_tensor = torch.tensor(labels, dtype=torch.float32)

        return {
            'user_id': user_id_tensor, # Scalar tensor
            'item_ids': item_ids_tensor, # Shape: (1 + num_hard_negs + num_random_negs,)
            'labels': labels_tensor # Shape: (1 + num_hard_negs + num_random_negs,)
        }


# Model Train

In [None]:
model = TwoTowerModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

In [None]:
dataset = TwoTowerNegSamplingDataset(
    interactions=[(0, 10), (1, 20), (2, 30)], # example interactions
    user_to_hard_negs={
        0: [11, 12, 13],
        1: [21, 22, 23],
        2: [31, 32, 33]
    },
    all_item_ids=list(range(n_items)),
    num_hard_negs=2,
    num_random_negs=2
)

## Loss Function

Each training step consists of a forward pass, a backward pass, and a parameter update.

The Binary Cross Entropy loss function is used here:
$$
\mathcal{L}_{\text{BCE}}(y, \hat{y})
= - \left[ y \log(\hat{y}) + (1 - y) \log(1 - \hat{y}) \right]
$$

Where:
* $y \in \{0,1\}$ = true label
* $\hat{y}$ = predicted probability

For the multiclass case it is

$$
\mathcal{L}_{\text{CE}}(y, \hat{y})
= - \sum_{i=1}^{C} y_i \log(\hat{y}_i)
$$

Where:
* C = number of classes
* $y_i \in \{0,1\}$ is 1 for the correct class
* $\hat{y}_i$ = predicted probability for class i
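As a quick sanity check, the binary formula can be compared with `nn.BCEWithLogitsLoss`, the loss used in this notebook. The logits and labels below are arbitrary example values.

```python
import torch
import torch.nn as nn

logits = torch.tensor([2.0, -1.0, 0.5])  # hypothetical raw scores
labels = torch.tensor([1.0, 0.0, 0.0])

# Manual BCE: apply the sigmoid, then the binary formula, then average
y_hat = torch.sigmoid(logits)
manual = -(labels * torch.log(y_hat) + (1 - labels) * torch.log(1 - y_hat)).mean()

# nn.BCEWithLogitsLoss fuses the sigmoid and the loss for numerical stability
builtin = nn.BCEWithLogitsLoss()(logits, labels)
```

The two values agree up to floating-point error, which confirms that the model output passed to `loss_fn` should be raw logits, not probabilities.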

Another option for the loss function is a contrastive loss. In contrastive learning, every other user's positive item in the batch automatically becomes a negative. This eliminates the need for explicit negative sampling for the loss. The contrastive loss drives each user embedding closer to its positive item and further from the negative items.

For a batch of size B:
* User embeddings: $u_1, u_2, ..., u_B$
* Positive item embeddings: $v_1, v_2, ..., v_B$

For user i, we want:
* $u_i$ close to $v_i$ (positive)
* $u_i$ far from $v_j$ where $j \neq i$ (negatives)

The loss for a single user, with temperature $\tau$, is:

$$
L_i = - \log \left( \frac{\exp(u_i \cdot v_i / \tau)}{\sum_{j=1}^{B} \exp(u_i \cdot v_j / \tau)} \right)
$$

In [None]:
def contrastive_loss(u, v, temperature=0.1):
    """
    Compute contrastive loss between user embeddings u and item embeddings v.

    Args:
        u: Tensor of shape (batch_size, tower_dim) - user embeddings
        v: Tensor of shape (batch_size, tower_dim) - item embeddings (positive pairs)
        temperature: scaling factor for logits
    """
    # Normalise to unit sphere
    u = F.normalize(u, p=2, dim=-1)
    v = F.normalize(v, p=2, dim=-1)

    # Similarity matrix: [batch_size, batch_size]
    logits = u @ v.T / temperature

    # labels: each row i has positive at position i
    labels = torch.arange(len(u)).to(u.device)

    # classic InfoNCE loss
    return F.cross_entropy(logits, labels)

## Training Cycles

In [None]:
loader = DataLoader(dataset, batch_size=32, shuffle=True) # DataLoader for batching

for batch in loader:
    user_ids = batch['user_id']          # (B,)
    item_ids = batch['item_ids']         # (B, N)
    labels = batch['labels']             # (B, N)

    logits = model(
        user_ids,
        torch.zeros_like(user_ids), # dummy country_id
        item_ids,
        torch.zeros_like(item_ids)  # dummy category_id
    )  # (B, N)
    loss = loss_fn(logits.view(-1), labels.view(-1))

    optimizer.zero_grad()
    loss.backward() # calculate gradients
    optimizer.step() # update parameters


Many production models combine a contrastive loss (in-batch negatives, no sampler) with a BCE loss (explicit negatives, including impressions and category-level negatives). While the contrastive loss gives global structure, BCE injects real negative signals based on impressions. The total loss is written as

$$
\mathcal{L} = \lambda \, \mathcal{L}_{\text{contrastive}} + (1 - \lambda) \, \mathcal{L}_{\text{BCE}}
$$
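A minimal sketch of this weighted combination is shown below. The function name `combined_loss`, the default `lam=0.5`, and the random tensors are assumptions for illustration; in practice the mixing weight is a hyperparameter to tune.

```python
import torch
import torch.nn.functional as F

def combined_loss(u, v_pos, explicit_logits, explicit_labels, lam=0.5, temperature=0.1):
    """
    u, v_pos:        (B, D) L2-normalised user / positive-item embeddings
    explicit_logits: (B, N) scores for explicitly sampled candidates
    explicit_labels: (B, N) 1.0 for positives, 0.0 for negatives
    lam:             hypothetical mixing weight between the two terms
    """
    # In-batch contrastive term: row i's positive sits on the diagonal
    sim = u @ v_pos.T / temperature
    targets = torch.arange(u.size(0), device=u.device)
    l_contrastive = F.cross_entropy(sim, targets)

    # BCE term over the explicit positives and negatives
    l_bce = F.binary_cross_entropy_with_logits(explicit_logits, explicit_labels)

    return lam * l_contrastive + (1 - lam) * l_bce

torch.manual_seed(0)
B, D, N = 4, 8, 5
u = F.normalize(torch.randn(B, D), dim=-1)
v_pos = F.normalize(torch.randn(B, D), dim=-1)
explicit_logits = torch.randn(B, N)
explicit_labels = torch.zeros(B, N)
explicit_labels[:, 0] = 1.0  # first candidate of each row is the positive
loss = combined_loss(u, v_pos, explicit_logits, explicit_labels, lam=0.7)
```

Setting `lam=1.0` recovers the pure in-batch contrastive loss, and `lam=0.0` recovers plain BCE on the sampled candidates.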

# Precompute Item Embeddings

In [None]:
model.eval() # set model to eval mode
user_tower = model.user_tower
item_tower = model.item_tower

In [None]:
item_ids_all = torch.arange(n_items) # (n_items,)

@torch.no_grad() # no gradients needed at inference time
def compute_item_embeddings(item_ids, batch_size = 1024) -> torch.Tensor:
    """
    Precompute item embeddings for all items in the catalog.
    Returns:
        item_embeddings: Tensor of shape (n_items, tower_dim)
    """
    embs = []
    for i in range(0, len(item_ids), batch_size):
        # reshape to (batch_size, 1) because ItemTower expects (batch_size, N)
        batch = item_ids[i:i+batch_size].unsqueeze(1)
        # dummy category_id, matching the training loop; use real categories if available
        v = item_tower(batch, torch.zeros_like(batch)) # (batch_size, 1, tower_dim)
        embs.append(v.squeeze(1)) # (batch_size, tower_dim)
    return torch.cat(embs, dim=0) # (n_items, tower_dim)

item_embeddings = compute_item_embeddings(item_ids_all) # (n_items, tower_dim)

# Retrieval

At inference time, retrieval runs against the precomputed item embeddings, so only the user tower is evaluated per request.

In [None]:
import faiss # FAISS library for efficient similarity search

In [None]:
# Build FAISS index
d = item_embeddings.shape[1]
index = faiss.IndexFlatIP(d) # inner product index
index.add(item_embeddings.cpu().numpy()) # add item embeddings to index

In [None]:
# Recommend for a given user
def recommend_for_user(user_id: int, country_id: int, top_k: int = 10) -> List[Tuple[int, float]]:
    """
    Recommend top-k items for the given user using FAISS index.
    Args:
        user_id: int
        country_id: int
        top_k: int
    Returns:
        List of (item_id, score) tuples
    """
    user_id_tensor = torch.tensor([user_id], dtype=torch.long)
    country_id_tensor = torch.tensor([country_id], dtype=torch.long)
    with torch.no_grad():
        user_vec = user_tower(user_id_tensor, country_id_tensor) # (1, tower_dim)
        user_vec_np = user_vec.cpu().numpy().astype('float32') # FAISS requires float32
        # Search in FAISS index
        D, I = index.search(user_vec_np, top_k) # D: inner-product scores, I: row indices (equal to item IDs here)
        recommendations = [(int(item_id), float(score)) for item_id, score in zip(I[0], D[0])]
    return recommendations

# Offline Evaluation

Common evaluation metrics for recommendation problems are:

__Hit@K__: did the true item appear in the top-K?

__Recall@K__: what fraction of the user's "relevant" items are in the top-K? Recall@K is defined as

```Recall@K = (# true positives in top-K) / (# total true positives)```

Recall@K is the primary metric for two-tower retrieval.

For example,
```
Ground truth = {7, 10, 25}
Ranking = [3, 7, 10, 2, 25, 14, ...]
```

then 
* Hit@3 → True (because 7 and 10 are in the top 3)
* Recall@3 → 2/3 (the model retrieved 2 of the 3 relevant items)

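__NDCG@K (Normalised Discounted Cumulative Gain)__ is a more nuanced ranking metric.
* Top positions get high weight
* Lower positions get discounted
* Supports multiple positives

Usually, it is defined as:

$$
\text{NDCG@K} = \frac{\text{DCG@K}}{\text{IDCG@K}}, \qquad \text{DCG@K} = \sum_{r=1}^{K} \frac{rel_r}{\log_2(r + 1)}
$$

where $rel_r$ is the relevance of the item at rank $r$ and IDCG@K is the DCG@K of the ideal ordering.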


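__MRR (Mean Reciprocal Rank)__ rewards placing the true item very high. Each user contributes the reciprocal of the rank of their first relevant item, and MRR is the mean of these values over all users.

```if the true item is at rank 2 --> reciprocal rank = 1/2```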
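The four metrics can be sketched in plain Python and applied to the example ranking above. The function names are illustrative, and binary relevance is assumed (an item is either relevant or not), which is one common convention for NDCG.

```python
import math
from typing import List, Set

def hit_at_k(ranking: List[int], relevant: Set[int], k: int) -> bool:
    return any(item in relevant for item in ranking[:k])

def recall_at_k(ranking: List[int], relevant: Set[int], k: int) -> float:
    return sum(1 for item in ranking[:k] if item in relevant) / len(relevant)

def reciprocal_rank(ranking: List[int], relevant: Set[int]) -> float:
    # Reciprocal of the rank of the first relevant item (0.0 if none is found)
    for rank, item in enumerate(ranking, start=1):
        if item in relevant:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(ranking: List[int], relevant: Set[int], k: int) -> float:
    # Binary relevance: each relevant item contributes 1 / log2(rank + 1)
    dcg = sum(
        1.0 / math.log2(rank + 1)
        for rank, item in enumerate(ranking[:k], start=1)
        if item in relevant
    )
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / idcg if idcg > 0 else 0.0

ground_truth = {7, 10, 25}
ranking = [3, 7, 10, 2, 25, 14]

hit3 = hit_at_k(ranking, ground_truth, 3)        # True
recall3 = recall_at_k(ranking, ground_truth, 3)  # 2/3
rr = reciprocal_rank(ranking, ground_truth)      # 0.5 (first hit at rank 2)
ndcg3 = ndcg_at_k(ranking, ground_truth, 3)      # approximately 0.531
```

MRR over a test set is then the mean of `reciprocal_rank` across users.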

Instead of evaluating against all items in the catalog (often millions of items), the industry pattern is to build the evaluation set from positives plus sampled negatives. The steps are:
1. Pick positives for evaluation: items the user interacted with at test time (not in training).
2. Sample negatives:
    - often ~100 random items
    - or ~1000 items from a similar category
    - optionally, avoid the negatives that were used in training
3. Score them using:
    - user_tower(user_id)
    - item_tower(candidate_items)
    - dot product similarity
4. Rank: sort by score, descending.
5. Compute metrics:
    - Hit@K
    - Recall@K
    - MRR (mean reciprocal rank)
    - NDCG@K (discounts position)

In [18]:
@torch.no_grad() # no gradients needed for evaluation
def evaluate(model, test_users, user_to_test_pos, all_items, K=10, num_neg = 100):
    model.eval() # set model to eval mode
    user_tower = model.user_tower
    item_tower = model.item_tower

    hit_count = 0
    recall_count = 0
    total_positives = 0

    for user in test_users:
        pos_items = user_to_test_pos[user]
        pos_set = set(pos_items)
        total_positives += len(pos_items)

        # 1) Sample negatives, excluding this user's test positives
        neg_pool = [i for i in all_items if i not in pos_set]
        neg_items = torch.tensor(
            random.sample(neg_pool, num_neg),
            dtype=torch.long
        )

        # 2) Form candidate set = positives + negatives
        candidate_items = torch.cat([
            torch.tensor(pos_items, dtype=torch.long),
            neg_items
        ]) # (num_pos + num_neg,)

        # 3) Score them (dummy country_id / category_id, as in the training loop)
        user_tensor = torch.tensor([user], dtype=torch.long)
        u = user_tower(user_tensor, torch.zeros_like(user_tensor)) # (1, tower_dim)
        cand = candidate_items.unsqueeze(1) # (num_candidates, 1)
        v = item_tower(cand, torch.zeros_like(cand)).squeeze(1)  # (num_candidates, tower_dim)
        scores = torch.matmul(v, u.squeeze(0)) # (num_candidates,)

        # 4) Rank
        k = min(K, len(candidate_items))
        topk_indices = torch.topk(scores, k).indices # (k,)
        topk_items = candidate_items[topk_indices].tolist() # plain ints for set lookup

        # 5) Metrics
        n_hits = sum(1 for item in topk_items if item in pos_set)

        # Hit@K
        if n_hits > 0:
            hit_count += 1

        # Recall@K
        recall_count += n_hits

    # Return once, after all users have been scored
    return {
        'Hit@K': hit_count / len(test_users),
        'Recall@K': recall_count / total_positives
    }

# Ranking

Two-tower similarity (dot product / cosine) reflects how much a user embedding aligns with an item embedding.

But it does not consider:
* context
* freshness
* popularity bias
* business constraints
* short-term patterns
* price sensitivity
* real-time signals
* item quality
* clickability
* image attractiveness
* long descriptions
* dwell‐time behavior
* real conversion predictions

So retrieval towers capture stable, long-term user–item compatibility, while re-rankers capture contextual, short-term, and business-critical factors. If the catalog rarely changes and context doesn’t matter, the re-ranker becomes less important. But for real recommender systems, re-ranking addresses four fundamental limitations of retrieval towers.

1. Retrieval Towers Are Optimized for Recall, Not Ranking. Their job is to pull relevant items into the top-K, not to order them perfectly inside the top-K. Two-tower dot-product similarity is not expressive enough to rank items that differ only subtly.

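2. Two-Tower Models Are Limited in How They Use Rich Features. A common claim is that retrieval towers cannot use text features (BERT) or images (ResNet/ViT) while ranking models can. That claim is doubtful as stated: the text tower section above already feeds text into the ItemTower, and the ranking model runs right after retrieval, so it has to be fast as well. What the two-tower design does rule out is any feature that depends on the user and the item together, because each tower must encode its side independently for ANN search to work. A ranker scores only the retrieved candidates, so it can afford such user-item cross features.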

3. Retrieval Embeddings Cannot Capture Business Rules. Examples:
* “We must show sponsored items preferentially.”
* “Diversity rules: show items from different categories.”
* “Hide items out of stock.”
* “Avoid repeating the same item 3 times.”
* “Prioritize local sellers in this region.”
* “Don’t show NSFW content.”
* “Boost new items.”

Two-tower retrieval cannot enforce these, but ranking can.
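A few of these rules can be illustrated as a post-processing pass over the ranked candidates. This is a simplified sketch: the function `apply_business_rules`, the category mapping, and the per-category cap are hypothetical, and real systems typically fold such rules into the ranking stage or a dedicated policy layer.

```python
from collections import Counter
from typing import Dict, List, Set, Tuple

def apply_business_rules(
    ranked: List[Tuple[int, float]],   # (item_id, score), best first
    item_category: Dict[int, str],
    out_of_stock: Set[int],
    max_per_category: int = 2,
) -> List[Tuple[int, float]]:
    kept: List[Tuple[int, float]] = []
    per_category: Counter = Counter()
    seen: Set[int] = set()
    for item_id, score in ranked:
        if item_id in out_of_stock or item_id in seen:
            continue  # hide out-of-stock items and repeated items
        category = item_category.get(item_id, "unknown")
        if per_category[category] >= max_per_category:
            continue  # diversity rule: cap the number of items per category
        kept.append((item_id, score))
        per_category[category] += 1
        seen.add(item_id)
    return kept

ranked = [(1, 0.9), (2, 0.8), (3, 0.7), (1, 0.65), (4, 0.6), (5, 0.5)]
item_category = {1: "shoes", 2: "shoes", 3: "shoes", 4: "socks", 5: "watches"}
out_of_stock = {2}

filtered = apply_business_rules(ranked, item_category, out_of_stock)
# filtered == [(1, 0.9), (3, 0.7), (4, 0.6), (5, 0.5)]
```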

4. Retrieval Optimizes a Different Objective. Retrieval training is usually:
* contrastive loss
* or BCE
* or sampled softmax

Ranking training is usually:
* listwise loss
* pairwise ranking loss (RankNet / LambdaRank)
* click prediction loss
* conversion prediction loss
* CTR/CVR modeling

Retrieval chooses “good candidates”, while ranking chooses the “best order”.

# Practical Project

https://www.kaggle.com/code/twtw5201/book-recommendation/edit

Use Amazon's book recommendation.

# Advanced Feature Engineering

These are the most powerful techniques used in large-scale retrieval systems.

We’ll cover:
1.	Short-term + long-term dual-tower design
2.	Hierarchical sequence encoders
3.	Multi-modal item towers (text + image + video)
4.	Mixture-of-Experts (MoE) towers
5.	Feature gating / routing
6.	Distributional embeddings (TikTok/Facebook)
7.	Dynamic user state modeling

__Short-Term + Long-Term User Representations__

Many modern retrieval systems use two separate user representations. The long-term embedding captures "who the user is": stable interests, average category distributions, and embeddings from months of interactions; it is often an embedding bag averaged over a large history. The short-term embedding captures "what the user wants right now": the last session, the last X clicks, or the most recent query; it is often RNN- or Transformer-based. The two embeddings are concatenated, and often gated, to form the final embedding.

```
gate = sigmoid(W * concat(u_long, u_short))
u_final = gate * u_short + (1 - gate) * u_long
```

This lets the model adapt:
* when user intent is stable → long-term dominates
* when user intent shifts fast → short-term dominates
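The gating formula above can be written as a small PyTorch module. This is a sketch under the assumption that both embeddings share the same dimension; the class name `GatedUserFusion` and the sizes are illustrative.

```python
import torch
import torch.nn as nn

class GatedUserFusion(nn.Module):
    """Blend long-term and short-term user embeddings with a learned gate."""

    def __init__(self, dim: int):
        super().__init__()
        # Maps concat(u_long, u_short) to one gate value per embedding dimension
        self.gate_proj = nn.Linear(2 * dim, dim)

    def forward(self, u_long: torch.Tensor, u_short: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(self.gate_proj(torch.cat([u_long, u_short], dim=-1)))
        # gate near 1 -> short-term dominates; gate near 0 -> long-term dominates
        return gate * u_short + (1 - gate) * u_long

torch.manual_seed(0)
fusion = GatedUserFusion(dim=64)
u_long = torch.randn(8, 64)   # hypothetical long-term embeddings
u_short = torch.randn(8, 64)  # hypothetical short-term embeddings
u_final = fusion(u_long, u_short)  # shape (8, 64)
```

Because the gate lies in (0, 1), every coordinate of `u_final` is a convex combination of the corresponding long-term and short-term coordinates.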

__Hierarchical Sequence Encoder__

Hierarchical sequence encoder is used when sequences are long and noisy. The general structure is:
1.	Encode short segments with a small GRU (local patterns)
2.	Pool segment outputs
3.	Feed pooled outputs into a second GRU/Transformer (global patterns)

This helps:
* reduce noise
* encode very long histories (hundreds of interactions)
* capture both short and medium-term patterns
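The three steps can be sketched as a two-level GRU. The class name, the mean pooling, and the fixed `segment_len` are assumptions made for brevity; the sketch also assumes the history length is an exact multiple of the segment length, so real code would need padding and masking.

```python
import torch
import torch.nn as nn

class HierarchicalHistoryEncoder(nn.Module):
    def __init__(self, n_items: int, emb_dim: int = 32, hidden_dim: int = 64, segment_len: int = 10):
        super().__init__()
        self.segment_len = segment_len
        self.item_emb = nn.Embedding(n_items, emb_dim)
        self.local_gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)      # step 1
        self.global_gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)  # step 3

    def forward(self, history_item_ids: torch.Tensor) -> torch.Tensor:
        # history_item_ids: (B, T), with T divisible by segment_len (assumption)
        B, T = history_item_ids.shape
        S = T // self.segment_len                       # number of segments
        x = self.item_emb(history_item_ids)             # (B, T, emb_dim)
        x = x.reshape(B * S, self.segment_len, -1)      # (B*S, L, emb_dim)
        local_out, _ = self.local_gru(x)                # (B*S, L, hidden_dim)
        seg = local_out.mean(dim=1)                     # step 2: pool each segment
        seg = seg.reshape(B, S, -1)                     # (B, S, hidden_dim)
        _, h = self.global_gru(seg)                     # h: (1, B, hidden_dim)
        return h.squeeze(0)                             # (B, hidden_dim)

torch.manual_seed(0)
encoder = HierarchicalHistoryEncoder(n_items=1000)
history = torch.randint(0, 1000, (4, 50))  # 4 users, 50 interactions each
user_seq_emb = encoder(history)            # shape (4, 64)
```

The global GRU sees only 5 pooled segment vectors instead of 50 raw items, which is what makes long histories cheaper to encode.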

__Multi-modal item towers__

Items often have:
* text
* images
* videos
* numeric metadata
* reviews
* seller info

A real item tower combines all of them.

```
v = concat(
    item_id_emb,
    category_emb,
    title_bert_emb,
    image_cnn_emb,
    price_norm,
    brand_emb
)
v = MLP(v)
```

__Mixture-of-Experts (MoE)__

MoE is used when the user base is diverse.

```
UserTower:
    Expert 1 → sports-oriented expert
    Expert 2 → fashion-oriented expert
    Expert 3 → tech-oriented expert
```

Expert weights are predicted by a gating network based on user features. This can give large gains in:
* personalization
* niche categories
* robustness to data sparsity
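A dense (soft) mixture-of-experts tower could look like the sketch below. The class name `MoEUserTower` and all sizes are hypothetical, and every expert is evaluated for every user; large systems often use sparse top-k routing instead, which is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoEUserTower(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, n_experts: int = 3):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Linear(out_dim, out_dim))
            for _ in range(n_experts)
        ])
        self.gate = nn.Linear(in_dim, n_experts)  # gating network over user features

    def forward(self, x: torch.Tensor):
        weights = F.softmax(self.gate(x), dim=-1)                            # (B, E)
        expert_out = torch.stack([expert(x) for expert in self.experts], 1)  # (B, E, out_dim)
        mixed = (weights.unsqueeze(-1) * expert_out).sum(dim=1)              # (B, out_dim)
        return F.normalize(mixed, p=2, dim=-1), weights

torch.manual_seed(0)
tower = MoEUserTower(in_dim=32, out_dim=64)
user_features = torch.randn(8, 32)  # hypothetical concatenated user features
user_vec, expert_weights = tower(user_features)  # (8, 64), (8, 3)
```

Returning `expert_weights` alongside the embedding makes it easy to inspect which expert dominates for a given user.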

__Feature Gating / Routing__

Used when some features should dominate in certain contexts.

Examples:
* if user session is long → short-term tower has higher weight
* if user comes from search → query embedding dominates
* if country = US → price sensitivity dominates
* if device = mobile → video preference dominates

This makes embedding spaces context-aware.

__Distributional Embeddings__

Instead of representing a user as a single vector, represent the user as a:
* set of vectors
* Gaussian distribution
* mixture of clusters
* attention-weighted bag of embeddings

Reasons:
* user preferences are multi-modal
* single vector collapses too much information
* multi-vector retrieval improves coverage
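For the set-of-vectors case, one simple scoring rule is to take the best match across a user's interest vectors. The sketch below assumes each user has `K` interest vectors; the function name and the max-over-interests rule are illustrative choices, not the only option.

```python
import torch
import torch.nn.functional as F

def multi_interest_scores(user_vecs: torch.Tensor, item_vecs: torch.Tensor) -> torch.Tensor:
    """
    user_vecs: (B, K, D) K interest vectors per user
    item_vecs: (M, D)    candidate item embeddings
    Returns:   (B, M)    score of each item = best match over the K interests
    """
    sims = torch.einsum("bkd,md->bkm", user_vecs, item_vecs)  # (B, K, M)
    return sims.max(dim=1).values

torch.manual_seed(0)
user_vecs = F.normalize(torch.randn(2, 3, 16), dim=-1)   # 2 users, 3 interests each
item_vecs = F.normalize(torch.randn(100, 16), dim=-1)    # 100 candidate items
scores = multi_interest_scores(user_vecs, item_vecs)     # (2, 100)
top_scores, top_items = scores.topk(5, dim=-1)           # top-5 items per user
```

With an ANN index, the same idea is usually implemented by issuing one query per interest vector and merging the results.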

__Dynamic User State Modeling__

Real-time ER-type signals:
* user just clicked X
* user hovered over Y
* user scrolled to section Z
* dwell time
* search query edits

A real two-tower system ingests these real-time signals to form the user embedding in milliseconds.

This is the frontier of retrieval research.