# 🎯 Notebook 05: Attention Mechanisms - A Deep Dive

**Week 3-4: Deep Learning & NLP Foundations**  
**Gen AI Masters Program**

---

## 📋 Objectives

In the last few notebooks, we've seen the power of LSTMs and Transformers. A key ingredient in the success of the Transformer is **attention**. But attention is not just one thing; it's a general and powerful concept that can be used in many different ways to allow a model to dynamically **focus on specific parts of its input** when performing a task.

Instead of forcing a model to compress an entire input sequence (like a long sentence) into a single fixed-size vector (a major bottleneck), an attention mechanism computes a set of **attention weights**. These weights determine how much importance each part of the input receives at a given step.
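To make the "relevance-scoring" and "weighted-sum" idea concrete, here is a minimal sketch with one query and three input vectors that serve as both keys and values (a simplification; the names `query` and `inputs` are illustrative only). Dot products score relevance, a softmax turns the scores into weights, and the output is the weighted sum of the inputs:

```python
import torch
import torch.nn.functional as F

# Three input vectors (e.g., encoder states), used here as both keys and values.
inputs = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 1.0]])
# The query we want to "answer" using the inputs.
query = torch.tensor([1.0, 0.0])

# 1. Relevance scores: dot product between the query and each input.
scores = inputs @ query              # shape: (3,)
# 2. Attention weights: softmax turns the scores into a distribution that sums to 1.
weights = F.softmax(scores, dim=-1)  # shape: (3,)
# 3. Context vector: weighted sum of the inputs.
context = weights @ inputs           # shape: (2,)

print("scores: ", scores)
print("weights:", weights)
print("context:", context)
```

Inputs 0 and 2 have the same dot product with the query, so they receive equal weight (about 0.42 each), while the least relevant input gets about 0.16.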

By the end of this notebook, you will be able to:
1.  **Understand the Core Intuition**: Grasp the high-level concept of attention as a "relevance-scoring" and "weighted-sum" mechanism.
2.  **Implement Scaled Dot-Product Attention**: Implement the most common type of attention from scratch (the same one used in the Transformer) and understand the roles of the **Query (Q)**, **Key (K)**, and **Value (V)** vectors.
3.  **Implement Multi-Head Attention**: Build a `MultiHeadAttention` module that runs several scaled dot-product attention heads in parallel, as in the Transformer, and understand why multiple heads help.
4.  **Visualize Attention Weights**: Create heatmaps to visualize the attention weights, providing a powerful and interpretable view into the model's decision-making process.

**Estimated Time:** 2-3 hours

---

## 📚 Why is Attention So Important?

1.  **Solving the Bottleneck Problem**: In older encoder-decoder models (like those using LSTMs), the encoder had to compress the entire meaning of a long input sentence into a single "context vector." This was a huge bottleneck. Attention allows the decoder to "look back" at the entire input sequence at every step of the output generation, focusing on the most relevant parts.
2.  **Interpretability**: Attention weights are easy to inspect. By visualizing them, we can see which inputs the model weights most heavily when it makes a prediction. For example, in machine translation, we can see which source words the model focuses on when generating a specific target word.
3.  **Performance Boost**: Attention mechanisms have consistently led to state-of-the-art results in a wide range of tasks, from NLP to computer vision. They are a fundamental building block of most modern deep learning architectures.
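The sketch below contrasts the two designs under simplifying assumptions: random tensors stand in for encoder outputs and decoder states, and the decoder uses plain dot-product scoring (one of several scoring functions used in practice). The point is that each decoder step computes its own context vector from *all* encoder states, instead of reusing one fixed vector. The helper name `attention_context` is hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
src_len, hidden = 5, 8

# Hypothetical encoder outputs: one hidden state per source position.
encoder_outputs = torch.randn(src_len, hidden)

# Without attention: the decoder only ever sees one fixed-size vector
# (here, the last encoder state), no matter which output step it is on.
fixed_context = encoder_outputs[-1]

def attention_context(decoder_state, encoder_outputs):
    """Dot-product attention over all encoder states for one decoder step."""
    scores = encoder_outputs @ decoder_state   # (src_len,)
    weights = F.softmax(scores, dim=-1)        # (src_len,), sums to 1
    context = weights @ encoder_outputs        # (hidden,)
    return context, weights

# With attention: two different (hypothetical) decoder states at two output steps
# each get their own context vector, built from the whole source sequence.
contexts = []
for step in range(2):
    decoder_state = torch.randn(hidden)
    context, weights = attention_context(decoder_state, encoder_outputs)
    contexts.append(context)
    print(f"step {step}: weights over source = {weights.numpy().round(2)}")
```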

This notebook will demystify the attention mechanism and give you a solid, hands-on understanding of this fundamental concept. Let's begin! 🚀

## 🚀 Agenda

1.  **Setting the Stage**: We'll import the necessary libraries and configure our environment.
2.  **Scaled Dot-Product Attention from Scratch**: We'll implement the most famous attention mechanism, breaking down the formula and the roles of Query, Key, and Value.
3.  **Multi-Head Attention**: We'll wrap scaled dot-product attention in a `MultiHeadAttention` module that runs several heads in parallel.
    *   **Projections**: Linear layers that create per-head Queries, Keys, and Values.
    *   **Parallel Attention**: Scaled dot-product attention computed for all heads at once.
    *   **Concatenation & Output Projection**: Combining the heads, followed by a residual connection and layer normalization.
4.  **Visualizing Attention**: We'll create heatmaps of the attention weights, for a single head and for all heads of a multi-head module.
5.  **Summary and Next Steps**: We'll recap the key ideas and point to the remaining building blocks of the Transformer.

```python
# --- 1. Set up the Environment ---

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math
from tqdm.notebook import tqdm

# For visualization
import matplotlib.pyplot as plt
import seaborn as sns

# --- Configuration ---

# Set a seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Set the default device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✅ Using device: {device.upper()}")

# --- Plotting Style ---
plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams["figure.figsize"] = (10, 6)
sns.set_palette("colorblind")
```

## 🧠 Part 2: Scaled Dot-Product Attention from Scratch

This is the heart of the Transformer and the most common form of attention used today. The formula looks intimidating at first, but we'll break it down into simple steps.

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

Let's understand the components, which are just matrices derived from the input embeddings:

*   **Query (Q):** A matrix of vectors, where each vector represents the current word/position we are focusing on. It's asking a question, like "What parts of the input are relevant to *me*?"
*   **Key (K):** A matrix of vectors, one for each word in the input sequence. You can think of these as "labels" or "identifiers" for the words. They are the "database keys" that the Query will be compared against.
*   **Value (V):** A matrix of vectors, also one for each word in the input. These vectors contain the actual information or content of the words.
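In self-attention, Q, K, and V are typically produced from the *same* input embeddings by three learned linear layers. A minimal sketch, assuming illustrative sizes; `w_q`, `w_k`, and `w_v` are hypothetical names, not a library API:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model, d_k, d_v = 6, 32, 16, 16

# Input embeddings for one sequence: one row per word.
x = torch.randn(seq_len, d_model)

# Three separate learned projections of the same embeddings.
w_q = nn.Linear(d_model, d_k, bias=False)
w_k = nn.Linear(d_model, d_k, bias=False)
w_v = nn.Linear(d_model, d_v, bias=False)

Q = w_q(x)  # (seq_len, d_k): what each position is "looking for"
K = w_k(x)  # (seq_len, d_k): what each position "offers" for matching
V = w_v(x)  # (seq_len, d_v): the content each position passes on

print(Q.shape, K.shape, V.shape)
```

In cross-attention, Q would instead come from one sequence (e.g., a decoder) and K and V from another (e.g., an encoder).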

**The Process:**

1.  **Compute Scores (`QK^T`):** We take the matrix product of the Queries with the transposed Keys. This efficiently computes the dot product between every query vector and every key vector. If a query is similar to a key, their dot product will be large. This gives us a matrix of **raw attention scores**, with one row per query and one column per key, indicating how relevant each key position is to each query position.
2.  **Scale (`/ sqrt(d_k)`):** We divide the scores by the square root of the key dimension (`d_k`). If the components of the queries and keys are independent with zero mean and unit variance, their dot product has variance `d_k`, so raw scores grow with the dimension. Large scores push the softmax toward a near one-hot distribution, where its gradients become vanishingly small; scaling keeps the scores in a range where training remains stable.
3.  **Apply Mask (Optional):** In some cases (like in a Transformer decoder), we need to prevent a position from attending to future positions, or to padding tokens. A mask sets these disallowed scores to a very large negative number (e.g., `-1e9`), so they receive (effectively) zero weight after the softmax.
4.  **Apply Softmax:** We apply a softmax function along the last dimension of the scaled scores. This turns the scores for each query into a probability distribution (the weights sum to 1). These are our final **attention weights**. (Our implementation also applies dropout to these weights during training.)
5.  **Weighted Sum (`* V`):** We perform a matrix multiplication between the attention weights and the Value matrix. This produces the final output: a new representation for each query position, in which the values that received high attention contribute more.
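Step 2 can be checked numerically. This sketch assumes query and key components are independent standard normals (the assumption behind the √d_k factor) and uses 8 keys per query. It reports the spread of the raw and scaled scores and the average largest softmax weight; a value near 1.0 means the softmax has become nearly one-hot.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_samples, n_keys = 2000, 8

results = {}
for d_k in (4, 64, 512):
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, n_keys, d_k)
    # Raw dot products between each query and its n_keys keys: (n_samples, n_keys)
    raw = torch.einsum("nd,nkd->nk", q, k)
    scaled = raw / d_k ** 0.5
    # Average largest softmax weight per query (1.0 = fully one-hot).
    peak_raw = F.softmax(raw, dim=-1).max(dim=-1).values.mean().item()
    peak_scaled = F.softmax(scaled, dim=-1).max(dim=-1).values.mean().item()
    results[d_k] = (raw.std().item(), scaled.std().item(), peak_raw, peak_scaled)
    print(f"d_k={d_k:4d}  std(raw)={results[d_k][0]:6.2f}  std(scaled)={results[d_k][1]:.2f}  "
          f"max weight raw={peak_raw:.2f}  scaled={peak_scaled:.2f}")
```

The raw standard deviation grows roughly like √d_k, while the scaled scores stay near 1 for every dimension, which keeps the softmax away from its saturated regime.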

Let's implement this in code.

```python
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    """
    Computes Scaled Dot-Product Attention as described in "Attention Is All You Need".
    This is the core building block of Multi-Head Attention.
    """
    def __init__(self, temperature, attn_dropout=0.1):
        """
        Args:
            temperature (float): The scaling factor (sqrt(d_k)).
            attn_dropout (float): Dropout rate for the attention weights.
        """
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """
        Performs the forward pass of the attention mechanism.

        Args:
            q (torch.Tensor): Query tensor. Shape: (..., len_q, d_k), e.g. (batch_size, n_head, len_q, d_k)
            k (torch.Tensor): Key tensor. Shape: (..., len_k, d_k)
            v (torch.Tensor): Value tensor. Shape: (..., len_v, d_v)
                              Note: len_k and len_v must be the same.
            mask (torch.Tensor, optional): Positions where the mask is 0/False are blocked;
                                           positions where it is 1/True are kept. Must broadcast
                                           to (..., len_q, len_k), e.g. (batch_size, 1, 1, len_k).
                                           Defaults to None.

        Returns:
            output (torch.Tensor): The context vectors after applying attention.
                                   Shape: (..., len_q, d_v)
            attn (torch.Tensor): The attention weights.
                                 Shape: (..., len_q, len_k)
        """
        # 1. & 2. Compute Scores and Scale: Q * K^T / sqrt(d_k)
        # Transposing the last two dims (instead of dims 2 and 3) makes this work for
        # 3-D (batch, len, d) as well as 4-D (batch, n_head, len, d) inputs.
        # Matmul shapes: (..., len_q, d_k) @ (..., d_k, len_k) -> (..., len_q, len_k)
        attn = torch.matmul(q / self.temperature, k.transpose(-2, -1))

        # 3. Apply Mask (if provided)
        if mask is not None:
            # Positions where mask == 0 (False) are blocked: they get a very large
            # negative score, so they receive (effectively) zero weight after the softmax.
            attn = attn.masked_fill(mask == 0, -1e9)

        # 4. Apply Softmax to get attention weights
        # Softmax over the last dimension (len_k) makes each query's weights sum to 1.
        # Dropout is then applied to the weights (active only in training mode).
        attn = self.dropout(F.softmax(attn, dim=-1))

        # 5. Weighted Sum: Attention Weights * V
        # Matmul shapes: (..., len_q, len_k) @ (..., len_k, d_v) -> (..., len_q, d_v)
        output = torch.matmul(attn, v)

        return output, attn

# --- Let's test it with some example tensors ---

# Parameters
batch_size = 2
n_head = 8      # Number of attention heads
len_q = 4       # Length of query sequence
len_k = 6       # Length of key/value sequence
d_k = 64        # Dimension of keys/queries
d_v = 64        # Dimension of values

# Scaling factor
temperature = np.sqrt(d_k)

# Create random tensors for Q, K, V
q = torch.randn(batch_size, n_head, len_q, d_k)
k = torch.randn(batch_size, n_head, len_k, d_k)
v = torch.randn(batch_size, n_head, len_k, d_v) # len_k and len_v are the same

# Create an instance of our attention module
attention_layer = ScaledDotProductAttention(temperature=temperature)

# Get the output
output, attn_weights = attention_layer(q, k, v)

# Print the shapes to verify
print("--- Tensor Shapes ---")
print("Query (q):          ", q.shape)
print("Key (k):            ", k.shape)
print("Value (v):          ", v.shape)
print("\nOutput from Attention:", output.shape)
print("Attention Weights:  ", attn_weights.shape)

# --- Verification ---
# The output for each head should have the shape (len_q, d_v)
# The attention weights for each head should have the shape (len_q, len_k)
assert output.shape == (batch_size, n_head, len_q, d_v)
assert attn_weights.shape == (batch_size, n_head, len_q, len_k)

print("\n✅ Shapes are correct!")
```
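The module above was tested without a mask. The sketch below exercises step 3 with a causal (look-ahead) mask and cross-checks the result against PyTorch's built-in `F.scaled_dot_product_attention` (available in PyTorch 2.0 and later). To stay self-contained it restates the module as a hypothetical function `sdpa` without dropout, which should match `ScaledDotProductAttention` in eval mode.

```python
import torch
import torch.nn.functional as F

def sdpa(q, k, v, mask=None):
    """Functional restatement of the module above, without dropout.
    mask: 0/False = blocked, 1/True = allowed (same convention as the module)."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
batch, n_head, seq_len, d_k = 2, 4, 5, 16
q = torch.randn(batch, n_head, seq_len, d_k)
k = torch.randn(batch, n_head, seq_len, d_k)
v = torch.randn(batch, n_head, seq_len, d_k)

# Causal mask: position i may attend only to positions j <= i.
# Shape (seq_len, seq_len) broadcasts over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

out, weights = sdpa(q, k, v, mask=causal_mask)

# Upper-triangular (future) entries receive zero weight.
print(weights[0, 0])

# For a boolean attn_mask, PyTorch treats True as "take part in attention",
# which matches our convention.
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
print("matches built-in:", torch.allclose(out, ref, atol=1e-5))
```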

## 🧠 Part 3: Multi-Head Attention (MHA)

Scaled Dot-Product Attention is powerful, but Multi-Head Attention takes it a step further. Instead of calculating attention just once, MHA runs the process multiple times in parallel and concatenates the results.

**Why is this useful?**

It allows the model to jointly attend to information from different "representation subspaces" at different positions. A single attention head might learn to focus on, for example, subject-verb relationships, while another head might focus on adjective-noun pairings. By having multiple heads, the model can capture a richer variety of linguistic patterns simultaneously.

**The Process:**

1.  **Initial Projections:** The input Queries, Keys, and Values are not fed directly into the attention mechanism. Instead, for *each head*, we create a separate linear projection (an `nn.Linear` layer) of Q, K, and V. This is what creates the different "representation subspaces."
2.  **Parallel Attention:** We perform Scaled Dot-Product Attention on each of these projected sets of Q, K, and V in parallel. This results in `h` (number of heads) different output matrices.
3.  **Concatenate & Final Projection:** The `h` output matrices are concatenated together. This combined matrix is then passed through one final linear layer (`nn.Linear`) to produce the ultimate output of the MHA block.
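The `MultiHeadAttention` module below uses one large linear layer per Q, K, and V instead of `n_head` separate ones. The sketch that follows (illustrative sizes, not the notebook's) checks that the two are mathematically equivalent: head `h` of the fused layer simply uses a `d_k`-row slice of its weight matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, d_model, n_head, d_k = 2, 5, 32, 4, 8

x = torch.randn(batch, seq_len, d_model)

# One fused projection for all heads, reshaped to (batch, n_head, seq_len, d_k).
w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
fused = w_qs(x).view(batch, seq_len, n_head, d_k).transpose(1, 2)

# The same computation with n_head separate projections; each one owns the
# rows [h*d_k, (h+1)*d_k) of the fused weight matrix (shape: out x in).
per_head = []
for h in range(n_head):
    head_proj = nn.Linear(d_model, d_k, bias=False)
    with torch.no_grad():
        head_proj.weight.copy_(w_qs.weight[h * d_k:(h + 1) * d_k])
    per_head.append(head_proj(x))          # (batch, seq_len, d_k)
separate = torch.stack(per_head, dim=1)    # (batch, n_head, seq_len, d_k)

print("fused == separate:", torch.allclose(fused, separate, atol=1e-6))
```

The fused version issues one large matrix multiplication instead of `n_head` small ones, which is typically faster on GPUs.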

This entire process is encapsulated in the `MultiHeadAttention` module below.

```python
class MultiHeadAttention(nn.Module):
    """
    Implements Multi-Head Attention as described in "Attention Is All You Need".
    This module contains multiple parallel Scaled Dot-Product Attention layers.
    """
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        """
        Args:
            n_head (int): The number of attention heads.
            d_model (int): The dimensionality of the input and output.
            d_k (int): The dimensionality of the queries and keys (per head).
            d_v (int): The dimensionality of the values (per head).
            dropout (float): Dropout rate.
        """
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # 1. Initial Projections for Q, K, V for all heads
        # We create one large linear layer for each of Q, K, V.
        # This is more efficient than creating n_head separate linear layers.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)

        # 3. Final Projection Layer
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        # 2. Parallel Attention Layer
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """
        Forward pass for Multi-Head Attention.

        Args:
            q (torch.Tensor): Queries. Shape: (batch_size, len_q, d_model)
            k, v (torch.Tensor): Keys and values. Shape: (batch_size, len_k, d_model)
            mask (torch.Tensor, optional): Attention mask of shape (batch_size, 1, len_k)
                or (batch_size, len_q, len_k); 0/False blocks a position. Defaults to None.

        Returns:
            output (torch.Tensor): Output tensor. Shape: (batch_size, len_q, d_model)
            attn (torch.Tensor): Attention weights. Shape: (batch_size, n_head, len_q, len_k)
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q # Save the original query for the residual connection

        # --- 1. Initial Projections & Reshape ---
        # Pass inputs through linear layers and split into n_head.
        # view() reshapes the tensor to (batch_size, seq_len, n_head, d_k/d_v)
        # transpose(1, 2) swaps seq_len and n_head to get the desired shape for attention:
        # (batch_size, n_head, seq_len, d_k/d_v)
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k).transpose(1, 2)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k).transpose(1, 2)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v).transpose(1, 2)

        # If a mask is provided, add a head dimension so it broadcasts over all heads.
        if mask is not None:
            mask = mask.unsqueeze(1)   # Shape: (batch_size, 1, 1 or len_q, len_k)

        # --- 2. Parallel Attention ---
        # q, k, v are now shape (batch_size, n_head, seq_len, d_k/d_v)
        output, attn = self.attention(q, k, v, mask=mask)

        # --- 3. Concatenate & Final Projection ---
        # Transpose back to (batch_size, len_q, n_head, d_v)
        # contiguous() is needed to create a contiguous block of memory.
        # view() can then combine the last two dimensions.
        output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        output = self.dropout(self.fc(output))

        # --- Add & Norm (Residual Connection) ---
        output += residual
        output = self.layer_norm(output)

        return output, attn

# --- Let's test it ---
# Parameters
d_model = 512
n_head = 8
d_k = 64
d_v = 64
seq_len = 10
batch_size = 2

# Create an instance of the MHA module
mha = MultiHeadAttention(n_head, d_model, d_k, d_v)

# Create random input tensors
q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)

# Get the output
output, attn_weights = mha(q, k, v)

# Print shapes to verify
print("--- MHA Tensor Shapes ---")
print("Input (q, k, v):    ", q.shape)
print("\nOutput of MHA:      ", output.shape)
print("Attention Weights:  ", attn_weights.shape)

# --- Verification ---
assert output.shape == (batch_size, seq_len, d_model)
assert attn_weights.shape == (batch_size, n_head, seq_len, seq_len)
print("\n✅ Shapes are correct!")
```
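For comparison, PyTorch ships a built-in `nn.MultiheadAttention`. A minimal sketch with the same sizes as the test above: with `batch_first=True` it uses the (batch, seq, feature) layout, and `average_attn_weights=False` returns per-head weights instead of their average. Unlike our module, it uses head size `embed_dim // num_heads`, includes projection biases by default, and does *not* apply the residual connection or LayerNorm; those belong to the surrounding layer (e.g., `nn.TransformerEncoderLayer`).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, d_model, n_head = 2, 10, 512, 8

x = torch.randn(batch, seq_len, d_model)

# Built-in multi-head attention; embed_dim must be divisible by num_heads.
mha_builtin = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, batch_first=True)
mha_builtin.eval()

with torch.no_grad():
    out, weights = mha_builtin(x, x, x, need_weights=True, average_attn_weights=False)

print("output: ", out.shape)      # (batch, seq_len, d_model)
print("weights:", weights.shape)  # (batch, n_head, seq_len, seq_len)
```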

## 📊 Part 4: Visualizing Attention Weights

One of the most powerful aspects of attention mechanisms is their **interpretability**. Unlike most neural network layers, whose internal activations are hard to read, attention weights are explicit probability distributions over the input, which gives us a useful window into how the model combines information from different positions.

By visualizing the attention weights as a heatmap, we can see:
- Which input words the model focuses on when generating each output word
- How the model distributes its "attention" across the sequence
- Whether the model has learned meaningful patterns (e.g., aligning source and target words in translation)
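The heatmaps below are square because they show self-attention. In cross-attention (for example, a decoder attending to an encoder), queries come from the output sequence and keys from the input sequence, so the weight matrix is rectangular: one row per output token, one column per input token. A minimal sketch with random, untrained vectors, so the particular alignment is arbitrary; the token lists and the per-row entropy summary are illustrative assumptions, not part of the notebook's method:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
input_tokens = ["le", "chat", "est", "noir"]   # hypothetical source sentence
output_tokens = ["the", "cat", "is"]           # hypothetical target prefix
d = 16

queries = torch.randn(len(output_tokens), d)   # one query per output token
keys = torch.randn(len(input_tokens), d)       # one key per input token

# Rectangular attention matrix: (len_out, len_in); each row sums to 1.
weights = F.softmax(queries @ keys.T / d ** 0.5, dim=-1)

# Which input token each output token attends to most.
for i, tok in enumerate(output_tokens):
    j = weights[i].argmax().item()
    print(f"{tok!r:7} -> {input_tokens[j]!r} (weight {weights[i, j].item():.2f})")

# Entropy per row: 0 = all weight on one input, log(len_in) = perfectly uniform.
entropy = -(weights * weights.clamp_min(1e-12).log()).sum(dim=-1)
print("row entropy:", entropy.numpy().round(2))
```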

Let's create some comprehensive visualizations!

```python
def visualize_attention(attention_weights, input_tokens, output_tokens, head_idx=0, title="Attention Weights"):
    """
    Visualizes attention weights as a heatmap.
    
    Args:
        attention_weights (torch.Tensor): Attention weights of shape (batch, n_heads, len_q, len_k)
        input_tokens (list): Input tokens/words (keys, one per column)
        output_tokens (list): Output tokens/words (queries, one per row)
        head_idx (int): Which attention head to visualize (default: 0)
        title (str): Title for the plot
    """
    # Extract weights for the first batch item and the specified head, as numpy
    # Shape: (len_q, len_k)
    weights = attention_weights[0, head_idx].detach().cpu().numpy()
    
    # Create figure
    fig, ax = plt.subplots(figsize=(12, 10))
    
    # Create heatmap
    im = ax.imshow(weights, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
    
    # Set ticks and labels
    ax.set_xticks(np.arange(len(input_tokens)))
    ax.set_yticks(np.arange(len(output_tokens)))
    ax.set_xticklabels(input_tokens, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(output_tokens, fontsize=10)
    
    # Labels and title
    ax.set_xlabel('Input Tokens (Keys)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Output Tokens (Queries)', fontsize=12, fontweight='bold')
    ax.set_title(f'{title} - Head {head_idx}', fontsize=14, fontweight='bold', pad=20)
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Attention Weight', rotation=270, labelpad=20, fontsize=11)
    
    # Add grid for better readability
    ax.set_xticks(np.arange(len(input_tokens))-.5, minor=True)
    ax.set_yticks(np.arange(len(output_tokens))-.5, minor=True)
    ax.grid(which="minor", color="white", linestyle='-', linewidth=2)
    
    # Annotate cells with attention values
    for i in range(len(output_tokens)):
        for j in range(len(input_tokens)):
            ax.text(j, i, f'{weights[i, j]:.2f}',
                    ha="center", va="center", color="black", fontsize=8)
    
    plt.tight_layout()
    return fig

# --- Example: Visualize Self-Attention on a Sample Sentence ---

# Create a sample sentence
sentence = "The cat sat on the mat"
tokens = sentence.split()
seq_len = len(tokens)

# Create sample query, key, value matrices.
# They are random (not learned), so any pattern in the heatmap is arbitrary.
d_model = 64
q = torch.randn(1, seq_len, d_model)
k = torch.randn(1, seq_len, d_model)
v = torch.randn(1, seq_len, d_model)

# Run through our scaled dot-product attention.
# The scaling factor is sqrt(d_model), not d_model. eval() disables attention
# dropout so that each row of weights is a proper softmax distribution.
attention_module = ScaledDotProductAttention(temperature=d_model ** 0.5)
attention_module.eval()
with torch.no_grad():
    output, attn_weights = attention_module(q, k, v)

# Sanity check: each row of the attention weights sums to 1
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(1, seq_len))

# attn_weights shape: (batch, seq_len, seq_len)
# Add an n_heads dimension for our visualization function
attn_weights_expanded = attn_weights.unsqueeze(1)  # Now (batch, 1, seq_len, seq_len)

print("🎨 Visualizing Self-Attention Weights")
print("=" * 60)
print(f"Sentence: '{sentence}'")
print(f"Attention weights shape: {attn_weights.shape}")
print("\nThe heatmap shows how much each word (row) attends to each word (column)")

fig = visualize_attention(
    attn_weights_expanded, 
    tokens, 
    tokens,
    head_idx=0,
    title="Self-Attention on Sample Sentence"
)
plt.show()

print("\n💡 Interpretation:")
print("- Darker red cells indicate strong attention; pale yellow cells indicate weak attention")
print("- Each row sums to 1.0 (softmax ensures this; dropout is off in eval mode)")
print("- Diagonal elements show how much each word attends to itself")
print("- Off-diagonal elements show cross-word attention")
print("- Q, K, and V are random here, so these particular patterns carry no linguistic meaning")
```

### Visualizing Multi-Head Attention

Multi-head attention is powerful because different heads can learn to focus on different aspects of the relationships between words. Let's visualize all heads at once to see this diversity!

```python
# --- Multi-Head Attention Visualization ---

# Create Multi-Head Attention module (reuses d_model, q, k, v, tokens from the previous cell)
n_heads = 4
d_k = 16
d_v = 16
mha = MultiHeadAttention(n_heads, d_model, d_k, d_v)
mha.eval()  # Disable dropout so each row of attention weights sums to 1

# Run input through MHA
with torch.no_grad():
    output_mha, attn_weights_mha = mha(q, k, v)

# attn_weights_mha shape: (batch, n_heads, seq_len, seq_len)
print(f"🎯 Multi-Head Attention with {n_heads} heads")
print(f"Attention weights shape: {attn_weights_mha.shape}")

# Create a 2x2 subplot grid, one panel per head
fig, axes = plt.subplots(2, 2, figsize=(16, 14))
fig.suptitle(f'Multi-Head Attention - All {n_heads} Heads\nSentence: "{sentence}"', 
             fontsize=16, fontweight='bold', y=0.995)

for head_idx in range(n_heads):
    row = head_idx // 2
    col = head_idx % 2
    ax = axes[row, col]
    
    # Get weights for this head
    weights = attn_weights_mha[0, head_idx].detach().cpu().numpy()
    
    # Create heatmap
    im = ax.imshow(weights, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
    
    # Set ticks and labels
    ax.set_xticks(np.arange(len(tokens)))
    ax.set_yticks(np.arange(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(tokens, fontsize=9)
    
    # Labels
    ax.set_xlabel('Input Tokens', fontsize=10)
    ax.set_ylabel('Output Tokens', fontsize=10)
    ax.set_title(f'Head {head_idx}', fontsize=12, fontweight='bold')
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Attention', rotation=270, labelpad=15, fontsize=9)
    
    # Add grid
    ax.set_xticks(np.arange(len(tokens))-.5, minor=True)
    ax.set_yticks(np.arange(len(tokens))-.5, minor=True)
    ax.grid(which="minor", color="white", linestyle='-', linewidth=1.5)
    
    # Annotate with values
    for i in range(len(tokens)):
        for j in range(len(tokens)):
            ax.text(j, i, f'{weights[i, j]:.2f}',
                    ha="center", va="center", color="black", fontsize=7)

plt.tight_layout()
plt.show()

print("\n💡 Key Observations:")
print("- This module is untrained, so differences between heads come from random initialization")
print("- In trained models, heads often learn distinct attention patterns:")
print("  some focus on local relationships (nearby words),")
print("  others capture long-range dependencies")
print("- This diversity of patterns is what makes multi-head attention powerful!")
print("\n✅ This visualization helps understand what the model 'sees'!")
```

## Part 5: Summary and Next Steps

In this notebook, we have demystified the attention mechanism, the cornerstone of modern NLP models like the Transformer.

### Key Takeaways:

*   **Attention is a Mechanism for Focusing:** It allows a model to weigh the importance of different parts of the input sequence when producing an output, rather than relying on a single fixed-size context vector from an RNN.
*   **Scaled Dot-Product Attention:** This is the most common implementation. It uses Query, Key, and Value matrices to calculate attention scores, which are then scaled and passed through a softmax to create a probability distribution (the attention weights). These weights are then used to create a weighted sum of the Value vectors.
*   **Multi-Head Attention (MHA):** This is the standard way to use attention in Transformers. It runs multiple Scaled Dot-Product Attention layers in parallel, allowing the model to capture different types of relationships and features from the input data simultaneously. The results are then concatenated and linearly projected to produce the final output.
*   **Residual Connections and Layer Normalization:** These are crucial components that are used alongside MHA to ensure stable training and effective information flow through deep networks.
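The "Add & Norm" step at the end of `MultiHeadAttention` follows the post-norm arrangement `LayerNorm(x + Sublayer(x))`. A minimal sketch, with a plain linear layer standing in (hypothetically) for the attention sublayer. It shows that LayerNorm rescales each position to roughly zero mean and unit variance, and that the residual path still carries the input through even when the sublayer contributes nothing:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, d_model = 2, 5, 32

x = torch.randn(batch, seq_len, d_model)
sublayer = nn.Linear(d_model, d_model)   # hypothetical stand-in for MHA
layer_norm = nn.LayerNorm(d_model, eps=1e-6)

# Post-norm residual block: add the sublayer output to its input, then normalize.
y = layer_norm(x + sublayer(x))

# LayerNorm normalizes over the feature dimension of each position
# (its learnable scale and shift start at 1 and 0, so they do not change this yet).
print("max |per-position mean|:", y.mean(dim=-1).abs().max().item())
print("average per-position std:", y.std(dim=-1, unbiased=False).mean().item())

# If the sublayer output were all zeros, the residual path would still
# deliver the input (normalized) to the next layer.
with torch.no_grad():
    sublayer.weight.zero_()
    sublayer.bias.zero_()
    y_zero = layer_norm(x + sublayer(x))
print("input preserved:", torch.allclose(y_zero, layer_norm(x)))
```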

### What We Built:

1.  A `ScaledDotProductAttention` module from scratch, understanding the core formula.
2.  A `MultiHeadAttention` module that wraps the scaled dot-product attention and adds the necessary linear projections and parallel head logic.

### Where to Go From Here:

*   **Transformer Encoder:** The next logical step is to see how Multi-Head Attention is used inside a full Transformer Encoder block. This involves combining MHA with a Position-wise Feed-Forward Network.
*   **Positional Encodings:** We haven't addressed *how* the model knows the order of the words. This is solved by adding positional encodings to the input embeddings before they are fed into the attention layers.
*   **Full Transformer Model:** Combine multiple Encoder blocks to build a full Transformer for a task like sentiment analysis or machine translation.
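The need for positional encodings can be checked directly: without them, self-attention is permutation-equivariant, so shuffling the input words only shuffles the outputs in the same way, and the layer cannot distinguish word orders. The sketch below verifies this with `nn.MultiheadAttention`, then builds the sinusoidal encoding from "Attention Is All You Need", PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The function name `sinusoidal_positional_encoding` is our own; learned positional embeddings are a common alternative.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model, n_head = 6, 16, 4
x = torch.randn(1, seq_len, d_model)
attn = nn.MultiheadAttention(d_model, n_head, batch_first=True).eval()
perm = torch.tensor([5, 4, 3, 2, 1, 0])   # reverse the word order

# 1. Without positional information, permuting the input permutes the output identically.
xp = x[:, perm]
with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(xp, xp, xp)
print("equivariant without PE:", torch.allclose(out[:, perm], out_perm, atol=1e-5))

def sinusoidal_positional_encoding(seq_len, d_model):
    """Sine on even feature indices, cosine on odd ones (d_model must be even)."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)   # (seq_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))               # 10000^(-2i/d_model)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

# 2. Adding the encoding ties each embedding to its position, so reordering the
#    words (while positions stay fixed) no longer just reorders the outputs.
pe = sinusoidal_positional_encoding(seq_len, d_model).unsqueeze(0)   # (1, seq_len, d_model)
x_pos = x + pe
xp_pos = xp + pe
with torch.no_grad():
    out_pos, _ = attn(x_pos, x_pos, x_pos)
    out_perm_pos, _ = attn(xp_pos, xp_pos, xp_pos)
print("equivariant with PE:", torch.allclose(out_pos[:, perm], out_perm_pos, atol=1e-5))
```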

This notebook provides the fundamental building blocks. By understanding how Q, K, and V interact, you are well on your way to mastering the Transformer architecture.