# Day 16: Order Matters - Pointer Networks

**Paper:** *Order Matters: Sequence to Sequence for Sets* - Vinyals, Bengio, Kudlur (2015)

We implement Pointer Networks (Vinyals, Fortunato & Jaitly, 2015). A Pointer Network generates its output sequence by selecting indices from the input set rather than by producing tokens from a fixed vocabulary. The notebook walks through pointer attention, a sorting task, and a small TSP model that decodes tours while masking already-visited cities.

---

## What You'll Learn

1. Why standard seq2seq fails when the output vocabulary depends on the input
2. The pointer attention mechanism (Eq. 1-3 in the paper)
3. How Read-Process-and-Write handles set-structured inputs
4. Sorting as a minimal test bed for pointer networks
5. Applying pointer networks to the Travelling Salesman Problem

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML

# Seed both random number generators (NumPy, then PyTorch) so every rerun of
# the notebook draws identical random numbers.
np.random.seed(42)
torch.manual_seed(42)

print(
    "Setup complete",
    f"PyTorch version: {torch.__version__}",
    f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}",
    sep="\n",
)

## Part 1: Understanding Pointer Attention

The core innovation is the pointer mechanism. Instead of selecting tokens from a fixed vocabulary, the model calculates attention over the input elements and uses the resulting distribution to pick an input index.

In [None]:
class SimplePointerAttention(nn.Module):
    """
    Additive (Bahdanau-style) pointer attention over a set of keys.

    score_j   = v^T tanh(W_k * key_j + W_q * query)
    attention = softmax_j(score_j), with masked positions given zero weight.
    """

    def __init__(self, hidden_dim=32):
        super().__init__()

        # Creation order (W_key, W_query, v) decides which random numbers each
        # layer receives under a fixed seed, so it is intentionally preserved.
        self.W_key = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_query = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, query, keys, mask=None):
        """
        Args:
            query: [batch, hidden_dim] - the state doing the pointing
            keys: [batch, seq_len, hidden_dim] - the encoded input elements
            mask: optional [batch, seq_len]; nonzero entries are excluded
        Returns:
            attention_weights: [batch, seq_len] - one probability per input
            position. Each row sums to 1, unless every position in the row is
            masked, in which case softmax over all -inf yields NaN.
        """
        # Broadcast the projected query across the seq_len axis of the keys.
        energy = torch.tanh(self.W_key(keys) + self.W_query(query)[:, None, :])
        scores = self.v(energy)[..., 0]  # drop the trailing size-1 axis

        if mask is None:
            return scores.softmax(dim=-1)

        # A -inf score becomes exactly zero probability after the softmax.
        return scores.masked_fill(mask.bool(), float('-inf')).softmax(dim=-1)

# Smoke test: one random query attending over five random keys.
# (Construction precedes the two randn draws, as the seeded RNG order matters.)
attention = SimplePointerAttention(32)

query = torch.randn(1, 32)
keys = torch.randn(1, 5, 32)

weights = attention(query=query, keys=keys)

print("Attention weights:", weights.squeeze(0).tolist())
print(f"Sum: {float(weights.sum()):.4f} (should be 1.0)")
print("Pointer attention correctly calculates focus over input positions.")

### Visualization of Attention

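The bar charts below show a set of input values and the attention weights a query assigns to them. The embedding, query, and attention layer are randomly initialized and untrained, so the attention pattern is arbitrary. The point is the output format: one probability per input position.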

In [None]:
def visualize_attention_example():
    """Plot five toy input values above the pointer attention over them.

    The embedding, the query, and the attention layer are all freshly
    (randomly) initialised and untrained, so which position receives the most
    weight is arbitrary: the plot illustrates the output format (one
    probability per input position), not a learned selection rule.
    """
    # Five scalar inputs, shape [1, 5]
    items = torch.tensor([[0.2, 0.8, 0.1, 0.9, 0.5]])
    positions = range(items.shape[1])

    # Pure illustration: no gradients are needed.
    with torch.no_grad():
        # Untrained linear embedding: each scalar -> 32-dim key, [1, 5, 32]
        embed = nn.Linear(1, 32)
        keys = embed(items.unsqueeze(-1))

        # Random query vector; it is not tuned to look for any particular value.
        query = torch.randn(1, 32)

        # Compute attention
        attention_layer = SimplePointerAttention(32)
        weights = attention_layer(query, keys)

    # Visualize
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6))

    # Input values
    ax1.bar(positions, items[0].numpy(), color='skyblue', alpha=0.7)
    ax1.set_title('Input Values', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Position')
    ax1.set_ylabel('Value')
    ax1.grid(True, alpha=0.3)

    # Attention weights
    ax2.bar(positions, weights[0].numpy(), color='coral', alpha=0.7)
    ax2.set_title('Attention Weights (Where to Point)', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Position')
    ax2.set_ylabel('Probability')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Show which position has highest attention
    best_pos = torch.argmax(weights[0]).item()
    print(f"Model points to position {best_pos} (value: {items[0, best_pos]:.2f})")

visualize_attention_example()

## Part 2: Sorting with Pointer Networks

We can test the Pointer Network on the task of sorting numbers. The model must learn the "selection rule" for ascending order.

- **Input:** `[0.7, 0.2, 0.9, 0.1, 0.5]` (unordered set)
- **Output:** `[3, 1, 4, 0, 2]` (indices in sorted order)

In [None]:
class PointerNetwork(nn.Module):
    """Encoder-decoder Pointer Network that orders a set of scalars.

    An LSTM encodes the inputs. An LSTM decoder, initialised from the
    encoder's final (h, c), emits one pointer distribution over the input
    positions per output step. The decoder's next input is the attention-
    weighted average of the encoder outputs (a "soft" pointer). No positions
    are masked, so the same index may receive the highest weight more than once.
    """

    def __init__(self, input_dim=1, hidden_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Submodule creation order (encoder, decoder, attention) fixes the
        # order of random initialisation under a seed.
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)  # reads the set
        self.decoder = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)  # emits pointers
        self.attention = SimplePointerAttention(hidden_dim)

    def forward(self, inputs):
        """
        Args:
            inputs: [batch, seq_len] numbers to sort
        Returns:
            [batch, seq_len, seq_len] - for every output step, a softmax
            distribution over the input positions
        """
        n_batch, n_steps = inputs.shape
        features = inputs[..., None]  # one scalar feature per element

        memory, state = self.encoder(features)

        # The first decoder step is fed an all-zero vector.
        step_in = torch.zeros(n_batch, 1, self.hidden_dim, device=features.device)

        distributions = []
        for _ in range(n_steps):
            out, state = self.decoder(step_in, state)
            probs = self.attention(out[:, 0, :], memory)
            distributions.append(probs)
            # Soft pointer: feed back the probability-weighted encoder outputs.
            step_in = torch.bmm(probs[:, None, :], memory)

        return torch.stack(distributions, dim=1)

# Sorting model: one scalar feature per element, 64-unit encoder/decoder LSTMs.
model = PointerNetwork(input_dim=1, hidden_dim=64)
print(f"Model parameters: {sum(map(torch.numel, model.parameters())):,}")

### Training on the Sorting Task

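We train the model with a cross-entropy (negative log-likelihood) loss. At each output step, it compares the pointer distribution with the index of the correct element. Each "epoch" in `train_sorting` samples one fresh random batch and takes a single optimizer step.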

In [None]:
def generate_sorting_batch(batch_size=32, seq_len=5):
    """Generate random sorting problems.

    Returns:
        inputs: [batch_size, seq_len] floats drawn by ``torch.rand`` (uniform
            on [0, 1)).
        targets: [batch_size, seq_len] indices that sort each row in ascending
            order, i.e. the input position the model should point to at each
            output step.
    """
    inputs = torch.rand(batch_size, seq_len)
    targets = torch.argsort(inputs, dim=1)
    return inputs, targets


def train_sorting(model, num_epochs=20, batch_size=32, seq_len=5, lr=1e-3):
    """Train ``model`` to emit the ascending-sort permutation of its input.

    Each "epoch" samples one fresh random batch and takes a single optimizer
    step, so ``num_epochs`` is effectively a step count.

    Args:
        model: module mapping [batch, seq_len] inputs to per-step pointer
            distributions of shape [batch, seq_len, seq_len] whose last axis
            sums to 1 (``PointerNetwork`` ends in a softmax).
        num_epochs: number of training steps (one batch each).
        batch_size: sequences per batch.
        seq_len: length of each sequence to sort.
        lr: Adam learning rate.

    Returns:
        (losses, accuracies): per-step training loss and per-step exact-match
        sequence accuracy (fraction of sequences whose every index is right).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # An earlier evaluation call (model.eval()) may have left the model in
    # eval mode; make sure we train in training mode.
    model.train()

    losses = []
    accuracies = []

    for epoch in range(num_epochs):
        inputs, targets = generate_sorting_batch(batch_size, seq_len)

        # The model returns probabilities, not logits. F.cross_entropy would
        # apply a second softmax to values in [0, 1], yielding a nearly uniform
        # distribution and vanishing gradients. Instead take the log ourselves
        # and use the negative log-likelihood directly; the clamp guards
        # against log(0) if a probability underflows.
        probs = model(inputs)
        log_probs = torch.log(probs.clamp(min=1e-9))
        loss = F.nll_loss(log_probs.reshape(-1, seq_len), targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # A sequence counts as correct only if every output step is correct.
        predictions = torch.argmax(probs, dim=-1)
        accuracy = (predictions == targets).all(dim=1).float().mean().item()

        losses.append(loss.item())
        accuracies.append(accuracy)

        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss.item():.4f}, Acc: {accuracy:.2%}")

    return losses, accuracies

print("Training model on sorting sequences of length 5...")
# NOTE(review): 20 single-batch steps is probably too few for the model to
# learn to sort reliably; consider raising num_epochs.
losses, accuracies = train_sorting(model, num_epochs=20, seq_len=5)

# Plot the per-step *training* metrics returned by train_sorting (there is no
# held-out validation set). Each "epoch" is one freshly sampled batch.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(losses, color='blue')
ax1.set_title('Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax2.plot(accuracies, color='green')
ax2.set_title('Sequence Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Exact-match accuracy')
ax2.set_ylim(-0.05, 1.05)
plt.tight_layout()
plt.show()

### Inference and Attention Heatmap

We test the trained model on a single sequence and visualize its pointer distribution at each output step as a heatmap.

In [None]:
def test_and_visualize(model, numbers):
    """Run ``model`` on one sequence and plot input, pointer heatmap, and output.

    Args:
        model: a ``PointerNetwork``-style module whose output is presumably
            [1, len(numbers), len(numbers)] per-step pointer distributions that
            already sum to 1 (``PointerNetwork`` ends in a softmax).
        numbers: list of floats to sort.

    The model's train/eval mode is restored before plotting.
    """
    was_training = model.training
    model.eval()
    inputs = torch.tensor([numbers]).float()

    try:
        with torch.no_grad():
            pointer_probs = model(inputs)
    finally:
        # Don't leave the model stuck in eval mode for later training cells.
        model.train(was_training)

    predictions = torch.argmax(pointer_probs, dim=-1)[0]
    true_order = torch.argsort(inputs[0])

    plt.figure(figsize=(14, 8))

    # 1. Input numbers
    ax1 = plt.subplot(3, 1, 1)
    ax1.bar(range(len(numbers)), numbers, color='skyblue', alpha=0.7)
    ax1.set_title('Input Numbers (Unsorted)')

    # 2. Attention heatmap. The outputs are already probabilities; applying
    # softmax again would squash every row toward uniform.
    ax2 = plt.subplot(3, 1, 2)
    attention_matrix = pointer_probs[0].numpy()
    im = ax2.imshow(attention_matrix, cmap='YlOrRd', aspect='auto')
    ax2.set_title('Pointer Attention Heatmap')
    ax2.set_ylabel('Output Step')
    ax2.set_xlabel('Input Index')
    plt.colorbar(im, ax=ax2)

    # 3. Predicted result (the unmasked model may repeat an index)
    ax3 = plt.subplot(3, 1, 3)
    sorted_nums = [numbers[i] for i in predictions.tolist()]
    is_correct = torch.equal(predictions, true_order)
    ax3.bar(range(len(sorted_nums)), sorted_nums, color='lightgreen' if is_correct else 'salmon')
    ax3.set_title(f"Predicted Sort Order (Correct: {is_correct})")

    plt.tight_layout()
    plt.show()

test_numbers = [0.7, 0.2, 0.9, 0.1, 0.5]
test_and_visualize(model, test_numbers)

## Part 3: Traveling Salesman Problem (TSP)

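TSP is a combinatorial optimization problem. The input is a set of city locations, and the goal is the shortest tour that visits every city exactly once. TSP is NP-hard, which makes it a demanding test of what Pointer Networks can learn. Here we only define the model: it decodes a tour greedily and masks cities that have already been visited, so every output is a valid permutation of the cities.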

In [None]:
class TSPPointerNetwork(nn.Module):
    """Pointer Network for TSP that greedily decodes a tour with visit masking.

    Cities are encoded by an LSTM. At each step the decoder attends over the
    encoded cities, already-visited cities are masked out, and the most
    probable remaining city is selected, so the tour is always a permutation
    of ``range(num_cities)``.
    """

    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.attention = SimplePointerAttention(hidden_dim)

    def forward(self, cities, return_probs=False):
        """
        Args:
            cities: [batch, num_cities, input_dim] city features (coordinates).
            return_probs: if True, also return the per-step pointer
                distributions. City selection is an argmax and carries no
                gradient, so these distributions are what a training loss
                (e.g. NLL against a reference tour) must be computed from.
        Returns:
            tour: [batch, num_cities] indices of the visited cities, in order.
            probs (only when return_probs=True): [batch, num_cities, num_cities]
                masked attention distribution at each decoding step.
        """
        batch_size, num_cities, _ = cities.shape
        encoder_outputs, (h, c) = self.encoder(cities)

        decoder_state = (h, c)
        decoder_input = torch.zeros(batch_size, 1, self.hidden_dim, device=cities.device)
        # 1 marks a visited city; the attention layer gives it -inf score.
        mask = torch.zeros(batch_size, num_cities, device=cities.device)

        tour = []
        step_probs = []
        for _ in range(num_cities):
            decoder_output, decoder_state = self.decoder(decoder_input, decoder_state)
            weights = self.attention(decoder_output.squeeze(1), encoder_outputs, mask)
            step_probs.append(weights)

            # Greedy decoding: pick the most probable unvisited city.
            _, selected = weights.max(dim=-1)
            tour.append(selected)

            # Mark selected city to prevent revisiting
            mask.scatter_(1, selected.unsqueeze(1), 1)

            # Next input: attention-weighted average of the encoded cities.
            context = torch.bmm(weights.unsqueeze(1), encoder_outputs)
            decoder_input = context

        tour = torch.stack(tour, dim=1)
        if return_probs:
            return tour, torch.stack(step_probs, dim=1)
        return tour

# Only construct the TSP model here (no training): 2-D city coordinates in,
# 64 hidden units.
tsp_model = TSPPointerNetwork(hidden_dim=64, input_dim=2)
print("TSP model initialized")

## Key Takeaways

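1. **Input Order Matters:** The LSTM encoder used here reads the inputs one after another, so its encoding depends on the order in which the set is presented. This is the paper's central observation. The paper's Read-Process-and-Write model addresses it in two ways. It embeds each element independently, and it processes the resulting memories with content-based attention, which is invariant to their order.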
2. **The Pointing Mechanism:** The model outputs indices into the input rather than tokens from a vocabulary, enabling it to handle variable input sizes and out-of-vocabulary numerical values.
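3. **Generalization:** The output space is simply the set of input positions, so a trained Pointer Network can be run on inputs longer than those seen during training. How well it generalizes depends on the task. In the original Pointer Networks experiments, performance generally degraded as inputs grew beyond the training lengths.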