# Module 05: Attention Mechanism

**Difficulty**: ⭐⭐⭐ Advanced  
**Estimated Time**: 130 minutes  
**Prerequisites**: [Module 04: Sequence-to-Sequence Models](04_sequence_to_sequence.ipynb)

## Learning Objectives

By the end of this notebook, you will be able to:

1. Understand the motivation and intuition behind attention mechanisms
2. Implement Bahdanau (additive) attention from scratch
3. Implement Luong (multiplicative) attention
4. Visualize attention weights to interpret model decisions
5. Understand self-attention as a precursor to Transformers
6. Compare different attention mechanisms and their trade-offs

## The Attention Revolution

### Problem with Standard Seq2Seq:

In vanilla seq2seq, the **entire input** is compressed into a **single fixed-size context vector**.

**Issues**:
- Information bottleneck for long sequences
- Encoder must remember everything
- Performance degrades with sequence length

### Solution: Attention Mechanism

**Key insight**: Let the decoder **attend to different parts** of the input at each decoding step!

**Analogy**: When translating "The cat sat on the mat" to French:
- When generating "chat" (cat), focus on "cat"
- When generating "le" (the), focus on "the"
- Dynamic focus on relevant input parts!

## Setup and Imports

In [None]:
# Core libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Visualization
%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

# Random seeds
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("✓ All libraries imported successfully!")

## 1. Attention Intuition

### How Attention Works:

1. **Query (Q)**: Current decoder state (what we're looking for)
2. **Keys (K)**: Encoder hidden states (what we can attend to)
3. **Values (V)**: Also encoder hidden states (what we retrieve)

**Steps**:
1. Compute alignment scores: $e_{ij} = \text{score}(h_i^{dec}, h_j^{enc})$
2. Normalize with softmax: $\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{ik})}$ (attention weights)
3. Weighted sum: $c_i = \sum_j \alpha_{ij} h_j^{enc}$ (context vector)

**Result**: Different context vector for each decoder step!
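
The three steps above can be sketched for a single decoder step as follows. This is a minimal illustration, not the notebook's final implementation: it assumes a plain dot product as the score function, and the tensor sizes and variable names (`enc_states`, `dec_state`) are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes (hypothetical): 4 encoder positions, hidden size 3
enc_states = torch.randn(4, 3)   # h_j^enc for each input position (keys and values)
dec_state = torch.randn(3)       # h_i^dec for one decoder step (query)

# Step 1: alignment scores e_ij, assuming a dot product as score()
scores = enc_states @ dec_state          # (4,)

# Step 2: softmax turns the scores into attention weights alpha_ij
weights = F.softmax(scores, dim=0)       # (4,), non-negative, sums to 1

# Step 3: the context vector c_i is the weighted sum of encoder states
context = weights @ enc_states           # (3,)

print("weights:", weights)
print("context:", context)
```

Repeating these steps with a new `dec_state` at every decoding step is what produces a different context vector each time.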

## 2. Bahdanau Attention (Additive)

**Bahdanau et al., 2015**: First attention mechanism for NMT.

**Score function**:
$$\text{score}(h_i^{dec}, h_j^{enc}) = v^T \tanh(W_1 h_i^{dec} + W_2 h_j^{enc})$$

Where $v$, $W_1$, $W_2$ are learned parameters.
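
To connect the formula to tensor code, the sketch below scores one decoder state against every encoder state twice: once pair by pair, exactly as written in the equation, and once in vectorized form. The random matrices `W1`, `W2` and vector `v` are stand-ins for learned parameters, and the sizes are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)
hidden_dim, seq_len = 4, 3

# Hypothetical random values standing in for the learned v, W_1, W_2
W1 = torch.randn(hidden_dim, hidden_dim)
W2 = torch.randn(hidden_dim, hidden_dim)
v = torch.randn(hidden_dim)

h_dec = torch.randn(hidden_dim)            # one decoder state h_i^dec
h_enc = torch.randn(seq_len, hidden_dim)   # encoder states h_j^enc

# Score each (decoder, encoder) pair exactly as in the formula
loop_scores = torch.stack([
    v @ torch.tanh(W1 @ h_dec + W2 @ h_enc[j]) for j in range(seq_len)
])

# Same computation vectorized: W1 h_dec broadcasts across all positions j
vec_scores = torch.tanh(h_dec @ W1.T + h_enc @ W2.T) @ v

print(loop_scores)
print(vec_scores)
```

Both versions give one scalar score per encoder position. The vectorized form relies on the same broadcasting that the module implementation uses.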

In [None]:
class BahdanauAttention(nn.Module):
    """
    Bahdanau (additive) attention mechanism.
    
    Reference: "Neural Machine Translation by Jointly Learning to Align and Translate"
    Bahdanau et al., ICLR 2015
    """
    
    def __init__(self, hidden_dim):
        """
        Parameters:
        -----------
        hidden_dim : int
            Dimension of hidden states
        """
        super(BahdanauAttention, self).__init__()
        
        # Learned parameters
        self.W_query = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_key = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
        
    def forward(self, query, keys, values, mask=None):
        """
        Compute attention.
        
        Parameters:
        -----------
        query : torch.Tensor
            Decoder hidden state (batch_size, hidden_dim)
        keys : torch.Tensor
            Encoder hidden states (batch_size, seq_len, hidden_dim)
        values : torch.Tensor
            Encoder hidden states (batch_size, seq_len, hidden_dim)
        mask : torch.Tensor or None
            Padding mask (batch_size, seq_len)
            
        Returns:
        --------
        context : torch.Tensor
            Attention-weighted context (batch_size, hidden_dim)
        attention_weights : torch.Tensor
            Attention distribution (batch_size, seq_len)
        """
        # Add a sequence axis so the query broadcasts against the keys
        # query: (batch, hidden) -> (batch, 1, hidden); seq_len is covered by broadcasting
        query_expanded = query.unsqueeze(1)
        
        # Compute alignment scores
        # score = v^T * tanh(W_query * query + W_key * keys)
        scores = self.v(torch.tanh(
            self.W_query(query_expanded) + self.W_key(keys)
        )).squeeze(-1)  # (batch, seq_len)
        
        # Apply mask if provided (set padded positions to -inf)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax to get attention weights
        attention_weights = F.softmax(scores, dim=1)  # (batch, seq_len)
        
        # Weighted sum of values
        context = torch.bmm(
            attention_weights.unsqueeze(1),  # (batch, 1, seq_len)
            values  # (batch, seq_len, hidden)
        ).squeeze(1)  # (batch, hidden)
        
        return context, attention_weights

print("✓ BahdanauAttention class defined!")

In [None]:
# Test Bahdanau attention
batch_size = 2
seq_len = 5
hidden_dim = 8

# Create dummy data
query = torch.randn(batch_size, hidden_dim)  # Current decoder state
keys = torch.randn(batch_size, seq_len, hidden_dim)  # Encoder states
values = keys  # Usually same as keys

# Initialize attention
attention = BahdanauAttention(hidden_dim)

# Forward pass
context, weights = attention(query, keys, values)

print(f"Query shape: {query.shape}")
print(f"Keys shape: {keys.shape}")
print(f"Context shape: {context.shape}")
print(f"Attention weights shape: {weights.shape}")
print("\nAttention weights (each row should sum to 1):")
print(weights)
print(f"Sum: {weights.sum(dim=1)}")
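
The `mask` argument is not exercised in the test above. The sketch below shows the masking step in isolation, assuming the convention used in `forward` (1 for a real token, 0 for padding). The score values are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical raw scores for a batch of 2 sequences padded to length 4
scores = torch.tensor([[2.0, 1.0, 0.5, 0.3],
                       [1.0, 3.0, 0.2, 0.1]])

# 1 marks a real token, 0 marks padding
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

# Padded positions get a score of -inf, so softmax assigns them zero weight
masked_scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(masked_scores, dim=1)

print(weights)
print(weights.sum(dim=1))
```

Each row still sums to 1 over the real tokens. A row with every position masked would produce NaN, so at least one unmasked position per sequence is assumed.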

## 3. Luong Attention (Multiplicative)

**Luong et al., 2015**: Simpler, more efficient attention.

**Three variants**:

1. **Dot**: $\text{score}(h_i^{dec}, h_j^{enc}) = h_i^{dec} \cdot h_j^{enc}$
2. **General**: $\text{score}(h_i^{dec}, h_j^{enc}) = (h_i^{dec})^T W h_j^{enc}$
3. **Concat**: $\text{score}(h_i^{dec}, h_j^{enc}) = v^T \tanh(W [h_i^{dec}; h_j^{enc}])$ (similar to Bahdanau)

Most common: **General (multiplicative)**
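
As a small illustration of how the dot and general variants relate, the sketch below computes both for one decoder state. The matrix `W` is random here as a stand-in for a learned parameter, and the sizes are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)
hidden_dim, seq_len = 4, 5

h_dec = torch.randn(hidden_dim)            # decoder state h_i^dec
h_enc = torch.randn(seq_len, hidden_dim)   # encoder states h_j^enc

# Dot: similarity between the raw hidden states
dot_scores = h_enc @ h_dec                 # (seq_len,)

# General: a matrix W sits between the two states (random stand-in here)
W = torch.randn(hidden_dim, hidden_dim)
general_scores = h_enc @ (W.T @ h_dec)     # h_dec^T W h_enc[j] for each j

# With W set to the identity, the general score reduces to the dot score
identity_scores = h_enc @ (torch.eye(hidden_dim).T @ h_dec)

print(dot_scores)
print(general_scores)
print(torch.allclose(identity_scores, dot_scores))
```

In other words, the general variant can be read as a dot product taken after a learned linear change of basis.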

In [None]:
class LuongAttention(nn.Module):
    """
    Luong (multiplicative) attention mechanism.
    
    Reference: "Effective Approaches to Attention-based Neural Machine Translation"
    Luong et al., EMNLP 2015
    """
    
    def __init__(self, hidden_dim, method='general'):
        """
        Parameters:
        -----------
        hidden_dim : int
            Dimension of hidden states
        method : str
            Attention method: 'dot', 'general', or 'concat'
        """
        super(LuongAttention, self).__init__()
        
        self.method = method
        self.hidden_dim = hidden_dim
        
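        if method == 'general':
            self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)
        elif method == 'concat':
            self.W = nn.Linear(hidden_dim * 2, hidden_dim)
            self.v = nn.Linear(hidden_dim, 1, bias=False)
        elif method != 'dot':
            # Fail early instead of hitting an undefined `scores` in forward()
            raise ValueError(f"Unknown attention method: {method!r}")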
        
    def forward(self, query, keys, values, mask=None):
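        """
        Compute Luong attention.

        Shapes match BahdanauAttention.forward:
        query (batch, hidden_dim); keys and values (batch, seq_len, hidden_dim);
        mask (batch, seq_len) or None.
        Returns context (batch, hidden_dim) and attention_weights (batch, seq_len).
        """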
        # Compute scores based on method
        if self.method == 'dot':
            # Simple dot product
            scores = torch.bmm(
                query.unsqueeze(1),  # (batch, 1, hidden)
                keys.transpose(1, 2)  # (batch, hidden, seq_len)
            ).squeeze(1)  # (batch, seq_len)
            
        elif self.method == 'general':
            # Learned transformation then dot product
            scores = torch.bmm(
                self.W(query).unsqueeze(1),
                keys.transpose(1, 2)
            ).squeeze(1)
            
        elif self.method == 'concat':
            # Concatenate and feed through network
            query_expanded = query.unsqueeze(1).expand(-1, keys.size(1), -1)
            combined = torch.cat([query_expanded, keys], dim=2)
            scores = self.v(torch.tanh(self.W(combined))).squeeze(-1)
        
        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax
        attention_weights = F.softmax(scores, dim=1)
        
        # Weighted sum
        context = torch.bmm(
            attention_weights.unsqueeze(1),
            values
        ).squeeze(1)
        
        return context, attention_weights

print("✓ LuongAttention class defined!")

**Exercise 1**: Compare attention mechanisms

1. Instantiate all three Luong attention variants using the class above
2. Compare their computational complexity
3. Test on same input and compare results
4. Which is fastest? Which uses most parameters?

In [None]:
# YOUR CODE HERE
# Compare different attention methods

## 4. Visualizing Attention Weights

Attention weights show **which input words** the model focuses on when generating each output word.

**Interpretability**: We can visualize and understand model decisions!

In [None]:
def visualize_attention(input_words, output_words, attention_weights):
    """
    Visualize attention weights as heatmap.
    
    Parameters:
    -----------
    input_words : list of str
        Source sentence words
    output_words : list of str
        Target sentence words
    attention_weights : np.ndarray
        Attention matrix (output_len, input_len)
    """
    plt.figure(figsize=(10, 8))
    
    sns.heatmap(
        attention_weights,
        xticklabels=input_words,
        yticklabels=output_words,
        cmap='YlOrRd',
        cbar_kws={'label': 'Attention Weight'},
        annot=True,
        fmt='.2f'
    )
    
    plt.xlabel('Input Sequence')
    plt.ylabel('Output Sequence')
    plt.title('Attention Weights Visualization')
    plt.tight_layout()
    plt.show()

# Example: English to French translation (simplified word-for-word gloss)
input_sent = ['the', 'cat', 'sat', 'on', 'the', 'mat']
output_sent = ['le', 'chat', 'assis', 'sur', 'le', 'tapis']

# Simulate attention (in practice, this comes from a trained model)
attention_matrix = np.array([
    [0.8, 0.1, 0.0, 0.0, 0.1, 0.0],  # 'le' attends to 'the'
    [0.1, 0.8, 0.1, 0.0, 0.0, 0.0],  # 'chat' attends to 'cat'
    [0.0, 0.2, 0.7, 0.1, 0.0, 0.0],  # 'assis' attends to 'sat'
    [0.0, 0.0, 0.2, 0.7, 0.1, 0.0],  # 'sur' attends to 'on'
    [0.1, 0.0, 0.0, 0.1, 0.8, 0.0],  # 'le' attends to 'the'
    [0.0, 0.0, 0.0, 0.1, 0.1, 0.8],  # 'tapis' attends to 'mat'
])

visualize_attention(input_sent, output_sent, attention_matrix)

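**Observation**: In a trained model, attention weights act as soft **alignments** between source and target words!

For language pairs with similar word order, such as this English-French example, we see a roughly diagonal pattern (monotonic alignment).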
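
One way to read an alignment off the heatmap is to take, for each output word, the input position with the largest weight. The sketch below does this for the simulated matrix from the visualization cell. The `argmax` rule and the monotonicity check are illustrative choices, not part of any trained model.

```python
import numpy as np

input_sent = ['the', 'cat', 'sat', 'on', 'the', 'mat']
output_sent = ['le', 'chat', 'assis', 'sur', 'le', 'tapis']

# Same simulated attention matrix as in the visualization cell
attention_matrix = np.array([
    [0.8, 0.1, 0.0, 0.0, 0.1, 0.0],
    [0.1, 0.8, 0.1, 0.0, 0.0, 0.0],
    [0.0, 0.2, 0.7, 0.1, 0.0, 0.0],
    [0.0, 0.0, 0.2, 0.7, 0.1, 0.0],
    [0.1, 0.0, 0.0, 0.1, 0.8, 0.0],
    [0.0, 0.0, 0.0, 0.1, 0.1, 0.8],
])

# For each output word (row), pick the input position with the largest weight
aligned = attention_matrix.argmax(axis=1)

for i, (out_word, j) in enumerate(zip(output_sent, aligned)):
    print(f"{out_word:>6} -> {input_sent[j]} ({attention_matrix[i, j]:.1f})")

# The alignment is monotonic if the chosen positions never move backwards
is_monotonic = bool(np.all(np.diff(aligned) >= 0))
print("Monotonic alignment:", is_monotonic)
```

For this matrix every output word aligns to the input word at the same position, which is the diagonal seen in the heatmap.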

## 5. Self-Attention: Introduction

**Self-attention**: Attention where query, keys, and values all come from the **same sequence**!

**Purpose**: Model relationships between words in the same sentence.

**Example**: "The animal didn't cross the street because **it** was too tired"
- What does "it" refer to?
- A trained self-attention layer can learn to make "it" attend strongly to "animal"

**This is the core of Transformers!** (Module 06)

In [None]:
class SelfAttention(nn.Module):
    """
    Self-attention mechanism (simplified scaled dot-product).
    
    This is the building block of Transformers!
    """
    
    def __init__(self, hidden_dim):
        super(SelfAttention, self).__init__()
        
        self.hidden_dim = hidden_dim
        
        # Linear projections for Q, K, V
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
        
    def forward(self, x, mask=None):
        """
        Compute self-attention.
        
        Parameters:
        -----------
        x : torch.Tensor
            Input sequence (batch, seq_len, hidden_dim)
        mask : torch.Tensor or None
            Attention mask broadcastable to (batch, seq_len, seq_len);
            positions where mask == 0 are excluded from attention
            
        Returns:
        --------
        output : torch.Tensor
            Attention output (batch, seq_len, hidden_dim)
        attention_weights : torch.Tensor
            Attention weights (batch, seq_len, seq_len)
        """
        batch_size, seq_len, _ = x.size()
        
        # Project to Q, K, V
        Q = self.query_proj(x)  # (batch, seq_len, hidden)
        K = self.key_proj(x)    # (batch, seq_len, hidden)
        V = self.value_proj(x)  # (batch, seq_len, hidden)
        
        # Scaled dot-product attention
        # scores = Q @ K^T / sqrt(d_k)
        scores = torch.bmm(Q, K.transpose(1, 2)) / self.hidden_dim ** 0.5
        # (batch, seq_len, seq_len)
        
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax over keys dimension
        attention_weights = F.softmax(scores, dim=-1)
        # (batch, seq_len, seq_len)
        
        # Weighted sum of values
        output = torch.bmm(attention_weights, V)
        # (batch, seq_len, hidden)
        
        return output, attention_weights

print("✓ SelfAttention class defined!")

In [None]:
# Test self-attention
seq_len = 6
hidden_dim = 8
batch_size = 2

# Input sequence
x = torch.randn(batch_size, seq_len, hidden_dim)

# Self-attention
self_attn = SelfAttention(hidden_dim)
output, attn_weights = self_attn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"\nAttention weights for first sample:")
print(attn_weights[0].detach().numpy().round(2))
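
The division by $\sqrt{d_k}$ in `SelfAttention.forward` can be motivated with a quick numerical check. The sketch below assumes query and key components drawn independently from a standard normal distribution. Real projections will not match that assumption exactly, so treat the numbers as an intuition aid only.

```python
import torch

torch.manual_seed(0)
n_pairs = 10_000
results = {}

for d_k in (8, 64, 512):
    q = torch.randn(n_pairs, d_k)    # unit-variance query components
    k = torch.randn(n_pairs, d_k)    # unit-variance key components
    raw = (q * k).sum(dim=1)         # unscaled dot products
    scaled = raw / d_k ** 0.5        # scaled as in SelfAttention.forward
    results[d_k] = (raw.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}  var(raw)={results[d_k][0]:8.1f}  "
          f"var(scaled)={results[d_k][1]:.2f}")
```

Under this assumption the variance of the raw dot product grows roughly in proportion to `d_k`, while the scaled version stays near 1. That keeps the softmax inputs in a range where the weights do not collapse onto a single position.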

**Exercise 2**: Analyze self-attention

1. Visualize self-attention weights for a sentence
2. Identify which words attend to which other words
3. Compare with encoder-decoder attention
4. Explain: Why is self-attention O(n²) in sequence length?

In [None]:
# YOUR CODE HERE
# Visualize and analyze self-attention patterns

## Summary

### Key Concepts Covered:

1. **Attention Motivation**:
   - Solves context bottleneck in seq2seq
   - Dynamic focus on relevant input parts
   - Different context for each decoder step

2. **Bahdanau Attention**:
   - Additive attention mechanism
   - First successful attention for NMT
   - Uses tanh and learned parameters

3. **Luong Attention**:
   - Multiplicative attention
   - Simpler and more efficient
   - Three variants: dot, general, concat

4. **Attention Visualization**:
   - Interpretable alignments
   - Shows what model focuses on
   - Useful for debugging and analysis

5. **Self-Attention**:
   - Query, key, value from same sequence
   - Models intra-sentence dependencies
   - Core building block of Transformers

### Attention Benefits:

✅ Better handling of long-range dependencies  
✅ Interpretability through visualization  
✅ Parallelizable (self-attention)  
✅ Better translation quality than fixed-context seq2seq, especially on long sentences  

### What's Next?

In **Module 06: Transformer Architecture**, we'll learn:
- Full transformer architecture using self-attention
- Multi-head attention
- Positional encoding
- The "Attention is All You Need" revolution

### Additional Resources:

- **Bahdanau Paper**: [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)
- **Luong Paper**: [Effective Approaches to Attention-based NMT](https://arxiv.org/abs/1508.04025)
- **Blog**: [Jay Alammar's Visualizing Attention](https://jalammar.github.io/visualizing-neural-machine-translation-mechanics-of-seq2seq-models-with-attention/)