# Day 15: Bahdanau Attention

**Paper:** "Neural Machine Translation by Jointly Learning to Align and Translate" - Bahdanau, Cho, Bengio (2014)

We implement the additive attention mechanism from scratch and train a seq2seq model on a sequence reversal task to visualize how attention learns to align input and output positions.

---

## What You'll Learn

1. How the fixed-length bottleneck in vanilla seq2seq motivates attention
2. The additive (Bahdanau) scoring function: $a(s_{i-1}, h_j) = v^T \tanh(W s_{i-1} + U h_j)$
3. How alignment weights produce a dynamic context vector at each decoding step
4. Bidirectional encoder design (Section 3.2 of the paper)
5. Visualizing attention matrices to verify learned alignments

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import matplotlib.pyplot as plt
import random

# Set seeds for reproducibility
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. The Attention Mechanism

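The core idea (Section 3 of the paper): at every decoding step, score each encoder state against the decoder's state, normalize the scores, and take a weighted average. Given:
- Previous decoder state `s` ($s_{i-1}$ in the paper; the query that asks "what do I need next?")
- Encoder outputs `h1, h2, ..., hn` (the annotations we attend to)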

We compute:
```
score(s, hj) = v^T * tanh(W_s*s + W_h*hj)           # How relevant is position j? (W_s, W_h = W, U in the paper)
alpha = softmax([score(s, h1), ..., score(s, hn)])  # alpha_j >= 0 and sum_j alpha_j = 1
context = sum_j alpha_j * hj                        # Weighted sum of encoder states
```
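
To make the three lines above concrete, here is a minimal NumPy sketch of additive attention for a single decoder step. The dimensions, the random weights, and the function name `additive_attention` are illustrative assumptions; `W_s`, `W_h`, and `v` play the roles of $W$, $U$, and $v$ in the scoring function.

```python
import numpy as np

def additive_attention(s, H, W_s, W_h, v):
    """Additive (Bahdanau-style) attention for one query.

    s:   (d_dec,)       previous decoder state (the query)
    H:   (n, d_enc)     encoder annotations h_1..h_n
    W_s: (d_att, d_dec), W_h: (d_att, d_enc), v: (d_att,)
    """
    # score_j = v^T tanh(W_s s + W_h h_j), computed for all j at once
    scores = np.tanh(H @ W_h.T + W_s @ s) @ v      # (n,)
    # numerically stable softmax over source positions
    exp = np.exp(scores - scores.max())
    alpha = exp / exp.sum()                          # (n,)
    # context = sum_j alpha_j h_j
    context = alpha @ H                              # (d_enc,)
    return context, alpha

rng = np.random.default_rng(0)
d_enc, d_dec, d_att, n = 4, 3, 5, 6
H = rng.normal(size=(n, d_enc))
s = rng.normal(size=d_dec)
W_s = rng.normal(size=(d_att, d_dec))
W_h = rng.normal(size=(d_att, d_enc))
v = rng.normal(size=d_att)

context, alpha = additive_attention(s, H, W_s, W_h, v)
print(alpha.round(3), alpha.sum())
```

Because the weights are non-negative and sum to 1, the context vector is always a convex combination of the encoder states.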

In [None]:
class BahdanauAttention(nn.Module):
    """
    Additive Attention (Bahdanau et al., 2014)
    
    The 'additive' name comes from: tanh(W_s*s + W_h*h)
    We ADD the transformed query and key, then apply tanh.
    """
    
    def __init__(self, encoder_dim, decoder_dim, attention_dim=None):
        super().__init__()
        if attention_dim is None:
            attention_dim = decoder_dim
        
        self.W_h = nn.Linear(encoder_dim, attention_dim, bias=False)  # Key transform
        self.W_s = nn.Linear(decoder_dim, attention_dim, bias=False)  # Query transform
        self.v = nn.Linear(attention_dim, 1, bias=False)              # Score projection
    
    def forward(self, decoder_state, encoder_outputs, mask=None):
        """
        Args:
            decoder_state: (batch, decoder_dim) - decoder hidden state used as the query (s_{i-1} in the paper)
            encoder_outputs: (batch, src_len, encoder_dim) - all encoder states
            mask: (batch, src_len) - True for positions to ignore (padding)
        
        Returns:
            context: (batch, encoder_dim) - weighted sum of encoder outputs
            weights: (batch, src_len) - attention distribution
        """
        # Transform encoder outputs: (batch, src_len, attention_dim)
        encoder_proj = self.W_h(encoder_outputs)
        
        # Transform decoder state: (batch, attention_dim) -> (batch, 1, attention_dim)
        decoder_proj = self.W_s(decoder_state).unsqueeze(1)
        
        # Additive attention: (batch, src_len, attention_dim)
        combined = torch.tanh(encoder_proj + decoder_proj)
        
        # Get scores: (batch, src_len)
        scores = self.v(combined).squeeze(-1)
        
        # Apply mask (set padded positions to -inf)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        
        # Normalize with softmax
        weights = F.softmax(scores, dim=-1)
        
        # Compute context: (batch, encoder_dim)
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
        
        return context, weights

# Verification
attention = BahdanauAttention(encoder_dim=128, decoder_dim=128)
decoder_state = torch.randn(2, 128)
encoder_outputs = torch.randn(2, 10, 128)

context, weights = attention(decoder_state, encoder_outputs)
print(f"Context shape: {context.shape}")
print(f"Weights shape: {weights.shape}")
print(f"Weights sum: {weights.sum(dim=-1)}")
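
The check above never exercises the `mask` argument. The NumPy sketch below, with made-up scores, shows why filling padded positions with `-inf` before the softmax works: $e^{-\infty} = 0$, so those positions receive exactly zero weight while the rest still sum to 1. The helper name `masked_softmax` is illustrative. (If every position in a row were masked, the softmax would return NaNs; that cannot happen in this notebook because every source sequence has at least one real token.)

```python
import numpy as np

def masked_softmax(scores, mask):
    """Softmax over the last axis, with mask=True positions excluded."""
    scores = np.where(mask, -np.inf, scores)
    # subtract the row max (taken over unmasked entries) for stability
    scores = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(scores)          # exp(-inf) == 0 for masked positions
    return exp / exp.sum(axis=-1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.5, 3.0],
                   [0.1, 0.2, 0.3, 0.4]])
# the second sequence has length 2: its last two positions are padding
mask = np.array([[False, False, False, False],
                 [False, False, True,  True]])
weights = masked_softmax(scores, mask)
print(weights.round(3))
```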

## 2. Bidirectional Encoder

Why bidirectional? When we attend to position 3, we want context from:
- What came BEFORE (positions 1, 2)
- What comes AFTER (positions 4, 5, ...)

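In the paper (Section 3.2), each annotation $h_j$ is the concatenation of a forward and a backward RNN state, so it summarizes both the words preceding and the words following position $j$. The encoder below concatenates the two directions and then projects them back to `hidden_size`.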
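
A minimal NumPy sketch of that idea, assuming a plain tanh RNN cell instead of a GRU to keep it short (function names, sizes, and weights are illustrative): one pass reads left to right, one reads right to left, and the annotation for position $j$ is their concatenation. Perturbing only the last token leaves the forward half of the first annotation unchanged but alters its backward half, which is exactly the "what comes after" information a unidirectional encoder would lack.

```python
import numpy as np

def rnn_pass(X, W_x, W_h):
    """Run a simple tanh RNN over the rows of X; return all hidden states."""
    h = np.zeros(W_h.shape[0])
    states = []
    for x in X:
        h = np.tanh(W_x @ x + W_h @ h)
        states.append(h)
    return np.stack(states)

def bidirectional_annotations(X, fwd, bwd):
    forward = rnn_pass(X, *fwd)               # row j has seen x_0..x_j
    backward = rnn_pass(X[::-1], *bwd)[::-1]  # row j has seen x_j..x_{n-1}
    return np.concatenate([forward, backward], axis=-1)

rng = np.random.default_rng(0)
d_in, d_h, n = 3, 4, 5
fwd = (0.5 * rng.normal(size=(d_h, d_in)), 0.5 * rng.normal(size=(d_h, d_h)))
bwd = (0.5 * rng.normal(size=(d_h, d_in)), 0.5 * rng.normal(size=(d_h, d_h)))

X = rng.normal(size=(n, d_in))
X_changed = X.copy()
X_changed[-1] += 1.0   # perturb only the LAST input token

A = bidirectional_annotations(X, fwd, bwd)
B = bidirectional_annotations(X_changed, fwd, bwd)
print(A.shape)  # (5, 8)
```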

In [None]:
class Encoder(nn.Module):
    """Bidirectional GRU Encoder"""
    
    def __init__(self, vocab_size, embed_size, hidden_size, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.gru = nn.GRU(
            embed_size, hidden_size,
            batch_first=True,
            bidirectional=True
        )
        # Project bidirectional output back to hidden_size
        self.projection = nn.Linear(hidden_size * 2, hidden_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, src, src_lengths):
        """
        Args:
            src: (batch, src_len) - source token IDs
            src_lengths: (batch,) - actual lengths (for packing)
        
        Returns:
            outputs: (batch, src_len, hidden_size) - encoder states
            hidden: (1, batch, hidden_size) - final hidden state
        """
        embedded = self.dropout(self.embedding(src))
        
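        # Pack so the GRU skips padding: the backward direction starts at each
        # sequence's true last token, and `hidden` holds the state at the true length
        packed = pack_padded_sequence(
            embedded, src_lengths.cpu(),
            batch_first=True, enforce_sorted=False
        )
        
        outputs, hidden = self.gru(packed)
        
        # Unpack; total_length keeps the time dimension equal to src.size(1),
        # so it always lines up with the padding mask built from src
        outputs, _ = pad_packed_sequence(outputs, batch_first=True, total_length=src.size(1))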
        
        # Project to hidden_size
        outputs = self.projection(outputs)
        
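        # Combine the final forward (hidden[0]) and backward (hidden[1]) states into
        # the initial decoder state. The paper uses only the backward state here
        # (s_0 = tanh(W_s h_1_backward)); concatenating both is a common variant.
        hidden = torch.cat([hidden[0], hidden[1]], dim=-1)
        hidden = torch.tanh(self.projection(hidden)).unsqueeze(0)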
        
        return outputs, hidden

# Test encoder
encoder = Encoder(vocab_size=50, embed_size=64, hidden_size=128)
src = torch.randint(3, 50, (2, 8))
src_lengths = torch.tensor([8, 5])

outputs, hidden = encoder(src, src_lengths)
print(f"Encoder outputs: {outputs.shape}")
print(f"Final hidden: {hidden.shape}")
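
Packing matters for correctness, not only speed. Without it, the backward direction for the shorter sequence (length 5 in the test above) would start at the padded end and read padding before reaching real tokens. Even all-zero padding embeddings (as produced by `padding_idx=0`) move the state, because the recurrence has a bias term. A NumPy sketch with a plain tanh RNN standing in for the GRU's backward direction (names, sizes, and weights are illustrative):

```python
import numpy as np

def backward_states(X, W_x, W_h, b):
    """Right-to-left tanh RNN (stand-in for the GRU's backward direction)."""
    h = np.zeros(W_h.shape[0])
    states = []
    for x in X[::-1]:
        h = np.tanh(W_x @ x + W_h @ h + b)
        states.append(h)
    return np.stack(states[::-1])  # re-order so row j belongs to position j

rng = np.random.default_rng(1)
d_in, d_h, length, padded_len = 3, 4, 5, 8
W_x = 0.5 * rng.normal(size=(d_h, d_in))
W_h = 0.5 * rng.normal(size=(d_h, d_h))
b = 0.5 * rng.normal(size=d_h)

tokens = rng.normal(size=(length, d_in))
padding = np.zeros((padded_len - length, d_in))  # like padding_idx=0 embeddings
padded = np.concatenate([tokens, padding])

packed_like = backward_states(tokens, W_x, W_h, b)        # starts at true last token
unpacked = backward_states(padded, W_x, W_h, b)[:length]  # starts at padded end
print(np.abs(packed_like - unpacked).max())
```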

## 3. Attention Decoder

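At each decoding step (following the paper, attention uses the previous state $s_{i-1}$, before the GRU update):
1. **Attend** - Score every encoder output against the previous hidden state and form a context vector
2. **Combine** - Concatenate the context vector with the embedding of the previous token
3. **Update** - Run that concatenation through one GRU step to get the new hidden state
4. **Predict** - Map [new hidden state; context; embedding] to logits over the target vocabulary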
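
The four steps above can be sketched in NumPy for one decoding step. This is an illustration under stated assumptions, not the notebook's module: a plain tanh RNN cell stands in for the GRU, a bare linear map stands in for the output layer, and `decode_step`, the parameter dictionary `p`, and all sizes are hypothetical.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def decode_step(y_prev_emb, s_prev, H, p):
    # 1. Attend: the query is the PREVIOUS decoder state s_{i-1}
    scores = np.tanh(H @ p["W_h"].T + p["W_s"] @ s_prev) @ p["v"]
    alpha = softmax(scores)
    context = alpha @ H
    # 2. Combine: concatenate previous-token embedding and context
    rnn_in = np.concatenate([y_prev_emb, context])
    # 3. Update: new state s_i (tanh RNN standing in for the GRU)
    s = np.tanh(p["W_in"] @ rnn_in + p["W_rec"] @ s_prev)
    # 4. Predict: logits from [s_i; context; embedding], as in forward_step
    logits = p["W_out"] @ np.concatenate([s, context, y_prev_emb])
    return logits, s, alpha

rng = np.random.default_rng(0)
d_emb, d_enc, d_dec, d_att, vocab, n = 2, 4, 3, 5, 7, 6
p = {
    "W_h": rng.normal(size=(d_att, d_enc)),
    "W_s": rng.normal(size=(d_att, d_dec)),
    "v": rng.normal(size=d_att),
    "W_in": rng.normal(size=(d_dec, d_emb + d_enc)),
    "W_rec": rng.normal(size=(d_dec, d_dec)),
    "W_out": rng.normal(size=(vocab, d_dec + d_enc + d_emb)),
}
H = rng.normal(size=(n, d_enc))
logits, s, alpha = decode_step(rng.normal(size=d_emb), np.zeros(d_dec), H, p)
```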

In [None]:
class AttentionDecoder(nn.Module):
    """GRU Decoder with Bahdanau Attention"""
    
    def __init__(self, vocab_size, embed_size, hidden_size, encoder_hidden_size, dropout=0.1):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.attention = BahdanauAttention(encoder_hidden_size, hidden_size)
        self.gru = nn.GRU(embed_size + encoder_hidden_size, hidden_size, batch_first=True)
        self.output = nn.Linear(hidden_size + encoder_hidden_size + embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward_step(self, input_token, hidden, encoder_outputs, mask=None):
        """
        One decoding step.
        
        Args:
            input_token: (batch,) - previous token
            hidden: (1, batch, hidden_size) - previous hidden state
            encoder_outputs: (batch, src_len, encoder_hidden_size)
            mask: (batch, src_len) - padding mask
        
        Returns:
            output: (batch, vocab_size) - unnormalized logits over the target vocabulary
            hidden: (1, batch, hidden_size) - new hidden state
            attention: (batch, src_len) - attention weights
        """
        embedded = self.dropout(self.embedding(input_token))
        
        # Compute attention
        context, attention = self.attention(hidden.squeeze(0), encoder_outputs, mask)
        
        # Combine embedding and context
        gru_input = torch.cat([embedded, context], dim=-1).unsqueeze(1)
        
        # GRU step
        gru_output, hidden = self.gru(gru_input, hidden)
        gru_output = gru_output.squeeze(1)
        
        # Predict next token
        combined = torch.cat([gru_output, context, embedded], dim=-1)
        output = self.output(combined)
        
        return output, hidden, attention
    
    def forward(self, trg, hidden, encoder_outputs, mask=None):
        """
        Full decoding pass with teacher forcing.
        """
        batch_size, trg_len = trg.shape
        vocab_size = self.output.out_features
        src_len = encoder_outputs.size(1)
        
        outputs = torch.zeros(batch_size, trg_len - 1, vocab_size, device=trg.device)
        attentions = torch.zeros(batch_size, trg_len - 1, src_len, device=trg.device)
        
        input_token = trg[:, 0]  # SOS token
        
        for t in range(trg_len - 1):
            output, hidden, attn = self.forward_step(input_token, hidden, encoder_outputs, mask)
            outputs[:, t] = output
            attentions[:, t] = attn
            input_token = trg[:, t + 1]  # Teacher forcing
        
        return outputs, attentions

# Test decoder
decoder = AttentionDecoder(
    vocab_size=50, embed_size=64, 
    hidden_size=128, encoder_hidden_size=128
)

trg = torch.randint(3, 50, (2, 6))
trg[:, 0] = 1  # SOS token

# `hidden` and `outputs` here are the encoder test results from above
dec_outputs, attentions = decoder(trg, hidden, outputs, mask=None)
print(f"Decoder outputs: {dec_outputs.shape}")
print(f"Attention matrices: {attentions.shape}")
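
With teacher forcing, the decoder input at step $t$ is the ground-truth token `trg[:, t]` and the prediction is scored against `trg[:, t + 1]`. That is why `forward` allocates `trg_len - 1` output slots and why the training loss compares against `trg[:, 1:]`. A pure-Python sketch on one reversal example (token ids chosen for illustration, using this notebook's SOS=1 and EOS=2):

```python
trg = [1, 7, 4, 8, 3, 5, 2]   # SOS + reversed sequence + EOS

decoder_inputs = trg[:-1]     # tokens fed in at steps 0..T-2
targets = trg[1:]             # what outputs[:, t] is compared against

for t, (inp, tgt) in enumerate(zip(decoder_inputs, targets)):
    print(f"step {t}: input {inp} -> predict {tgt}")
```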

## 4. Complete Seq2Seq Model

In [None]:
class Seq2SeqWithAttention(nn.Module):
    """Complete sequence-to-sequence model with attention."""
    
    def __init__(self, src_vocab, trg_vocab, embed_size=64, hidden_size=128, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(src_vocab, embed_size, hidden_size, dropout)
        self.decoder = AttentionDecoder(trg_vocab, embed_size, hidden_size, hidden_size, dropout)
        self.pad_idx = 0
    
    def forward(self, src, src_lengths, trg):
        mask = (src == self.pad_idx)
        encoder_outputs, hidden = self.encoder(src, src_lengths)
        outputs, attentions = self.decoder(trg, hidden, encoder_outputs, mask)
        return outputs, attentions
    
    @torch.no_grad()
    def translate(self, src, src_lengths, max_len=20, sos_idx=1, eos_idx=2):
        """Greedy decoding."""
        self.eval()
        batch_size = src.size(0)
        
        mask = (src == self.pad_idx)
        encoder_outputs, hidden = self.encoder(src, src_lengths)
        
        input_token = torch.full((batch_size,), sos_idx, dtype=torch.long, device=src.device)
        translations = []
        attentions = []
        
        for _ in range(max_len):
            output, hidden, attn = self.decoder.forward_step(
                input_token, hidden, encoder_outputs, mask
            )
            input_token = output.argmax(dim=-1)
            translations.append(input_token)
            attentions.append(attn)
            
            if (input_token == eos_idx).all():
                break
        
        return torch.stack(translations, dim=1), torch.stack(attentions, dim=1)

# Create model
model = Seq2SeqWithAttention(src_vocab=50, trg_vocab=50).to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
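
`translate` performs greedy decoding: at each step it feeds back its own argmax prediction instead of the ground truth, and it stops early once every sequence in the batch has emitted EOS. A framework-free sketch of the same loop for a single sequence, where `step_fn` is a hypothetical stand-in for `decoder.forward_step` that maps (previous token, state) to (scores, new state):

```python
def greedy_decode(step_fn, state, sos_idx=1, eos_idx=2, max_len=20):
    """Greedy decoding for one sequence; step_fn is a hypothetical stand-in."""
    token, output = sos_idx, []
    for _ in range(max_len):
        scores, state = step_fn(token, state)
        token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        output.append(token)
        if token == eos_idx:
            break
    return output

# Toy step function: emits the tokens of a fixed script, one per step.
script = [7, 4, 8, 2]
def toy_step(prev_token, t):
    scores = [0.0] * 10
    scores[script[t]] = 1.0
    return scores, t + 1

print(greedy_decode(toy_step, state=0))  # [7, 4, 8, 2]
```

As in `translate`, the EOS token is included in the returned sequence; `accuracy` and `visualize_attention` truncate at it afterwards.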

## 5. Dataset: Sequence Reversal

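A toy task with a known correct alignment, which lets us check what attention learns:
- Input: `[5, 3, 8, 4, 7]`
- Output: `[7, 4, 8, 3, 5]`

Data tokens start at 3 because 0, 1, and 2 are reserved for PAD, SOS, and EOS. If training succeeds, the attention should form a **reversed diagonal pattern**: the first output token attends to the last input token, the second to the second-to-last, and so on.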

In [None]:
class ReversalDataset(Dataset):
    """
    Generate (sequence, reversed_sequence) pairs.
    
    Tokens: 0=PAD, 1=SOS, 2=EOS, 3+=data
    """
    def __init__(self, num_samples=5000, min_len=4, max_len=10, vocab_size=50):
        self.samples = []
        for _ in range(num_samples):
            length = random.randint(min_len, max_len)
            seq = [random.randint(3, vocab_size - 1) for _ in range(length)]
            self.samples.append((seq, list(reversed(seq))))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        src, trg = self.samples[idx]
        trg = [1] + trg + [2]  # Add SOS and EOS
        return torch.tensor(src), torch.tensor(trg)

def collate_fn(batch):
    srcs, trgs = zip(*batch)
    src_lengths = torch.tensor([len(s) for s in srcs])
    src_padded = nn.utils.rnn.pad_sequence(srcs, batch_first=True, padding_value=0)
    trg_padded = nn.utils.rnn.pad_sequence(trgs, batch_first=True, padding_value=0)
    return src_padded, src_lengths, trg_padded

# Create datasets
train_data = ReversalDataset(5000)
val_data = ReversalDataset(500)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_data, batch_size=32, collate_fn=collate_fn)

# Show example
src, trg = train_data[0]
print(f"Source:  {src.tolist()}")
print(f"Target:  {trg.tolist()} (with SOS=1, EOS=2)")
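
`collate_fn` pads each batch to its longest source, and `Seq2SeqWithAttention.forward` derives the attention mask as `src == 0`. Because data tokens start at 3, PAD is the only token with id 0, so this comparison marks exactly the padded positions. A NumPy sketch of those two steps on a hand-made batch (the helper `pad_batch` and the values are illustrative):

```python
import numpy as np

def pad_batch(seqs, pad_idx=0):
    """Right-pad a list of token lists to the longest length."""
    max_len = max(len(s) for s in seqs)
    out = np.full((len(seqs), max_len), pad_idx, dtype=np.int64)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = s
    return out

srcs = [[5, 3, 8, 4, 7], [9, 6, 3]]
src = pad_batch(srcs)
lengths = np.array([len(s) for s in srcs])
mask = src == 0   # True where attention must be blocked
print(src)
print(mask)
```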

## 6. Training

In [None]:
def train_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    
    for src, src_lengths, trg in loader:
        src, src_lengths, trg = src.to(device), src_lengths.to(device), trg.to(device)
        
        optimizer.zero_grad()
        outputs, _ = model(src, src_lengths, trg)
        
        # Reshape for loss
        outputs = outputs.reshape(-1, outputs.size(-1))
        trg = trg[:, 1:].reshape(-1)  # Skip SOS
        
        loss = criterion(outputs, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(loader)

def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for src, src_lengths, trg in loader:
            src, src_lengths, trg = src.to(device), src_lengths.to(device), trg.to(device)
            outputs, _ = model(src, src_lengths, trg)
            outputs = outputs.reshape(-1, outputs.size(-1))
            trg = trg[:, 1:].reshape(-1)
            loss = criterion(outputs, trg)
            total_loss += loss.item()
    
    return total_loss / len(loader)

def accuracy(model, dataset, num_samples=200):
    model.eval()
    correct = 0
    
    for i in range(min(num_samples, len(dataset))):
        src, trg = dataset[i]
        src = src.unsqueeze(0).to(device)
        src_len = torch.tensor([len(src[0])]).to(device)
        
        pred, _ = model.translate(src, src_len)
        pred = pred[0].cpu().tolist()
        if 2 in pred:
            pred = pred[:pred.index(2)]
        
        target = trg[1:-1].tolist()
        if pred == target:
            correct += 1
    
    return correct / min(num_samples, len(dataset))

In [None]:
# Training loop
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.001)

EPOCHS = 30
history = {'train_loss': [], 'val_loss': [], 'accuracy': []}

print("Training...")
print("=" * 50)

for epoch in range(1, EPOCHS + 1):
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss = evaluate(model, val_loader, criterion)
    
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    
    if epoch % 5 == 0:
        acc = accuracy(model, val_data)
        history['accuracy'].append(acc)
        print(f"Epoch {epoch:2d} | Train: {train_loss:.4f} | Val: {val_loss:.4f} | Acc: {acc:.1%}")
    else:
        print(f"Epoch {epoch:2d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")

print("=" * 50)
final_acc = accuracy(model, val_data, 500)
print(f"Final Accuracy: {final_acc:.1%}")
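
`CrossEntropyLoss(ignore_index=0)` drops PAD targets from the loss; with the default `reduction='mean'`, the sum over the remaining targets is divided by the number of non-PAD targets rather than by the total number of positions. A NumPy sketch of that computation (the helper name, logits, and targets are made up for illustration):

```python
import numpy as np

def masked_cross_entropy(logits, targets, ignore_index=0):
    """Mean token-level cross-entropy over targets != ignore_index."""
    # log-softmax, computed stably
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]
    keep = targets != ignore_index
    return nll[keep].sum() / keep.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 5))          # 6 flattened positions, vocab of 5
targets = np.array([3, 4, 2, 0, 0, 0])    # last three positions are PAD
loss = masked_cross_entropy(logits, targets)
print(loss)
```

Whatever the model predicts at PAD positions has no effect on the loss or its gradient.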

## 7. Visualize Attention

For the reversal task, the attention should form a **reversed diagonal pattern**: output position $i$ attends to input position $n - 1 - i$, where $n$ is the source length and positions are 0-indexed.
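
For a source of length $n$, the ideal alignment is the flipped identity matrix. The sketch below builds it and adds a hypothetical diagnostic, `antidiagonal_mass`, that averages the weight an attention matrix places on that anti-diagonal (1.0 means perfectly sharp, correct alignment). It assumes a square `(n, n)` matrix, i.e. the trimmed `attn` inside `visualize_attention` when the prediction has the right length.

```python
import numpy as np

def ideal_reversal_alignment(n):
    """Output position i attends only to input position n - 1 - i."""
    return np.fliplr(np.eye(n))

def antidiagonal_mass(attn):
    """Average attention weight on the anti-diagonal (hypothetical helper)."""
    n = attn.shape[0]
    return float(attn[np.arange(n), n - 1 - np.arange(n)].mean())

ideal = ideal_reversal_alignment(4)
print(ideal)
uniform = np.full((4, 4), 0.25)
print(antidiagonal_mass(ideal), antidiagonal_mass(uniform))  # 1.0 0.25
```

A uniform matrix scores $1/n$, so values well above that on the trained model suggest the alignment has been learned.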

In [None]:
def visualize_attention(model, dataset, idx=0):
    """Visualize attention for a single example."""
    model.eval()
    
    src, trg = dataset[idx]
    src_t = src.unsqueeze(0).to(device)
    src_len = torch.tensor([len(src)]).to(device)
    
    pred, attentions = model.translate(src_t, src_len)
    
    # Process outputs
    attn = attentions.squeeze(0).cpu().numpy()
    pred_list = pred.squeeze(0).cpu().tolist()
    
    if 2 in pred_list:
        eos_idx = pred_list.index(2)
        pred_list = pred_list[:eos_idx]
        attn = attn[:eos_idx]
    
    source = [str(t) for t in src.tolist()]
    target = [str(t) for t in pred_list]
    expected = trg[1:-1].tolist()  # ground truth without SOS/EOS
    
    correct = pred_list == expected
    
    # Plot
    fig, ax = plt.subplots(figsize=(10, 8))
    
    im = ax.imshow(attn, cmap='Blues', aspect='auto', vmin=0, vmax=1)
    plt.colorbar(im, ax=ax, label='Attention Weight')
    
    ax.set_xticks(range(len(source)))
    ax.set_xticklabels(source, fontsize=12)
    ax.set_yticks(range(len(target)))
    ax.set_yticklabels(target, fontsize=12)
    
    ax.set_xlabel('Source (Input)', fontsize=13)
    ax.set_ylabel('Target (Output)', fontsize=13)
    
    status = "Correct" if correct else "Wrong"
    ax.set_title(f'Attention Pattern for Reversal Task\n{status}', fontsize=14)
    
    # Add values
    for i in range(len(target)):
        for j in range(len(source)):
            val = attn[i, j]
            color = 'white' if val > 0.5 else 'black'
            ax.text(j, i, f'{val:.2f}', ha='center', va='center', color=color, fontsize=10)
    
    plt.tight_layout()
    plt.show()
    
    print(f"Source:    {src.tolist()}")
    print(f"Predicted: {pred_list}")
    print(f"Expected:  {trg[1:-1].tolist()}")

In [None]:
# Visualize a few examples
for i in range(3):
    print(f"\n{'='*50}")
    print(f"Example {i+1}")
    print('='*50)
    visualize_attention(model, val_data, i)

## 8. Key Takeaways

### What We Learned

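1. **Attention removes the fixed-length bottleneck** - The decoder no longer has to recover the whole source from one vector; it builds a fresh context vector at every step

2. **It's differentiable** - Soft attention is a softmax-weighted average, so the alignment is learned jointly with translation by backpropagation, with no alignment labels

3. **It's inspectable** - The attention weights show which source positions each output step drew on

4. **The pattern tells the story** - On the reversal task, a reversed diagonal means the model learned the correct alignment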

### What's Next?

- **Self-Attention** - Attend within the same sequence
- **Transformers** - Replace RNNs entirely with attention

---

*Bahdanau attention is a direct ancestor of the Transformer (Vaswani et al. 2017). The core idea - letting the decoder dynamically attend to encoder states - carried forward into modern architectures, though the mechanism evolved from additive scoring over RNNs to scaled dot-product self-attention.*