# Lab B.3 Solutions: Two-Tower Retrieval System

Solutions to exercises from Lab B.3.

---

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the GPU when CUDA is usable in this environment; otherwise run on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Exercise 1 Solution: Temperature Tuning

Temperature controls the sharpness of the softmax distribution in contrastive learning.

In [None]:
def visualize_temperature_effect(logits, temperatures):
    """
    Visualize how temperature affects the softmax distribution.

    Draws one bar chart per temperature, side by side, each showing
    softmax(logits / temperature) on a fixed 0-1 probability axis, then
    displays the figure.

    Args:
        logits: Tensor of scores. Softmax is taken over dim 0, and each
            entry becomes one bar, so a 1-D tensor is expected.
        temperatures: Sequence of temperatures; one subplot is drawn for
            each. Nothing is drawn when the sequence is empty.
    """
    if len(temperatures) == 0:
        # plt.subplots rejects a zero-column grid; there is nothing to plot.
        return

    # squeeze=False always yields a 2-D array of Axes. Without it a single
    # temperature returns a bare Axes object, which cannot be zipped.
    fig, axes = plt.subplots(1, len(temperatures), figsize=(15, 4), squeeze=False)

    # detach + cpu so tensors that track gradients or live on an accelerator
    # can still be converted to NumPy for plotting.
    cpu_logits = logits.detach().cpu()

    for ax, temp in zip(axes[0], temperatures):
        probs = F.softmax(cpu_logits / temp, dim=0).numpy()
        ax.bar(range(len(probs)), probs)
        ax.set_title(f'Temperature = {temp}')
        ax.set_xlabel('Class')
        ax.set_ylabel('Probability')
        # Same y-range on every panel so the sharpness is directly comparable.
        ax.set_ylim(0, 1)

    plt.tight_layout()
    plt.show()


# Example logits (similarity scores) and the temperatures to compare,
# ordered from sharpest to softest.
logits = torch.tensor([2.0, 1.5, 0.5, 0.2, -0.5])
temperatures = [0.01, 0.07, 0.2, 0.5, 1.0]

visualize_temperature_effect(logits, temperatures)

# The leading emoji was stored as mis-decoded bytes; it is written here as
# the intended character.
print("📊 Observations:")
print("   - Low temp (0.01): Almost one-hot, very confident")
print("   - Medium temp (0.07): Sharp but smooth, good for learning")
print("   - High temp (0.5+): Soft distribution, weak learning signal")

In [None]:
# Training with different temperatures
def train_with_temperature(temperature, batch_size=32, steps=100, query_dim=128, seed=None):
    """
    Simulate in-batch softmax cross-entropy loss at a given temperature.

    No model is trained. Each step samples a fresh batch of random unit
    query vectors and builds the matching item vectors as a normalized
    0.8/0.2 blend of the query and random noise, so the diagonal of the
    similarity matrix is the positive pair. The loss for that batch is
    then recorded.

    Args:
        temperature: Strictly positive divisor applied to the similarities.
        batch_size: Number of (query, item) pairs per simulated batch.
        steps: Number of independent batches to sample.
        query_dim: Embedding dimensionality (128 matches the original
            hard-coded value).
        seed: Optional int. When given, sampling uses a private
            ``torch.Generator`` so results are reproducible without
            touching the global RNG state.

    Returns:
        List of ``steps`` Python floats, one loss value per batch.

    Raises:
        ValueError: If ``temperature`` is not strictly positive (zero would
            produce inf/NaN logits, a negative value would invert the ranking).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # The positive for row i is column i; this never changes between steps.
    labels = torch.arange(batch_size)

    losses = []
    for _ in range(steps):
        # Random normalized embeddings
        query = F.normalize(torch.randn(batch_size, query_dim, generator=generator), dim=1)
        noise = F.normalize(torch.randn(batch_size, query_dim, generator=generator), dim=1)

        # Add some positive signal (diagonal should be higher)
        item = F.normalize(0.8 * query + 0.2 * noise, dim=1)

        # Temperature-scaled similarity matrix, scored against the diagonal.
        logits = torch.matmul(query, item.T) / temperature
        losses.append(F.cross_entropy(logits, labels).item())

    return losses


# Compare temperatures: one simulated loss curve per temperature.
temps_to_test = [0.01, 0.07, 0.2, 0.5]
temp_results = {temp: train_with_temperature(temp) for temp in temps_to_test}

# Plot every curve on a shared axis so the loss levels can be compared.
plt.figure(figsize=(10, 5))
for temp, losses in temp_results.items():
    plt.plot(losses, label=f'temp={temp}')

plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Training Loss by Temperature')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# The leading emoji was stored as mis-decoded bytes; it is written here as
# the intended character.
print("\n📊 Key insight: Lower temperature = lower loss but can be unstable")
print("   Recommended: 0.05-0.1 for most contrastive learning tasks")

---

## Exercise 2 Solution: Batch Size Impact

In [None]:
def analyze_batch_size_negatives(batch_sizes):
    """
    Analyze how batch size affects number of negatives.

    With in-batch negatives, every example in a batch has one positive and
    treats the other ``bs - 1`` examples as negatives.

    Args:
        batch_sizes: Iterable of integer batch sizes.

    Returns:
        List of dicts, one per batch size and in input order, with keys
        ``batch_size``, ``positives``, ``negatives_per_positive``,
        ``total_negative_pairs`` and ``negative_ratio`` (an ``"N:1"`` string).
    """
    analysis = []

    for bs in batch_sizes:
        negatives_per_positive = bs - 1

        analysis.append({
            'batch_size': bs,
            'positives': bs,
            'negatives_per_positive': negatives_per_positive,
            # Every one of the bs positives is paired with each of its negatives.
            'total_negative_pairs': bs * negatives_per_positive,
            'negative_ratio': f'{negatives_per_positive}:1'
        })

    return analysis


# Batch sizes to tabulate; the plotting cell that follows reuses both names.
batch_sizes = [32, 64, 128, 256, 512, 1024, 2048]
analysis = analyze_batch_size_negatives(batch_sizes)

# Fixed-width table: column widths in the header match those in each row.
print("In-Batch Negatives Analysis")
print("="*60)
print(f"{'Batch Size':>12} | {'Negatives/Positive':>18} | {'Total Neg Pairs':>15}")
print("-"*60)
for row in analysis:
    print(f"{row['batch_size']:>12} | {row['negatives_per_positive']:>18} | {row['total_negative_pairs']:>15,}")

# The leading emoji was stored as mis-decoded bytes; it is written here as
# the intended character.
print("\n📊 Observation: Batch size 512 gives 511 negatives per positive - for free!")

In [None]:
# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Both panels share the same categorical x tick labels; build them once.
batch_labels = [str(bs) for bs in batch_sizes]

# Negatives per positive
negs = [a['negatives_per_positive'] for a in analysis]
axes[0].bar(batch_labels, negs, color='steelblue', edgecolor='black')
axes[0].set_xlabel('Batch Size')
axes[0].set_ylabel('Negatives per Positive')
axes[0].set_title('In-Batch Negatives Scale with Batch Size')
axes[0].grid(axis='y', alpha=0.3)

# GPU memory estimate (rough)
# Embedding dim = 128, FP32
# Counts only the query and item embedding tensors of one batch.
emb_dim = 128
bytes_per_float = 4
memory_mb = [(bs * emb_dim * bytes_per_float * 2) / (1024**2) for bs in batch_sizes]  # *2 for query+item

axes[1].bar(batch_labels, memory_mb, color='coral', edgecolor='black')
axes[1].set_xlabel('Batch Size')
axes[1].set_ylabel('Embedding Memory (MB)')
axes[1].set_title('Memory Usage by Batch Size')
axes[1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

# The leading emoji was stored as mis-decoded bytes; it is written here as
# the intended character.
print("\n💡 DGX Spark Advantage: With 128GB memory, you can use batch sizes of 4096+!")
print("   This gives 4095 negatives per positive - massive training signal.")

---

## Bonus: Hard Negative Mining

In [None]:
def hard_negative_mining(query_embeddings, item_embeddings, positive_indices, k=5):
    """
    Find hard negatives: items that are similar but not positive.

    For each query, items are ranked by dot-product similarity and the
    highest-scoring items that are not listed as positives are returned.

    Args:
        query_embeddings: (num_queries, dim) tensor
        item_embeddings: (num_items, dim) tensor
        positive_indices: Dict mapping query_idx -> set of positive item indices.
            Queries missing from the dict are treated as having no positives.
        k: Number of hard negatives to find per query

    Returns:
        Dict mapping query_idx -> list of hard negative indices, hardest
        first. A list holds fewer than ``k`` entries when the query has
        fewer than ``k`` non-positive items; positives are never returned.
    """
    # Compute all similarities: (num_queries, num_items)
    similarities = torch.matmul(query_embeddings, item_embeddings.T)
    num_items = similarities.shape[1]

    hard_negatives = {}

    for query_idx, row in enumerate(similarities):
        positives = list(positive_indices.get(query_idx, ()))

        # A boolean mask (rather than writing -inf in a loop) also gives an
        # exact count of distinct positives, even with repeated indices.
        is_positive = torch.zeros(num_items, dtype=torch.bool, device=row.device)
        if positives:
            is_positive[positives] = True

        # masked_fill returns a new tensor, so the similarity matrix is untouched.
        scores = row.masked_fill(is_positive, -float('inf'))

        # Never ask topk for more than the number of real negatives: a larger
        # k would either raise (k > num_items) or leak masked positives.
        available = num_items - int(is_positive.sum())
        _, top_k = torch.topk(scores, min(k, available))
        hard_negatives[query_idx] = top_k.tolist()

    return hard_negatives


# Example: random unit-norm embeddings, so the dot product is cosine similarity.
num_queries = 10
num_items = 100
dim = 64

queries = F.normalize(torch.randn(num_queries, dim), dim=1)
items = F.normalize(torch.randn(num_items, dim), dim=1)

# Assume each query has 1 positive (query i -> item 10 * i)
positive_indices = {i: {i * 10} for i in range(num_queries)}

hard_negs = hard_negative_mining(queries, items, positive_indices, k=5)

# Show only the first three queries to keep the output short.
print("Hard Negatives Found:")
for query_idx in range(3):
    print(f"  Query {query_idx}: Hard negatives = {hard_negs[query_idx]}")

# The leading emoji was stored as mis-decoded bytes; it is written here as
# the intended character.
print("\n💡 Use these hard negatives in the next training epoch for better learning!")

---

## Key Takeaways

1. **Temperature**: 0.05-0.1 is typically optimal. Too low = unstable, too high = weak signal.

2. **Batch size**: Larger is better for in-batch negatives. DGX Spark can go huge!

3. **Hard negatives**: Mining similar-but-wrong items improves discrimination.

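4. **Memory**: Two-tower scales because item embeddings are pre-computed, so answering a query only requires running the query tower and comparing its output against the stored item vectors.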