In [None]:
import random
from typing import Dict, List, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Two-Towers Model

In [None]:
# Sizes of the embedding tables and of the tower layers used throughout the notebook.
# Every ID passed to an embedding table must lie in [0, table size).
n_users = 10000 # number of rows in the user-ID embedding table
n_items = 50000 # number of rows in the item-ID embedding table; also used as the catalog size below
n_countries = 50 # number of rows in the country embedding table
n_categories = 200 # number of rows in the item-category embedding table
emb_dim = 32 # embedding dimension per feature
tower_dim = 64 # output dimension of each tower

In [None]:
# ---------------------------
# 2. Define the two towers
# ---------------------------
class UserTower(nn.Module):
    """
    User-side tower: maps (user_id, country_id) to a unit-length vector.

    Each feature is looked up in its own embedding table, the two embeddings are
    concatenated and passed through a two-layer MLP, and the result is L2-normalised.

    Args:
        num_users: Rows in the user-ID embedding table. Defaults to the notebook-level `n_users`.
        num_countries: Rows in the country embedding table. Defaults to `n_countries`.
        embedding_dim: Width of each feature embedding. Defaults to `emb_dim`.
        output_dim: Width of the tower output. Defaults to `tower_dim`.
    """

    def __init__(self, num_users=None, num_countries=None, embedding_dim=None, output_dim=None):
        super().__init__()
        # Fall back to the notebook-level configuration for any size that is not given,
        # so `UserTower()` keeps building exactly the same network as before.
        num_users = n_users if num_users is None else num_users
        num_countries = n_countries if num_countries is None else num_countries
        embedding_dim = emb_dim if embedding_dim is None else embedding_dim
        output_dim = tower_dim if output_dim is None else output_dim

        self.user_emb = nn.Embedding(num_users, embedding_dim)
        self.country_emb = nn.Embedding(num_countries, embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim * 2, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, user_id, country_id):
        """
        Perform forward pass of user tower.

        Args:
            user_id: Long tensor of shape (batch_size,)
            country_id: Long tensor of shape (batch_size,)

        Returns:
            Tensor of shape (batch_size, output_dim) whose rows have unit L2 norm.
        """
        u_id = self.user_emb(user_id) # (batch_size, emb_dim)
        u_country = self.country_emb(country_id) # (batch_size, emb_dim)
        x = torch.cat([u_id, u_country], dim=-1) # (batch_size, emb_dim * 2)
        out = self.mlp(x) # (batch_size, tower_dim)
        # L2-normalise so that a dot product with an item vector is a cosine similarity
        out = F.normalize(out, p=2, dim=-1)
        return out

class ItemTower(nn.Module):
    """
    Item-side tower: maps (item_id, category_id) to a unit-length vector.

    Each feature is looked up in its own embedding table, the two embeddings are
    concatenated along the last axis and passed through a two-layer MLP, and the
    result is L2-normalised. All leading axes of the inputs are preserved.

    Args:
        num_items: Rows in the item-ID embedding table. Defaults to the notebook-level `n_items`.
        num_categories: Rows in the category embedding table. Defaults to `n_categories`.
        embedding_dim: Width of each feature embedding. Defaults to `emb_dim`.
        output_dim: Width of the tower output. Defaults to `tower_dim`.
    """

    def __init__(self, num_items=None, num_categories=None, embedding_dim=None, output_dim=None):
        super().__init__()
        # Fall back to the notebook-level configuration for any size that is not given,
        # so `ItemTower()` keeps building exactly the same network as before.
        num_items = n_items if num_items is None else num_items
        num_categories = n_categories if num_categories is None else num_categories
        embedding_dim = emb_dim if embedding_dim is None else embedding_dim
        output_dim = tower_dim if output_dim is None else output_dim

        self.item_emb = nn.Embedding(num_items, embedding_dim)
        self.cat_emb = nn.Embedding(num_categories, embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim * 2, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, item_id, category_id):
        """
        Perform forward pass of item tower for the batch of items. The items contain both
        positive and negative samples.

        Args:
            item_id: Long tensor of shape (batch_size, N)
            category_id: Long tensor of shape (batch_size, N)

        Returns:
            Tensor of shape (batch_size, N, output_dim); every vector along the last
            axis has unit L2 norm.
        """
        i_id = self.item_emb(item_id) # (batch_size, N, emb_dim)
        i_cat = self.cat_emb(category_id) # (batch_size, N, emb_dim)
        x = torch.cat([i_id, i_cat], dim=-1) # (batch_size, N, emb_dim * 2)
        out = self.mlp(x) # (batch_size, N, tower_dim) -- the MLP acts on the last axis only
        # L2-normalise so that a dot product with a user vector is a cosine similarity
        out = F.normalize(out, p=2, dim=-1)
        return out

In [None]:
class TwoTowerModel(nn.Module):
    """
    Combines a user tower and an item tower and scores every (user, candidate item)
    pair by the dot product of the two tower outputs.
    """

    def __init__(self):
        super().__init__()
        self.user_tower = UserTower()
        self.item_tower = ItemTower()

    def forward(self, user_id, country_id, item_id, category_id):
        """
        Score N candidate items for each user in the batch.

        Args:
            user_id: Tensor of shape (batch_size,)
            country_id: Tensor of shape (batch_size,)
            item_id: Tensor of shape (batch_size, N)
            category_id: Tensor of shape (batch_size, N)

        Returns:
            Tensor of shape (batch_size, N) with one score per candidate.
        """
        user_repr = self.user_tower(user_id, country_id)  # (batch_size, tower_dim)
        item_repr = self.item_tower(item_id, category_id)  # (batch_size, N, tower_dim)

        # Repeat each user vector once per candidate: (batch_size, N, tower_dim)
        num_candidates = item_repr.size(1)
        user_tiled = user_repr[:, None, :].expand(-1, num_candidates, -1)

        # Dot product over the feature axis: (batch_size, N)
        return (user_tiled * item_repr).sum(dim=-1)

# Negative Sampling Dataset

In [None]:
# Negative sampling dataset
class TwoTowerNegSamplingDataset(Dataset):
    """
    Dataset for a two-tower model with negative sampling.

    Each sample is one observed (user, positive item) interaction followed by
    sampled negatives: up to `num_hard_negs` hard negatives and enough random
    negatives to bring the total to `1 + num_hard_negs + num_random_negs` items.
    Keeping that length constant is what allows the default DataLoader collation
    to stack samples into a (batch, N) tensor.
    """

    def __init__(self, 
                 interactions: Sequence[Tuple], # Sequence of (user_id, pos_item_id)
                 user_to_hard_negs: Dict[int, List[int]],
                 all_item_ids: Sequence[int], # Catalog of all item IDs
                 num_hard_negs: int = 2, # number of hard negatives per positive
                 num_random_negs: int = 2, # number of random negatives per positive
                 user_to_pos_items: Dict[int, List[int]] = None # optional mapping of user to their positive items
    ):
        """
        Args:
            interactions: List of (user_id, pos_item_id) tuples.
            user_to_hard_negs: Dict mapping user_id to list of hard negative item_ids.
            all_item_ids: Sequence of all item IDs in the catalog.
            num_hard_negs: Number of hard negatives to sample per positive.
            num_random_negs: Number of random negatives to sample per positive.
            user_to_pos_items: Optional dict mapping user_id to their positive item_ids to avoid
                sampling positives as negatives. When omitted (or empty) it is derived from
                `interactions`. The mapping is snapshotted into per-user sets at construction.
        """
        self.interactions = list(interactions)
        self.user_to_hard_negs = user_to_hard_negs
        self.all_item_ids = all_item_ids
        self.num_hard_negs = num_hard_negs
        self.num_random_negs = num_random_negs
        self.user_to_pos_items = user_to_pos_items or self._build_user_pos_items()
        # Convert all_item_ids to a set for faster lookup
        self.all_item_id_set = set(all_item_ids)
        # Per-user positive sets, built once instead of on every __getitem__ call
        self._user_pos_sets = {
            u_id: set(items) for u_id, items in self.user_to_pos_items.items()
        }

    def _build_user_pos_items(self) -> Dict[int, List[int]]:
        """Group the positive item IDs in `interactions` by user."""
        user_pos = {}
        for u_id, i_id in self.interactions:
            user_pos.setdefault(u_id, []).append(i_id)
        return user_pos

    def _positive_set(self, user_id: int):
        """Return the set of items known to be positive for `user_id` (empty if unknown)."""
        return self._user_pos_sets.get(user_id, frozenset())

    def __len__(self):
        return len(self.interactions)

    def _sample_hard_negs(self, user_id: int, pos_item_id: int) -> List[int]:
        """
        Pick up to `num_hard_negs` hard negatives for the user.

        Candidates equal to the current positive, or to any other known positive of
        the user, are dropped so that a true positive is never labelled 0. Duplicate
        candidates are collapsed. Fewer than `num_hard_negs` items are returned when
        not enough candidates remain.
        """
        positives = self._positive_set(user_id)
        seen = set()
        candidates = []
        for item in self.user_to_hard_negs.get(user_id, []):
            if item == pos_item_id or item in positives or item in seen:
                continue
            seen.add(item)
            candidates.append(item)
        if len(candidates) <= self.num_hard_negs:
            # Not enough hard negatives: return all that are available
            return candidates
        return random.sample(candidates, self.num_hard_negs)

    def _sample_random_negs(self, user_id: int, pos_item_id: int, num_samples=None) -> List[int]:
        """
        Sample random negatives ensuring they are not in user's positive items.

        - Randomly draw from all_item_ids (with replacement) and reject any item that is
          the current positive or another known positive of the user.
        - If the attempt budget runs out first, build the exact list of valid items and
          fill the remainder from it, so the requested count is met whenever at least
          one valid negative exists in the catalog.

        Args:
            user_id: User the negatives are sampled for.
            pos_item_id: Positive item of the current interaction.
            num_samples: How many negatives to return; defaults to `num_random_negs`.

        Returns:
            List of item IDs. It is shorter than requested only when the catalog is
            empty or holds no valid negative for this user.
        """
        target = self.num_random_negs if num_samples is None else num_samples
        if target <= 0 or len(self.all_item_ids) == 0:
            return []

        positives = self._positive_set(user_id)
        negatives = []
        max_attempts = target * 10  # bound the rejection loop
        attempts = 0
        while len(negatives) < target and attempts < max_attempts:
            sampled_item = random.choice(self.all_item_ids)
            if sampled_item != pos_item_id and sampled_item not in positives:
                negatives.append(sampled_item)
            attempts += 1

        if len(negatives) < target:
            # Rejection sampling kept hitting positives (small catalog or a very active
            # user): enumerate the valid pool once and draw the rest from it.
            pool = [
                item for item in self.all_item_ids
                if item != pos_item_id and item not in positives
            ]
            if pool:
                negatives.extend(random.choices(pool, k=target - len(negatives)))
        return negatives

    def __getitem__(self, idx):
        """
        Build one training sample for the interaction at `idx`:
        - one positive item (label 1.0)
        - up to num_hard_negs hard negatives (label 0.0)
        - random negatives (label 0.0), including extra ones that replace any
          missing hard negatives so the sample length stays fixed

        Returns:
            Dict with
            - 'user_id': scalar long tensor
            - 'item_ids': long tensor, positive first, then the negatives
            - 'labels': float32 tensor aligned with 'item_ids'
        """
        user_id, pos_item_id = self.interactions[idx]

        # Hard negatives first; top up with random negatives if there are too few
        hard_negs = self._sample_hard_negs(user_id, pos_item_id)
        shortfall = self.num_hard_negs - len(hard_negs)
        random_negs = self._sample_random_negs(
            user_id, pos_item_id, num_samples=self.num_random_negs + shortfall
        )

        item_ids = [pos_item_id, *hard_negs, *random_negs]
        labels = [1.0] + [0.0] * (len(item_ids) - 1)

        return {
            'user_id': torch.tensor(user_id, dtype=torch.long), # Scalar tensor
            'item_ids': torch.tensor(item_ids, dtype=torch.long), # Shape: (1 + num_hard_negs + num_random_negs,)
            'labels': torch.tensor(labels, dtype=torch.float32) # Shape: (1 + num_hard_negs + num_random_negs,)
        }


# Training

In [None]:
# ---------------------------
# 3. One training step
# ---------------------------
# Build the model; both towers take their sizes from the notebook-level constants.
model = TwoTowerModel()
# Adam over all parameters of the model (embedding tables and MLPs of both towers).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Binary cross-entropy applied to the raw scores: each (user, item) pair is an
# independent 0/1 target, and the sigmoid is applied inside the loss.
loss_fn = nn.BCEWithLogitsLoss()

In [None]:
# Toy dataset: three (user, positive item) interactions, three hard-negative
# candidates per user, and the full item-ID range as the random-negative pool.
dataset = TwoTowerNegSamplingDataset(
    interactions=[(0, 10), (1, 20), (2, 30)], # example (user_id, pos_item_id) pairs
    user_to_hard_negs={
        0: [11, 12, 13],
        1: [21, 22, 23],
        2: [31, 32, 33]
    },
    all_item_ids=list(range(n_items)), # every valid row of the item embedding table
    num_hard_negs=2,
    num_random_negs=2
)

In [None]:
loader = DataLoader(dataset, batch_size=32, shuffle=True) # DataLoader for batching

# Put the model in training mode explicitly: a later cell calls model.eval(), so
# without this, re-running the training cell would train in eval mode.
model.train()

# One pass over the data: score every (user, candidate) pair and fit the 0/1 labels.
for batch in loader:
    user_ids = batch['user_id']          # (B,)
    item_ids = batch['item_ids']         # (B, N)
    labels = batch['labels']             # (B, N)

    logits = model(
        user_ids,
        torch.zeros_like(user_ids), # dummy country_id
        item_ids,
        torch.zeros_like(item_ids)  # dummy category_id
    )  # (B, N)
    # Flatten so that every (user, item) pair is one binary classification example
    loss = loss_fn(logits.view(-1), labels.view(-1))

    optimizer.zero_grad() # clear gradients left over from the previous step
    loss.backward() # calculate gradients
    optimizer.step() # update parameters


# Precompute Item Embeddings

In [None]:
model.eval() # set model to eval mode
# Module-level aliases for the two towers; the embedding-precompute and retrieval
# cells below call them directly.
user_tower = model.user_tower
item_tower = model.item_tower

In [None]:
item_ids_all = torch.arange(n_items) # (n_items,)

def compute_item_embeddings(item_ids, batch_size = 1024, category_ids = None) -> torch.Tensor:
    """
    Precompute item embeddings with the item tower, one mini-batch at a time.

    Args:
        item_ids: 1-D long tensor of item IDs to embed.
        batch_size: Number of items passed through the tower per step.
        category_ids: Optional 1-D long tensor aligned with `item_ids`. Defaults to
            all zeros, the same dummy category the training loop feeds the tower.

    Returns:
        item_embeddings: Tensor of shape (len(item_ids), tower_dim)
    """
    if category_ids is None:
        category_ids = torch.zeros_like(item_ids)

    embs = []
    # Disable autograd inside the function so that every call is gradient-free,
    # not only a call that happens to sit inside an outer no_grad block.
    with torch.no_grad():
        for start in range(0, len(item_ids), batch_size):
            # The item tower expects (batch_size, N); use N = 1 here
            ids = item_ids[start:start + batch_size].unsqueeze(1)
            cats = category_ids[start:start + batch_size].unsqueeze(1)
            v = item_tower(ids, cats) # (batch_size, 1, tower_dim)
            embs.append(v.squeeze(1)) # (batch_size, tower_dim)
    return torch.cat(embs, dim=0) # (n_items, tower_dim)

item_embeddings = compute_item_embeddings(item_ids_all) # (n_items, tower_dim)

# Retrieval

Two-tower at inference: retrieval with precomputed item embeddings

In [None]:
import faiss # FAISS library for efficient similarity search

In [None]:
# Build FAISS index
# The index dimensionality must equal the width of the item embeddings.
d = item_embeddings.shape[1]
# Exact (brute-force) inner-product index; the tower outputs are L2-normalised,
# so the inner product is a cosine similarity.
index = faiss.IndexFlatIP(d) # inner product index
# Row i of the index is the embedding of item ID i, so the positions returned by a
# search can be read directly as item IDs.
index.add(item_embeddings.cpu().numpy()) # add item embeddings to index

In [None]:
# Recommend for a given user
def recommend_for_user(user_id: int, country_id: int, top_k: int = 10) -> List[Tuple[int, float]]:
    """
    Recommend top-k items for the given user using FAISS index.

    Args:
        user_id: int
        country_id: int
        top_k: Maximum number of items to return. It is capped at the number of
            vectors in the index; a non-positive value yields an empty list.

    Returns:
        List of (item_id, score) tuples, best score first. The score is the inner
        product between the user vector and the item vector.
    """
    # Never ask the index for more neighbours than it holds, and never for k <= 0
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    user_id_tensor = torch.tensor([user_id], dtype=torch.long)
    country_id_tensor = torch.tensor([country_id], dtype=torch.long)
    with torch.no_grad():
        user_vec = user_tower(user_id_tensor, country_id_tensor) # (1, tower_dim)
    user_vec_np = user_vec.cpu().numpy().astype('float32') # FAISS requires float32

    # Search in FAISS index
    scores, item_ids = index.search(user_vec_np, k) # each of shape (1, k)
    # FAISS pads missing results with id -1; drop them defensively
    return [
        (int(item_id), float(score))
        for item_id, score in zip(item_ids[0], scores[0])
        if item_id >= 0
    ]

# Offline Evaluation

Common evaluation metrics for recommendation problems are:

__Hit@K__: did the true item appear in the top-K?

__Recall@K__: what fraction of the user's "relevant" items are in the top-K? Recall@K is defined as

```Recall@K = (# true positives in top-K) / (# total true positives)```

Recall@K is the primary metric for two-tower retrieval.

For example,
```
Ground truth = {7, 10, 25}
Ranking = [3, 7, 10, 2, 25, 14, ...]
```

then 
* Hit@3 → True (because 7 and 10 are in the top 3)
* Recall@3 → 2/3 (the model retrieved 2 of the 3 relevant items)

NDCG@K (Normalised Discounted Cumulative Gain) is a more nuanced ranking metric. 
* Top positions get high weight
* Lower positions get discounted
* Supports multiple positives

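Usually, it is defined as:

$$\mathrm{NDCG@K} = \frac{\mathrm{DCG@K}}{\mathrm{IDCG@K}}, \qquad \mathrm{DCG@K} = \sum_{i=1}^{K} \frac{rel_i}{\log_2(i + 1)}$$

where $rel_i$ is the relevance of the item at rank $i$ (1 for a relevant item, 0 otherwise) and IDCG@K is the DCG@K of the ideal ranking, in which all relevant items come first.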


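__MRR (Mean Reciprocal Rank)__ rewards placing the true item very high. For each user, take the reciprocal of the rank of the first relevant item; MRR is the mean of these values over all users.

```if the true item is at rank 2 --> reciprocal rank = 1/2```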

Instead of evaluating against every item in the catalog (which often contains millions of items), the common industry pattern is to build the evaluation set from the positives plus sampled negatives. Here are the steps:
1.	Pick positives for evaluation
Items the user interacted with at test time (not in training).
2.	Sample negatives
* Often ~100 random items
* or ~1000 items from similar category
* Optionally: the same negatives used in training are avoided
3.	Score them
Use:
* user_tower(user_id)
* item_tower(candidate_items)
* dot product similarity
4.	Rank
Sort by score descending.
5.	Compute metrics
* Hit@K
* Recall@K
* MRR (mean reciprocal rank)
* NDCG@K (discounts position)

In [18]:
def evaluate(model, test_users, user_to_test_pos, all_items, K=10, num_neg = 100,
             user_to_country=None, item_to_category=None):
    """
    Sampled offline evaluation: rank each user's test positives against randomly
    sampled negatives and report Hit@K and Recall@K.

    Args:
        model: Object exposing `eval()`, `user_tower(user_id, country_id)` and
            `item_tower(item_id, category_id)`.
        test_users: Sized collection of user IDs to evaluate.
        user_to_test_pos: Mapping user_id -> list of held-out positive item IDs.
        all_items: Sequence of all item IDs to sample negatives from.
        K: Cut-off of the ranking metrics.
        num_neg: Number of negatives sampled per user.
        user_to_country: Optional mapping user_id -> country_id. Missing users (or
            no mapping at all) use country 0, the dummy value used in training.
        item_to_category: Optional mapping item_id -> category_id. Missing items
            (or no mapping at all) use category 0, the dummy value used in training.

    Returns:
        Dict with
        - 'Hit@K': fraction of test users with at least one positive in their top-K
        - 'Recall@K': positives found in the top-K divided by all positives,
          pooled over users
    """
    model.eval() # set model to eval mode
    user_tower = model.user_tower
    item_tower = model.item_tower
    user_to_country = user_to_country or {}
    item_to_category = item_to_category or {}

    hit_count = 0
    recall_count = 0
    total_positives = 0

    with torch.no_grad(): # evaluation needs no gradients
        for user in test_users:
            # De-duplicate positives (order preserved) so none is counted twice
            pos_items = list(dict.fromkeys(user_to_test_pos[user]))
            pos_set = set(pos_items)
            total_positives += len(pos_items)

            # 1) Sample negatives. Draw len(pos_set) extra items so that at least
            #    num_neg remain after the positives are removed (catalog permitting).
            draw = min(len(all_items), num_neg + len(pos_set))
            neg_items = [
                item for item in random.sample(all_items, draw) if item not in pos_set
            ][:num_neg]

            # 2) Form candidate set = positives + negatives
            candidates = pos_items + neg_items
            k = min(K, len(candidates)) # topk cannot exceed the candidate count
            if k <= 0:
                continue
            candidate_items = torch.tensor(candidates, dtype=torch.long) # (num_candidates,)
            candidate_cats = torch.tensor(
                [item_to_category.get(item, 0) for item in candidates],
                dtype=torch.long
            ) # (num_candidates,)

            # 3) Score them
            user_tensor = torch.tensor([user], dtype=torch.long)
            country_tensor = torch.tensor([user_to_country.get(user, 0)], dtype=torch.long)
            u = user_tower(user_tensor, country_tensor) # (1, tower_dim)
            v = item_tower(
                candidate_items.unsqueeze(1), candidate_cats.unsqueeze(1)
            ).squeeze(1) # (num_candidates, tower_dim)
            scores = torch.matmul(v, u.squeeze(0)) # (num_candidates,)

            # 4) Rank
            topk_indices = torch.topk(scores, k).indices # (k,)
            topk_items = candidate_items[topk_indices].tolist() # plain ints

            # 5) Metrics
            num_found = sum(1 for item in topk_items if item in pos_set)
            # Hit@K
            if num_found > 0:
                hit_count += 1
            # Recall@K
            recall_count += num_found

    # Report only after every user has been scored; guard the empty cases
    num_users = len(test_users)
    return {
        'Hit@K': hit_count / num_users if num_users else 0.0,
        'Recall@K': recall_count / total_positives if total_positives else 0.0
    }