# Attention Mechanism in Sequence-to-Sequence (Seq2Seq) Models

## Overview
This notebook demonstrates the implementation of an attention mechanism in a Seq2Seq model for machine translation. The attention mechanism allows the decoder to focus on different parts of the input sequence when generating each output token.

### Key Concepts:
- **Encoder**: Processes the input sequence and produces one output vector per input word plus a final hidden state
- **Decoder**: Generates output sequence using attention over encoder outputs
- **Attention Mechanism**: Dynamically weights encoder outputs to create context for each decoder step
- **Attention Weights**: Softmax-normalized scores indicating focus on each input word
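
To make the last two bullets concrete, here is a minimal, framework-free sketch of how raw scores become attention weights and then a context vector. The scores and the two-dimensional encoder outputs are made-up illustrative numbers, not values taken from the model in this notebook.

```python
import math

def softmax(scores):
    """Turn raw scores into weights that are positive and sum to 1."""
    shifted = [s - max(scores) for s in scores]  # subtract max for numerical stability
    exps = [math.exp(s) for s in shifted]
    total = sum(exps)
    return [e / total for e in exps]

def context_vector(weights, encoder_outputs):
    """Weighted sum of encoder outputs (one vector per input word)."""
    dim = len(encoder_outputs[0])
    return [sum(w * vec[d] for w, vec in zip(weights, encoder_outputs))
            for d in range(dim)]

# Hypothetical values: three input words, 2-dimensional encoder outputs
scores = [2.0, 0.5, 0.5]
encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]

weights = softmax(scores)
context = context_vector(weights, encoder_outputs)
print("weights:", [round(w, 3) for w in weights])
print("context:", [round(c, 3) for c in context])
```

The first word receives the largest weight because it has the largest score, so the context vector leans toward that word's encoder output.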

## Part 1: Import Libraries and Setup

We'll start by importing all necessary libraries for building and training our attention-based Seq2Seq model.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import seaborn as sns

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Part 2: Dataset Preparation

We'll use a small English-to-French translation dataset. In practice, you'd use larger datasets from sources like Multi30k or WMT.

In [None]:
# Simple English-French translation pairs
pairs = [
    ("i am a student", "je suis etudiant"),
    ("he is a teacher", "il est professeur"),
    ("she is a doctor", "elle est medecin"),
    ("we are engineers", "nous sommes ingenieurs"),
    ("they are artists", "ils sont artistes")
]

print("Dataset:")
for eng, fra in pairs:
    print(f"  EN: {eng:30s} | FR: {fra}")

## Part 3: Language Vocabulary Class

This class builds a vocabulary by mapping words to indices. Special tokens:
- `<SOS>` (Start Of Sequence): Index 0, marks the beginning of a sentence
- `<EOS>` (End Of Sequence): Index 1, marks the end of a sentence
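
As an illustration of how these indices are used, the sketch below converts a sentence into a list of indices and appends `<EOS>`. The helper name `sentence_to_indices` and the tiny hand-written vocabulary are assumptions made for this example; they are not part of the notebook's code.

```python
SOS_TOKEN, EOS_TOKEN = "<SOS>", "<EOS>"

# Hand-written toy vocabulary; a real one would be built from the dataset
word2index = {SOS_TOKEN: 0, EOS_TOKEN: 1, "je": 2, "suis": 3, "etudiant": 4}
index2word = {i: w for w, i in word2index.items()}

def sentence_to_indices(sentence, word2index, add_eos=True):
    """Map a whitespace-tokenized sentence to indices, optionally ending in <EOS>."""
    indices = [word2index[w] for w in sentence.split()]
    if add_eos:
        indices.append(word2index[EOS_TOKEN])
    return indices

indices = sentence_to_indices("je suis etudiant", word2index)
print(indices)                           # [2, 3, 4, 1]
print([index2word[i] for i in indices])  # ['je', 'suis', 'etudiant', '<EOS>']
```

Ending target sequences with `<EOS>` gives a decoder an explicit signal to learn for when to stop generating.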

In [None]:
class Lang:
    """Class to manage vocabulary and word-to-index mapping."""

    def __init__(self):
        """Initialize with special tokens."""
        self.word2index = {"<SOS>": 0, "<EOS>": 1}
        self.index2word = {0: "<SOS>", 1: "<EOS>"}
        self.n_words = 2  # Vocabulary size so far (the two special tokens)

    def add_sentence(self, sentence):
        """Add all words from a sentence to vocabulary."""
        for word in sentence.split(' '):
            if word not in self.word2index:
                self.word2index[word] = self.n_words
                self.index2word[self.n_words] = word
                self.n_words += 1

# Build vocabularies for both languages
input_lang, output_lang = Lang(), Lang()

for eng, fra in pairs:
    input_lang.add_sentence(eng)
    output_lang.add_sentence(fra)

print(f"English Vocabulary Size: {input_lang.n_words}")
print(f"French Vocabulary Size: {output_lang.n_words}")
print(f"\nEnglish Word2Index: {input_lang.word2index}")
print(f"\nFrench Word2Index: {output_lang.word2index}")

## Part 4: Encoder Architecture

The **Encoder** reads the input sequence and produces:
- **Output**: Hidden state at each timestep (used for attention)
- **Hidden State**: Final context vector passed to decoder

### Components:
- **Embedding Layer**: Converts word indices to dense vectors
- **GRU Layer**: Gated Recurrent Unit processes embeddings sequentially
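
A quick shape check can make these two components easier to follow. The sketch below uses made-up sizes (`vocab_size = 12`, `hidden_size = 8`) and feeds a single word index through an embedding layer and a GRU, with one timestep and a batch size of 1.

```python
import torch
import torch.nn as nn

vocab_size, hidden_size = 12, 8  # illustrative sizes only

embedding = nn.Embedding(vocab_size, hidden_size)
gru = nn.GRU(hidden_size, hidden_size)

word_index = torch.tensor([5])            # a single word index
hidden = torch.zeros(1, 1, hidden_size)   # (num_layers, batch, hidden_size)

embedded = embedding(word_index).view(1, 1, -1)  # (seq_len=1, batch=1, hidden_size)
output, hidden = gru(embedded, hidden)

print(embedded.shape)  # torch.Size([1, 1, 8])
print(output.shape)    # torch.Size([1, 1, 8])
print(hidden.shape)    # torch.Size([1, 1, 8])

# For a single-layer GRU run for one step, the output equals the new hidden state
print(torch.allclose(output, hidden))
```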

In [None]:
class Encoder(nn.Module):
    """Encoder: Maps input sequence to hidden representations."""

    def __init__(self, input_size, hidden_size):
        """Initialize encoder layers.

        Args:
            input_size: Size of input vocabulary
            hidden_size: Dimension of hidden states
        """
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size

        # Embedding layer: converts word indices to vectors
        self.embedding = nn.Embedding(input_size, hidden_size)

        # GRU layer: processes sequence of embeddings
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Forward pass through encoder.

        Args:
            input: Word index tensor
            hidden: Previous hidden state

        Returns:
            output: Current output (used for attention)
            hidden: New hidden state
        """
        # Embed the input word
        embedded = self.embedding(input).view(1, 1, -1)  # Shape: (1, 1, hidden_size)

        # Pass through GRU
        output, hidden = self.gru(embedded, hidden)

        return output, hidden

print("Encoder class defined successfully!")

## Part 5: Attention Decoder Architecture

The **Attention Decoder** generates the output sequence while dynamically attending to input words.

### Attention Mechanism Steps:
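1. **Attention Score**: Concatenate the embedded decoder input with the decoder hidden state and pass it through a linear layer, giving one score per input position (up to `max_length`)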
2. **Attention Weights**: Apply softmax to get probability distribution over inputs
3. **Context Vector**: Weighted sum of encoder outputs
4. **Combine**: Concatenate the context with the embedded decoder input, project back to `hidden_size`, and pass through the GRU
5. **Output**: Linear layer followed by log-softmax produces log-probabilities over the target vocabulary
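
The steps above can be traced with random tensors to see how the shapes fit together. This is a sketch with illustrative sizes, not the trained model: the layers are freshly initialized, and the inputs are random stand-ins for the embedded decoder input, the decoder hidden state, and the encoder outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, max_length = 8, 10  # illustrative sizes

attn = nn.Linear(hidden_size * 2, max_length)
attn_combine = nn.Linear(hidden_size * 2, hidden_size)

embedded = torch.randn(1, hidden_size)                  # embedded decoder input
hidden = torch.randn(1, hidden_size)                    # decoder hidden state
encoder_outputs = torch.randn(max_length, hidden_size)  # one row per source position

# Steps 1-2: scores over source positions, normalized with softmax
attn_weights = F.softmax(attn(torch.cat((embedded, hidden), dim=1)), dim=1)  # (1, max_length)

# Step 3: context vector as a weighted sum of encoder outputs
context = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))  # (1, 1, hidden_size)

# Step 4: merge the context with the embedded input before the GRU
gru_input = F.relu(attn_combine(torch.cat((embedded, context[0]), dim=1)))  # (1, hidden_size)

print(attn_weights.shape, context.shape, gru_input.shape)
print(float(attn_weights.sum()))  # approximately 1.0
```

The `bmm` call is just a batched matrix product: a `(1, 1, max_length)` row of weights times a `(1, max_length, hidden_size)` matrix of encoder outputs.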

In [None]:
class AttnDecoder(nn.Module):
    """Decoder with Attention mechanism."""

    def __init__(self, hidden_size, output_size, max_length=10):
        """Initialize decoder layers.

        Args:
            hidden_size: Dimension of hidden states
            output_size: Size of output vocabulary
            max_length: Maximum sequence length for attention
        """
        super(AttnDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.max_length = max_length

        # Embedding layer for target vocabulary
        self.embedding = nn.Embedding(self.output_size, self.hidden_size)

        # Attention layer: computes attention scores
        # Input: concatenation of embedding and hidden state
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)

        # Attention combine layer: combines context and embedding
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)

        # GRU layer
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)

        # Output projection: maps hidden state to vocabulary scores (logits)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Forward pass with attention.

        Args:
            input: Word index for decoder input
            hidden: Decoder hidden state
            encoder_outputs: All encoder outputs for attention

        Returns:
            output: Log-probabilities over output vocabulary
            hidden: New hidden state
            attn_weights: Attention weights
        """
        # Step 1: Embed the input
        embedded = self.embedding(input).view(1, 1, -1)  # Shape: (1, 1, hidden_size)

        # Step 2: Calculate attention weights
        # Concatenate embedded input with current hidden state
        attn_input = torch.cat((embedded[0], hidden[0]), 1)  # Shape: (1, hidden_size*2)
        attn_scores = self.attn(attn_input)  # Shape: (1, max_length)
        attn_weights = F.softmax(attn_scores, dim=1)  # Normalize to probabilities

        # Step 3: Apply attention weights to encoder outputs
        # This creates a weighted context vector
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        # Step 4: Combine embedded input with attention context
        output = torch.cat((embedded[0], attn_applied[0]), 1)  # Shape: (1, hidden_size*2)
        output = self.attn_combine(output).unsqueeze(0)  # Shape: (1, 1, hidden_size)
        output = F.relu(output)

        # Step 5: Pass through GRU
        output, hidden = self.gru(output, hidden)

        # Step 6: Generate output log-probabilities
        output = F.log_softmax(self.out(output[0]), dim=1)

        return output, hidden, attn_weights

print("Attention Decoder class defined successfully!")

## Part 6: Training Function

The training function implements the full forward and backward pass:

### Training Steps:
1. **Encode**: Process entire input sequence, storing all outputs
2. **Decode with Teacher Forcing**: Feed ground truth as input at each step
3. **Calculate Loss**: Compare predictions with ground truth
4. **Backpropagate**: Update model parameters via gradient descent
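
Step 3 relies on pairing `log_softmax` outputs with `NLLLoss`. The sketch below, using random logits and made-up target indices, checks that this pairing matches `F.cross_entropy` applied to the raw logits, and shows the per-token averaging used when reporting a loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 6
logits = torch.randn(3, vocab_size)   # raw scores for three decoder steps
targets = torch.tensor([2, 4, 1])     # made-up ground-truth indices

criterion = nn.NLLLoss()
loss = 0.0
for step in range(targets.size(0)):
    log_probs = F.log_softmax(logits[step].unsqueeze(0), dim=1)  # (1, vocab_size)
    loss = loss + criterion(log_probs, targets[step].unsqueeze(0))

per_token_loss = loss.item() / targets.size(0)
reference = F.cross_entropy(logits, targets).item()  # mean over the three steps

print(round(per_token_loss, 4), round(reference, 4))
```

Both numbers agree because cross-entropy is, by definition, log-softmax followed by negative log-likelihood.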

In [None]:
def train(input_tensor, target_tensor, encoder, decoder,
          encoder_optimizer, decoder_optimizer, criterion, max_length=10):
    """Train the encoder and decoder for one sample.

    Args:
        input_tensor: Input sequence as tensor
        target_tensor: Target sequence as tensor
        encoder: Encoder model
        decoder: Decoder model with attention
        encoder_optimizer: Optimizer for encoder
        decoder_optimizer: Optimizer for decoder
        criterion: Loss function (NLLLoss)
        max_length: Maximum sequence length

    Returns:
        Loss averaged over the target tokens
    """
    # Initialize encoder hidden state
    encoder_hidden = torch.zeros(1, 1, encoder.hidden_size, device=device)

    # Clear gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # Store encoder outputs for attention
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    # ENCODER: Process input sequence word by word
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]  # Store for attention

    # DECODER: Generate output sequence with teacher forcing
    # Start with SOS (Start of Sequence) token
    decoder_input = torch.tensor([[0]], device=device)
    decoder_hidden = encoder_hidden  # Pass encoder's final state to decoder

    # Teacher forcing: feed ground truth at each step
    for di in range(target_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(
            decoder_input, decoder_hidden, encoder_outputs)

        # Calculate loss for this step
        loss += criterion(decoder_output, target_tensor[di].unsqueeze(0))

        # Use ground truth as next input (teacher forcing)
        decoder_input = target_tensor[di]

    # BACKPROPAGATION
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

print("Training function defined successfully!")

## Part 7: Model Initialization

Now we'll initialize the encoder and decoder models with hyperparameters.

In [None]:
# Hyperparameters
hidden_size = 128
learning_rate = 0.01
num_epochs = 400

# Initialize models
encoder = Encoder(input_lang.n_words, hidden_size).to(device)
decoder = AttnDecoder(hidden_size, output_lang.n_words).to(device)

# Initialize optimizers
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

# Loss function (for log probabilities)
criterion = nn.NLLLoss()

print(f"Encoder: {encoder}")
print(f"\nDecoder: {decoder}")
print(f"\nTotal encoder parameters: {sum(p.numel() for p in encoder.parameters())}")
print(f"Total decoder parameters: {sum(p.numel() for p in decoder.parameters())}")

## Part 8: Training Loop

Train the model for multiple epochs, iterating over all translation pairs.

In [None]:
print(f"Starting training for {num_epochs} epochs...\n")

for epoch in range(num_epochs):
    total_loss = 0

    for eng, fra in pairs:
        # Convert sentences to tensor of word indices
        input_tensor = torch.tensor(
            [input_lang.word2index[w] for w in eng.split(' ')],
            device=device
        )
        # Append <EOS> so the decoder learns when to stop generating
        target_tensor = torch.tensor(
            [output_lang.word2index[w] for w in fra.split(' ')]
            + [output_lang.word2index["<EOS>"]],
            device=device
        )

        # Train on this pair
        loss = train(input_tensor, target_tensor, encoder, decoder,
                    encoder_optimizer, decoder_optimizer, criterion)
        total_loss += loss

    # Print progress
    if (epoch + 1) % 50 == 0:
        avg_loss = total_loss / len(pairs)
        print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

print("\nTraining completed!")

## Part 9: Inference Function

This function translates a sentence without teacher forcing (greedy decoding).

### Decoding Strategy:
- **Greedy Decoding**: Select word with highest probability at each step
- **Attention Visualization**: Collect attention weights for visualization
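
The greedy loop can be illustrated without a neural network. In the sketch below, `toy_step` is a hypothetical stand-in for one decoder step: it returns a fixed score table keyed by the previous token, which is enough to show the argmax-and-feed-back pattern and the `<EOS>` stopping rule.

```python
SOS, EOS = 0, 1
index2word = {0: "<SOS>", 1: "<EOS>", 2: "je", 3: "suis", 4: "etudiant"}

# Hypothetical next-token scores, keyed by the previous token index
SCORES = {
    0: [0.0, 0.1, 0.7, 0.1, 0.1],    # after <SOS>      -> "je"
    2: [0.0, 0.1, 0.1, 0.7, 0.1],    # after "je"       -> "suis"
    3: [0.0, 0.1, 0.1, 0.1, 0.7],    # after "suis"     -> "etudiant"
    4: [0.0, 0.8, 0.1, 0.05, 0.05],  # after "etudiant" -> <EOS>
}

def toy_step(prev_token):
    """Stand-in for one decoder step: returns scores over the vocabulary."""
    return SCORES[prev_token]

def greedy_decode(step_fn, max_length=10):
    """Pick the highest-scoring token at each step and feed it back in."""
    token, words = SOS, []
    for _ in range(max_length):
        scores = step_fn(token)
        token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        if token == EOS:
            break
        words.append(index2word[token])
    return words

print(greedy_decode(toy_step))  # ['je', 'suis', 'etudiant']
```

Note that `max_length` acts as a safety cap: if the step function never produces `<EOS>`, the loop still terminates.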

In [None]:
def translate(encoder, decoder, sentence, max_length=10):
    """Translate a sentence without teacher forcing.

    Args:
        encoder: Trained encoder
        decoder: Trained decoder
        sentence: Input sentence string
        max_length: Maximum output length

    Returns:
        translated_words: List of translated words
    """
    with torch.no_grad():
        # Convert input sentence to tensor
        input_tensor = torch.tensor(
            [input_lang.word2index[w] for w in sentence.split(' ')],
            device=device
        )

        # Initialize encoder
        encoder_hidden = torch.zeros(1, 1, encoder.hidden_size, device=device)
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        # Encode input
        for i in range(input_tensor.size(0)):
            encoder_output, encoder_hidden = encoder(input_tensor[i], encoder_hidden)
            encoder_outputs[i] = encoder_output[0, 0]

        # Initialize decoder with SOS token
        decoder_input = torch.tensor([[0]], device=device)
        decoder_hidden = encoder_hidden

        translated_words = []

        # Generate output greedily
        for _ in range(max_length):
            decoder_output, decoder_hidden, attn_weights = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )

            # Select word with highest probability
            topi = decoder_output.argmax(1)

            # Stop if EOS token is generated
            if topi.item() == 1:
                break

            # Add word to translation
            translated_words.append(output_lang.index2word[topi.item()])
            decoder_input = topi.view(1, 1)

        return translated_words

# Test translation
print("Translation Examples:")
for eng, fra in pairs:
    translated = translate(encoder, decoder, eng)
    print(f"EN: {eng:30s} | FR (True): {fra:30s} | FR (Predicted): {' '.join(translated)}")

## Part 10: Visualization - Attention Heatmap

Now we'll visualize the attention weights as a heatmap. This shows which input words the decoder focuses on when generating each output word.

### Interpretation:
- **X-axis**: Input words (English)
- **Y-axis**: Output words (French)
- **Color intensity**: Attention weight (with the `bone` colormap used below, lighter = higher attention)
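
One caveat when reading such a heatmap: if the attention softmax runs over a fixed number of positions (`max_length`) rather than only the real input words, trimming the matrix to the actual input length leaves rows that may sum to less than 1. The sketch below uses made-up scores to show this, along with an optional per-row renormalization for display. Renormalizing is an assumption of this example, not something the plotting code in this notebook does.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    exps = [math.exp(s - max(scores)) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

max_length, input_length = 10, 4

# Made-up scores for one decoder step over all max_length positions
scores = [1.5, 0.2, 0.1, 2.0] + [0.0] * (max_length - input_length)
full_row = softmax(scores)

trimmed = full_row[:input_length]   # the part a trimmed heatmap would display
trimmed_mass = sum(trimmed)         # attention left on real input words
renormalized = [w / trimmed_mass for w in trimmed]

print(round(sum(full_row), 6))
print(round(trimmed_mass, 3))
print(round(sum(renormalized), 6))
```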

In [None]:
def visualize_attention(encoder, decoder, sentence, max_length=10):
    """Visualize attention weights as a heatmap.

    Args:
        encoder: Trained encoder
        decoder: Trained decoder
        sentence: Input sentence to visualize
        max_length: Maximum sequence length
    """
    with torch.no_grad():
        # Convert input to tensor
        input_tensor = torch.tensor(
            [input_lang.word2index[word] for word in sentence.split(' ')],
            device=device
        )
        input_length = input_tensor.size(0)

        # Initialize encoder
        encoder_hidden = torch.zeros(1, 1, encoder.hidden_size, device=device)
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        # Encode input
        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0, 0]

        # Initialize decoder
        decoder_input = torch.tensor([[0]], device=device)
        decoder_hidden = encoder_hidden
        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        # Generate output and collect attention weights
        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_attentions[di] = decoder_attention.data

            # Get word with highest probability
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == 1:  # EOS
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])
            decoder_input = topi.squeeze().detach()

        # Trim attention matrix to actual lengths
        attention_data = decoder_attentions[:len(decoded_words), :input_length].cpu().numpy()

        # Create figure
        fig, ax = plt.subplots(figsize=(10, 6))

        # Plot heatmap
        im = ax.imshow(attention_data, cmap='bone', aspect='auto')
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('Attention Weight', rotation=270, labelpad=20)

        # Set ticks and labels
        ax.set_xticks(np.arange(input_length))
        ax.set_yticks(np.arange(len(decoded_words)))
        ax.set_xticklabels(sentence.split(' '), rotation=45, ha='right')
        ax.set_yticklabels(decoded_words)

        # Add grid
        ax.set_xticks(np.arange(input_length) - 0.5, minor=True)
        ax.set_yticks(np.arange(len(decoded_words)) - 0.5, minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)

        # Labels
        ax.set_xlabel('Input Words (English)', fontsize=12, fontweight='bold')
        ax.set_ylabel('Output Words (French)', fontsize=12, fontweight='bold')
        ax.set_title(
            f'Attention Mechanism: "{sentence}" → "{" ".join(decoded_words)}"',
            fontsize=13, fontweight='bold', pad=15
        )

        plt.tight_layout()
        plt.show()

        return decoded_words, attention_data

print("Visualization function defined!")

## Part 11: Visualize Attention for Sample Sentences

Let's see how the attention mechanism works for different input sentences.

In [None]:
# Visualize attention for different sentences
print("="*60)
print("ATTENTION VISUALIZATION FOR SENTENCE 1")
print("="*60)
visualize_attention(encoder, decoder, "i am a student")

print("\n" + "="*60)
print("ATTENTION VISUALIZATION FOR SENTENCE 2")
print("="*60)
visualize_attention(encoder, decoder, "he is a teacher")

print("\n" + "="*60)
print("ATTENTION VISUALIZATION FOR SENTENCE 3")
print("="*60)
visualize_attention(encoder, decoder, "she is a doctor")

## Part 12: Enhanced Visualization with Seaborn

Alternative visualization using seaborn for a more polished look. With the `YlOrRd` colormap used here, darker red means higher attention.

In [None]:
def visualize_attention_seaborn(encoder, decoder, sentence, max_length=10):
    """Visualize attention using seaborn heatmap.

    Args:
        encoder: Trained encoder
        decoder: Trained decoder
        sentence: Input sentence
        max_length: Maximum sequence length
    """
    with torch.no_grad():
        input_tensor = torch.tensor(
            [input_lang.word2index[word] for word in sentence.split(' ')],
            device=device
        )
        input_length = input_tensor.size(0)

        encoder_hidden = torch.zeros(1, 1, encoder.hidden_size, device=device)
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0, 0]

        decoder_input = torch.tensor([[0]], device=device)
        decoder_hidden = encoder_hidden
        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == 1:
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])
            decoder_input = topi.squeeze().detach()

        attention_data = decoder_attentions[:len(decoded_words), :input_length].cpu().numpy()

        # Seaborn heatmap
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.heatmap(
            attention_data,
            xticklabels=sentence.split(' '),
            yticklabels=decoded_words,
            cmap='YlOrRd',
            cbar_kws={'label': 'Attention Weight'},
            ax=ax,
            linewidths=0.5,
            linecolor='gray'
        )

        ax.set_xlabel('Input Words (English)', fontsize=12, fontweight='bold')
        ax.set_ylabel('Output Words (French)', fontsize=12, fontweight='bold')
        ax.set_title(
            f'Attention Heatmap (Seaborn): "{sentence}"',
            fontsize=13, fontweight='bold'
        )
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()

        return decoded_words, attention_data

# Visualize with seaborn
print("SEABORN VISUALIZATION")
visualize_attention_seaborn(encoder, decoder, "we are engineers")

## Part 13: Summary and Key Insights

### What We Learned:

1. **Encoder-Decoder Architecture**: The encoder processes the input and passes its final hidden state (the context vector) to the decoder

2. **Attention Mechanism**:
   - Calculates dynamic attention weights for each decoder step
   - Allows the model to focus on relevant input words
   - Tends to help most on longer sequences, where a single fixed-size context vector becomes a bottleneck

3. **Training Process**:
   - Uses teacher forcing during training
   - Loss is calculated for each output word
   - Gradients flow through both encoder and decoder

4. **Inference**:
   - Uses greedy decoding (select highest probability word)
   - No teacher forcing at inference time
   - Attention weights offer some insight into what the model focuses on

### Attention Weights Interpretation:
- The heatmap shows how much the decoder attends to each input word
- In the `bone` heatmap, lighter colors = higher attention; in the seaborn `YlOrRd` heatmap, darker red = higher attention
- Typically shows alignment between source and target languages
- Can reveal linguistic patterns the model has learned

### Extensions:
- **Multi-head Attention**: Multiple attention mechanisms in parallel
- **Transformer Models**: Replace RNNs with self-attention
- **Beam Search**: Generate multiple hypotheses for better translations
- **Larger Datasets**: Use real datasets (WMT, Multi30k) for better performance
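
As a rough illustration of the beam search idea, the sketch below keeps the `beam_width` best partial hypotheses at each step instead of a single one. The `toy_log_probs` function and its probability table are hypothetical; in a real system this would be one decoder step returning log-probabilities.

```python
import math

SOS, EOS = 0, 1

def toy_log_probs(prefix):
    """Hypothetical decoder step: next-token log-probs given the tokens so far."""
    table = {
        (SOS,):      {2: 0.6, 3: 0.4},
        (SOS, 2):    {1: 0.5, 4: 0.5},
        (SOS, 3):    {1: 0.9, 4: 0.1},
        (SOS, 2, 4): {1: 1.0},
        (SOS, 3, 4): {1: 1.0},
    }
    return {tok: math.log(p) for tok, p in table[tuple(prefix)].items()}

def beam_search(step_fn, beam_width=2, max_length=5):
    """Return the best (tokens, log-probability) pair found by beam search."""
    beams = [([SOS], 0.0)]   # (tokens, cumulative log-probability)
    finished = []
    for _ in range(max_length):
        # Expand every live hypothesis by every possible next token
        candidates = []
        for tokens, score in beams:
            for tok, logp in step_fn(tokens).items():
                candidates.append((tokens + [tok], score + logp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        # Keep the top beam_width; hypotheses ending in EOS are set aside
        beams = []
        for tokens, score in candidates[:beam_width]:
            (finished if tokens[-1] == EOS else beams).append((tokens, score))
        if not beams:
            break
    pool = finished if finished else beams
    return max(pool, key=lambda c: c[1])

best_tokens, best_score = beam_search(toy_log_probs)
print(best_tokens, round(math.exp(best_score), 2))  # [0, 3, 1] 0.36
```

With `beam_width=1` the procedure reduces to greedy decoding, which on this toy table commits to token 2 first and ends with probability 0.3 instead of 0.36.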

In [None]:
# Final statistics
print("\n" + "="*60)
print("MODEL STATISTICS")
print("="*60)
print(f"Encoder hidden size: {hidden_size}")
print(f"Encoder parameters: {sum(p.numel() for p in encoder.parameters()):,}")
print(f"Decoder parameters: {sum(p.numel() for p in decoder.parameters()):,}")
print(f"Total parameters: {sum(p.numel() for p in encoder.parameters()) + sum(p.numel() for p in decoder.parameters()):,}")
print(f"\nTraining epochs: {num_epochs}")
print(f"Learning rate: {learning_rate}")
print(f"Device: {device}")
print("="*60)