# 🎯 Day 16: Order Matters - Pointer Networks

**Paper:** *Order Matters: Sequence to Sequence for Sets* (Vinyals et al., 2015)

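Welcome to this interactive notebook! Today we'll explore **Pointer Networks**, a neural architecture that reads a collection of input items and outputs an ordered sequence by pointing back to those items. Pointer Networks were introduced by Vinyals, Fortunato & Jaitly (2015); this paper builds on them to study how the order of inputs and outputs affects sequence-to-sequence models, especially when the input is a set.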

## 🍕 The Pizza Delivery Analogy

Imagine you're a pizza delivery driver:

- **Input (Set):** You get 5 pizza orders from different addresses - order doesn't matter
- **Output (Sequence):** You need to plan your delivery route - order DOES matter!

That's what Pointer Networks do:
1. Read a SET of items (no order)
2. Encode them into hidden representations (one vector per item)
3. Output a SEQUENCE by "pointing" to items one-by-one

Let's dive in! 🚀

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✅ Setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

## 📚 Part 1: Understanding Pointer Attention

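The core innovation is the **pointer mechanism**:
- Instead of a softmax over a fixed output vocabulary
- We take a softmax over the input positions, i.e. we "point" to one of the inputs, so the number of possible outputs grows with the input length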

Let's build it step by step!
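
In the notation of the original Pointer Networks paper, with encoder states $e_1, \dots, e_n$, decoder state $d_i$ at output step $i$, input set $\mathcal{P}$, and $C_i$ the index chosen at step $i$, the pointer distribution is computed as follows. In `SimplePointerAttention` below, `W_key`, `W_query` and `v` play the roles of $W_1$, $W_2$ and $v$:

$$
u^i_j = v^\top \tanh\left(W_1 e_j + W_2 d_i\right), \qquad j = 1, \dots, n
$$

$$
p\left(C_i \mid C_1, \dots, C_{i-1}, \mathcal{P}\right) = \operatorname{softmax}\left(u^i\right)
$$

The softmax runs over the $n$ input positions, so the output size is set by the input length rather than by a vocabulary.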

In [None]:
class SimplePointerAttention(nn.Module):
    """
    Simplified pointer attention mechanism.
    
    Formula: attention(query, keys) = softmax(v^T tanh(W_q * query + W_k * keys))
    """
    
    def __init__(self, hidden_dim=32):
        super().__init__()
        
        self.W_key = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_query = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
    
    def forward(self, query, keys, mask=None):
        """
        Args:
            query: [batch, hidden_dim] - what we're looking for
            keys: [batch, seq_len, hidden_dim] - where we search
            mask: [batch, seq_len] - 1 for positions to mask
        Returns:
            attention_weights: [batch, seq_len]
        """
        # Transform query and keys
        query_proj = self.W_query(query).unsqueeze(1)  # [batch, 1, hidden_dim]
        keys_proj = self.W_key(keys)                    # [batch, seq_len, hidden_dim]
        
        # Additive attention
        scores = self.v(torch.tanh(query_proj + keys_proj))  # [batch, seq_len, 1]
        scores = scores.squeeze(-1)  # [batch, seq_len]
        
        # Apply mask (set masked positions to -inf)
        if mask is not None:
            scores = scores.masked_fill(mask.bool(), float('-inf'))
        
        # Softmax to get probabilities
        attention_weights = torch.softmax(scores, dim=-1)
        
        return attention_weights

# Test it!
attention = SimplePointerAttention(hidden_dim=32)

query = torch.randn(1, 32)
keys = torch.randn(1, 5, 32)

weights = attention(query, keys)

print("Attention weights:", weights[0].tolist())
print(f"Sum: {weights.sum():.4f} (should be 1.0)")
print(f"\n✅ Pointer attention works! It's deciding which input position to focus on.")
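
`SimplePointerAttention` fills masked scores with `-inf` before the softmax. The NumPy sketch below is a standalone re-implementation for illustration (the `masked_softmax` helper is hypothetical, not part of the class). It shows why the trick works: `exp(-inf) = 0`, so masked positions get exactly zero probability while the remaining ones still sum to 1. The TSP model in Part 3 relies on this to avoid revisiting cities.

```python
import numpy as np

def masked_softmax(scores, mask):
    """Softmax over positions where mask is 0; positions with mask == 1 get probability 0."""
    scores = np.where(mask.astype(bool), -np.inf, scores)
    shifted = scores - scores.max()   # max is taken over the unmasked entries
    exp = np.exp(shifted)             # exp(-inf) == 0.0, no warning
    return exp / exp.sum()

scores = np.array([2.0, 1.0, 0.5, 3.0, -1.0])
mask = np.array([0, 0, 0, 1, 0])      # position 3 was already chosen
probs = masked_softmax(scores, mask)
print(probs.round(3))
print(probs[3], probs.sum())
```

Note that without the mask, position 3 (score 3.0) would win; with it, the highest remaining score (position 0) does.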

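### 🎯 Interactive: Visualize Attention

Let's see what the attention "sees". The layer below is untrained (random weights and a random query), so which position it favors is arbitrary; the point is the shape of the output, one probability per input position: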

In [None]:
def visualize_attention_example():
    # Create 5 input items with different values
    items = torch.tensor([[0.2, 0.8, 0.1, 0.9, 0.5]])
    
    # Simple embedding
    embed = nn.Linear(1, 32)
    keys = embed(items.unsqueeze(-1))
    
    # Random query: the layer is untrained, so it is not "looking for" anything yet
    query = torch.randn(1, 32)
    
    # Compute attention
    attention_layer = SimplePointerAttention(32)
    weights = attention_layer(query, keys)
    
    # Visualize
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6))
    
    # Input values
    ax1.bar(range(5), items[0].numpy(), color='skyblue', alpha=0.7)
    ax1.set_title('Input Values', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Position')
    ax1.set_ylabel('Value')
    ax1.grid(True, alpha=0.3)
    
    # Attention weights
    ax2.bar(range(5), weights[0].detach().numpy(), color='coral', alpha=0.7)
    ax2.set_title('Attention Weights (Where to Point)', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Position')
    ax2.set_ylabel('Probability')
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Show which position has highest attention (arbitrary for an untrained layer)
    best_pos = torch.argmax(weights[0]).item()
    print(f"\n👉 Untrained layer points to position {best_pos} (value: {items[0, best_pos]:.2f})")

visualize_attention_example()

## 🔢 Part 2: Sorting with Pointer Networks

Let's tackle the simplest problem: **sorting numbers**

- Input: `[0.7, 0.2, 0.9, 0.1, 0.5]` (unordered set)
- Output: `[3, 1, 4, 0, 2]` (indices in sorted order)

This means: point to position 3 (0.1), then 1 (0.2), then 4 (0.5), then 0 (0.7), then 2 (0.9)
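
The target sequence is simply the `argsort` of the input, which is how `generate_sorting_batch` builds its targets with `torch.argsort`. A quick NumPy check on the example above:

```python
import numpy as np

values = np.array([0.7, 0.2, 0.9, 0.1, 0.5])
order = np.argsort(values)        # input indices in ascending order of value
print(order.tolist())             # [3, 1, 4, 0, 2]
print(values[order].tolist())     # [0.1, 0.2, 0.5, 0.7, 0.9]
```

At output step `t` the model is trained to put its probability mass on `order[t]`, which turns sorting into `seq_len` classification problems over the input positions.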

In [None]:
class PointerNetwork(nn.Module):
    """Full Pointer Network for sorting."""
    
    def __init__(self, input_dim=1, hidden_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim
        
        # Encoder: Process input set
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        
        # Decoder: Generate output sequence
        self.decoder = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        
        # Pointer attention
        self.attention = SimplePointerAttention(hidden_dim)
    
    def forward(self, inputs):
        """
        Args:
            inputs: [batch, seq_len] numbers to sort
        Returns:
            all_pointers: [batch, seq_len, seq_len] pointer probabilities (softmax over input positions, not logits)
        """
        batch_size, seq_len = inputs.shape
        inputs = inputs.unsqueeze(-1)  # [batch, seq_len, 1]
        
        # Encode
        encoder_outputs, (h, c) = self.encoder(inputs)
        
        # Decode
        decoder_state = (h, c)
        decoder_input = torch.zeros(batch_size, 1, self.hidden_dim, device=inputs.device)
        
        all_pointers = []
        
        for t in range(seq_len):
            # One decoding step
            decoder_output, decoder_state = self.decoder(decoder_input, decoder_state)
            
            # Compute where to point
            weights = self.attention(decoder_output.squeeze(1), encoder_outputs)
            all_pointers.append(weights)
            
            # Next input: weighted sum of encoder outputs
            context = torch.bmm(weights.unsqueeze(1), encoder_outputs)
            decoder_input = context
        
        return torch.stack(all_pointers, dim=1)

# Create model
model = PointerNetwork(input_dim=1, hidden_dim=64)
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
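
`PointerNetwork.forward` returns softmax probabilities (it reuses `SimplePointerAttention`), not raw logits. `F.cross_entropy` expects logits and applies its own log-softmax, so passing it probabilities squashes them through a second softmax. The NumPy sketch below shows the consequence for a single step over 5 inputs: even a perfect, one-hot prediction gets a loss of log(1 + 4/e) ≈ 0.905 instead of 0, so the loss has a floor and the gradient signal is weak. Taking the log of the probabilities and using negative log-likelihood (`F.nll_loss` in PyTorch) avoids this. The `softmax` helper here is a plain illustration, not a library function.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - np.max(x))
    return z / z.sum()

n_inputs = 5
target = 0
probs = np.zeros(n_inputs)
probs[target] = 1.0                    # a perfect, one-hot pointer distribution

# Correct: the output is already a probability, so NLL = -log p[target] = log(1 / p[target])
nll = np.log(1.0 / probs[target])

# Bug: a loss that expects logits applies a second softmax to the probabilities
double_softmax_loss = -np.log(softmax(probs)[target])

# The second softmax bounds the loss from below by log(1 + (n - 1) / e)
floor = np.log(1 + (n_inputs - 1) / np.e)
print(f"NLL: {nll:.3f}, double-softmax loss: {double_softmax_loss:.3f}, floor: {floor:.3f}")
# NLL: 0.000, double-softmax loss: 0.905, floor: 0.905
```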

### 🏋️ Train on Sorting Task

Let's train the model to sort numbers:

In [None]:
def generate_sorting_batch(batch_size=32, seq_len=5):
    """Generate random sorting problems."""
    # Random numbers
    inputs = torch.rand(batch_size, seq_len)
    
    # Sorting indices (targets)
    targets = torch.argsort(inputs, dim=1)
    
    return inputs, targets

def train_sorting(model, num_steps=1000, batch_size=32, seq_len=5, log_every=100):
    """Train model to sort numbers. Every step samples a fresh batch, so the unit is steps, not epochs."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    
    losses = []
    accuracies = []
    
    for step in range(num_steps):
        # Generate a fresh training batch
        inputs, targets = generate_sorting_batch(batch_size, seq_len)
        
        # Forward pass: probs[b, t, j] = P(output step t points to input j)
        probs = model(inputs)
        
        # The model already applies softmax, so take the log and use NLL.
        # (F.cross_entropy expects raw logits and would apply a second softmax.)
        log_probs = torch.log(probs.clamp(min=1e-9)).reshape(-1, seq_len)
        loss = F.nll_loss(log_probs, targets.reshape(-1))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Sequence accuracy: a sequence counts only if every pointer is correct
        predictions = torch.argmax(probs, dim=-1)
        accuracy = (predictions == targets).all(dim=1).float().mean().item()
        
        losses.append(loss.item())
        accuracies.append(accuracy)
        
        if (step + 1) % log_every == 0:
            print(f"Step {step+1}/{num_steps} - Loss: {loss.item():.4f}, Acc: {accuracy:.2%}")
    
    return losses, accuracies

# Train! (increase num_steps if accuracy is still low)
print("🏋️ Training on sorting task...\n")
losses, accuracies = train_sorting(model, num_steps=1000, seq_len=5)

# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(losses, color='red', linewidth=2)
ax1.set_title('Training Loss', fontsize=14, fontweight='bold')
ax1.set_xlabel('Step')
ax1.set_ylabel('Loss')
ax1.grid(True, alpha=0.3)

ax2.plot(accuracies, color='green', linewidth=2)
ax2.set_title('Accuracy (Perfect Sequences)', fontsize=14, fontweight='bold')
ax2.set_xlabel('Step')
ax2.set_ylabel('Accuracy')
ax2.set_ylim([0, 1])
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\n✅ Final accuracy: {accuracies[-1]:.2%}")
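
`train_sorting` reports *sequence* accuracy: a sequence counts only if every pointer is right, so this number can stay near zero early in training even while the model is improving. Tracking per-position accuracy alongside it is often a more informative signal. A small NumPy illustration with made-up predictions (hypothetical values, not model output):

```python
import numpy as np

# Hypothetical targets and predictions for 3 sequences of length 5
targets = np.array([[3, 1, 4, 0, 2],
                    [0, 1, 2, 3, 4],
                    [4, 3, 2, 1, 0]])
predictions = np.array([[3, 1, 4, 0, 2],    # fully correct
                        [0, 1, 3, 3, 4],    # one wrong pointer
                        [4, 3, 2, 0, 1]])   # two wrong pointers

correct = predictions == targets
sequence_acc = correct.all(axis=1).mean()   # what train_sorting reports
position_acc = correct.mean()               # fraction of individual pointers that are right
print(f"sequence accuracy: {sequence_acc:.3f}, per-position accuracy: {position_acc:.3f}")
# sequence accuracy: 0.333, per-position accuracy: 0.800
```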

### 🎯 Interactive: Test Your Own Numbers!

Let's sort some numbers and visualize what the model is doing:

In [None]:
def test_and_visualize(model, numbers):
    """
    Test model on custom numbers and visualize attention.
    
    Args:
        model: Trained PointerNetwork
        numbers: List of numbers to sort
    """
    model.eval()
    
    # Prepare input
    inputs = torch.tensor([numbers]).float()
    
    # Get predictions
    with torch.no_grad():
        logits = model(inputs)
        predictions = torch.argmax(logits, dim=-1)[0]
    
    # Ground truth
    true_order = torch.argsort(inputs[0])
    
    # Create visualization
    fig = plt.figure(figsize=(14, 8))
    
    # 1. Input numbers
    ax1 = plt.subplot(3, 1, 1)
    bars = ax1.bar(range(len(numbers)), numbers, color='skyblue', alpha=0.7, edgecolor='black')
    for i, (bar, num) in enumerate(zip(bars, numbers)):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                f'{num:.2f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
        ax1.text(bar.get_x() + bar.get_width()/2, -0.05, 
                f'idx:{i}', ha='center', va='top', fontsize=10, color='gray')
    ax1.set_title('Input Numbers (Unsorted)', fontsize=14, fontweight='bold')
    ax1.set_ylim([0, max(numbers) + 0.2])
    ax1.set_xticks([])
    ax1.grid(True, alpha=0.3, axis='y')
    
    # 2. Attention heatmap (the model already returns softmax probabilities,
    #    so no second softmax is applied here)
    ax2 = plt.subplot(3, 1, 2)
    attention_matrix = logits[0].numpy()
    im = ax2.imshow(attention_matrix, cmap='YlOrRd', aspect='auto')
    ax2.set_title('Attention Heatmap (Where Model Points)', fontsize=14, fontweight='bold')
    ax2.set_ylabel('Output Step')
    ax2.set_xlabel('Input Position')
    ax2.set_yticks(range(len(numbers)))
    ax2.set_yticklabels([f'Step {i+1}' for i in range(len(numbers))])
    ax2.set_xticks(range(len(numbers)))
    plt.colorbar(im, ax=ax2, label='Attention Weight')
    
    # 3. Predicted order
    ax3 = plt.subplot(3, 1, 3)
    sorted_nums = [numbers[i] for i in predictions.tolist()]
    bars = ax3.bar(range(len(sorted_nums)), sorted_nums, 
                   color='lightgreen' if torch.equal(predictions, true_order) else 'salmon', 
                   alpha=0.7, edgecolor='black')
    for i, (bar, num, idx) in enumerate(zip(bars, sorted_nums, predictions.tolist())):
        ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                f'{num:.2f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
        ax3.text(bar.get_x() + bar.get_width()/2, -0.05, 
                f'from:{idx}', ha='center', va='top', fontsize=10, color='gray')
    
    title = 'Predicted Order ✅ CORRECT!' if torch.equal(predictions, true_order) else 'Predicted Order ❌ WRONG'
    ax3.set_title(title, fontsize=14, fontweight='bold')
    ax3.set_ylim([0, max(numbers) + 0.2])
    ax3.set_xticks([])
    ax3.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
    
    # Print results
    print("\n" + "="*60)
    print("üéØ Sorting Results")
    print("="*60)
    print(f"Input:          {numbers}")
    print(f"Predicted order: {predictions.tolist()}")
    print(f"True order:      {true_order.tolist()}")
    print(f"Sorted values:   {sorted_nums}")
    
    if torch.equal(predictions, true_order):
        print("\n‚úÖ Perfect sorting!")
    else:
        print("\n‚ùå Not quite right. Try training longer!")

# Test on example
test_numbers = [0.7, 0.2, 0.9, 0.1, 0.5]
print(f"Testing on: {test_numbers}\n")
test_and_visualize(model, test_numbers)
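
Unlike the TSP model in Part 3, `PointerNetwork` does not mask positions it has already pointed to, so plain `argmax` decoding can select the same index twice and the "sorted values" can then contain duplicates. One decoding-time fix, sketched below in NumPy on a made-up probability matrix, is to exclude already-chosen positions before each `argmax`. The `greedy_permutation` helper is hypothetical and only illustrates the idea; the masking inside the attention in Part 3 does the same thing during the forward pass.

```python
import numpy as np

def greedy_permutation(step_probs):
    """Greedy decoding that never repeats an index.

    step_probs: [seq_len, seq_len] array; row t is the pointer distribution at output step t.
    """
    step_probs = np.asarray(step_probs, dtype=float)
    visited = np.zeros(step_probs.shape[1], dtype=bool)
    order = []
    for row in step_probs:
        scores = np.where(visited, -np.inf, row)   # rule out positions already used
        idx = int(np.argmax(scores))
        order.append(idx)
        visited[idx] = True
    return order

# Hypothetical pointer probabilities for the input [0.7, 0.2, 0.9, 0.1, 0.5]
probs = np.array([
    [0.1, 0.1, 0.1, 0.6, 0.1],
    [0.1, 0.3, 0.0, 0.5, 0.1],   # plain argmax would pick index 3 again here
    [0.1, 0.1, 0.1, 0.1, 0.6],
    [0.6, 0.1, 0.1, 0.1, 0.1],
    [0.1, 0.1, 0.6, 0.1, 0.1],
])
print(probs.argmax(axis=1).tolist())   # [3, 3, 4, 0, 2]
print(greedy_permutation(probs))       # [3, 1, 4, 0, 2]
```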

### 🎮 Try Your Own Numbers!

Change the numbers below and see how the model performs:

In [None]:
# ✏️ EDIT THESE NUMBERS!
my_numbers = [0.3, 0.8, 0.1, 0.6, 0.4]

test_and_visualize(model, my_numbers)

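## 🗺️ Part 3: Traveling Salesman Problem (TSP)

Now let's tackle a **harder problem**: visiting every city exactly once and returning to the start along the shortest possible route.

- Input: Set of city locations `[(x₁, y₁), (x₂, y₂), ...]`
- Output: Tour order, e.g. `[0, 3, 1, 4, 2]`, that minimizes the total distance

This is **NP-hard**: no known algorithm finds the optimal tour in polynomial time in general, so exact solvers become impractical as the number of cities grows.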
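
To see why exact solutions get expensive, here is a brute-force solver. It is a plain-Python sketch for illustration, not something from the paper: it fixes city 0 as the start and tries every ordering of the remaining cities, i.e. (n-1)! tours. That is fine for a handful of cities (and can give ground-truth tours for tiny instances), but the count explodes quickly.

```python
import itertools
import math

def tour_length(cities, tour):
    """Total length of a closed tour (including the edge back to the start)."""
    return sum(math.dist(cities[tour[i]], cities[tour[(i + 1) % len(tour)]])
               for i in range(len(tour)))

def brute_force_tsp(cities):
    """Exact TSP: try every tour that starts at city 0, which is (n-1)! candidates."""
    n = len(cities)
    best, best_len = None, math.inf
    for perm in itertools.permutations(range(1, n)):
        candidate = (0,) + perm
        length = tour_length(cities, candidate)
        if length < best_len:
            best, best_len = list(candidate), length
    return best, best_len

square = [(0, 0), (1, 0), (1, 1), (0, 1)]
best_tour, best_len = brute_force_tsp(square)
print(best_tour, best_len)   # [0, 1, 2, 3] 4.0

for n in (5, 8, 10, 15):
    print(n, "cities:", math.factorial(n - 1), "candidate tours")
```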

In [None]:
class TSPPointerNetwork(nn.Module):
    """Pointer Network for TSP with masking."""
    
    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim
        
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.attention = SimplePointerAttention(hidden_dim)
    
    def forward(self, cities):
        """
        Args:
            cities: [batch, num_cities, 2] - city coordinates
        Returns:
            tour: [batch, num_cities] - predicted tour
            all_logits: [batch, num_cities, num_cities] - pointer probabilities (masked softmax, not raw logits)
        """
        batch_size, num_cities, _ = cities.shape
        
        # Encode
        encoder_outputs, (h, c) = self.encoder(cities)
        
        # Decode with masking
        decoder_state = (h, c)
        decoder_input = torch.zeros(batch_size, 1, self.hidden_dim, device=cities.device)
        
        mask = torch.zeros(batch_size, num_cities, device=cities.device)
        tour = []
        all_logits = []
        
        for t in range(num_cities):
            decoder_output, decoder_state = self.decoder(decoder_input, decoder_state)
            
            # Compute attention with masking
            weights = self.attention(decoder_output.squeeze(1), encoder_outputs, mask)
            all_logits.append(weights)
            
            # Greedy selection
            _, selected = weights.max(dim=-1)
            tour.append(selected)
            
            # Update mask (mark selected city as visited)
            mask.scatter_(1, selected.unsqueeze(1), 1)
            
            # Update decoder input
            context = torch.bmm(weights.unsqueeze(1), encoder_outputs)
            decoder_input = context
        
        tour = torch.stack(tour, dim=1)
        all_logits = torch.stack(all_logits, dim=1)
        
        return tour, all_logits

def compute_tour_length(cities, tour):
    """Compute total tour distance."""
    length = 0.0
    for i in range(len(tour)):
        current = tour[i]
        next_city = tour[(i + 1) % len(tour)]
        length += np.linalg.norm(cities[current] - cities[next_city])
    return length

# Create TSP model
tsp_model = TSPPointerNetwork(input_dim=2, hidden_dim=64)
print(f"TSP model created with {sum(p.numel() for p in tsp_model.parameters()):,} parameters")

### 🗺️ Visualize TSP Tour

In [None]:
def visualize_tsp(cities, tour, title="TSP Tour"):
    """Visualize a TSP tour."""
    plt.figure(figsize=(8, 8))
    
    # Plot cities
    plt.scatter(cities[:, 0], cities[:, 1], c='blue', s=200, alpha=0.6, zorder=3, edgecolors='black', linewidth=2)
    
    # Label cities
    for i, (x, y) in enumerate(cities):
        plt.text(x, y, str(i), fontsize=12, ha='center', va='center', 
                fontweight='bold', color='white', zorder=4)
    
    # Draw tour
    for i in range(len(tour)):
        start_idx = tour[i]
        end_idx = tour[(i + 1) % len(tour)]
        
        start = cities[start_idx]
        end = cities[end_idx]
        
        # Draw arrow
        dx = end[0] - start[0]
        dy = end[1] - start[1]
        plt.arrow(start[0], start[1], dx*0.9, dy*0.9,
                 head_width=0.03, head_length=0.03, 
                 fc='red', ec='red', alpha=0.7, zorder=2,
                 length_includes_head=True, linewidth=2)
    
    # Compute and display tour length
    length = compute_tour_length(cities, tour)
    
    plt.title(f"{title}\nTour Length: {length:.3f}", fontsize=14, fontweight='bold')
    plt.xlim(-0.1, 1.1)
    plt.ylim(-0.1, 1.1)
    plt.grid(True, alpha=0.3)
    plt.gca().set_aspect('equal', adjustable='box')  # equal aspect without overriding the limits set above
    plt.tight_layout()
    plt.show()

# Generate random cities
np.random.seed(42)
num_cities = 8
cities = np.random.rand(num_cities, 2)

# Test untrained model
tsp_model.eval()
cities_tensor = torch.from_numpy(cities).float().unsqueeze(0)

with torch.no_grad():
    tour, _ = tsp_model(cities_tensor)
    tour = tour[0].numpy()

visualize_tsp(cities, tour, "Untrained Model")

print(f"\nTour: {tour.tolist()}")
print(f"Length: {compute_tour_length(cities, tour):.3f}")
print("\n💡 Notice: the untrained model's weights are random, so its tour is arbitrary")
print("   Training (supervised on good tours, or with reinforcement learning) is what teaches it to find shorter tours!")
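
A quick sanity baseline for the tour lengths above is the nearest-neighbour heuristic: start at a city and always move to the closest unvisited one. It is not optimal, but it is cheap and usually far better than an arbitrary tour, so a trained pointer network should at least beat it. This NumPy sketch generates its own cities (`demo_cities`, via `np.random.default_rng(42)`, so they differ from the notebook's `cities`); to compare against the cell above, call `nearest_neighbor_tour(cities)` instead.

```python
import numpy as np

def nearest_neighbor_tour(cities, start=0):
    """Greedy heuristic: from the current city, always go to the closest unvisited one."""
    cities = np.asarray(cities, dtype=float)
    n = len(cities)
    visited = np.zeros(n, dtype=bool)
    tour = [start]
    visited[start] = True
    for _ in range(n - 1):
        dists = np.linalg.norm(cities - cities[tour[-1]], axis=1)
        dists[visited] = np.inf            # never revisit a city
        nxt = int(np.argmin(dists))
        tour.append(nxt)
        visited[nxt] = True
    return tour

def tour_length(cities, tour):
    """Closed-tour length, including the edge back to the start."""
    ordered = np.asarray(cities, dtype=float)[tour]
    return float(np.linalg.norm(ordered - np.roll(ordered, -1, axis=0), axis=1).sum())

rng = np.random.default_rng(42)
demo_cities = rng.random((8, 2))           # 8 random cities in the unit square
nn_tour = nearest_neighbor_tour(demo_cities)
print(nn_tour, round(tour_length(demo_cities, nn_tour), 3))
```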

## 🎓 Key Takeaways

### ✨ What We Learned

1. **Pointer Networks** solve set-to-sequence problems:
   - Input: Unordered set
   - Output: Ordered sequence by "pointing" to input elements

2. **Three Key Components:**
   - **Encoder:** Process input set
   - **Decoder:** Generate sequence step-by-step
   - **Pointer Attention:** Decide which input to point to at each step

3. **Applications:**
   - ✅ Sorting numbers
   - ✅ Convex hull computation
   - ✅ Traveling Salesman Problem
   - ✅ Any problem where output references input!
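
One caveat on "Input: Unordered set": the encoders in this notebook are LSTMs, which read the inputs in a fixed order, so shuffling the input can change their output. The *Order Matters* paper addresses this with a Read-Process-Write architecture whose Process block repeatedly attends over the set's embeddings; because attention pooling is a weighted sum, its result does not depend on the order of the elements. The NumPy sketch below illustrates that property. It is a simplified stand-in, not the paper's exact model: a single `tanh` layer with random weights (`W_q`, hypothetical) replaces the paper's LSTM, the attention score is a plain dot product (an assumption), and a toy RNN (`rnn_encode`, also hypothetical) is included for contrast.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 8
n_items = 5

# Hypothetical random (untrained) weights, for illustration only
W_q = rng.normal(size=(2 * dim, dim)) / np.sqrt(2 * dim)  # stands in for the Process-block LSTM
U = rng.normal(size=(dim, dim)) / np.sqrt(dim)            # toy RNN encoder, for contrast
V = rng.normal(size=(dim, dim)) / np.sqrt(dim)

def attention_read(memories, query):
    """Dot-product attention, then a weighted sum of the memory rows."""
    scores = memories @ query
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ memories          # a sum over rows: row order does not matter

def process_set(memories, steps=3):
    """Simplified Process block: q_t from q*_{t-1}, read r_t, then q*_t = [q_t, r_t]."""
    q_star = np.zeros(2 * dim)
    for _ in range(steps):
        q = np.tanh(q_star @ W_q)
        r = attention_read(memories, q)
        q_star = np.concatenate([q, r])
    return q_star

def rnn_encode(memories):
    """Order-sensitive encoder: each hidden state depends on everything read before it."""
    h = np.zeros(dim)
    for m in memories:
        h = np.tanh(h @ U + m @ V)
    return h

memories = rng.normal(size=(n_items, dim))   # one embedding per set element
perm = np.array([4, 2, 0, 3, 1])             # a fixed shuffle of the set

print(np.allclose(process_set(memories), process_set(memories[perm])))  # True
print(np.allclose(rnn_encode(memories), rnn_encode(memories[perm])))
```

The attention-based summary is unchanged by the shuffle, while the RNN's final state generally changes.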

### 🚀 Going Further

Want to dive deeper? Check out:

1. **Exercise files** (`exercises/exercise_*.py`) - Build components from scratch
2. **Solution files** (`solutions/solution_*.py`) - See complete implementations
3. **PAPER_NOTES.md** - Deep dive into the original paper
4. **CHEATSHEET.md** - Quick reference guide

### 🎯 Challenges

Try these on your own:

1. Train the sorting model on longer sequences (10-20 numbers)
2. Implement beam search instead of greedy decoding
3. Add a 2-opt post-processing step for TSP
4. Try reinforcement learning instead of supervised learning

---

**Happy learning! 🎉**

Remember: The key insight is that **order matters** in the output, even when it doesn't in the input. Pointer Networks elegantly handle this by learning to point!