# Module 1: Deep Dive into Transformers for Text.

**Goal:** Understand the inner workings of the Transformer architecture for Natural Language Processing (NLP) tasks. We will explore tokenization, embeddings, positional encoding, self-attention (encoder & decoder), cross-attention, and how these components fit together.

**Approach:** We will use a combination of theoretical explanations, visualizations, and hands-on code examples. We will leverage the Hugging Face `transformers` library to inspect a pre-trained model and also implement key components from scratch using PyTorch to solidify understanding. We will **not** be training a model, but rather dissecting an existing one.

> For a visual overview of the Transformer architecture (**HIGHLY RECOMMENDED**), please see:
> * https://jalammar.github.io/illustrated-transformer/
> * https://poloclub.github.io/transformer-explainer/  
> ![Transformer Architecture](https://jalammar.github.io/images/t/The_transformer_encoders_decoders.png)


**Let's get started!**

## Prerequisites

In [None]:
!pip install -q transformers torch matplotlib seaborn numpy pandas plotly

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import math
from transformers import AutoTokenizer, AutoModel
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import logging as hf_logging
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings

# Global plotting defaults used by every Matplotlib/Seaborn figure below
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (10, 6)

# Seed both PyTorch and NumPy so randomly initialised layers / random draws
# are repeatable between runs (optional)
torch.manual_seed(42)
np.random.seed(42)

# Only show Hugging Face log messages at ERROR level
hf_logging.set_verbosity_error()

# Silence all Python warnings for the rest of the session.
# NOTE: this keeps the notebook output clean but also hides deprecation notices.
warnings.filterwarnings("ignore")

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [None]:
# Visualize tokenization
def visualize_tokenization(text, tokenizer):
    """Draw a Matplotlib table showing how ``tokenizer`` splits ``text``.

    The table lists every token together with its vocabulary ID (special
    tokens excluded). The fully encoded sequence, including special tokens,
    is printed underneath the table as a caption.

    Args:
        text (str): Sentence to tokenize.
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids``,
            ``encode`` and ``convert_ids_to_tokens``.
    """
    # Plain tokens and IDs -> the rows of the table
    pieces = tokenizer.tokenize(text)
    piece_ids = tokenizer.convert_tokens_to_ids(pieces)
    table_data = pd.DataFrame({'Token': pieces, 'ID': piece_ids})

    # Full encoding, special tokens included -> the caption
    ids_with_special = tokenizer.encode(text, add_special_tokens=True)
    tokens_with_special = tokenizer.convert_ids_to_tokens(ids_with_special)

    # One row per token plus the header row determines the figure height
    row_count = len(table_data) + 1
    fig, ax = plt.subplots(figsize=(8, row_count * 0.5 + 2))
    for axis_mode in ('tight', 'off'):
        ax.axis(axis_mode)

    token_table = ax.table(
        cellText=table_data.values,
        colLabels=table_data.columns,
        cellLoc='center',
        loc='center',
    )
    token_table.auto_set_font_size(False)
    token_table.set_fontsize(10)
    token_table.scale(1, 1.5)

    ax.set_title(f"Tokenization Visualization for:\n{text}", fontsize=12, fontweight='bold')

    # Caption anchored at the bottom of the figure
    caption = (
        f"Tokens with special tokens: {tokens_with_special}\n"
        f"Input IDs with special tokens: {ids_with_special}"
    )
    fig.text(0.5, 0.01, caption, wrap=True, horizontalalignment='center', fontsize=10)

    plt.show()

# --- Interactive Plot with Plotly (Optional, but nice!) ---
def plot_attention_interactive(attention_layer_tensor, tokens, layer_idx):
    """Show one layer's self-attention as an interactive Plotly heatmap.

    One heatmap trace is added per attention head; a dropdown toggles which
    trace is visible and updates the figure title accordingly.

    Args:
        attention_layer_tensor (torch.Tensor): Attention weights for one layer,
            indexed as [batch, head, query, key]. Only batch item 0 is plotted.
        tokens (list): Token labels, used for both the X and the Y axis.
        layer_idx (int): Layer index; only used in the title.
    """
    # Convert once up front. detach()/cpu() make this work for tensors that
    # live on a GPU or still track gradients (a bare .numpy() raises there),
    # matching what plot_cross_attention_interactive does.
    attention_np = attention_layer_tensor.detach().cpu().numpy()
    num_heads = attention_np.shape[1]

    fig = make_subplots(rows=1, cols=1)

    # Add one trace per head; only the first one starts out visible
    for h in range(num_heads):
        fig.add_trace(
            go.Heatmap(
                z=attention_np[0, h, :, :],  # Select batch 0, head h
                x=tokens,
                y=tokens,
                colorscale='Viridis',
                name=f'Head {h}',
                visible=(h == 0),  # Only first head visible initially
                showscale=False  # Hide individual color bars
            ),
            row=1, col=1
        )

    # Create dropdown menu to select head
    buttons = []
    for h in range(num_heads):
        visibility = [False] * num_heads
        visibility[h] = True
        buttons.append(dict(
            label=f'Head {h}',
            method='update',
            # First dict restyles the traces, second one relayouts the figure.
            # 'title.text' is the explicit attribute path for the title string.
            args=[{'visible': visibility},
                  {'title.text': f'Self-Attention Weights - Layer {layer_idx}, Head {h}'}]
        ))

    fig.update_layout(
        updatemenus=[dict(
            active=0,
            buttons=buttons,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=1.15,
            yanchor="top"
        )],
        title=f'Self-Attention Weights - Layer {layer_idx}, Head 0',
        xaxis_title="Key (Attended To)",
        yaxis_title="Query (Attending From)",
        yaxis_autorange='reversed',  # Put the first token at the top-left
        height=700,
        width=700
    )
    fig.show()

In [None]:
def plot_attention_heatmap_improved(attention_matrix, x_labels, y_labels, title, figsize=(12, 10), cmap="viridis", fmt=".2f",
                                    xlabel="Encoder Tokens (Keys / Values from Source)",
                                    ylabel="Decoder Tokens (Queries from Target)"):
    """
    Plots a clearer heatmap for attention weights using Matplotlib/Seaborn.

    Args:
        attention_matrix (np.array): 2D numpy array of attention weights.
        x_labels (list): List of labels for the X-axis (Keys/Values).
        y_labels (list): List of labels for the Y-axis (Queries).
        title (str): Title for the plot.
        figsize (tuple): Figure size.
        cmap (str): Colormap name.
        fmt (str): String format for annotations (e.g., '.2f' for 2 decimal places). Set to None to disable annotations.
        xlabel (str): Axis title for the X-axis. Defaults to the cross-attention
            wording; override it when plotting self-attention.
        ylabel (str): Axis title for the Y-axis. Defaults to the cross-attention
            wording; override it when plotting self-attention.
    """
    # Per-cell numbers are only readable on small matrices, so annotate only
    # when both dimensions are below 20 and a format string was supplied.
    n_rows, n_cols = attention_matrix.shape[0], attention_matrix.shape[1]
    show_annotations = n_rows < 20 and n_cols < 20 and fmt is not None

    plt.figure(figsize=figsize)
    sns.heatmap(
        attention_matrix,
        xticklabels=x_labels,
        yticklabels=y_labels,
        cmap=cmap,
        linewidths=.5,
        linecolor='lightgray',  # Add lines between cells
        cbar=True,              # Show color bar
        annot=show_annotations,
        fmt=fmt,
        annot_kws={"size": 8}   # Adjust annotation font size if needed
    )
    plt.xlabel(xlabel, fontsize=12)
    plt.ylabel(ylabel, fontsize=12)
    plt.title(title, fontsize=14)
    plt.xticks(rotation=45, ha='right', fontsize=10)
    plt.yticks(rotation=0, fontsize=10)
    plt.tight_layout(pad=1.5)  # Adjust layout to prevent overlap
    plt.show()


def plot_cross_attention_interactive(attention_layer_tensor, encoder_tokens, decoder_tokens, layer_idx):
    """
    Creates an interactive heatmap specifically for CROSS-attention using Plotly.

    One heatmap trace is added per attention head; a dropdown toggles which
    trace is visible and updates the figure title.

    Args:
        attention_layer_tensor (torch.Tensor): Attention weights tensor for a layer.
                                              Shape: (batch_size, num_heads, target_seq_len, source_seq_len)
                                              Only batch item 0 is plotted.
        encoder_tokens (list): List of tokens for the source sequence (X-axis).
        decoder_tokens (list): List of tokens for the target sequence (Y-axis).
        layer_idx (int): The index of the decoder layer being visualized.
    """
    num_heads = attention_layer_tensor.shape[1]
    # Ensure tensor is detached, on CPU and NumPy for plotting.
    # detach() is required for tensors that still track gradients.
    attention_layer_np = attention_layer_tensor.detach().cpu().numpy()

    fig = make_subplots(rows=1, cols=1)

    # Add traces for each head
    for h in range(num_heads):
        fig.add_trace(
            go.Heatmap(
                z=attention_layer_np[0, h, :, :],  # Select batch 0, head h
                x=encoder_tokens,                # Use encoder tokens for X axis
                y=decoder_tokens,                # Use decoder tokens for Y axis
                colorscale='Viridis',
                name=f'Head {h}',
                hoverongaps=False,
                visible=(h == 0),  # Only first head visible initially
                showscale=True,    # Show color scale by default
                # Color bar title: nested title.text / title.side replaces the
                # deprecated flat `titleside` attribute.
                colorbar=dict(title=dict(text='Attention Weight', side='right'))
            ),
            row=1, col=1
        )

    # Create dropdown menu to select head
    buttons = []
    for h in range(num_heads):
        # Create a visibility mask: True for the selected head, False for others
        visibility = [False] * num_heads
        visibility[h] = True
        buttons.append(dict(
            label=f'Head {h}',
            method='update',
            # Update visibility and the plot title when a head is selected
            args=[{'visible': visibility},
                  {'title.text': f'Cross-Attention Weights - Decoder Layer {layer_idx}, Head {h}'}]
        ))

    fig.update_layout(
        updatemenus=[dict(
            active=0,
            buttons=buttons,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.05,  # Position dropdown slightly left
            xanchor="left",
            y=1.18,  # Position dropdown slightly higher
            yanchor="top"
        )],
        title=dict(
            text=f'Cross-Attention Weights - Decoder Layer {layer_idx}, Head 0',  # Initial title
            x=0.5,  # Center title
            xanchor='center'
        ),
        xaxis_title="Encoder Tokens (Key / Value from Source)",
        yaxis_title="Decoder Tokens (Query from Target)",
        yaxis_autorange='reversed',  # Puts decoder start token at the top
        xaxis_tickangle=-45,         # Angle ticks for better readability
        height=700,                  # Adjust height if needed
        width=850,                   # Adjust width if needed
        xaxis_showgrid=False,        # Hide grid lines for cleaner look
        yaxis_showgrid=False,
        hovermode='closest',         # Show hover info for the closest point
    )

    # Smaller tick labels keep long token lists readable
    fig.update_yaxes(tickfont=dict(size=10))
    fig.update_xaxes(tickfont=dict(size=10))

    fig.show()

## 1. Introduction: What is a Transformer?

The Transformer architecture was introduced in the paper ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) by Vaswani et al. (2017). It revolutionized NLP by relying entirely on **attention mechanisms**, discarding recurrence (like RNNs/LSTMs) and convolutions.

**Why was this revolutionary?**

1.  **Parallelization:** Unlike RNNs that process sequences step-by-step, Transformers can process all tokens in a sequence simultaneously, making training much faster on modern hardware (GPUs/TPUs).
2.  **Long-Range Dependencies:** Attention mechanisms allow the model to directly model dependencies between any two tokens in the sequence, regardless of their distance, overcoming limitations of RNNs with long sequences.

**High-Level Architecture:**

The original Transformer has an **Encoder-Decoder** structure, commonly used for sequence-to-sequence tasks like machine translation or summarization.

> ![Transformer Architecture](https://deeprevision.github.io/posts/001-transformer/transformer.png)


*   **Encoder:** Maps an input sequence of symbols $(x_1, ..., x_n)$ to a sequence of continuous representations $\mathbf{z} = (z_1, ..., z_n)$.
*   **Decoder:** Given $\mathbf{z}$, generates an output sequence $(y_1, ..., y_m)$ one symbol at a time (autoregressive).

Many popular models are variations:
*   **Encoder-Only:** BERT, RoBERTa, ALBERT (good for NLU tasks like classification, QA, NER).
*   **Decoder-Only:** GPT series (good for text generation).
*   **Encoder-Decoder:** Original Transformer, T5, BART (good for seq2seq tasks).

We'll explore all the key components shown in the diagram.

## 2. Tokenization: Converting Text to Numbers

Transformers don't understand raw text. We need to convert text into a sequence of numerical IDs, a process called **tokenization**.

**Why not just use words?**
* Vocabulary size would be enormous (millions of words, including typos, variations).
* Handling unknown words ("Out-Of-Vocabulary" or OOV problem).

**Common Tokenization Strategies:**

1. **Word-Based:** Split by spaces/punctuation. Simple, but suffers from large vocab & OOV.
2. **Character-Based:** Split into individual characters. Small vocab, no OOV, but loses word-level meaning and creates very long sequences.
3. **Subword-Based:** The sweet spot! Breaks words into smaller, meaningful units. Common words remain intact, rare words are broken down. Handles OOV gracefully and keeps vocabulary size manageable.
    * **Byte-Pair Encoding (BPE):** Starts with characters, iteratively merges most frequent pairs. Used by GPT, RoBERTa.
        * **Byte-Level BPE (BBPE):** A variant of BPE that operates directly on the raw bytes of the input text rather than on characters. This approach enhances robustness since it can efficiently handle any Unicode string, including rare symbols and diverse languages. It is famously used in models like GPT-2.
    * **WordPiece:** Similar to BPE, but merges pairs that maximize likelihood of training data. Used by BERT, DistilBERT.
    * **SentencePiece:** Treats the input as a raw stream, includes whitespace in the tokens. Language-agnostic. Used by T5, XLNet.

**Let's use a real tokenizer (WordPiece from BERT).**


In [None]:
# Load a pre-trained tokenizer (bert-base-uncased)
tokenizer_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

print(f"Loaded Tokenizer: {tokenizer_name}")
print(f"Vocabulary Size: {tokenizer.vocab_size}")
print(f"Special Tokens: {tokenizer.special_tokens_map}")
print(f"CLS token: {tokenizer.cls_token} (ID: {tokenizer.cls_token_id})")
print(f"SEP token: {tokenizer.sep_token} (ID: {tokenizer.sep_token_id})")
print(f"PAD token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
print(f"UNK token: {tokenizer.unk_token} (ID: {tokenizer.unk_token_id})")

# Example sentences (text2 is used by the batch-encoding cell)
text = "Transformers are powerful! They revolutionized NLP."
text2 = "Let's tokenize this sentence using BertTokenizer."

# 1. Tokenize: Split into tokens (subwords)
tokens = tokenizer.tokenize(text)
print(f"\nText: '{text}'")
print(f"Tokens: {tokens}")

# 2. Convert tokens to IDs
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(f"Token IDs: {token_ids}")

# 3. Add special tokens ([CLS] and [SEP]) - common practice for BERT
# [CLS] often used for classification tasks, [SEP] separates segments.
# Calling the tokenizer directly is the supported API (`encode_plus` is
# deprecated); it returns the same 'input_ids' / 'attention_mask' fields.
encoded_plus = tokenizer(text, add_special_tokens=True)
special_token_ids = encoded_plus['input_ids']
special_tokens = tokenizer.convert_ids_to_tokens(special_token_ids)

print(f"\nTokens with Special Tokens: {special_tokens}")
print(f"Token IDs with Special Tokens: {special_token_ids}")
print(f"Attention Mask: {encoded_plus['attention_mask']}") # Tells model which tokens to attend to (ignore padding)

# 4. Decode: Convert IDs back to text
decoded_text = tokenizer.decode(special_token_ids, skip_special_tokens=False)
decoded_text_no_special = tokenizer.decode(special_token_ids, skip_special_tokens=True)
print(f"\nDecoded Text (with special): {decoded_text}")
print(f"Decoded Text (no special): {decoded_text_no_special}")

# Example showing subword splitting
print("\nSubword Example:")
print(f"Tokenizing 'revolutionized': {tokenizer.tokenize('revolutionized')}")
print(f"Tokenizing 'Tokenization': {tokenizer.tokenize('Tokenization')}") # Note the capitalization handling for 'uncased' model

In [None]:
# Batch encoding example (handling multiple sentences, padding, truncation)
# `text` and `text2` come from the tokenizer cell above.
sentences = [text, text2]
batch_encoded = tokenizer(
    sentences,
    padding=True,        # Pad shorter sequences to the length of the longest
    truncation=True,     # Truncate sequences longer than max_length
    max_length=20,       # Example max length
    return_tensors="pt"  # Return PyTorch tensors
)

# One row per sentence; the attention mask marks real tokens vs. padding
print("\nBatch Encoding Example:")
print("Input IDs (Batch):\n", batch_encoded['input_ids'])
print("Attention Mask (Batch):\n", batch_encoded['attention_mask'])


# Render the token/ID tables for two more example sentences
visualize_tokenization("Let's understand tokenization.", tokenizer)
visualize_tokenization("You are great students!", tokenizer)

**Practical Tips (Tokenization):**
*   Always use the *exact* tokenizer that the pre-trained model was trained with. Mismatched tokenizers lead to poor performance.
*   Pay attention to `uncased` vs `cased` models. `uncased` models lowercase text before tokenization.
*   Understand the role of special tokens (`[CLS]`, `[SEP]`, `[PAD]`, `[UNK]`, `[MASK]`) for your specific model and task.
*   Padding and Attention Masks are crucial when processing batches of sequences with different lengths. The attention mask tells the model to ignore padding tokens.

## 3. Embeddings and Positional Encoding

Once we have token IDs, we need to convert them into dense vectors (embeddings) that the model can process.

**Input Embeddings:**
*   A simple lookup table (Embedding Matrix).
*   Each token ID corresponds to a row in the matrix.
*   The size of the matrix is `(vocabulary_size, embedding_dimension)`.
*   `embedding_dimension` (often denoted `d_model`) is a hyperparameter (e.g., 768 for BERT-base).
*   These embeddings are *learned* during training.

**The Problem:** Standard Transformers have no built-in sense of sequence order (unlike RNNs). If you shuffle the input embeddings, the self-attention output would also be shuffled – the meaning would be lost!

**Positional Encoding (PE):**
*   Injects information about the position of each token in the sequence.
*   These positional encodings are *added* to the input embeddings.
*   The original paper used fixed sine and cosine functions of different frequencies:
    *   $PE_{(pos, 2i)} = \sin(pos / 10000^{2i / d_{model}})$
    *   $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i / d_{model}})$
    *   Where `pos` is the token position, `i` is the dimension index within the embedding vector (`0` to `d_model/2 - 1`), and `d_model` is the embedding dimension.

**Why sine and cosine?**
*   Produces unique encodings for each position.
*   Allows the model to easily attend to relative positions, since $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$.
*   Values remain bounded between -1 and 1.
*   Can generalize to sequence lengths longer than those seen during training (though performance might degrade).

### Task:
* Learn how `nn.Embedding` works. What must be written inside it?
* Learn how positional encoding works. Write the final formulas for PE.

**Useful links**:
1. https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html

In [None]:
# We'll use a pre-trained model's embeddings later,
# but here's how you'd define one conceptually:
vocab_size = tokenizer.vocab_size # From our BERT tokenizer
d_model = 128 # Smaller dimension for easier visualization/computation here
# EXERCISE: replace the placeholder with the arguments of nn.Embedding
# (this cell does not run until the placeholder is filled in)
embedding_layer = nn.Embedding(<YOUR_CODE>)

# Example: Get embeddings for our batch_encoded IDs
# (Using the conceptual layer, not BERT's actual embeddings yet)
# detach().clone() gives an independent copy of the ID tensor
sample_ids = batch_encoded['input_ids'].detach().clone()
sample_embeddings = embedding_layer(sample_ids)
print(f"Sample Token IDs shape: {sample_ids.shape}") # (batch_size, seq_len)
print(f"Sample Embeddings shape: {sample_embeddings.shape}") # (batch_size, seq_len, d_model)

# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    """Adds fixed sine/cosine positional encodings to a batch of embeddings.

    The encoding table is built once in ``__init__`` and stored as a
    non-trainable buffer ``pe`` of shape (1, max_len, d_model). ``forward``
    adds the first ``seq_len`` rows to the input and then applies dropout.

    EXERCISE: the two ``<YOUR_CODE>`` placeholders in ``__init__`` must be
    filled in before this class can be used.
    """

    def __init__(self, d_model, max_len=512, dropout=0.1):
        """
        Args:
            d_model (int): Dimension of the embeddings.
            max_len (int): Maximum sequence length.
            dropout (float): Dropout probability.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create positional encoding matrix (max_len, d_model)
        pe = torch.zeros(max_len, d_model)

        # Position indices (0, 1, ..., max_len-1) -> shape (max_len, 1)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # Frequency term: exp(2i * (-ln(10000) / d_model)) = 1 / 10000^(2i / d_model).
        # Note this is the *reciprocal* of the denominator in the PE formula.
        # torch.arange(0, d_model, 2) yields the even indices 2i = 0, 2, 4, ...
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Calculate sine for even indices, cosine for odd indices
        pe[:, 0::2] = <YOUR_CODE> # Even indices
        pe[:, 1::2] = <YOUR_CODE> # Odd indices

        # Add a batch dimension (1, max_len, d_model) so it can be added to input embeddings
        # Using register_buffer makes 'pe' part of the model's state, but not a parameter to be trained
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input embeddings (batch_size, seq_len, d_model).
        Returns:
            torch.Tensor: Embeddings with added positional encoding.
        """
        # x.size(1) is the sequence length of the current batch
        # Add positional encoding to the input embeddings
        # self.pe is (1, max_len, d_model), slice it to match input seq_len
        # NOTE: there is no check for seq_len > max_len; in that case the slice
        # only has max_len rows and no longer matches x's sequence length.
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x) # Apply dropout


# Instantiate Positional Encoding
max_sequence_length = 50 # Max length for our example
pe_layer = PositionalEncoding(d_model, max_len=max_sequence_length)

# Apply PE to our sample embeddings
# NOTE: pe_layer is never switched to eval mode here, so its dropout
# (default p=0.1) is active and randomly zeroes some output values.
embeddings_with_pe = pe_layer(sample_embeddings)
print(f"\nEmbeddings with PE shape: {embeddings_with_pe.shape}")

# --- Visualize Positional Encoding ---
def visualize_pe(pe_layer, max_len_to_show=50, d_model_to_show=128):
    """Plot a heatmap of the positional-encoding table stored in ``pe_layer``.

    Args:
        pe_layer: Module exposing a ``pe`` buffer of shape (1, max_len, d_model).
        max_len_to_show (int): Maximum number of positions (rows) to plot.
        d_model_to_show (int): Maximum number of embedding dimensions (columns) to plot.
    """
    pe = pe_layer.pe.squeeze(0).cpu().numpy()  # Remove batch dim, move to CPU, convert to numpy
    pe = pe[:max_len_to_show, :d_model_to_show]  # Slice for visualization

    # Slicing clamps to the table's real size, so report what is actually
    # drawn rather than what was requested.
    shown_positions, shown_dims = pe.shape

    plt.figure(figsize=(12, 8))
    sns.heatmap(pe, cmap="viridis")
    plt.xlabel("Embedding Dimension Index")
    plt.ylabel("Token Position in Sequence")
    plt.title(f"Positional Encoding (First {shown_positions} Positions, {shown_dims} Dimensions)")
    plt.show()

# Plot the full table: every position and every embedding dimension
visualize_pe(pe_layer, max_len_to_show=max_sequence_length, d_model_to_show=d_model)

# Talking points to read alongside the heatmap above
print("\nObservations from PE visualization:")
print("- Each row (position) has a unique encoding pattern.")
print("- Columns (dimensions) vary with different frequencies (sine/cosine waves).")
print("- Smooth transitions between positions, potentially allowing generalization.")

In [None]:
# --- Combine Embeddings and PE using a real model (BERT) ---
# Load the pre-trained bert-base-uncased model and move it to the chosen device
model_name = "bert-base-uncased"
bert_model = AutoModel.from_pretrained(model_name).to(device)
bert_embeddings = bert_model.embeddings # Access the embedding module

# Prepare input for BERT
text_for_bert = "Example sentence for BERT embeddings."
inputs = tokenizer(text_for_bert, return_tensors="pt", padding=True, truncation=True).to(device)
input_ids = inputs['input_ids']

# Get embeddings from BERT (includes token, position, and token_type embeddings added together)
bert_model.eval() # Set model to evaluation mode
with torch.no_grad(): # Disable gradient calculation for inference
    outputs = bert_model(**inputs, output_hidden_states=True)
    final_embeddings = outputs.hidden_states[0] # Embeddings are the first hidden state

print(f"\n--- BERT Embeddings ---")
print(f"Input text: '{text_for_bert}'")
print(f"Input IDs shape: {input_ids.shape}")
print(f"BERT Final Embeddings (Layer 0 output) shape: {final_embeddings.shape}") # (batch_size, seq_len, 768)

# Rebuild the embedding output by hand from its three lookup tables
bert_embedding_dim = bert_model.config.hidden_size # Should be 768 for base
token_embed = bert_embeddings.word_embeddings(input_ids)
# Position IDs 0..seq_len-1 with a leading batch dimension -> shape (1, seq_len)
pos_embed = bert_embeddings.position_embeddings(torch.arange(input_ids.shape[1], device=device).unsqueeze(0))
# BERT also has token_type_embeddings, usually 0 for single sentences
tok_type_ids = torch.zeros_like(input_ids, device=device)
tok_type_embed = bert_embeddings.token_type_embeddings(tok_type_ids)

summed_embeds = token_embed + pos_embed + tok_type_embed
# LayerNorm is applied after summing in BERT's embedding layer
normed_embeds = bert_embeddings.LayerNorm(summed_embeds)
# Dropout is also applied (a no-op here because bert_model.eval() was called above)
final_embeds_manual = bert_embeddings.dropout(normed_embeds)

# Check if manually calculated embeddings match the model's output
# Use a tolerance due to potential floating point differences
print(f"Manually calculated embeddings match model output[0]? {torch.allclose(final_embeds_manual, final_embeddings, atol=1e-6)}")

# Visualize BERT's Positional Embeddings (different from the sin/cos version)
# BERT uses learned positional embeddings
bert_pos_embed_weights = bert_embeddings.position_embeddings.weight.detach().cpu().numpy()

plt.figure(figsize=(12, 8))
# Only show first N positions for clarity
num_pos_to_show = 50
sns.heatmap(bert_pos_embed_weights[:num_pos_to_show, :], cmap="viridis")
plt.xlabel("Embedding Dimension Index")
plt.ylabel("Token Position in Sequence")
plt.title(f"BERT Learned Positional Embeddings (First {num_pos_to_show} Positions)")
plt.show()

print("Note: BERT uses *learned* positional embeddings, not the fixed sin/cos ones.")
print("The pattern looks different, but serves the same purpose: encoding position.")

**Practical Tips (Embeddings & PE):**
*   The `d_model` (embedding dimension) must be consistent throughout the Transformer layers.
*   While the original paper used fixed sin/cos PE, many modern Transformers (like BERT) use *learned* positional embeddings, which are just another embedding layer indexed by position.
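*   Adding PE is crucial. Without it, self-attention is permutation-equivariant: shuffling the input tokens merely shuffles the outputs in the same way, so the model cannot tell different word orders apart.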
*   Dropout is typically applied after adding PE.
*   BERT includes a third type of embedding: Segment Embeddings (or Token Type Embeddings), used to distinguish between different sentences in input pairs (e.g., for Next Sentence Prediction or Question Answering). We saw this as `token_type_embeddings` above.

## 4. The Encoder: Processing the Input

The Encoder's job is to take the sequence of input embeddings (with PE) and produce a sequence of contextualized representations. It consists of a stack of identical layers (N layers, e.g., N=6 for the original Transformer, N=12 for BERT-base).

**Each Encoder Layer has two main sub-layers:**

1.  **Multi-Head Self-Attention (MHA):** Allows each token to attend to all other tokens in the *input* sequence (including itself) to capture contextual information.
2.  **Position-wise Feed-Forward Network (FFN):** A simple fully connected feed-forward network applied independently to each position.

Residual connections (`Add`) and Layer Normalization (`Norm`) are used around each sub-layer.

**Encoder Layer Structure:**
> ![Transformer Architecture](https://www.researchgate.net/profile/Vittorio-Mazzia/publication/352992757/figure/fig1/AS:1042115790389249@1625471174152/Transformer-encoder-layer-architecture-left-and-schematic-overview-of-a-multi-head.ppm)

**Let's break down Multi-Head Self-Attention first.**

### 4.1 Multi-Head Self-Attention (MHA)

**Self-Attention Core Idea:** For each token, we want to compute a representation that is a weighted sum of the representations of *all* tokens in the sequence. The weights determine "how much attention" one token should pay to another when representing itself.

**Scaled Dot-Product Attention:** This is the building block.

1.  **Project Embeddings:** Create three vectors for each input embedding vector $x_i$:
    *   **Query ($q_i$):** Represents the current token "asking" for information. $q_i = W_q x_i$
    *   **Key ($k_j$):** Represents the token $x_j$ "offering" information or being indexed. $k_j = W_k x_j$
    *   **Value ($v_j$):** Represents the actual content/representation of token $x_j$. $v_j = W_v x_j$
    (Where $W_q, W_k, W_v$ are learned weight matrices).

2.  **Calculate Attention Scores:** Compute the dot product between the query of the current token ($q_i$) and the keys of all tokens ($k_j$). This measures compatibility/similarity.
    *   $score_{ij} = q_i \cdot k_j$

3.  **Scale Scores:** Divide the scores by the square root of the dimension of the key vectors ($d_k$). This prevents the dot products from growing too large for high dimensions, which could saturate the softmax function.
    *   $scaled\_score_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}}$

4.  **Apply Softmax:** Normalize the scores across all source tokens ($j$) so they sum to 1, yielding the attention weights ($\alpha_{ij}$).
    *   $\alpha_{ij} = \text{softmax}_j(scaled\_score_{ij}) = \frac{\exp(scaled\_score_{ij})}{\sum_{l=1}^{n} \exp(scaled\_score_{il})}$

5.  **Compute Weighted Sum:** Multiply the attention weights ($\alpha_{ij}$) by the corresponding value vectors ($v_j$) and sum them up. This gives the output representation $z_i$ for token $i$.
    *   $z_i = \sum_{j=1}^{n} \alpha_{ij} v_j$

**In Matrix Form (for the whole sequence):**
Input $X$ (batch_size, seq_len, d_model)
Queries $Q = X W_q$
Keys $K = X W_k$
Values $V = X W_v$
(Where $W_q, W_k, W_v$ are projection matrices, shapes like (d_model, d_k) or (d_model, d_v))

Attention Output $Z = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$

**Multi-Head Attention:** Instead of just one set of Q, K, V projections, MHA uses multiple "heads".

1.  Project $X$ into $h$ sets of $(Q_i, K_i, V_i)$ using different weight matrices ($W_{q,i}, W_{k,i}, W_{v,i}$) for each head $i=1...h$. The dimensions are typically $d_k = d_v = d_{model} / h$.
2.  Perform Scaled Dot-Product Attention independently for each head, producing $h$ output vectors $Z_i$.
    *   $head_i = \text{Attention}(Q_i, K_i, V_i)$, where $Q_i = X W_{q,i}$, $K_i = X W_{k,i}$, $V_i = X W_{v,i}$
3.  Concatenate the outputs from all heads: $Concat(head_1, ..., head_h)$.
4.  Apply a final linear projection ($W_o$) to the concatenated output to get the final MHA result.
    *   $\text{MultiHead}(Q, K, V) = \text{Concat}(head_1, ..., head_h) W_o$

**Why multiple heads?** Allows the model to jointly attend to information from different representation subspaces at different positions. A single head might average away important details; multiple heads capture different types of relationships.

**Let's implement Scaled Dot-Product Attention and MHA from scratch.**

### Task:
* Learn how Attention works. Try to understand formula and fill in the blanks.

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Calculate scaled dot-product attention: the attention output and the attention weights.

    EXERCISE: the three ``<YOUR_CODE>`` placeholders must be filled in before
    this function can run.

    Args:
        q (torch.Tensor): Queries. Shape: (batch_size, n_heads, seq_len_q, d_k)
        k (torch.Tensor): Keys. Shape: (batch_size, n_heads, seq_len_k, d_k)
        v (torch.Tensor): Values. Shape: (batch_size, n_heads, seq_len_v, d_v)
                          Usually seq_len_k = seq_len_v
        mask (torch.Tensor, optional): Mask to apply (e.g., for padding or look-ahead).
                                       Shape should be broadcastable to (batch_size, n_heads, seq_len_q, seq_len_k).
                                       Positions where mask == 0 are masked out (their scores are
                                       replaced by -1e9 before the softmax); non-zero positions are attended to.

    Returns:
        torch.Tensor: Output tensor after attention. Shape: (batch_size, n_heads, seq_len_q, d_v)
        torch.Tensor: Attention weights. Shape: (batch_size, n_heads, seq_len_q, seq_len_k)
    """
    d_k = k.size(-1) # Dimension of keys (used for the 1/sqrt(d_k) scaling)

    # MatMul Q and K transpose -> (batch_size, n_heads, seq_len_q, seq_len_k)
    scores = <YOUR_CODE>

    # Apply mask if provided (we will use it in next sections, just wait a bit)
    if mask is not None:
        # Entries with mask == 0 get a large negative score, so softmax gives them ~0 weight
        scores = scores.masked_fill(mask == 0, -1e9) # Use large negative value

    # Apply softmax to get attention weights
    attn_weights = <YOUR_CODE> # Softmax over the key sequence length dimension

    # MatMul attention weights and V -> (batch_size, n_heads, seq_len_q, d_v)
    output = <YOUR_CODE>

    return output, attn_weights


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention built on top of `scaled_dot_product_attention`.

    The inputs are meant to be projected with W_q / W_k / W_v, split into
    `n_heads` heads of size `d_k = d_model // n_heads`, attended per head,
    concatenated again and passed through the output projection W_o.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        """
        Args:
            d_model (int): Total dimension of the model.
            n_heads (int): Number of attention heads. d_model must be divisible by n_heads.
            dropout (float): Dropout probability.
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads # Dimension of each head's key/query/value

        # Linear layers for Q, K, V projections (can be combined for efficiency)
        self.W_q = nn.Linear(d_model, d_model) # Projects input to Q space for all heads combined
        self.W_k = nn.Linear(d_model, d_model) # Projects input to K space
        self.W_v = nn.Linear(d_model, d_model) # Projects input to V space

        # Output projection layer
        self.W_o = nn.Linear(d_model, d_model)

        # NOTE(review): this dropout module is created but not referenced in the
        # visible parts of split_heads() / forward().
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        """
        Split the last dimension into (n_heads, d_k).
        Transpose to shape (batch_size, n_heads, seq_len, d_k).

        Args:
            x (torch.Tensor): Input tensor (batch_size, seq_len, d_model).
            batch_size (int): Batch size.

        Returns:
            torch.Tensor: Tensor reshaped for multi-head attention.
                          Shape: (batch_size, n_heads, seq_len, d_k)
        """
        # -1 lets view() infer seq_len, so the same helper works for q, k and v
        x = x.view(batch_size, -1, self.n_heads, self.d_k) # (batch_size, seq_len, n_heads, d_k)
        return x.transpose(1, 2) # (batch_size, n_heads, seq_len, d_k)

    def forward(self, query, key, value, mask=None):
        """
        Forward pass for Multi-Head Attention.

        Args:
            query (torch.Tensor): Query input. Shape: (batch_size, seq_len_q, d_model)
            key (torch.Tensor): Key input. Shape: (batch_size, seq_len_k, d_model)
            value (torch.Tensor): Value input. Shape: (batch_size, seq_len_v, d_model)
                                  Usually seq_len_k = seq_len_v
            mask (torch.Tensor, optional): Mask. Shape broadcastable to (batch_size, 1, seq_len_q, seq_len_k).
                                           It is passed unchanged to `scaled_dot_product_attention`,
                                           where entries equal to 0 are masked out.

        Returns:
            torch.Tensor: Output tensor after MHA. Shape: (batch_size, seq_len_q, d_model)
            torch.Tensor: Attention weights. Shape: (batch_size, n_heads, seq_len_q, seq_len_k)
        """
        batch_size = query.size(0)

        # 1. Project Q, K, V using linear layers (W_q, W_k, W_v defined in __init__)
        q = <YOUR_CODE> # (batch_size, seq_len_q, d_model)
        k = <YOUR_CODE>  # (batch_size, seq_len_k, d_model)
        v = <YOUR_CODE>   # (batch_size, seq_len_v, d_model)

        # 2. Split into multiple heads
        q = self.split_heads(q, batch_size) # (batch_size, n_heads, seq_len_q, d_k)
        k = self.split_heads(k, batch_size) # (batch_size, n_heads, seq_len_k, d_k)
        v = self.split_heads(v, batch_size) # (batch_size, n_heads, seq_len_v, d_k)

        # 3. Apply scaled dot-product attention
        # The mask needs to be compatible: (batch_size, 1, seq_len_q, seq_len_k)
        # or similar broadcastable shape.
        attention_output, attn_weights = scaled_dot_product_attention(q, k, v, mask)
        # attention_output shape: (batch_size, n_heads, seq_len_q, d_k)
        # attn_weights shape: (batch_size, n_heads, seq_len_q, seq_len_k)

        # 4. Concatenate heads and project back
        # Transpose back to (batch_size, seq_len_q, n_heads, d_k);
        # contiguous() is needed because view() below cannot operate on the transposed layout
        attention_output = attention_output.transpose(1, 2).contiguous()
        # Reshape to (batch_size, seq_len_q, d_model)
        concat_attention = attention_output.view(batch_size, -1, self.d_model)

        # Apply final linear layer W_o
        output = self.W_o(concat_attention) # (batch_size, seq_len_q, d_model)

        return output, attn_weights

In [None]:
# Smoke test for the from-scratch MultiHeadAttention: run random data through it
# and check the output shapes.

# Create dummy input data
batch_size = 1
seq_len = 5
d_model = 128
n_heads = 8 # d_model must be divisible by n_heads (128 / 8 = 16 dims per head)

# Uniform random values in [0, 1) standing in for token embeddings
dummy_input = torch.rand(batch_size, seq_len, d_model)

# Instantiate MHA layer
mha_layer = MultiHeadAttention(d_model, n_heads)

# In self-attention, query, key, and value are the same
# No mask needed for this simple example yet
output, attn_weights = mha_layer(dummy_input, dummy_input, dummy_input, mask=None)

print("--- MHA Scratch Implementation ---")
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}") # Expected to equal the input shape
print(f"Attention weights shape: {attn_weights.shape}") # (batch_size, n_heads, seq_len, seq_len)

### 4.2 Visualizing Self-Attention

Now, let's use a pre-trained BERT model to visualize the *actual* attention weights on a real sentence. This shows which words the model focuses on when representing a specific word.

We'll extract the attention weights from one of BERT's encoder layers. BERT-base has 12 layers and 12 heads per layer.

### Task:
* Analyze the attention weights: which words are most strongly "connected" to each other in your chosen layer?

In [None]:
# Load BERT with attention outputs enabled, run one sentence through it and
# keep the attention weights of a single layer for plotting.

# Reload BERT model if needed, ensuring we get attention outputs
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True).to(device)
model.eval() # Set to evaluation mode

# Choose a sentence where attention patterns might be interesting
# text = "The cat sat on the mat, because it was tired." # "it" should attend to "cat" or "mat"?
text = "The quick brown fox jumps over the lazy dog."
# text = "I went to the bank to deposit money." # "bank" has multiple meanings

inputs = tokenizer(text, return_tensors="pt").to(device)
# Token strings for the single sentence in the batch; used as axis labels later
token_list = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

print(f"Input Text: '{text}'")
print(f"Tokens: {token_list}")

# Perform inference and get outputs, including attentions
with torch.no_grad():
    outputs = model(**inputs)
    attentions = outputs.attentions # Tuple of tensors, one for each layer
    # Each tensor shape: (batch_size, num_heads, sequence_length, sequence_length)

# Let's examine attentions from a specific layer
# Valid values are 0 .. len(attentions) - 1
layer_index = <YOUR_CODE> # You can choose first layer (0) or any other :)
attention_layer = attentions[layer_index].cpu() # Move to CPU for plotting
# Shape: (batch_size, num_heads, seq_len, seq_len)

# Average attention weights across all heads for a simpler view first
attention_avg_heads = attention_layer.mean(dim=1).squeeze(0) # Squeeze batch dim
# Shape: (seq_len, seq_len)

# --- Plotting Function ---
def plot_attention_heatmap(attention_matrix, x_labels, y_labels, title):
    """Draw an attention-weight matrix as a labelled heatmap and show it.

    Columns are labelled with ``x_labels`` (keys), rows with ``y_labels``
    (queries); ``title`` is used as the figure title.
    """
    heatmap_options = {
        "xticklabels": x_labels,
        "yticklabels": y_labels,
        "cmap": 'viridis',
        "linewidths": .1,
    }

    plt.figure(figsize=(10, 8))
    sns.heatmap(attention_matrix, **heatmap_options)

    # Title and axis captions
    plt.title(title)
    plt.ylabel("Query (Attending From)")
    plt.xlabel("Key (Attended To)")

    # Keep the row labels horizontal, tilt the column labels
    plt.yticks(rotation=0)
    plt.xticks(rotation=45, ha='right')

    plt.tight_layout()
    plt.show()

# Plot average attention weights for the selected layer (`layer_index`)
plot_attention_heatmap(attention_avg_heads.numpy(), token_list, token_list, f"Average Self-Attention (Layer {layer_index})")

In [None]:
# Plot an interactive per-head heatmap for the selected layer (`layer_index`)
# NOTE(review): plot_attention_interactive is not defined in this part of the
# notebook; presumably it comes from an earlier helper cell.
print("\n--- Interactive Self-Attention Plot ---")
plot_attention_interactive(attention_layer, token_list, layer_index)

# The plot is interactive: click on it to choose which head is displayed

**Interpreting Self-Attention Visualizations:**

*   **Diagonal:** Tokens usually attend strongly to themselves (check Head 6).
*   **Off-Diagonal Bright Spots:** Indicate strong attention between different tokens. Look for meaningful relationships (e.g., pronouns attending to nouns, verbs attending to subjects/objects).
*   **`[CLS]` Token:** Often aggregates information from the entire sequence, especially in later layers, as it's used for classification tasks.
*   **`[SEP]` Token:** Marks boundaries; attention patterns around it can be interesting.
*   **Padding Tokens (`[PAD]`):** Should have near-zero attention weights directed *towards* them if the attention mask is working correctly (though they might attend *to* other tokens).
*   **Different Heads, Different Patterns:** Notice how different heads capture different relationships (e.g., some might focus on local context, others on syntactic dependencies, others on distant relationships). This highlights the benefit of MHA.

**Practical Tips (Self-Attention):**
*   Queries and keys must share the same dimensionality $d_k$ (they are compared via a dot product), but values may use a different $d_v$. In practice all three are usually equal ($d_k = d_v = d_{model} / n_{heads}$).
*   Scaling by $\sqrt{d_k}$ is crucial for stable training.
*   Attention masks are essential for handling padding and for decoder self-attention (look-ahead mask).

### 4.3 Add & Norm and Feed-Forward Network

**Add & Norm:**

*   **Residual Connection (`Add`):** The input to the sub-layer is added to the output of the sub-layer ($x + \text{Sublayer}(x)$).
    *   **Why?** Helps gradients flow during training (mitigates vanishing gradients), allows layers to learn modifications to the identity function, leading to deeper networks.
*   **Layer Normalization (`Norm`):** Normalizes the activations *across the feature dimension* for each token independently.
    *   $LN(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$
    *   Where $\mu$ and $\sigma$ are the mean and standard deviation calculated over the `d_model` dimension for a specific token, and $\gamma$ (gamma) and $\beta$ (beta) are learnable scale and shift parameters.
    *   **Why Layer Norm (vs Batch Norm)?** Its statistics are computed per token, so it does not depend on the batch size or the sequence length, both of which vary widely in NLP. It stabilizes activations and improves training.

**Position-wise Feed-Forward Network (FFN):**

*   A simple network applied independently to each position (each token's representation) in the sequence.
*   Consists of two linear transformations with a non-linearity (usually ReLU or GELU) in between.
    *   $\text{FFN}(x) = \text{max}(0, xW_1 + b_1)W_2 + b_2$ (using ReLU)
    *   The inner dimension (`d_ff`, typically $4 \times d_{model}$) is larger than the input/output dimension (`d_model`).
*   **Why?** Adds non-linearity and capacity to the model, allowing it to learn more complex transformations of the token representations after the attention mechanism has mixed information across the sequence. It can be seen as processing the information aggregated by the attention layer for each token.

**Let's implement the FFN and put together a full Encoder Layer.**

### Task:
* Implement the forward pass of `PositionwiseFeedForward` and `EncoderLayer`.

In [None]:
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network: two linear layers (d_model -> d_ff -> d_model)
    with an activation and dropout in between.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model (int): Input and output dimension.
            d_ff (int): Inner dimension (usually 4 * d_model).
            dropout (float): Dropout probability.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff) # Expand to the inner dimension
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model) # Project back to the model dimension
        # Common activation functions: ReLU or GELU (used in BERT)
        # self.activation = nn.ReLU()
        self.activation = nn.GELU()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor (batch_size, seq_len, d_model).
        Returns:
            torch.Tensor: Output tensor (batch_size, seq_len, d_model).
        """
        # linear1 -> activation -> dropout -> linear2
        <YOUR_CODE>
        return x

class EncoderLayer(nn.Module):
    """
    One Transformer encoder layer: multi-head self-attention followed by a
    position-wise feed-forward network, each wrapped in dropout, a residual
    connection and layer normalization.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        """
        Args:
            d_model (int): Model dimension.
            n_heads (int): Number of attention heads.
            d_ff (int): Inner dimension of FFN.
            dropout (float): Dropout probability.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        # Layer Normalization layers (one per sub-layer)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6) # Epsilon for numerical stability
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)

        # Separate dropout modules for the attention and the FFN branch
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        Forward pass for a single Encoder layer.

        Args:
            x (torch.Tensor): Input tensor (batch_size, seq_len, d_model).
            mask (torch.Tensor): Attention mask (broadcastable to batch_size, 1, seq_len, seq_len).

        Returns:
            torch.Tensor: Output tensor (batch_size, seq_len, d_model).
        """
        # 1. Multi-Head Self-Attention + Add & Norm
        # Hint: self.self_attn returns a tuple (output, attention_weights)
        <YOUR_CODE> # Self-attention: Q, K, V are all x
        # Apply dropout to attention output, then add residual connection, then layer norm
        <YOUR_CODE>

        # 2. Feed-Forward Network + Add & Norm
        <YOUR_CODE>
        # Apply dropout to FFN output, then add residual connection, then layer norm
        <YOUR_CODE>

        return x

In [None]:
# Smoke test for the from-scratch EncoderLayer with a padding mask.
batch_size = 1
seq_len = 10
d_model = 128 # Model width; must be divisible by n_heads
n_heads = 8   # Number of attention heads (128 / 8 = 16 dims per head)
d_ff = d_model * 4 # Common practice

# Dummy input and mask
dummy_input = torch.rand(batch_size, seq_len, d_model)
# Example padding mask: assume last 3 tokens are padding
# Mask needs shape compatible with MHA: (batch_size, 1, seq_len_q, seq_len_k)
# Here seq_len_q = seq_len_k = seq_len; the two singleton dims broadcast over
# heads and over query positions, so the same key positions are hidden for every query
dummy_mask = torch.ones(batch_size, 1, 1, seq_len) # Start with all ones (attend)
dummy_mask[:, :, :, -3:] = 0 # Mask out last 3 tokens (0 means mask)
# Ensure mask is on the correct device if using GPU
dummy_mask = dummy_mask.to(dummy_input.device)

# Instantiate EncoderLayer
encoder_layer = EncoderLayer(d_model, n_heads, d_ff)

# Pass input through the layer
encoder_output = encoder_layer(dummy_input, dummy_mask)

print("--- Encoder Layer Scratch Implementation ---")
print(f"Input shape: {dummy_input.shape}")
# The mask shape is printed symbolically, not read from dummy_mask
print(f"Mask shape (original): (batch_size, 1, 1, seq_len)")
print(f"Output shape: {encoder_output.shape}") # Should be same as input shape

### 4.4 Layer-wise Attention Pattern Evolution (Encoder Self-Attention)
How do the self-attention patterns change as information propagates through the layers of the Encoder?

Lower layers might focus more on local, syntactic relationships, while higher layers might capture broader, more semantic context. Let's visualize the average self-attention weights from different layers in BERT.

In [None]:
# Reload BERT with attention outputs enabled so this cell also works on its own
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True).to(device)
model.eval()

# Use a sentence where context matters
text_layer_viz = "The animal didn't cross the street because it was too tired."
# text_layer_viz = "I went to the bank to deposit money, not the river bank."

inputs_layer_viz = tokenizer(text_layer_viz, return_tensors="pt").to(device)
tokens_layer_viz = tokenizer.convert_ids_to_tokens(inputs_layer_viz["input_ids"][0])

# Get attentions from all layers
with torch.no_grad():
    outputs_layer_viz = model(**inputs_layer_viz)
    attentions_all_layers = outputs_layer_viz.attentions # Tuple of (batch_size, num_heads, seq_len, seq_len)

num_layers = len(attentions_all_layers)
print(f"Number of layers in {model_name}: {num_layers}")

# Select layers to visualize: the first and the last encoder layer.
# The second entry must be num_layers - 1 so that it matches the
# "(Latest)" description printed below.
layers_to_show = [0, num_layers - 1]

print(f"\nVisualizing Average Self-Attention for Layers: {layers_to_show}")

for layer_idx in layers_to_show:
    # Get attention for the specific layer, move to CPU
    attention_layer = attentions_all_layers[layer_idx].cpu()
    # Average across heads
    attention_avg_heads = attention_layer.mean(dim=1).squeeze(0) # Squeeze batch dim
    # Shape: (seq_len, seq_len)

    # NOTE(review): plot_attention_heatmap_improved is not defined in this part
    # of the notebook; presumably it comes from an earlier helper cell.
    plot_attention_heatmap_improved(
        attention_avg_heads.numpy(),
        x_labels=tokens_layer_viz,
        y_labels=tokens_layer_viz,
        title=f"Average Self-Attention Pattern (Layer {layer_idx})",
        figsize=(10, 8), # Adjust size if needed
        fmt=".1f" # Use fewer decimals for less clutter
    )

print("\nObserve how patterns change:")
print(f"- Layer {layers_to_show[0]} (First): Often focuses on local context, diagonal, maybe syntax.")
print(f"- Layer {layers_to_show[1]} (Latest): Often shows more diffused attention, potentially focusing on key semantic tokens or the [CLS] token aggregating context.")
print("Look at the row for 'it' - which token(s) does it attend to most strongly in different layers ('animal' vs 'street')?")

### 4.5 Using Encoder Output: The [CLS] Token for Classification (Conceptual)
For models like BERT (Encoder-only), a common way to perform sequence classification (e.g., sentiment analysis) is to use the final hidden state corresponding to the special $[CLS]$ token.

This token's representation is assumed to aggregate the contextual information from the entire sequence after passing through all encoder layers. A simple linear classifier is then trained on top of this single vector.

In [None]:
# Run one sentence through BERT and pick out the final hidden state of the
# [CLS] token, which is what a classification head would consume.

# Ensure BERT model and tokenizer are loaded
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device) # No need for attentions here
model.eval()

# Input sentence
text_cls = "This movie was absolutely fantastic!"
inputs_cls = tokenizer(text_cls, return_tensors="pt", truncation=True, padding=True).to(device)
print(f"Input Text: '{text_cls}'")
print(f"Input IDs: {inputs_cls['input_ids']}")
tokens_cls = tokenizer.convert_ids_to_tokens(inputs_cls['input_ids'][0])
print(f"Tokens: {tokens_cls}")

# Get the outputs from the BERT model (no gradients are tracked here)
with torch.no_grad():
    outputs_cls = model(**inputs_cls)
    # outputs_cls.last_hidden_state contains the final hidden states for all tokens
    # Shape: (batch_size, sequence_length, hidden_size)
    last_hidden_states = outputs_cls.last_hidden_state

print(f"\nShape of final hidden states: {last_hidden_states.shape}")

# The [CLS] token is always the first token (index 0)
cls_hidden_state = last_hidden_states[:, 0, :] # Select the hidden state for the [CLS] token
# Shape: (batch_size, hidden_size)
print(f"Shape of [CLS] token hidden state: {cls_hidden_state.shape}")
print(f"Hidden size (d_model) for {model_name}: {model.config.hidden_size}")

In [None]:
# --- Conceptual Classifier ---
# In a real scenario, you would define and train this layer
hidden_size = model.config.hidden_size
num_classes = 2 # Example: Positive/Negative sentiment (binary)

# Define a simple linear layer (untrained, i.e. randomly initialised, so the
# resulting logits carry no meaning)
classifier_head = nn.Linear(hidden_size, num_classes).to(device)

# Pass the [CLS] token's hidden state through the classifier
# (We're not training, just showing the mechanism)
# cls_hidden_state comes from the previous cell, where it was computed under torch.no_grad()
logits_cls = classifier_head(cls_hidden_state)
# Shape: (batch_size, num_classes)
print(f"\nShape of output logits from conceptual classifier: {logits_cls.shape}")

# These logits would then be used with a loss function (e.g., CrossEntropyLoss)
# during training, or passed through a softmax for prediction probabilities during inference.
print("Mechanism: The [CLS] token's final hidden state is used as the aggregated sequence representation for classification tasks.")

Try to understand the code above. We will speak a lot about BERT models next Monday!

**Practical Tips (Encoder Layer):**
*   The order matters: Attention -> Add & Norm -> FFN -> Add & Norm.
*   Dropout is applied at multiple points (usually after MHA, after FFN, after PE/Embedding addition) to prevent overfitting.
*   Layer Normalization parameters ($\gamma, \beta$) are learned.
*   The choice of activation in FFN (ReLU, GELU) can impact performance. GELU is common in modern Transformers like BERT.

## 5. The Decoder: Generating the Output

The Decoder's role is to generate the output sequence token by token, based on the encoded input representation ($\mathbf{z}$) and the previously generated output tokens. It also consists of a stack of N identical layers.

**Each Decoder Layer has *three* main sub-layers:**

1.  **Masked Multi-Head Self-Attention:** Allows each position in the *output* sequence to attend to *previous* positions (including itself) in the output sequence. The "Masked" part is crucial to prevent attending to future tokens, maintaining the autoregressive property (generation depends only on past outputs).
2.  **Multi-Head Cross-Attention:** This is where the Decoder interacts with the Encoder output. The queries ($Q$) come from the output of the previous Decoder sub-layer (Masked MHA), while the keys ($K$) and values ($V$) come from the **output of the Encoder**. This allows each output token to attend to relevant parts of the *input* sequence.
3.  **Position-wise Feed-Forward Network (FFN):** Same structure and function as in the Encoder layer, applied to the output of the Cross-Attention sub-layer.

Residual connections (`Add`) and Layer Normalization (`Norm`) are used around each of these three sub-layers.

**Decoder Layer Structure:**
> ![Transformer Architecture](https://upload.wikimedia.org/wikipedia/commons/thumb/5/55/Transformer%2C_one_decoder_block.png/640px-Transformer%2C_one_decoder_block.png)


**Let's look at the two new attention mechanisms in the Decoder.**

### 5.1 Masked Multi-Head Self-Attention

This is almost identical to the Encoder's self-attention, but with one critical difference: the **look-ahead mask**.

**Look-Ahead Mask:**
*   Prevents positions from attending to subsequent positions.
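*   Ensures that the prediction for position $i$ can only depend on the known outputs at positions less than $i$. (The decoder input is the target sequence shifted right by one, so attending to input positions $\le i$ exposes only outputs at positions $< i$.)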
*   Implemented by creating a mask matrix where entries corresponding to future positions (upper triangle of the attention score matrix) are set to $-\infty$ (or a large negative number) before the softmax step.

**Example Mask (for seq_len = 4):**
> [[ 0., -inf, -inf, -inf], <- Pos 0 can only attend to Pos 0<br>
> [ 0., 0., -inf, -inf], <- Pos 1 can attend to Pos 0, 1<br>
> [ 0., 0., 0., -inf], <- Pos 2 can attend to Pos 0, 1, 2<br>
> [ 0., 0., 0., 0.]] <- Pos 3 can attend to Pos 0, 1, 2, 3<br>
> (Where 0 allows attention, -inf prevents it)

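*Note: The example above is an additive mask (0 keeps a score, $-\infty$ removes it). Our implementation uses the other common convention: a binary/boolean mask in which 1 (True) means "attend" and 0 (False) means "mask out". `scaled_dot_product_attention` applies it with `masked_fill(mask == 0, -1e9)`.*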

We can modify our `scaled_dot_product_attention` function or pass the appropriate mask to our existing `MultiHeadAttention` module.

### Task:
* Implement the function that creates the look-ahead mask.

**Useful functions**:
1. https://pytorch.org/docs/stable/generated/torch.triu.html
2. https://pytorch.org/docs/stable/generated/torch.ones.html


In [None]:
# Function to create a look-ahead mask
def create_look_ahead_mask(size):
    """
    Creates a look-ahead mask for self-attention.
    Mask has 1s (True) where attention is allowed, 0s (False) where it is blocked.
    Shape: (1, 1, size, size) to be broadcastable.

    Args:
        size (int): Sequence length; the mask covers size x size query/key pairs.

    Returns:
        torch.Tensor: The inverted upper-triangle mask, i.e. allowed positions are
                      the lower triangle including the diagonal.
    """
    # The intermediate mask should be a boolean tensor marking only the positions
    # strictly above the diagonal; on an integer tensor `~` would be a bitwise NOT
    # instead of a logical one.
    mask = <YOUR_CODE> # Upper triangle (True where we want to mask)
    # In our attention function, we use mask == 0 for masking, so we invert
    return ~mask # Lower triangle + diagonal = True (allow attention)

In [None]:
# Run the from-scratch MultiHeadAttention as masked (decoder-style) self-attention.
seq_len_dec = 6
d_model_dec = 128
n_heads_dec = 8

# Dummy decoder input (batch of 1)
dummy_decoder_input = torch.rand(1, seq_len_dec, d_model_dec)

# Create the look-ahead mask
look_ahead_mask = create_look_ahead_mask(seq_len_dec)
print("--- Look-Ahead Mask ---")
# Print the mask content (True/1 means allow, False/0 means mask)
print(look_ahead_mask.squeeze().int())

# Instantiate MHA layer
decoder_self_mha = MultiHeadAttention(d_model_dec, n_heads_dec)

# Apply masked self-attention (query, key and value are all the decoder input)
masked_attn_output, masked_attn_weights = decoder_self_mha(
    dummy_decoder_input, dummy_decoder_input, dummy_decoder_input,
    mask=look_ahead_mask # Pass the look-ahead mask
)

print("\n--- Masked Self-Attention Output ---")
print(f"Decoder Input shape: {dummy_decoder_input.shape}")
print(f"Masked Attention Output shape: {masked_attn_output.shape}")
print(f"Masked Attention Weights shape: {masked_attn_weights.shape}")

In [None]:
# We use the weights from our dummy example here.
# For real models, we'd extract from a decoder layer.

# Average weights across heads for visualization
# detach() is required before numpy() because the weights were produced by a
# module with trainable parameters outside of torch.no_grad()
avg_masked_weights = masked_attn_weights.mean(dim=1).squeeze(0).detach().numpy()

# Generate labels for the plot (one per decoder position)
dec_labels = [f"Dec_{i}" for i in range(seq_len_dec)]

plot_attention_heatmap(avg_masked_weights, dec_labels, dec_labels, "Average Masked Self-Attention (Decoder)")

print("\nObservations from Masked Self-Attention visualization:")
print("- The upper triangle (above the main diagonal) should have near-zero attention weights due to the mask.")
print("- Each token can only attend to itself and preceding tokens in the sequence.")

### 5.2 Multi-Head Cross-Attention

This is the mechanism that allows the Decoder to incorporate information from the Encoder's output.

*   **Queries ($Q$):** Come from the Decoder's previous sub-layer (output of the masked self-attention + Add & Norm). Shape: `(batch_size, target_seq_len, d_model)`.
*   **Keys ($K$) and Values ($V$):** Come from the **final output of the Encoder stack**. They are the same for every step of the Decoder. Shape: `(batch_size, source_seq_len, d_model)`.

The `MultiHeadAttention` module we implemented earlier can be used directly, just by feeding the correct inputs:
`cross_attn_output, cross_attn_weights = mha_layer(query=decoder_intermediate_output, key=encoder_output, value=encoder_output, mask=padding_mask)`

**Important Masking Note:** In cross-attention, the mask used should typically be the **padding mask** from the *Encoder's input* sequence. This prevents the Decoder from attending to padding tokens in the source sequence. The look-ahead mask is *not* used here because the Decoder is allowed (and encouraged) to attend to *any* position in the encoded input sequence.

**Let's visualize Cross-Attention using a pre-trained Encoder-Decoder model (like T5).**

In [None]:
# Load a small T5 model and tokenizer
model_name_t5 = "t5-small"
tokenizer_t5 = T5Tokenizer.from_pretrained(model_name_t5)
model_t5 = T5ForConditionalGeneration.from_pretrained(model_name_t5, output_attentions=True).to(device)
model_t5.eval()

# Example: English to French translation task
task_prefix = "translate English to French: "
# Let's use a slightly longer sentence to better see attention patterns
input_text = "This framework is very useful for NLP tasks."
target_text = "Ce framework est très utile pour les tâches NLP." # Corresponding French translation

# Encode the input (English)
encoder_inputs = tokenizer_t5(task_prefix + input_text, return_tensors="pt", padding=True, truncation=True).to(device)
encoder_input_ids = encoder_inputs['input_ids']
# Decode each id on its own to get readable token strings; this avoids showing
# SentencePiece markers such as the word-boundary symbol '▁' in the plot labels
encoder_tokens = [tokenizer_t5.decode([t_id]) for t_id in encoder_input_ids[0]]


# Encode the target (French) - simulating decoder input during training/teacher forcing
decoder_input_ids_raw = tokenizer_t5(target_text, return_tensors="pt", padding=True, truncation=True).input_ids
# T5 requires decoder input IDs to start with a pad token ID for generation.
# During training/teacher forcing, we shift the target sequence right and add the start token.
# NOTE(review): _shift_right is a private helper of the Hugging Face T5 model
# and may change between library versions.
decoder_input_ids = model_t5._shift_right(decoder_input_ids_raw).to(device)
# Get tokens for the decoder input sequence
decoder_tokens = [tokenizer_t5.decode([t_id]) for t_id in decoder_input_ids[0]]


print(f"Encoder Input Text: '{task_prefix + input_text}'")
print(f"Encoder Tokens: {encoder_tokens}")
print(f"Decoder Input Text (simulated): '{target_text}'")
print(f"Decoder Input Tokens: {decoder_tokens}") # Starts with the start token '<pad>'; the right shift drops the last target token (normally '</s>')

In [None]:
# --- Perform Inference and Get Attentions ---
with torch.no_grad():
    outputs_t5 = model_t5(
        input_ids=encoder_input_ids,
        decoder_input_ids=decoder_input_ids, # Provide decoder inputs for teacher forcing / attention viz
        output_attentions=True,
        return_dict=True
    )
    # Cross attentions are present in the 'cross_attentions' field of the output object
    cross_attentions = outputs_t5.cross_attentions # Tuple of tensors, one for each decoder layer

# --- Visualize Cross-Attention ---

# Select a decoder layer to visualize (e.g., layer 0 or a middle layer like 3)
layer_index_dec = 3 # Try changing this index (0 to 5 for t5-small)
cross_attention_layer = cross_attentions[layer_index_dec].cpu()
# Shape: (batch_size, num_heads, target_sequence_length, source_sequence_length)

# Average across heads for the static plot
cross_attention_avg = cross_attention_layer.mean(dim=1).squeeze(0) # Squeeze batch dim
# Shape: (target_seq_len, source_seq_len)

print(f"\n--- Static Cross-Attention Plot (Layer {layer_index_dec}, Averaged Heads) ---")
# Rows are labelled with the shifted decoder *input* tokens, i.e. the token fed
# in at each step, not the token predicted at that step.
plot_attention_heatmap_improved(
    cross_attention_avg.numpy(),
    x_labels=encoder_tokens,  # Source tokens (English)
    y_labels=decoder_tokens,  # Target tokens (French)
    title=f"Average Cross-Attention (Decoder Layer {layer_index_dec})"
)

In [None]:
print(f"\n--- Interactive Cross-Attention Plot (Layer {layer_index_dec}) ---")
# Pass the full layer tensor (with heads) to the interactive function
# NOTE(review): plot_cross_attention_interactive is not defined in this part of
# the notebook; presumably it comes from an earlier helper cell.
plot_cross_attention_interactive(
    cross_attention_layer,
    encoder_tokens,
    decoder_tokens,
    layer_index_dec
)

**Interpreting Cross-Attention Visualizations:**

*   Each row corresponds to a token being generated by the Decoder (Query).
*   Each column corresponds to a token in the original input sequence processed by the Encoder (Key/Value).
*   Bright spots indicate which input tokens the Decoder paid most attention to when generating a specific output token.
*   You can often see alignments between source and target words (e.g., the French target token "maison" attending strongly to the English source token "house").
*   The `</s>` (end-of-sentence) token in the Decoder might attend broadly or to specific concluding words in the input.

**Practical Tips (Decoder):**
*   The two masks (look-ahead for self-attention, padding for cross-attention) are critical and distinct.
*   During inference (actual generation), the Decoder operates autoregressively: generate one token, feed it back as input for the next step, repeat until an end-of-sequence token is produced.
*   Teacher Forcing (used during training): Feed the *actual* ground-truth target sequence token as input at each step, regardless of what the model predicted previously. This stabilizes training. Our T5 visualization simulated this.

### 5.3 Visualizing Attention Masking Effects
We've discussed padding masks (for Encoder self-attention and Cross-attention) and look-ahead masks (for Decoder self-attention). Let's explicitly visualize how these masks affect the attention scores before the softmax step and the final attention weights after softmax.

This helps understand why masking is crucial. Remember, the goal of masking is to prevent attention to certain tokens by setting their pre-softmax scores to a very large negative number (like -infinity), which results in near-zero probability after softmax.

In [None]:
# --- Setup Dummy Data ---
batch_size = 1
n_heads = 1 # Simulate single head for clarity
seq_len = 6
d_k = 8 # Small dimension for easier inspection

# Dummy Q, K, V (normally derived from projections)
q = torch.randn(batch_size, n_heads, seq_len, d_k)
k = torch.randn(batch_size, n_heads, seq_len, d_k)
# NOTE(review): v is not used anywhere in this cell; only scores and weights are shown
v = torch.randn(batch_size, n_heads, seq_len, d_k) # d_v = d_k here

# --- 1. Padding Mask Example ---
# Assume last 2 tokens are padding
# Mask shape (batch_size, n_heads, seq_len_q, seq_len_k) -> (1, 1, 6, 6)
# Mask should be 1 (True) for tokens to *keep*, 0 (False) for tokens to *mask out*
padding_mask = torch.ones(batch_size, n_heads, seq_len, seq_len)
padding_mask[:, :, :, -2:] = 0 # Mask attention TO the last 2 key tokens

# Calculate scores without mask (already scaled by sqrt(d_k))
scores_unmasked = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

# Calculate scores WITH mask applied *before* softmax
scores_masked_padding = scores_unmasked.masked_fill(padding_mask == 0, -1e9)

# Calculate attention weights (softmax)
# NOTE(review): attn_weights_unmasked is computed but not plotted in this cell
attn_weights_unmasked = F.softmax(scores_unmasked, dim=-1)
attn_weights_masked_padding = F.softmax(scores_masked_padding, dim=-1)

# Visualize: scaled scores (left), the mask (middle), masked attention weights (right)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
sns.heatmap(scores_unmasked.squeeze().detach().numpy(), annot=True, cmap="coolwarm", fmt=".1f", ax=axes[0], cbar=False)
axes[0].set_title("Raw Scores (No Mask)")
axes[0].set_xlabel("Key Position")
axes[0].set_ylabel("Query Position")

sns.heatmap(padding_mask.squeeze().detach().numpy(), annot=True, cmap="gray", fmt=".0f", ax=axes[1], cbar=False)
axes[1].set_title("Padding Mask (0 = Masked)")
axes[1].set_xlabel("Key Position")


sns.heatmap(attn_weights_masked_padding.squeeze().detach().numpy(), annot=True, cmap="viridis", fmt=".2f", ax=axes[2])
axes[2].set_title("Attention Weights (With Padding Mask)")
axes[2].set_xlabel("Key Position")

plt.suptitle("Effect of Padding Mask on Attention Weights", fontsize=14)
# Leave room at the top of the figure for the suptitle
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

print("Observation: Notice how the last two columns in the final attention weights are near zero due to the padding mask.")

In [None]:
# --- 2. Look-Ahead Mask Example ---
# Mask built by the helper defined earlier in the notebook (0 marks positions to hide)
look_ahead_mask = create_look_ahead_mask(seq_len)

# Reuse the raw scores from the padding example; hidden positions are pushed to -1e9
scores_masked_lookahead = scores_unmasked.masked_fill(look_ahead_mask == 0, -1e9)

# Attention weights (softmax over the key axis)
attn_weights_masked_lookahead = torch.softmax(scores_masked_lookahead, dim=-1)

# Visualize: raw scores | mask | resulting weights
panels = [
    (scores_unmasked, "Raw Scores (No Mask)", dict(cmap="coolwarm", fmt=".1f", cbar=False)),
    (look_ahead_mask.squeeze().int(), "Look-Ahead Mask (0 = Masked)", dict(cmap="gray", fmt=".0f", cbar=False)),
    (attn_weights_masked_lookahead, "Attention Weights (With Look-Ahead Mask)", dict(cmap="viridis", fmt=".2f")),
]
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for axis, (tensor, title, heatmap_kwargs) in zip(axes, panels):
    sns.heatmap(tensor.squeeze().detach().numpy(), annot=True, ax=axis, **heatmap_kwargs)
    axis.set_title(title)
    axis.set_xlabel("Key Position")
# Only the left-most panel carries the row label
axes[0].set_ylabel("Query Position")

plt.suptitle("Effect of Look-Ahead Mask on Attention Weights", fontsize=14)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

print("Observation: Notice how the upper triangle (excluding the diagonal) in the final attention weights is near zero due to the look-ahead mask. Each query position can only attend to itself and previous key positions.")

## 6. Final Linear Layer and Softmax

After the final Decoder layer produces its output representations (shape: `batch_size, target_seq_len, d_model`), we need to convert these back into probabilities over the vocabulary for each position.

1.  **Linear Layer:** A final linear layer (without bias is common, but depends on implementation) projects the `d_model`-dimensional vector for each position into a `vocab_size`-dimensional vector (logits).
    *   Input: `(batch_size, target_seq_len, d_model)`
    *   Weight Matrix: `(d_model, vocab_size)`
    *   Output (Logits): `(batch_size, target_seq_len, vocab_size)`

2.  **Softmax:** The logits are converted into probabilities using the softmax function, applied independently at each position along the `vocab_size` dimension.
  * $P(y_i | y_{<i}, x) = \text{softmax}(\text{Linear}(\text{decoder_output}_i))$
  *   The output represents the probability distribution over the entire vocabulary for the next token at each position.

During inference, we typically select the token with the highest probability (greedy decoding) or use more advanced decoding strategies such as beam search, top-k sampling, or nucleus (top-p) sampling.

In [None]:
# Stand-in for "last Decoder layer output + final linear layer":
# the logits were already computed by the T5 model in an earlier cell.
# Shape: (batch_size, target_seq_len, vocab_size)
final_decoder_output = outputs_t5.logits.cpu()

print("--- Final Output ---")
print(f"Decoder Output (Logits) shape: {final_decoder_output.shape}")

# Softmax over the vocabulary axis turns logits into probabilities
probabilities = final_decoder_output.softmax(dim=-1)
print(f"Probabilities shape: {probabilities.shape}")

# Greedy choice: most probable token at every position -> (batch_size, target_seq_len)
predicted_token_ids = probabilities.argmax(dim=-1)
print(f"Predicted Token IDs shape: {predicted_token_ids.shape}")

# Decode the first (and only) sequence of the batch
greedy_ids = predicted_token_ids[0]
predicted_tokens = tokenizer_t5.convert_ids_to_tokens(greedy_ids)
predicted_sentence = tokenizer_t5.decode(greedy_ids, skip_special_tokens=True)

print(f"\nTarget Tokens (Input to Decoder): {decoder_tokens}")
print(f"Predicted Tokens (Greedy):        {predicted_tokens}")
print(f"Predicted Sentence (Greedy):      '{predicted_sentence}'")
# Note: the prediction may not exactly match the target_text that was fed in.
# The *entire* target sequence was supplied at once (teacher-forcing style),
# whereas true autoregressive generation builds the sequence step by step.

## 7. Conclusion & Further Exploration

We have journeyed through the core components of the Transformer architecture:

1.  **Tokenization:** Breaking text into manageable pieces (subwords) and mapping them to IDs.
2.  **Embeddings & Positional Encoding:** Converting IDs to vectors and injecting sequence order information.
3.  **Encoder:** Processing the input sequence using Self-Attention and Feed-Forward networks to build contextualized representations.
4.  **Decoder:** Generating the output sequence using Masked Self-Attention (autoregressive property), Cross-Attention (linking to input), and Feed-Forward networks.
5.  **Final Projection:** Converting final representations to vocabulary probabilities.

**Key Takeaways & Practical Tips Recap:**

*   **Attention is Powerful:** It allows modeling long-range dependencies effectively.
*   **Multi-Head Attention:** Captures diverse relationships simultaneously.
*   **Masking is Crucial:** For handling padding and ensuring autoregressive behavior in the decoder.
*   **Residuals & LayerNorm:** Essential for training deep networks.
*   **Pre-trained Models:** Leverage models trained on massive datasets (like BERT, T5, GPT) via libraries like Hugging Face `transformers`. Always match the tokenizer to the model.
*   **Visualization Helps:** Understanding attention patterns provides insight into model behavior.

**Where to go from here?**

*   **Explore Different Architectures:** Dive into BERT (Encoder-only), GPT (Decoder-only) specifics.
*   **Fine-tuning:** Learn how to adapt pre-trained models to specific downstream tasks (classification, QA, summarization, etc.).
*   **Generation Strategies:** Investigate beam search, top-k, nucleus sampling for better text generation.
*   **Efficiency:** Look into techniques like knowledge distillation, quantization, and efficient attention variants (e.g., Linformer, Performer).
*   **Implement a Full Transformer:** Try building and maybe even training a small Transformer from scratch on a toy task.
*   **Read the Paper:** Go back to "Attention Is All You Need" now with a deeper understanding.

# Module 2: Vision Transformers (ViT)

**Goal**: Understand how the main ideas of Transformers for NLP (Module 1) can be transferred to computer vision. We will explore how tokenization, embeddings, positional encodings, and attention are realized for images.

**Approach**: We will leverage the Hugging Face transformers library to finetune and inspect the key components of Vision Transformer for image classification.

## Prerequisites

In [None]:
! pip install transformers datasets evaluate



In [None]:
from datasets import load_dataset
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import evaluate
import cv2
import os
os.environ["WANDB_DISABLED"] = "true"

In [None]:
def compute_distance_matrix(patch_size, num_patches, length):
    """Return pairwise Euclidean distances between patch grid positions.

    Patch ``i`` is placed at grid cell ``(i // length, i % length)``, i.e.
    patches are numbered consecutively along a grid with ``length`` cells per
    line. Distances between grid cells are scaled by ``patch_size``.

    Args:
        patch_size: Side length of one patch (the scale applied to grid distances).
        num_patches: Total number of patches.
        length: Number of patches per grid line; must be positive.

    Returns:
        A ``(num_patches, num_patches)`` float array whose entry ``[i, j]`` is
        ``patch_size`` times the Euclidean distance between the grid cells of
        patches ``i`` and ``j``. The diagonal is zero.

    Raises:
        ValueError: If ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError("length must be positive")

    # Grid coordinates of every patch, computed in one shot
    rows, cols = np.divmod(np.arange(num_patches), length)
    coords = np.stack([rows, cols], axis=-1)
    # Broadcasting yields all pairwise coordinate differences: (N, N, 2)
    offsets = coords[:, None, :] - coords[None, :, :]
    # A single vectorized norm replaces N^2 separate np.linalg.norm calls
    return patch_size * np.linalg.norm(offsets, axis=-1)

## 1. Introduction

The Vision Transformer (ViT) architecture was proposed in the paper ["An Image Is Worth 16X16 Words: Transformers for Image Recognition at Scale"](https://arxiv.org/pdf/2010.11929) by Dosovitskiy et al.

Inspired by the success of Transformers in NLP, the authors proposed applying the Transformer architecture to images with the fewest possible modifications.

When pre-trained on large datasets and transferred to tasks with fewer datapoints, ViT achieves performance comparable to state-of-the-art convolutional networks while requiring fewer computational resources to train.

**High-level architecture**

![](https://yastatic.net/s3/education-portal/media/Vi_T_c7f1bfbedc_e047c467f5.webp)

The main components of ViT:
1. **Linear projection**: an image is converted into a set of vectors as Transformer input.
2. **Transformer Encoder**: encoder similar to encoder for NLP.
3. **MLP Head**: fully-connected classification head.

## 2. ViT architecture

### 2.1 Linear Projection

To handle $2D$ images, we need to convert them to a sequence of embeddings.

Let $H, W$ - spatial height and width, $C$ - number of channels.

1. Split an image $x\in\mathbb{R}^{H\times W \times C}$ into patches of shape $P\times P \times C$.

2. Flatten each $P \times P \times C$ patch into a vector of length $P^2 \cdot C$.

Thus, we obtain $x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$, where $N = HW / P^2$ - number of patches (also serves as the effective input sequence length for the Transformer). Image patches are treated the same way as tokens (words) in NLP.

3. Obtain patch embeddings by flattening the patches and linearly projecting them to $D$-dimensional vectors.

$$z_0 = [x_{class}; x_p^1\pmb{E}; x_p^2\pmb{E}; \dots x_p^N\pmb{E}] + \pmb{E}_{pos}, \quad \pmb{E}\in \mathbb{R}^{(P^2\cdot C) \times D}, \quad \pmb{E}_{pos} \in \mathbb{R}^{(N + 1) \times D}$$

Note: similar to BERT [CLS] token, a learnable embedding $z_0^0 = x_{class}$ is added to the patch embeddings.

4. Similar to NLP, position embeddings are added to patch embeddings to retain positional information. Position embeddings are learnable $1D$ parameters.

### 2.2 Transformer Encoder

5. The resulting sequence $z_0$ of embedding vectors is input to Transformer model with alternating Multihead Self-Attention layers (MSA) and MLP blocks. LayerNorm (LN) is applied before each block, and residual connections after each block.

$$z^\prime_l = \text{MSA}(\text{LN}(z_{l-1})) + z_{l-1}, \quad l=1 \dots L$$

$$z_l = \text{MLP}(\text{LN}(z^\prime_l)) + z^\prime_l, \quad l=1 \dots L$$

### 2.3 MLP Head

6. Classification token's state at the output of the Transformer encoder $z_L^0$ serves as the image representation $y$.

$$y = \text{LN}(z_L^0)$$

The classification head is implemented by an MLP. For image classification, the model is trained in a supervised way.


## 3. Practice

Now, let us perform pre-trained ViT finetuning and inspect its key components.

### 3.1 Dataset Preprocessing

First, let us load the dataset. While the original dataset contains $101$ classes, we will use only the first $20$ classes for faster convergence.

In [None]:
# Hub identifier of the dataset and how many of its classes to keep
dataset_path = "ethz/food101"
N_LABELS = 20 # set the desired number of classes

# Download / load every split of the dataset
dataset_raw = load_dataset(dataset_path)

# Class names come from the 'label' feature; keep only the first N_LABELS of them
labels_full = dataset_raw['train'].features['label'].names
labels = labels_full[:N_LABELS]
idx2label = dict(enumerate(labels))


def _keep_example(example):
    # Keep an example only if its class is among the selected ones and its image is RGB
    return example['label'] in idx2label and example['image'].mode == "RGB"


# Filter dataset: choose examples with desired classes
dataset_raw = dataset_raw.filter(_keep_example)
dataset_raw

Visualize some images from the original dataset.

In [None]:
# Draw 10 distinct validation indices; the fixed seed makes the choice repeatable
np.random.seed(0)
random_idx = np.random.choice(len(dataset_raw['validation']), size=10, replace=False)
batch = dataset_raw['validation'][random_idx]

# Show the images on a 2 x 5 grid, titled with class name and original size
fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(25, 10), tight_layout=True)
ax = ax.flatten()
for i, (img, label) in enumerate(zip(batch['image'], batch['label'])):
    first_dim, second_dim = img.size[0], img.size[1]
    panel = ax[i]
    panel.imshow(img)
    panel.set_title(f"Class: {idx2label[label]}\n Size: {first_dim}x{second_dim}")

To finetune ViT model, we will apply the corresponding image preprocessing as dataset transform.

In [None]:
# Hub identifier of the pre-trained ViT checkpoint used throughout the practice part
model_name = 'google/vit-base-patch16-224'
# Load the image processor that ships with this checkpoint
# (it provides `image_mean`, `image_std` and `size`, which later cells read)
processor = ViTImageProcessor.from_pretrained(model_name)

Define transform function using the ViTImageProcessor and apply to the dataset

**Task**:
1. Apply the processor from the previous cell to each image in the batch.

In [None]:
def transform(batch):
    """Convert a raw dataset batch into model inputs.

    Exercise: fill in the processor call. Later cells read the result's
    'pixel_values' entry as tensors, so the processor output must provide it.
    The class ids are copied from the raw 'label' column into 'labels'.
    """
    # Apply processor to images from the batch
    proc_batch = <YOUR_CODE>
    # Copy the image labels under the key the rest of the notebook expects
    proc_batch['labels'] = batch['label']
    return proc_batch

# Register the transform on the dataset; the result is stored as `dataset`
dataset = dataset_raw.with_transform(transform)

Inspect the images after transform.

In [None]:
# Fetch the same random batch, this time through the transformed dataset
subset = dataset['validation'][random_idx]
# Per-channel statistics stored in the processor, used to map pixel values back for display
mean = torch.tensor(processor.image_mean)
std = torch.tensor(processor.image_std)

# Plot examples
fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(25, 10), tight_layout=True)
ax = ax.flatten()
# After transform, 'image' is replaced with 'pixel_values', 'label' - with 'labels'
for i, (img, label) in enumerate(zip(subset['pixel_values'], subset['labels'])):
    # Channels-last layout for imshow, then invert the normalization
    proc_img = img.permute(1, 2, 0) * std + mean
    panel = ax[i]
    panel.imshow(proc_img)
    panel.set_title(f"Class: {idx2label[label]}\n Size: {proc_img.shape[0]}x{proc_img.shape[1]}")

### 3.2 Model Training

To train the model, we need some utils.

Define collate function to form batches

In [None]:
def collate_fn(batch):
    """Merge a list of per-example dicts into a single batch dict.

    Args:
        batch: Sequence of dicts, each holding a 'pixel_values' tensor and a
            'labels' value.

    Returns:
        Dict with 'pixel_values' stacked along a new leading axis and 'labels'
        gathered into one tensor.
    """
    images = []
    targets = []
    for example in batch:
        images.append(example['pixel_values'])
    for example in batch:
        targets.append(example['labels'])

    collated = {}
    collated['pixel_values'] = torch.stack(images)
    collated['labels'] = torch.tensor(targets)
    return collated

Load and define accuracy metric

In [None]:
# Load the metric once, up front, instead of reloading it on every evaluation call
accuracy_metric = evaluate.load("accuracy")


def compute_metrics(eval_input):
    """Compute classification accuracy for one evaluation pass.

    Args:
        eval_input: A pair ``(logits, labels)``; ``logits`` has one row of
            class scores per example.

    Returns:
        The result of the accuracy metric's ``compute`` call for the argmax
        predictions against ``labels``.
    """
    logits, labels = eval_input
    # Predicted class = index of the largest score in each row
    preds = np.argmax(logits, axis=1)
    return accuracy_metric.compute(predictions=preds, references=labels)

Load the pre-trained ViT

**Task**:
1. Fill in basic arguments in functions
2. Use 1 epoch with 0.0001 LR and batch_size of 64.

In [None]:
# Specify the checkpoint name and the number of classes
model = ViTForImageClassification.from_pretrained(
    model_name,
    # One output per class kept in the filtered dataset
    num_labels=len(idx2label),
    # The number of classes differs from the checkpoint's original head,
    # so allow the mismatched classifier weights to be replaced
    ignore_mismatched_sizes=True,
)

Setup training arguments and trainer

In [None]:
# Exercise: fill in the hyperparameters requested in the task above.
# Choose batch sizes to fit GPU memory
TRAIN_BATCH_SIZE = <YOUR_CODE>
EVAL_BATCH_SIZE =  <YOUR_CODE>
# Hint: the task asks for 1 epoch; that is enough here, and training is slow
NUM_EPOCHS = <YOUR_CODE>
# Set appropriate learning rate
LR = <YOUR_CODE>

In [None]:
# Exercise: pass the hyperparameters defined in the previous cell.
training_args = TrainingArguments(
  # Output directory: replace None with the desired path
  output_dir=None,
  per_device_train_batch_size=<YOUR_CODE>,
  per_device_eval_batch_size=<YOUR_CODE>,
  num_train_epochs=<YOUR_CODE>,
  learning_rate=<YOUR_CODE>,
  # Set logging frequency (in training steps)
  logging_steps=50,
  # Evaluate periodically during training.
  # NOTE(review): newer transformers releases call this argument `eval_strategy` - confirm against the installed version
  evaluation_strategy="steps",
  # Mixed-precision training. NOTE(review): presumably requires a CUDA GPU - confirm
  fp16=True,
  # Keep all dataset columns. Presumably needed because `transform` reads the raw
  # 'image' and 'label' columns on the fly - verify
  remove_unused_columns=False,
)

In [None]:
# Exercise: pass the model and the training arguments created above.
trainer = Trainer(
    # Model
    model=<YOUR_CODE>,
    # Training args
    args=<YOUR_CODE>,
    # Collate function: stacks per-example dicts into one batch
    data_collator=collate_fn,
    # Metric function: accuracy on the evaluation split
    compute_metrics=compute_metrics,
    # Train dataset
    train_dataset=dataset["train"],
    # Validation dataset
    eval_dataset=dataset["validation"],
    # ViT ImageProcessor.
    # NOTE(review): recent transformers releases deprecate `tokenizer=` in favour of
    # `processing_class=` - confirm against the installed version
    tokenizer=processor,
)

  trainer = Trainer(


Train the model

In [None]:
# Run fine-tuning with the arguments configured above
trainer.train()

Evaluate finetuned model

In [None]:
# Evaluate on the `eval_dataset` given to the Trainer and display the returned metrics
trainer.evaluate()

### 3.3 Inspect Linear Projection

Choose an example image, get model prediction

**Task:**
* How do we get a prediction from logits? Do you remember? Implement it.

In [None]:
# Image from the validation dataset
img = dataset['validation'][0]['pixel_values']
# True label
true_label = dataset['validation'][0]['labels']
# Obtain logits
logits = model(img.unsqueeze(0).to('cuda')).logits.detach().cpu()
# Get predictions
pred_label = <YOUR_CODE>
# Inverse process image
img = img.permute((1, 2, 0)) * std + mean

Plot the image

In [None]:
# Show the de-normalized image with the true and predicted class names
true_name = idx2label[true_label]
pred_name = idx2label[pred_label]
plt.imshow(img)
plt.title(f'True label: {true_name},\n Predicted label: {pred_name}')

Let us visualize image patches. First, get patch size and number of patches

In [None]:
# Side length of one square patch, read from the model config
PATCH_SIZE = model.config.patch_size
# Spatial height and width stored in the processor configuration
H, W = processor.size['height'], processor.size['width']
# Number of patches: image area divided by the area of one patch
NUM_PATCHES = (H * W) // (PATCH_SIZE * PATCH_SIZE)
# Display both values as the cell output
PATCH_SIZE, NUM_PATCHES

In [None]:
# Example patches
patches = img.unfold(0, PATCH_SIZE, PATCH_SIZE).unfold(1, PATCH_SIZE, PATCH_SIZE).permute((0, 1, 3, 4, 2))
patches.shape

In [None]:
# Number of patches along the first two axes of `patches` (the patch-grid rows and columns)
H_NUM_PATCHES, W_NUM_PATCHES = patches.shape[:2]
H_NUM_PATCHES, W_NUM_PATCHES

In [None]:
# Visualize patches: one subplot per patch, arranged like the patch grid itself
fig, ax = plt.subplots(H_NUM_PATCHES, W_NUM_PATCHES)
for row_idx, axes_row in enumerate(ax):
    for col_idx, cell in enumerate(axes_row):
        cell.imshow(patches[row_idx, col_idx])
        cell.axis('off')

In ViT, each patch is projected to a $D$-dimensional vector by a linear projection, which the model implements as a convolution. Define a convolution layer that performs this linear projection.

**Task:**
* Fill in the blanks.

[Hint: use model.config to derive input and output channels]

In [None]:
# Exercise: fill in the blanks (hint: the values can be derived from model.config).
# Input channels
in_channels = <YOUR_CODE>
# Output channels [Hint: this is D]
hidden_size = <YOUR_CODE>
# Define convolution parameters
kernel_size = <YOUR_CODE>
stride = <YOUR_CODE>
padding = <YOUR_CODE>
# Define the projection as a 2D convolution
projection_layer = nn.Conv2d(in_channels=in_channels,
                             out_channels=hidden_size,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=padding)

Compare your result with model projection layer.

Visualize learned projection convolution weights.

In [None]:
# Learned weight of the patch-projection convolution, with the two kernel axes
# moved to the front and the output-channel axis moved to the end
projection = model.vit.embeddings.patch_embeddings.projection
weight = projection.weight.detach().cpu().permute((2, 3, 1, 0)).numpy()
print ('Weight shape: ', weight.shape)

# Min-max scale each output channel separately so every kernel is displayable as an image
flat_weight = weight.reshape(-1, hidden_size)
scaled_weight = MinMaxScaler().fit_transform(flat_weight)
scaled_weight = scaled_weight.reshape(PATCH_SIZE, PATCH_SIZE, in_channels, -1)

# Visualize first 25 kernels on a 5 x 5 grid
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(5, 5), tight_layout=True)
axes = axes.flatten()
for kernel_idx, panel in enumerate(axes):
    panel.imshow(scaled_weight[..., kernel_idx])
    panel.axis("off")

Learned convolution filters look like reasonable low-dimensional patterns of patches.

### 3.4 Inspect positional embeddings

Explore similarity between learned positional embeddings.

First, extract positional encodings. Do not account for classification token.

**Task:**
* Implement a basic cosine-similarity computation.

In [None]:
# Do not forget to detach: the embeddings are a trainable parameter.
# `[1:]` drops the first position, which belongs to the classification token.
pos_embeddings = model.vit.embeddings.position_embeddings.squeeze().detach().cpu()[1:]
pos_embeddings.shape

In [None]:
# Normalize embeddings: divide every row by its L2 norm so each has unit length
norm_pos_embeddings = (pos_embeddings / pos_embeddings.norm(dim=1, keepdim=True))
# Exercise: compute the pairwise cosine similarity between all position embeddings
cos_sim = <YOUR_CODE>

Visualize cosine similarity

In [None]:
# Show the similarity matrix as an image and attach a colorbar for the value scale
pos = plt.imshow(cos_sim)
plt.colorbar(pos, fraction=0.05)

There is a clear diagonal pattern indicating that each position is most similar to itself. Also, there are repeated diagonal patterns resembling positional encodings in NLP.

### 3.5 Attention Distance

In [None]:
img = dataset['validation'][0]['pixel_values']
# Run on whatever device the model lives on instead of assuming CUDA;
# inference needs no gradients.
with torch.no_grad():
    outputs = model(img.unsqueeze(0).to(model.device), output_attentions=True)
# `outputs['attentions']` holds one tensor per layer. Stacking them and squeezing out
# the size-1 batch axis gives shape
# (num_layers x num_heads x (num_patches + 1) x (num_patches + 1))
attn_tensor = torch.stack([x.cpu() for x in outputs['attentions']]).squeeze().numpy()
print(attn_tensor.shape)

In [None]:
# Pairwise distances between patch positions, with two leading singleton axes
# so they broadcast over the (layer, head) axes of the attention tensor
distance_matrix = compute_distance_matrix(PATCH_SIZE, NUM_PATCHES, H_NUM_PATCHES)[None, None, :]
# Patch-to-patch attention only: drop the classification token's row and column
patch_attention = attn_tensor[:, :, 1:, 1:]
# Attention-weighted distance summed over keys, then averaged over query tokens
weighted_distances = patch_attention * distance_matrix
mean_distances = weighted_distances.sum(axis=-1).mean(axis=-1)

In [None]:
# `mean_distances` has one row per transformer block and one column per head
num_heads = mean_distances.shape[-1]
for idx, mean_dist in enumerate(mean_distances):
    # All heads of block `idx` share the same x position
    plt.scatter(x=[idx] * num_heads, y=mean_dist, label=f"transformer_block_{idx}")
plt.legend(loc="lower right")
# The x axis is the block (layer) index, not the head index
plt.xlabel("Transformer Block", fontsize=14)
plt.ylabel("Attention Distance", fontsize=14)
plt.title(model_name, fontsize=14)
plt.grid()
plt.show()

Different attention heads may have different attention distances using both local and global information of the image. When depth increases, transformer layers focus on global information.

### 3.6 Attention Heatmaps

Finally, visualize the attention maps.

First, use the same validation image with inverse transforms as an example.

In [None]:
# Same validation image as before
img = dataset['validation'][0]['pixel_values']
# Channels last, then invert the normalization so the image can be displayed
img = img.permute((1, 2, 0)) * std + mean
img.shape

This time we will look at final layer and classification token.

In [None]:
# Last layer (-1), all heads, query = classification token (index 0), keys = patches (1:).
# Each head's row of patch weights is reshaped back onto the patch grid.
last_attn_cls = attn_tensor[-1, :, 0, 1:].reshape((-1, H_NUM_PATCHES, W_NUM_PATCHES))
last_attn_cls.shape

Since the attention map is smaller than the image, resize it to match the image size.

In [None]:
# Upsample every head's attention map to the image resolution.
# cv2.resize expects dsize as (width, height), hence (W, H).
last_attn_cls = np.stack([cv2.resize(a, dsize=(W, H)) for a in last_attn_cls])
last_attn_cls.shape

Plot image together with attention maps for each head.

In [None]:
# One panel per attention head (2 x 6 grid): the image with that head's
# classification-token attention map drawn semi-transparently on top
fig, axes = plt.subplots(nrows=2, ncols=6, figsize=(12, 4), tight_layout=True)
axes = axes.flatten()
for i, ax in enumerate(axes):
    head_map = last_attn_cls[i]
    ax.imshow(img)
    ax.imshow(head_map, alpha=0.6)
    ax.title.set_text(f"Attention head: {i}")
    ax.axis("off")

Attention maps are computed between the tokens of the same image. Thus, attention scores highlight which parts of images are important. This example illustrates the purpose of attention mechanism.