# Chapter 8: Attention Mechanism

This notebook covers **Chapter 8** of the Deep Learning in Hebrew book, focusing on the **Attention Mechanism** - one of the most revolutionary concepts in deep learning that has transformed natural language processing and many other domains.

## Overview

The **Attention Mechanism** is a fundamental breakthrough that allows models to focus on relevant parts of the input when producing each part of the output. Unlike traditional sequence-to-sequence models that compress all information into a single fixed-size vector, attention enables dynamic alignment between input and output sequences.

The attention mechanism was first introduced in the context of sequence-to-sequence models to address the bottleneck problem of encoding long sequences. However, its true power was revealed in 2017 with the paper **"Attention is All You Need"**, which introduced the **Transformer** architecture - a model that uses attention mechanisms exclusively, without any recurrent or convolutional layers.

This chapter will take you through:
1. **Sequence-to-Sequence Learning with Attention** - How attention solves the bottleneck problem
2. **Bahdanau and Luong Attention** - Early attention mechanisms
3. **Transformer Architecture** - The revolutionary attention-only model
4. **Self-Attention** - Attention within a single sequence
5. **Multi-Head Attention** - Parallel attention mechanisms
6. **Positional Encoding** - Representing sequence order without recurrence
7. **Transformer Applications** - BERT, GPT, and more

---

## Table of Contents

### 8.1 Sequence to Sequence Learning and Attention
- 8.1.1 [Attention in Seq2Seq Models](#811-attention-in-seq2seq-models)
- 8.1.2 [Bahdanau Attention and Luong Attention](#812-bahdanau-attention-and-luong-attention)

### 8.2 Transformer
- 8.2.1 [Positional Encoding](#821-positional-encoding)
- 8.2.2 [Self-Attention Layer](#822-self-attention-layer)
- 8.2.3 [Multi-Head Attention](#823-multi-head-attention)
- 8.2.4 [Transformer End to End](#824-transformer-end-to-end)
- 8.2.5 [Transformer Applications](#825-transformer-applications)

## Setup and Imports

Let's start by importing the necessary libraries for this chapter.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import math
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Set matplotlib style
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (14, 8)
plt.rcParams['font.size'] = 11

# 8.1 Sequence to Sequence Learning and Attention

## 8.1.1 Attention in Seq2Seq Models

Traditional sequence-to-sequence (Seq2Seq) models, as we saw in Chapter 6, use RNNs in an encoder-decoder architecture. The input sequence passes through an **encoder** that produces a fixed-size vector representing the original sequence, including the order of its elements and the relationships between them. A **decoder** then decodes the information in this vector and generates output in another domain.

For example, a Seq2Seq model can encode a sentence in one language into a certain vector and then decode the vector into a sentence in a second language.

### The Bottleneck Problem

The conventional way to create the vector and decode it was using different architectures of RNNs, such as deep networks containing memory components like LSTM and GRU. These models encountered a problem with long sequences, because the vector is limited in its ability to contain relationships between many elements.

The problem is that **all information from the input sequence must be compressed into a single fixed-size vector**. This creates a bottleneck:
- For short sequences, the vector may be sufficient
- For long sequences, information gets lost or diluted
- The decoder must generate the entire output from this single vector

### The Attention Solution

To deal with this problem, we can use a different approach: instead of relying only on a single vector at the output of the encoder, we use all the hidden states of the encoder in combination with the hidden states of the decoder. This lets the model find dependencies between elements of the input sequence and elements of the output sequence (**general attention**), as well as relationships between elements of the input sequence itself (**self-attention**).

### How Attention Works

For example, when translating the sentence "How was your day" into another language, the attention mechanism generates a new vector for each word in the output sequence, where each vector quantifies how strongly the current output word is related to each of the words in the original sentence.

In this way, each element of the output sequence can access each element of the input sequence. This mechanism is called **attention**.

### Key Benefits of Attention

1. **Dynamic Alignment**: Each output element can focus on different parts of the input
2. **No Information Bottleneck**: Direct access to all encoder hidden states
3. **Interpretability**: Attention weights show which input parts are important
4. **Handles Long Sequences**: No need to compress everything into one vector

### Attention Mechanism Overview

The attention mechanism computes a **context vector** for each output time step by:
1. Computing **alignment scores** between the current decoder state and all encoder states
2. Converting scores to **attention weights** using softmax
3. Computing a **weighted sum** of encoder states (context vector)
4. Using the context vector along with the decoder state to generate output
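
The four steps above can be sketched for a single decoder time step. This is a minimal NumPy illustration, not the chapter's implementation: the toy sizes, the random states, and the choice of a plain dot product as the alignment score are all assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes: 4 encoder positions, hidden size 3
encoder_states = rng.normal(size=(4, 3))   # h_1 ... h_4
decoder_state = rng.normal(size=(3,))      # current decoder state

# 1. Alignment scores: one scalar per encoder state
#    (a dot product is assumed here; other score functions are possible)
scores = encoder_states @ decoder_state            # shape (4,)

# 2. Attention weights: softmax over the encoder positions
exp_scores = np.exp(scores - scores.max())         # subtract max for stability
weights = exp_scores / exp_scores.sum()            # shape (4,), sums to 1

# 3. Context vector: weighted sum of the encoder states
context = weights @ encoder_states                 # shape (3,)

# 4. Combine the context with the decoder state before predicting the output
decoder_input = np.concatenate([context, decoder_state])   # shape (6,)

print(weights.round(3), context.shape, decoder_input.shape)
```

A new `weights` vector, and therefore a new `context`, is computed at every output step, which is what makes the alignment dynamic.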

In [None]:
# Visualize Attention Mechanism Concept
fig, axes = plt.subplots(1, 2, figsize=(18, 7))

# Traditional Seq2Seq (without attention)
ax1 = axes[0]
ax1.set_xlim(-0.5, 6)
ax1.set_ylim(-0.5, 4)
ax1.set_aspect('equal')
ax1.axis('off')
ax1.set_title('Traditional Seq2Seq (No Attention)', fontsize=14, weight='bold', pad=20)

# Input sequence
input_words = ['How', 'was', 'your', 'day']
for i, word in enumerate(input_words):
    ax1.text(0, 3-i*0.8, word, fontsize=11, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black', linewidth=2))
    if i < len(input_words) - 1:
        ax1.arrow(0.3, 3-i*0.8, 0.2, -0.4, head_width=0.08, head_length=0.06, fc='black', ec='black')

# Encoder
ax1.text(2, 2, 'Encoder\n(LSTM/GRU)', fontsize=10, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))
ax1.arrow(0.8, 2, 0.8, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Bottleneck vector
ax1.text(3.5, 2, 'Context\nVector', fontsize=10, weight='bold', ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='orange', edgecolor='black', linewidth=2))
ax1.text(3.5, 1.2, '(Fixed Size)', fontsize=8, ha='center', style='italic', color='red')

# Decoder
ax1.text(5, 2, 'Decoder\n(LSTM/GRU)', fontsize=10, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightcoral', edgecolor='black', linewidth=2))
ax1.arrow(4.2, 2, 0.5, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Output sequence
output_words = ['Comment', 'était', 'ta', 'journée']
for i, word in enumerate(output_words):
    ax1.text(5, 0.5-i*0.3, word, fontsize=9, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black'))
    if i < len(output_words) - 1:
        ax1.arrow(5, 0.5-i*0.3-0.15, 0, -0.15, head_width=0.1, head_length=0.05, fc='black', ec='black')

# Problem annotation
ax1.annotate('Information\nBottleneck!', xy=(3.5, 1.2), xytext=(3.5, 0.3),
            arrowprops=dict(arrowstyle='->', color='red', lw=2),
            fontsize=10, weight='bold', color='red', ha='center')

# Seq2Seq with Attention
ax2 = axes[1]
ax2.set_xlim(-0.5, 7)
ax2.set_ylim(-0.5, 4.5)
ax2.set_aspect('equal')
ax2.axis('off')
ax2.set_title('Seq2Seq with Attention', fontsize=14, weight='bold', pad=20)

# Input sequence
for i, word in enumerate(input_words):
    ax2.text(0, 3.5-i*0.7, word, fontsize=10, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black', linewidth=2))

# Encoder with hidden states
ax2.text(2, 3.5, 'Encoder', fontsize=10, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))
for i in range(len(input_words)):
    ax2.text(2, 2.8-i*0.7, f'h_{i+1}', fontsize=9, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', alpha=0.7))

# Attention mechanism
ax2.text(4, 3.5, 'Attention', fontsize=10, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='orange', edgecolor='black', linewidth=2))
ax2.text(4, 2.8, 'Context', fontsize=9, ha='center',
        bbox=dict(boxstyle='round', facecolor='wheat', edgecolor='black'))
ax2.text(4, 2.1, 'Vector', fontsize=9, ha='center',
        bbox=dict(boxstyle='round', facecolor='wheat', edgecolor='black'))

# Attention connections (example for first output word)
for i in range(len(input_words)):
    # Connection from encoder hidden state to attention
    ax2.plot([2.3, 3.5], [2.8-i*0.7, 2.45], 'b-', linewidth=1.5, alpha=0.4)
    # Connection from attention to decoder
    ax2.plot([4.5, 5.2], [2.45, 2.1], 'b-', linewidth=1.5, alpha=0.4)

# Decoder
ax2.text(6, 2.1, 'Decoder', fontsize=10, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightcoral', edgecolor='black', linewidth=2))

# Output sequence
for i, word in enumerate(output_words):
    ax2.text(6, 0.5-i*0.3, word, fontsize=9, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black'))

# Annotation
ax2.text(3.5, 0.2, 'Dynamic context for each output!', fontsize=9, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='green', linewidth=2))

plt.tight_layout()
plt.show()

print("Key Differences:")
print("=" * 60)
print("Traditional Seq2Seq:")
print("  - Single fixed-size context vector")
print("  - Information bottleneck for long sequences")
print("  - All input information compressed into one vector")
print("\nSeq2Seq with Attention:")
print("  - Dynamic context vector for each output step")
print("  - Direct access to all encoder hidden states")
print("  - Attention weights show which inputs are important")
print("  - No information bottleneck")

## 8.1.2 Bahdanau Attention and Luong Attention

The two classic attention variants differ mainly in how the alignment score between a decoder state and an encoder state $h_s$ is computed:

- **Bahdanau (additive) attention** uses the previous decoder state $h_{t-1}$: $\text{score} = v^T \tanh(W_e h_s + W_d h_{t-1})$
- **Luong (multiplicative) attention** uses the current decoder state $h_t$, with three variants:
  - dot: $\text{score} = h_t^T h_s$
  - general: $\text{score} = h_t^T W h_s$
  - concat: $\text{score} = v^T \tanh(W [h_t; h_s])$

In [None]:
# Bahdanau Attention Implementation
class BahdanauAttention(nn.Module):
    """
    Bahdanau (Additive) Attention mechanism.
    """
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.hidden_size = hidden_size
        
        # Attention weights
        self.W_encoder = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_decoder = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, 1, bias=False)
    
    def forward(self, decoder_hidden, encoder_outputs):
        """
        Compute attention weights and context vector.
        
        Parameters:
        -----------
        decoder_hidden : tensor, shape (batch, hidden_size)
            Previous decoder hidden state
        encoder_outputs : tensor, shape (seq_len, batch, hidden_size)
            All encoder hidden states
        
        Returns:
        --------
        context : tensor, shape (batch, hidden_size)
            Context vector
        attention_weights : tensor, shape (batch, seq_len)
            Attention weights for visualization
        """
        seq_len, batch_size, _ = encoder_outputs.shape
        
        # Reshape for computation
        decoder_hidden = decoder_hidden.unsqueeze(0)  # (1, batch, hidden_size)
        
        # Compute alignment scores
        # score = v^T * tanh(W_e * h_e + W_d * h_d)
        encoder_proj = self.W_encoder(encoder_outputs)  # (seq_len, batch, hidden_size)
        decoder_proj = self.W_decoder(decoder_hidden)  # (1, batch, hidden_size)
        
        # Add and apply tanh
        combined = torch.tanh(encoder_proj + decoder_proj)  # (seq_len, batch, hidden_size)
        
        # Compute scores
        scores = self.v(combined).squeeze(-1)  # (seq_len, batch)
        scores = scores.transpose(0, 1)  # (batch, seq_len)
        
        # Compute attention weights
        attention_weights = F.softmax(scores, dim=1)  # (batch, seq_len)
        
        # Compute context vector as weighted sum
        encoder_outputs = encoder_outputs.transpose(0, 1)  # (batch, seq_len, hidden_size)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)  # (batch, 1, hidden_size)
        context = context.squeeze(1)  # (batch, hidden_size)
        
        return context, attention_weights

# Luong Attention Implementation
class LuongAttention(nn.Module):
    """
    Luong (Multiplicative) Attention mechanism.
    """
    def __init__(self, hidden_size, method='general'):
        super(LuongAttention, self).__init__()
        self.hidden_size = hidden_size
        self.method = method
        
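        if method == 'general':
            self.W = nn.Linear(hidden_size, hidden_size, bias=False)
        elif method == 'concat':
            self.W = nn.Linear(hidden_size * 2, hidden_size, bias=False)
            self.v = nn.Linear(hidden_size, 1, bias=False)
        elif method != 'dot':
            # 'dot' needs no parameters; any other value is unsupported
            raise ValueError(f"Unknown attention method: {method!r}")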
    
    def forward(self, decoder_hidden, encoder_outputs):
        """
        Compute attention weights and context vector.
        
        Parameters:
        -----------
        decoder_hidden : tensor, shape (batch, hidden_size)
            Current decoder hidden state
        encoder_outputs : tensor, shape (seq_len, batch, hidden_size)
            All encoder hidden states
        
        Returns:
        --------
        context : tensor, shape (batch, hidden_size)
            Context vector
        attention_weights : tensor, shape (batch, seq_len)
            Attention weights
        """
        seq_len, batch_size, _ = encoder_outputs.shape
        
        # Transpose encoder_outputs once at the beginning: (seq_len, batch, hidden_size) -> (batch, seq_len, hidden_size)
        encoder_outputs_batch = encoder_outputs.transpose(0, 1)  # (batch, seq_len, hidden_size)
        
        if self.method == 'dot':
            # Dot product: score = h_t^T * h_s
            decoder_hidden = decoder_hidden.unsqueeze(1)  # (batch, 1, hidden_size)
            scores = torch.bmm(decoder_hidden, encoder_outputs_batch.transpose(1, 2))  # (batch, 1, seq_len)
            scores = scores.squeeze(1)  # (batch, seq_len)
        
        elif self.method == 'general':
            # General: score = h_t^T * W * h_s
            decoder_hidden = decoder_hidden.unsqueeze(1)  # (batch, 1, hidden_size)
            encoder_proj = self.W(encoder_outputs_batch)  # (batch, seq_len, hidden_size)
            scores = torch.bmm(decoder_hidden, encoder_proj.transpose(1, 2))  # (batch, 1, seq_len)
            scores = scores.squeeze(1)  # (batch, seq_len)
        
        elif self.method == 'concat':
            # Concat: score = v^T * tanh(W * [h_t; h_s])
            decoder_hidden = decoder_hidden.unsqueeze(1).expand(-1, seq_len, -1)  # (batch, seq_len, hidden_size)
            combined = torch.cat([decoder_hidden, encoder_outputs_batch], dim=2)  # (batch, seq_len, 2*hidden_size)
            scores = self.v(torch.tanh(self.W(combined)))  # (batch, seq_len, 1)
            scores = scores.squeeze(2)  # (batch, seq_len)
        
        # Compute attention weights
        attention_weights = F.softmax(scores, dim=1)  # (batch, seq_len)
        
        # Compute context vector (encoder_outputs_batch is already in correct shape: batch, seq_len, hidden_size)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs_batch)  # (batch, 1, hidden_size)
        context = context.squeeze(1)  # (batch, hidden_size)
        
        return context, attention_weights

# Test both attention mechanisms
hidden_size = 64
seq_len = 10
batch_size = 4

# Create dummy encoder outputs and decoder hidden state
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)
decoder_hidden_bahdanau = torch.randn(batch_size, hidden_size)
decoder_hidden_luong = torch.randn(batch_size, hidden_size)

# Test Bahdanau Attention
bahdanau_attn = BahdanauAttention(hidden_size)
context_bahdanau, weights_bahdanau = bahdanau_attn(decoder_hidden_bahdanau, encoder_outputs)

# Test Luong Attention (general method)
luong_attn = LuongAttention(hidden_size, method='general')
context_luong, weights_luong = luong_attn(decoder_hidden_luong, encoder_outputs)

print("Attention Mechanisms Comparison:")
print("=" * 60)
print(f"Hidden size: {hidden_size}")
print(f"Sequence length: {seq_len}")
print(f"Batch size: {batch_size}")
print(f"\nBahdanau Attention:")
print(f"  Context shape: {context_bahdanau.shape}")
print(f"  Attention weights shape: {weights_bahdanau.shape}")
print(f"  Attention weights sum (should be 1.0): {weights_bahdanau.sum(dim=1)}")
print(f"\nLuong Attention (general):")
print(f"  Context shape: {context_luong.shape}")
print(f"  Attention weights shape: {weights_luong.shape}")
print(f"  Attention weights sum (should be 1.0): {weights_luong.sum(dim=1)}")

# Visualize attention weights
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Bahdanau attention weights
im1 = axes[0].imshow(weights_bahdanau.detach().numpy(), cmap='Blues', aspect='auto')
axes[0].set_title('Bahdanau Attention Weights', fontsize=12, weight='bold')
axes[0].set_xlabel('Encoder Position', fontsize=11)
axes[0].set_ylabel('Batch Item', fontsize=11)
plt.colorbar(im1, ax=axes[0])

# Luong attention weights
im2 = axes[1].imshow(weights_luong.detach().numpy(), cmap='Reds', aspect='auto')
axes[1].set_title('Luong Attention Weights', fontsize=12, weight='bold')
axes[1].set_xlabel('Encoder Position', fontsize=11)
axes[1].set_ylabel('Batch Item', fontsize=11)
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

print("\nKey Differences:")
print("=" * 60)
print("Bahdanau (Additive):")
print("  - Uses previous decoder hidden state h_{t-1}")
print("  - Score: v^T * tanh(W_e*h_e + W_d*h_{t-1})")
print("  - More parameters, more expressive")
print("\nLuong (Multiplicative):")
print("  - Uses current decoder hidden state h_t")
print("  - Score: h_t^T * W * h_s (general) or h_t^T * h_s (dot)")
print("  - Fewer parameters, more efficient")

# 8.2 Transformer

After the attention mechanism began to gain momentum, an architecture based entirely on attention was proposed, one that does not use recurrent memory components at all. This architecture is called the **Transformer**.

The Transformer offers two new elements in order to find relationships between elements in a certain sequence - **positional encoding** and **self-attention**.

The Transformer architecture, introduced in the paper "Attention is All You Need" (Vaswani et al., 2017), revolutionized deep learning by showing that attention mechanisms alone, without any recurrent or convolutional layers, can achieve state-of-the-art results on many tasks.

## 8.2.1 Positional Encoding

RNN-based architectures inherently use the order of elements in the sequence. Another approach to representing the order between sequence elements is called **positional encoding**, in which we add to each element of the input a piece of information about its position in the sequence, and thus we can use architectures without RNN.

### Why Positional Encoding?

Since the Transformer uses only attention mechanisms, which have no built-in notion of order (reordering the input elements simply reorders the outputs in the same way), we need a way to inject information about the order of elements in the sequence. Positional encoding provides this information.
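
The claim that attention alone cannot see order can be checked numerically. The sketch below is an illustration under simplifying assumptions: it uses NumPy, identity projections ($Q = K = V = x$), and a hypothetical helper named `simple_self_attention`. Shuffling the input rows only shuffles the output rows in the same way.

```python
import numpy as np

def simple_self_attention(x):
    """Scaled dot-product self-attention with identity projections (Q = K = V = x)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)          # row-wise softmax
    return weights @ x

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))         # 5 tokens, 8 features, no positional information
perm = np.array([3, 0, 4, 1, 2])    # an arbitrary reordering of the tokens

out = simple_self_attention(x)
out_perm = simple_self_attention(x[perm])

# The output for the shuffled sequence is just the shuffled original output
order_is_invisible = np.allclose(out_perm, out[perm])
print(order_is_invisible)
```

Adding a position-dependent vector to each row of `x` before the attention step breaks this symmetry, which is the role of positional encoding.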

### Mathematical Formulation

Formally, for an input element $x_t \in \mathbb{R}^d$ at position $t$, we compute a position vector $p_t$ of dimension $d \times 1$ (with $d$ even) and add it to $x_t$.

For position $t$ and dimension $i$, where $k = \lfloor i/2 \rfloor$ indexes the sine/cosine pair:

$$p_t(i) = \begin{cases}
\sin(\omega_k t), & \text{if } i = 2k \text{ (even)} \\
\cos(\omega_k t), & \text{if } i = 2k + 1 \text{ (odd)}
\end{cases}$$

where $\omega_k = \frac{1}{10000^{2k/d}}$, so both dimensions of a pair share the same frequency.

The full positional encoding vector for position $t$ is:

$$p_t = \begin{bmatrix}
\sin(\omega_0 t) \\
\cos(\omega_0 t) \\
\sin(\omega_1 t) \\
\cos(\omega_1 t) \\
\vdots \\
\sin(\omega_{d/2-1} t) \\
\cos(\omega_{d/2-1} t)
\end{bmatrix}_{d \times 1}$$
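
As a small worked example, the vector can be computed by hand for a tiny dimension. This sketch assumes $d = 4$ and a zero-based pair index $k$ (so the first pair has frequency 1 and the second has frequency $1/100$); the function name `positional_encoding` is hypothetical.

```python
import math

def positional_encoding(t, d):
    """Sinusoidal encoding of position t for an even dimension d (zero-based pair index k)."""
    p = []
    for k in range(d // 2):
        omega_k = 1.0 / (10000 ** (2 * k / d))
        p.append(math.sin(omega_k * t))   # dimension 2k
        p.append(math.cos(omega_k * t))   # dimension 2k + 1
    return p

p0 = positional_encoding(0, 4)
p1 = positional_encoding(1, 4)
print(p0)   # [0.0, 1.0, 0.0, 1.0]
print(p1)   # [sin(1), cos(1), sin(0.01), cos(0.01)]
```

At position 0 every sine is 0 and every cosine is 1; moving to position 1 changes the first pair noticeably and the second pair only slightly, because its frequency is 100 times lower.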

### Understanding Positional Encoding

To understand how this vector encodes order, we can think of it as a continuous version of binary representation. If we write a sequence of consecutive numbers in binary form, the higher the weight of a bit, the less frequently it changes, so the frequency at which a bit changes indicates its position.

The MSB (Most Significant Bit) changes at the lowest frequency, while the LSB (Least Significant Bit) changes at the highest frequency.

Similarly, we can use trigonometric functions with decreasing frequencies. This is essentially the vector $p_t$: it contains trigonometric functions whose frequencies decrease along the dimensions, and the combination of their values at position $t$ encodes the position of the element.
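
The binary analogy can be made concrete by counting how often each bit flips while counting upward. This is a plain-Python sketch; the choice of 4-bit numbers is an assumption made only to keep the output short.

```python
# Count how often each bit flips while counting from 0 to 15 (4-bit numbers)
n_bits = 4
numbers = range(2 ** n_bits)

flips = [0] * n_bits                      # flips[0] is the LSB, flips[3] is the MSB
for prev, curr in zip(numbers, numbers[1:]):
    changed = prev ^ curr                 # bits that differ between neighbours
    for bit in range(n_bits):
        if (changed >> bit) & 1:
            flips[bit] += 1

print(flips)  # [15, 7, 3, 1]
```

Each step up in bit weight roughly halves the flip frequency, in the same way that each sine/cosine pair in $p_t$ oscillates more slowly than the previous one.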

### Key Properties

1. **Relative Position Information**: For any pair of functions with the same frequency $[\sin(\omega_i t), \cos(\omega_i t)]$, a linear transformation maps the encoding of position $t$ to the encoding of position $t + \phi$:

$$M \cdot \begin{bmatrix} \sin(\omega_i t) \\ \cos(\omega_i t) \end{bmatrix} = \begin{bmatrix} \sin(\omega_i (t + \phi)) \\ \cos(\omega_i (t + \phi)) \end{bmatrix}$$

where $M = \begin{bmatrix} \cos(\omega_i \phi) & \sin(\omega_i \phi) \\ -\sin(\omega_i \phi) & \cos(\omega_i \phi) \end{bmatrix}$

Because $M$ depends only on the offset $\phi$ and not on $t$, the model can easily learn to relate any two positions by their relative distance.
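
The relative-position property can be verified numerically. The sketch below assumes NumPy, a hypothetical embedding size of 8, and arbitrary values for the position $t$ and the offset $\phi$. For each frequency it builds $M$ from the offset alone and checks that it maps the pair at position $t$ to the pair at position $t + \phi$.

```python
import numpy as np

d = 8                                   # hypothetical embedding size
k = np.arange(d // 2)
omega = 1.0 / 10000 ** (2 * k / d)      # one frequency per (sin, cos) pair

def pair(t, w):
    """The (sin, cos) pair of the encoding at position t for frequency w."""
    return np.array([np.sin(w * t), np.cos(w * t)])

t, phi = 10, 7                          # arbitrary position and offset
max_error = 0.0
for w in omega:
    # M depends only on the offset phi, not on the position t
    M = np.array([[ np.cos(w * phi), np.sin(w * phi)],
                  [-np.sin(w * phi), np.cos(w * phi)]])
    shifted = M @ pair(t, w)
    max_error = max(max_error, np.abs(shifted - pair(t + phi, w)).max())

print(max_error)   # close to zero, up to floating-point rounding
```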

2. **Extrapolation**: Because the sinusoids are defined for any position, the encoding can be computed for sequences longer than those seen during training, which may help the model extrapolate to such lengths.

3. **Deterministic**: Positional encodings are fixed and not learned (though learned positional embeddings are also used in some models).

In [None]:
# Positional Encoding Implementation
class PositionalEncoding(nn.Module):
    """
    Positional encoding for Transformer.
    """
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # Compute div_term: 1 / (10000^(2i/d_model))
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                            (-math.log(10000.0) / d_model))
        
        # Apply sin to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cos to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension: (1, max_len, d_model)
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter, but part of model state)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Add positional encoding to input.
        
        Parameters:
        -----------
        x : tensor, shape (batch, seq_len, d_model)
            Input embeddings
        
        Returns:
        --------
        x + positional_encoding : tensor, shape (batch, seq_len, d_model)
        """
        # x is (batch, seq_len, d_model)
        # pe is (1, max_len, d_model)
        # We need to slice pe to match seq_len
        x = x + self.pe[:, :x.size(1), :]
        return x

# Test Positional Encoding
d_model = 128
max_len = 100
seq_len = 20
batch_size = 2

pos_encoding = PositionalEncoding(d_model, max_len)

# Create dummy input embeddings
x = torch.randn(batch_size, seq_len, d_model)

# Add positional encoding
x_with_pos = pos_encoding(x)

print("Positional Encoding:")
print("=" * 60)
print(f"Model dimension (d_model): {d_model}")
print(f"Sequence length: {seq_len}")
print(f"Input shape: {x.shape}")
print(f"Output shape: {x_with_pos.shape}")
print(f"Positional encoding shape: {pos_encoding.pe.shape}")

# Visualize positional encoding
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Plot 1: Positional encoding for different positions
ax1 = axes[0, 0]
positions = [0, 10, 20, 50]
colors = ['blue', 'green', 'red', 'orange']
for pos, color in zip(positions, colors):
    ax1.plot(pos_encoding.pe[0, pos, :64].numpy(), label=f'Position {pos}', color=color, linewidth=2)
ax1.set_title('Positional Encoding Values (First 64 dims)', fontsize=12, weight='bold')
ax1.set_xlabel('Dimension', fontsize=11)
ax1.set_ylabel('Encoding Value', fontsize=11)
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Heatmap of positional encoding
ax2 = axes[0, 1]
pe_matrix = pos_encoding.pe[0, :50, :].numpy()
im = ax2.imshow(pe_matrix.T, cmap='coolwarm', aspect='auto', interpolation='nearest')
ax2.set_title('Positional Encoding Heatmap (50 positions, all dims)', fontsize=12, weight='bold')
ax2.set_xlabel('Position', fontsize=11)
ax2.set_ylabel('Dimension', fontsize=11)
plt.colorbar(im, ax=ax2)

# Plot 3: Frequency analysis - how different frequencies change
ax3 = axes[1, 0]
positions = torch.arange(0, 50)
# Plot different frequency components
for i in [0, 2, 4, 8, 16]:
    freq_component = pos_encoding.pe[0, :50, i].numpy()
    ax3.plot(positions, freq_component, label=f'Dim {i} (freq {i//2})', linewidth=2)
ax3.set_title('Different Frequency Components', fontsize=12, weight='bold')
ax3.set_xlabel('Position', fontsize=11)
ax3.set_ylabel('Encoding Value', fontsize=11)
ax3.legend()
ax3.grid(True, alpha=0.3)

# Plot 4: Relative position encoding (demonstration)
ax4 = axes[1, 1]
# Show how position 10 can be transformed to position 20
pos_10 = pos_encoding.pe[0, 10, :8].numpy()
pos_20 = pos_encoding.pe[0, 20, :8].numpy()
x_pos = np.arange(8)
width = 0.35
ax4.bar(x_pos - width/2, pos_10, width, label='Position 10', alpha=0.7)
ax4.bar(x_pos + width/2, pos_20, width, label='Position 20', alpha=0.7)
ax4.set_title('Positional Encoding Comparison (First 8 dims)', fontsize=12, weight='bold')
ax4.set_xlabel('Dimension', fontsize=11)
ax4.set_ylabel('Encoding Value', fontsize=11)
ax4.legend()
ax4.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nKey Properties of Positional Encoding:")
print("=" * 60)
print("1. Sinusoidal functions with different frequencies")
print("2. Lower dimensions have higher frequencies (change more rapidly)")
print("3. Higher dimensions have lower frequencies (change more slowly)")
print("4. Allows model to learn relative positions")
print("5. Can extrapolate to longer sequences than seen in training")

## 8.2.2 Self-Attention Layer

In addition to positional encoding, the idea arose to perform attention not only between input elements and output elements, but also between input elements themselves. This is called **self-attention**.

### Motivation for Self-Attention

For each element in the sequence, we want to create a new representation that captures both the original element and information about its relationship to the other elements. The idea is to take each element in the sequence and calculate its similarity to all other elements in the sequence.

Elements whose representations are similar (close in the embedding space) receive high similarity values, while dissimilar (distant) elements receive low values. In NLP, these could be words that are likely to appear together, and in images, these could be similar pixels.

The relationship between two elements is computed as the inner product of their representation vectors. Each inner product yields a real-valued coefficient, so we can form a weighted sum of the original elements using these coefficients and obtain a new representation of the current element that also encodes its relationship to similar elements in the sequence.

In other words, we can look at the vector containing the relationships of an element in the sequence as its new representation that reflects its relationships with the rest of the sequence elements.
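
The idea can be illustrated before any learned matrices are introduced. In this NumPy sketch the three 2-dimensional vectors are made up for the example, and normalizing the coefficients with a softmax is an assumption that anticipates the formula used later in this section.

```python
import numpy as np

# Three hypothetical element representations; the first two point in a similar direction
x = np.array([[1.0, 0.0],
              [0.9, 0.1],
              [0.0, 1.0]])

similarity = x @ x.T                              # inner products between all pairs
weights = np.exp(similarity)
weights /= weights.sum(axis=1, keepdims=True)     # normalize so each row sums to 1

new_repr = weights @ x                            # row i: weighted sum of all elements
print(weights.round(2))
print(new_repr.round(2))
```

Row 0 of `weights` gives more weight to element 1 than to element 2, so the new representation of element 0 is pulled toward the element it resembles.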

### Query, Key, and Value

Formally, to calculate self-attention, we create three matrices of coefficients for the input sequence. These matrices are called **Value**, **Key**, and **Query**, where each row in each matrix corresponds to an element of the input.

Using these matrices, we calculate the attention score:

$$\text{Attention}(Q, K, V) = \text{SoftMax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V$$

where:
- $Q$ (Query): What am I looking for?
- $K$ (Key): What do I offer?
- $V$ (Value): What information do I contain?

### Step-by-Step Computation

1. **Create Q, K, V matrices**: For input sequence $x$, we compute:
   - $Q = x W_Q$ (Query matrix)
   - $K = x W_K$ (Key matrix)
   - $V = x W_V$ (Value matrix)
   
   where $W_Q, W_K, W_V$ are learned weight matrices.

2. **Compute attention scores**: For each query $q_i$ and key $k_j$:
   $$score_{ij} = \frac{q_i^T k_j}{\sqrt{d_k}}$$
   
   The scaling factor $\sqrt{d_k}$ prevents the dot products from growing too large.

3. **Apply softmax**: Convert scores to probabilities:
   $$w_{ij} = \text{SoftMax}(score_{ij}) = \frac{\exp(score_{ij})}{\sum_{k=1}^{n} \exp(score_{ik})}$$

4. **Weighted sum**: Compute output for each position:
   $$z_i = \sum_{j=1}^{n} w_{ij} v_j$$
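
The role of the scaling factor in step 2 can be seen empirically. The sketch below assumes query and key components drawn independently from a standard normal distribution (an idealization, not a property of trained models) and compares the spread of raw and scaled dot products for several values of $d_k$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 2000          # hypothetical sample count for the estimate

std_by_dim = {}
for d_k in (4, 64, 256):
    q = rng.normal(size=(n_samples, d_k))   # unit-variance query components
    k = rng.normal(size=(n_samples, d_k))   # unit-variance key components
    raw = (q * k).sum(axis=1)               # unscaled dot products q . k
    scaled = raw / np.sqrt(d_k)             # scaled as in step 2
    std_by_dim[d_k] = (raw.std(), scaled.std())
    print(f"d_k={d_k:4d}  raw std={raw.std():6.2f}  scaled std={scaled.std():5.2f}")
```

Under these assumptions the raw standard deviation grows roughly like $\sqrt{d_k}$, while the scaled scores stay near 1 regardless of dimension, which keeps the softmax inputs in a comparable range.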

### Understanding the Computation

For an input sequence $x$, we get three matrices, where each element $x_i$ of the original sequence creates one row in each of the matrices. When we take the row $q_i = x_i W_Q$ and compute its inner product with each of the rows of matrix $K$, we get a new vector whose element $j$ measures how strongly elements $i$ and $j$ of the original sequence are related.

Performing this multiplication for the entire input sequence creates a new matrix where each row represents the relationship between a certain element and the rest of the sequence elements. This multiplication is essentially $Q \cdot K^T$, where each $q_i^T k_j$ represents the relationship between element $i$ and element $j$.

We scale the result by $\sqrt{d_k}$, the square root of the key dimension, so that large dot products do not push SoftMax into saturated regions where gradients become very small, and then normalize each row using SoftMax. In this way, we get a matrix of numbers in the range [0,1] in which each row sums to 1, representing the relationship between every pair of elements in the original sequence.

### The Output

For each element in the sequence, we compute a new representation by weighting the vectors in $V$:

$$z_i = \sum_{j=1}^{n} w_{ij} v_j = \sum_{j=1}^{n} \frac{\exp(q_i^T k_j / \sqrt{d_k})}{\sum_{l=1}^{n} \exp(q_i^T k_l / \sqrt{d_k})} v_j$$

The resulting sequence $z$ is a new representation of the sequence, where each $z_i$ represents element $i$ in the original sequence together with information about the relationships between it and the rest of the sequence elements. The resulting sequence can be passed through a decoder or several additional layers, and thus perform various tasks.

In [None]:
# Self-Attention Implementation from Scratch
def self_attention_naive(Q, K, V):
    """
    Naive implementation of self-attention for understanding.
    
    Parameters:
    -----------
    Q : tensor, shape (batch, seq_len, d_k)
        Query matrix
    K : tensor, shape (batch, seq_len, d_k)
        Key matrix
    V : tensor, shape (batch, seq_len, d_v)
        Value matrix
    
    Returns:
    --------
    output : tensor, shape (batch, seq_len, d_v)
        Self-attention output
    attention_weights : tensor, shape (batch, seq_len, seq_len)
        Attention weights for visualization
    """
    d_k = Q.size(-1)
    
    # Step 1: Compute attention scores
    # Q @ K^T: (batch, seq_len, d_k) @ (batch, d_k, seq_len) -> (batch, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Step 2: Apply softmax
    attention_weights = F.softmax(scores, dim=-1)
    
    # Step 3: Weighted sum of values
    # attention_weights @ V: (batch, seq_len, seq_len) @ (batch, seq_len, d_v) -> (batch, seq_len, d_v)
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

# PyTorch Self-Attention Module
class SelfAttention(nn.Module):
    """
    Self-Attention layer implementation.
    """
    def __init__(self, d_model, d_k=None, d_v=None):
        super(SelfAttention, self).__init__()
        self.d_model = d_model
        self.d_k = d_k if d_k is not None else d_model
        self.d_v = d_v if d_v is not None else d_model
        
        # Linear projections for Q, K, V
        self.W_Q = nn.Linear(d_model, self.d_k)
        self.W_K = nn.Linear(d_model, self.d_k)
        self.W_V = nn.Linear(d_model, self.d_v)
    
    def forward(self, x):
        """
        Forward pass through self-attention.
        
        Parameters:
        -----------
        x : tensor, shape (batch, seq_len, d_model)
            Input sequence
        
        Returns:
        --------
        output : tensor, shape (batch, seq_len, d_v)
            Self-attention output
        attention_weights : tensor, shape (batch, seq_len, seq_len)
            Attention weights
        """
        batch_size, seq_len, _ = x.shape
        
        # Compute Q, K, V
        Q = self.W_Q(x)  # (batch, seq_len, d_k)
        K = self.W_K(x)  # (batch, seq_len, d_k)
        V = self.W_V(x)  # (batch, seq_len, d_v)
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply softmax
        attention_weights = F.softmax(scores, dim=-1)
        
        # Weighted sum
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights

# Test Self-Attention
d_model = 64
seq_len = 10
batch_size = 2

# Create input sequence
x = torch.randn(batch_size, seq_len, d_model)

# Test naive implementation
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output_naive, weights_naive = self_attention_naive(Q, K, V)

# Test PyTorch module
self_attn = SelfAttention(d_model)
output_module, weights_module = self_attn(x)

print("Self-Attention Implementation:")
print("=" * 60)
print(f"Input shape: {x.shape}")
print(f"Model dimension (d_model): {d_model}")
print(f"Sequence length: {seq_len}")
print(f"\nNaive Implementation:")
print(f"  Output shape: {output_naive.shape}")
print(f"  Attention weights shape: {weights_naive.shape}")
print(f"  Attention weights sum (should be 1.0): {weights_naive.sum(dim=-1)[0, 0]}")
print(f"\nPyTorch Module:")
print(f"  Output shape: {output_module.shape}")
print(f"  Attention weights shape: {weights_module.shape}")
print(f"  Attention weights sum (should be 1.0): {weights_module.sum(dim=-1)[0, 0]}")

# Visualize attention weights
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Attention weights for first batch item
im1 = axes[0].imshow(weights_naive[0].detach().numpy(), cmap='Blues', aspect='auto', vmin=0, vmax=1)
axes[0].set_title('Self-Attention Weights (Naive, First Batch)', fontsize=12, weight='bold')
axes[0].set_xlabel('Key Position (j)', fontsize=11)
axes[0].set_ylabel('Query Position (i)', fontsize=11)
axes[0].set_xticks(range(seq_len))
axes[0].set_yticks(range(seq_len))
plt.colorbar(im1, ax=axes[0])

im2 = axes[1].imshow(weights_module[0].detach().numpy(), cmap='Reds', aspect='auto', vmin=0, vmax=1)
axes[1].set_title('Self-Attention Weights (Module, First Batch)', fontsize=12, weight='bold')
axes[1].set_xlabel('Key Position (j)', fontsize=11)
axes[1].set_ylabel('Query Position (i)', fontsize=11)
axes[1].set_xticks(range(seq_len))
axes[1].set_yticks(range(seq_len))
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

print("\nSelf-Attention Key Points:")
print("=" * 60)
print("1. Each position attends to all positions (including itself)")
print("2. Attention weights sum to 1.0 for each query position")
print("3. Higher weights indicate stronger relationships")
print("4. Self-attention is permutation-equivariant: it ignores token order (needs positional encoding)")
print("5. Complexity: O(n²) where n is sequence length")

### Detailed Example: Self-Attention Computation

Let's trace through a concrete example to understand how self-attention works:

**Input**: Sequence of 3 words: ["The", "cat", "sat"]

1. **Embedding**: Each word is converted to a vector (e.g., 4-dimensional for simplicity)
   - "The" → [0.2, 0.1, 0.3, 0.1]
   - "cat" → [0.5, 0.3, 0.2, 0.4]
   - "sat" → [0.3, 0.2, 0.4, 0.3]

2. **Create Q, K, V**: Apply learned transformations
   - $Q = X W_Q$, $K = X W_K$, $V = X W_V$

3. **Compute Scores**: For each word, compute similarity with all words
   - "The" attends to ["The", "cat", "sat"]
   - "cat" attends to ["The", "cat", "sat"]
   - "sat" attends to ["The", "cat", "sat"]

4. **Softmax**: Convert scores to probabilities
   - Each row sums to 1.0

5. **Weighted Sum**: Combine value vectors according to attention weights
   - Output for "cat" might heavily weight "sat" (they often appear together)
   - Output for "The" might weight all words more evenly (it's a common word)

### Why Self-Attention Works

1. **Parallel Computation**: Unlike RNNs, all positions can be computed in parallel
2. **Long-Range Dependencies**: Direct connections between any two positions
3. **Interpretability**: Attention weights give an indication of which words the model relates to each other
4. **Flexible Relationships**: Can learn different types of relationships (syntactic, semantic, etc.)
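
One consequence of these direct, order-free connections is that self-attention by itself has no notion of word order: reordering the input only reorders the output in the same way. The sketch below checks this on the three embeddings from the example above. It is plain Python and, as a simplifying assumption, uses identity projections ($Q = K = V = X$); the function names are illustrative only.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(X):
    """Scaled dot-product self-attention with Q = K = V = X (assumption for brevity)."""
    d_k = len(X[0])
    out = []
    for q in X:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in X]
        w = softmax(scores)
        out.append([sum(w[j] * X[j][c] for j in range(len(X))) for c in range(d_k)])
    return out

X = [[0.2, 0.1, 0.3, 0.1],   # "The"
     [0.5, 0.3, 0.2, 0.4],   # "cat"
     [0.3, 0.2, 0.4, 0.3]]   # "sat"
perm = [2, 0, 1]              # reorder to ["sat", "The", "cat"]

Z = self_attention(X)
Z_perm = self_attention([X[p] for p in perm])

# Each word receives the same output vector wherever it sits in the sequence
same = all(math.isclose(a, b, abs_tol=1e-9)
           for p_idx, p in enumerate(perm)
           for a, b in zip(Z_perm[p_idx], Z[p]))
print("Outputs follow the permutation:", same)
```

Because the output for each word does not depend on its position, order information has to be injected separately, which is the role of positional encoding.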

In [None]:
# Detailed Self-Attention Example
def demonstrate_self_attention():
    """
    Step-by-step demonstration of self-attention computation.
    """
    # Simple example: 3 words, 4-dimensional embeddings
    words = ["The", "cat", "sat"]
    seq_len = len(words)
    d_model = 4
    
    # Create simple embeddings (normally these would be learned)
    embeddings = torch.tensor([
        [0.2, 0.1, 0.3, 0.1],  # "The"
        [0.5, 0.3, 0.2, 0.4],  # "cat"
        [0.3, 0.2, 0.4, 0.3]   # "sat"
    ], dtype=torch.float32)
    
    # Simple weight matrices (normally learned)
    W_Q = torch.eye(d_model) * 0.5
    W_K = torch.eye(d_model) * 0.5
    W_V = torch.eye(d_model) * 0.5
    
    # Compute Q, K, V
    Q = torch.matmul(embeddings, W_Q)
    K = torch.matmul(embeddings, W_K)
    V = torch.matmul(embeddings, W_V)
    
    # Compute attention scores
    scores = torch.matmul(Q, K.T) / math.sqrt(d_model)
    
    # Apply softmax
    attention_weights = F.softmax(scores, dim=-1)
    
    # Compute output
    output = torch.matmul(attention_weights, V)
    
    print("Self-Attention Step-by-Step Example:")
    print("=" * 60)
    print(f"Words: {words}")
    print(f"\n1. Input Embeddings (shape {embeddings.shape}):")
    for i, word in enumerate(words):
        print(f"   {word:5s}: {embeddings[i].numpy()}")
    
    print(f"\n2. Attention Scores (before softmax, shape {scores.shape}):")
    print(scores.numpy())
    
    print(f"\n3. Attention Weights (after softmax, shape {attention_weights.shape}):")
    print(attention_weights.numpy())
    print("\n   Each row sums to 1.0 (probabilities)")
    
    print(f"\n4. Output (shape {output.shape}):")
    for i, word in enumerate(words):
        print(f"   {word:5s}: {output[i].numpy()}")
    
    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Attention weights heatmap
    im1 = axes[0].imshow(attention_weights.numpy(), cmap='Blues', aspect='auto', vmin=0, vmax=1)
    axes[0].set_title('Attention Weights Matrix', fontsize=12, weight='bold')
    axes[0].set_xlabel('Key Position (attended to)', fontsize=11)
    axes[0].set_ylabel('Query Position (attending from)', fontsize=11)
    axes[0].set_xticks(range(seq_len))
    axes[0].set_yticks(range(seq_len))
    axes[0].set_xticklabels(words)
    axes[0].set_yticklabels(words)
    
    # Add text annotations
    for i in range(seq_len):
        for j in range(seq_len):
            axes[0].text(j, i, f'{attention_weights[i, j].item():.2f}',
                         ha="center", va="center", color="black", fontweight='bold')
    
    plt.colorbar(im1, ax=axes[0])
    
    # Bar chart showing attention for "cat"
    axes[1].bar(words, attention_weights[1].numpy(), color=['blue', 'green', 'red'], alpha=0.7)
    axes[1].set_title('Attention Weights for "cat"', fontsize=12, weight='bold')
    axes[1].set_ylabel('Attention Weight', fontsize=11)
    axes[1].set_ylim([0, 1])
    axes[1].grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
    
    return attention_weights, output

# Run demonstration
attn_weights, output = demonstrate_self_attention()

## 8.2.3 Multi-Head Attention

We can use the self-attention mechanism multiple times in parallel. Each copy has its own three matrices $(Q_h, K_h, V_h)$, from which it computes its own attention weights and its own new representations of the sequence elements. Each such mechanism is called an **attention head**, and the combination of multiple attention heads is called **Multi-Head Attention**.

In this way, for each input element $x_i$, there are several different representations $z_i^h$, one per head. These are concatenated and multiplied by a weight matrix $W_O$ to obtain a single representation of the element that combines all attention heads.

### Why Multi-Head Attention?

Different attention heads can learn to focus on different aspects of the relationships:
- **Head 1**: Might focus on syntactic relationships (subject-verb, etc.)
- **Head 2**: Might focus on semantic relationships (synonyms, related concepts)
- **Head 3**: Might focus on long-range dependencies
- **Head 4**: Might focus on positional relationships

By combining multiple heads, the model can capture richer, more nuanced relationships.

### Mathematical Formulation

For $h$ attention heads:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W_O$$

where each head is:

$$\text{head}_i = \text{Attention}(Q W_Q^i, K W_K^i, V W_V^i)$$

and:
- $W_Q^i, W_K^i, W_V^i$ are learned projections for head $i$
- $W_O$ is the output projection matrix
- Typically $d_k = d_v = d_{model} / h$ to keep total parameters similar
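
The last point can be checked with simple arithmetic. The sketch below counts the weights of $W_Q^i, W_K^i, W_V^i$ and $W_O$ for different numbers of heads, assuming $d_k = d_v = d_{model} / h$ and ignoring bias terms. The function name and the value $d_{model} = 512$ are chosen for illustration.

```python
def attention_param_count(d_model, num_heads):
    """Weights in W_Q, W_K, W_V and W_O (biases ignored), with d_k = d_v = d_model / num_heads."""
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    per_head = 3 * d_model * d_k                 # W_Q^i, W_K^i, W_V^i for one head
    output_proj = (num_heads * d_k) * d_model    # W_O acts on the concatenated heads
    return num_heads * per_head + output_proj

for h in (1, 2, 4, 8):
    print(f"heads={h}: {attention_param_count(512, h):,} weights")
```

Every setting gives the same total of $4 \cdot d_{model}^2$ weights: using more heads splits the same projection budget into smaller pieces rather than adding parameters.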

### Implementation Details

1. **Split dimensions**: Divide $d_{model}$ into $h$ heads
2. **Parallel computation**: Compute attention for all heads simultaneously
3. **Concatenate**: Combine all head outputs
4. **Project**: Apply final linear transformation
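
Steps 1 and 3 are pure bookkeeping and can be sketched without a tensor library. The example below works on a single sequence stored as nested lists (no batch dimension) and leaves out the learned projections; `split_heads` and `concat_heads` are illustrative names, not part of any library.

```python
def split_heads(x, num_heads):
    """(seq_len, d_model) -> (num_heads, seq_len, d_k), slicing the feature axis."""
    d_model = len(x[0])
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    return [[row[h * d_k:(h + 1) * d_k] for row in x] for h in range(num_heads)]

def concat_heads(heads):
    """(num_heads, seq_len, d_k) -> (seq_len, d_model), the inverse of split_heads."""
    seq_len = len(heads[0])
    return [[v for head in heads for v in head[t]] for t in range(seq_len)]

# Toy input: 3 positions, d_model = 8; entry (t, c) holds the value 10*t + c
x = [[float(10 * t + c) for c in range(8)] for t in range(3)]

heads = split_heads(x, num_heads=4)
print(len(heads), len(heads[0]), len(heads[0][0]))   # num_heads, seq_len, d_k
print(heads[1][2])                                   # features 2..3 of position 2
print(concat_heads(heads) == x)                      # splitting then concatenating is lossless
```

Each head sees a contiguous slice of the feature axis for every position, which corresponds to reshaping the last dimension into `(num_heads, d_k)` in a tensor implementation.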

### Benefits

- **Diverse representations**: Each head can specialize
- **Parallel computation**: All heads computed simultaneously
- **Richer modeling**: Captures multiple relationship types
- **Scalability**: Capacity can be increased by adding heads while keeping $d_k$ per head fixed (which also increases $d_{model}$)

In [None]:
# Multi-Head Attention Implementation
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention implementation.
    """
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections for Q, K, V (for all heads)
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_O = nn.Linear(d_model, d_model)
    
    def forward(self, Q, K, V, mask=None):
        """
        Forward pass through multi-head attention.
        
        Parameters:
        -----------
        Q, K, V : tensor, shape (batch, seq_len, d_model)
            Query, Key, Value inputs (Q may have a different seq_len than K and V in encoder-decoder attention)
        mask : tensor, optional, broadcastable to (batch, num_heads, seq_len_q, seq_len_k)
            For example (batch, 1, seq_len_q, seq_len_k); positions where mask == 0 are not attended to
        
        Returns:
        --------
        output : tensor, shape (batch, seq_len_q, d_model)
            Multi-head attention output
        attention_weights : tensor, shape (batch, num_heads, seq_len_q, seq_len_k)
            Attention weights for all heads
        """
        batch_size, seq_len_q, _ = Q.shape
        _, seq_len_k, _ = K.shape
        _, seq_len_v, _ = V.shape
        
        # Project and reshape for multi-head
        # Q: (batch, seq_len_q, d_model) -> (batch, seq_len_q, num_heads, d_k) -> (batch, num_heads, seq_len_q, d_k)
        Q = self.W_Q(Q).view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        # K: (batch, seq_len_k, d_model) -> (batch, seq_len_k, num_heads, d_k) -> (batch, num_heads, seq_len_k, d_k)
        K = self.W_K(K).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        # V: (batch, seq_len_v, d_model) -> (batch, seq_len_v, num_heads, d_k) -> (batch, num_heads, seq_len_v, d_k)
        V = self.W_V(V).view(batch_size, seq_len_v, self.num_heads, self.d_k).transpose(1, 2)
        
        # Compute attention for all heads in parallel
        # Q: (batch, num_heads, seq_len_q, d_k)
        # K: (batch, num_heads, seq_len_k, d_k)
        # scores: (batch, num_heads, seq_len_q, seq_len_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply mask if provided
        # mask should be (batch, 1, seq_len_q, seq_len_k) or broadcastable
        if mask is not None:
            # Expand mask to match scores dimensions if needed
            if mask.dim() == 4 and mask.size(1) == 1:
                # mask is (batch, 1, seq_len_q, seq_len_k), expand to (batch, num_heads, seq_len_q, seq_len_k)
                mask = mask.expand(batch_size, self.num_heads, seq_len_q, seq_len_k)
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Softmax
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        # (batch, num_heads, seq_len_q, seq_len_k) @ (batch, num_heads, seq_len_k, d_k)
        # -> (batch, num_heads, seq_len_q, d_k)
        output = torch.matmul(attention_weights, V)
        
        # Concatenate heads
        # (batch, num_heads, seq_len_q, d_k) -> (batch, seq_len_q, num_heads, d_k)
        output = output.transpose(1, 2).contiguous()
        # -> (batch, seq_len_q, d_model)
        output = output.view(batch_size, seq_len_q, self.d_model)
        
        # Final projection
        output = self.W_O(output)
        
        return output, attention_weights

# Test Multi-Head Attention
d_model = 64
num_heads = 8
seq_len = 10
batch_size = 2

# Create input
x = torch.randn(batch_size, seq_len, d_model)

# Test multi-head attention
multi_head_attn = MultiHeadAttention(d_model, num_heads)
output, attention_weights = multi_head_attn(x, x, x)

print("Multi-Head Attention:")
print("=" * 60)
print(f"Input shape: {x.shape}")
print(f"Model dimension (d_model): {d_model}")
print(f"Number of heads: {num_heads}")
print(f"Dimension per head (d_k): {d_model // num_heads}")
print(f"\nOutput shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print(f"  - Batch: {attention_weights.shape[0]}")
print(f"  - Heads: {attention_weights.shape[1]}")
print(f"  - Query positions: {attention_weights.shape[2]}")
print(f"  - Key positions: {attention_weights.shape[3]}")

# Visualize attention from different heads
fig, axes = plt.subplots(2, 4, figsize=(18, 8))
fig.suptitle('Attention Weights from Different Heads (First Batch Item)', 
              fontsize=14, weight='bold', y=1.02)

for head in range(num_heads):
    row = head // 4
    col = head % 4
    ax = axes[row, col]
    
    im = ax.imshow(attention_weights[0, head].detach().numpy(), 
                   cmap='viridis', aspect='auto', vmin=0, vmax=1)
    ax.set_title(f'Head {head+1}', fontsize=10, weight='bold')
    ax.set_xlabel('Key Position', fontsize=9)
    ax.set_ylabel('Query Position', fontsize=9)
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

print("\nMulti-Head Attention Benefits:")
print("=" * 60)
print("1. Each head can learn different types of relationships")
print("2. Parallel computation of all heads")
print("3. Richer representation by combining multiple perspectives")
print("4. Total parameters similar to single-head (d_k = d_model / num_heads)")
print("5. Attention weights show what each head focuses on")

In [None]:
# Complete Transformer Implementation
class FeedForward(nn.Module):
    """Feed-forward network in Transformer."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

class EncoderLayer(nn.Module):
    """Single encoder layer."""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention with residual
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))
        
        # Feed-forward with residual
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        
        return x

class DecoderLayer(nn.Module):
    """Single decoder layer."""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.enc_dec_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
    
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Masked self-attention
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        
        # Encoder-decoder attention
        attn_output, _ = self.enc_dec_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout2(attn_output))
        
        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout3(ff_output))
        
        return x

class Transformer(nn.Module):
    """Complete Transformer model."""
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8, 
                 num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, max_len=5000, dropout=0.1):
        super(Transformer, self).__init__()
        
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        # Encoder
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_encoder_layers)
        ])
        
        # Decoder
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_decoder_layers)
        ])
        
        # Output projection
        self.output_proj = nn.Linear(d_model, tgt_vocab_size)
        
        self.d_model = d_model
    
    def generate_mask(self, src, tgt):
        """Generate masks for encoder and decoder."""
        batch_size = src.size(0)
        src_len = src.size(1)
        tgt_len = tgt.size(1)
        
        # Source mask for encoder self-attention: (batch, 1, 1, src_len)
        # This prevents attending to padding tokens in the source
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, src_len)
        
        # Target mask for decoder self-attention: (batch, 1, tgt_len, tgt_len)
        # This prevents attending to padding tokens (masked along the key axis) and future tokens
        tgt_padding_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, tgt_len)
        # True above the diagonal, i.e. at future positions: (tgt_len, tgt_len)
        look_ahead_mask = torch.triu(torch.ones(tgt_len, tgt_len, device=tgt.device), diagonal=1).bool()
        look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, tgt_len, tgt_len)
        tgt_mask = tgt_padding_mask & ~look_ahead_mask  # broadcasts to (batch, 1, tgt_len, tgt_len)
        
        # Source mask for encoder-decoder attention: (batch, 1, tgt_len, src_len)
        # This prevents attending to padding tokens in the source when querying from target
        src_mask_enc_dec = (src != 0).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, src_len)
        src_mask_enc_dec = src_mask_enc_dec.expand(batch_size, 1, tgt_len, src_len)  # (batch, 1, tgt_len, src_len)
        
        return src_mask, tgt_mask, src_mask_enc_dec
    
    def forward(self, src, tgt):
        # Embeddings + positional encoding
        src_emb = self.pos_encoding(self.src_embedding(src) * math.sqrt(self.d_model))
        tgt_emb = self.pos_encoding(self.tgt_embedding(tgt) * math.sqrt(self.d_model))
        
        # Generate masks
        src_mask, tgt_mask, src_mask_enc_dec = self.generate_mask(src, tgt)
        
        # Encoder
        enc_output = src_emb
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)
        
        # Decoder
        dec_output = tgt_emb
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, enc_output, src_mask_enc_dec, tgt_mask)
        
        # Output projection
        output = self.output_proj(dec_output)
        
        return output

# Test Transformer
src_vocab_size = 1000
tgt_vocab_size = 1000
d_model = 128
num_heads = 8
num_layers = 3

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, 
                         num_layers, num_layers, d_ff=512, max_len=100)

# Create dummy sequences
src_seq_len = 10
tgt_seq_len = 12
batch_size = 2

src = torch.randint(1, src_vocab_size, (batch_size, src_seq_len))
tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_seq_len))

# Forward pass
output = transformer(src, tgt)

print("Complete Transformer Model:")
print("=" * 60)
print(f"Source vocabulary size: {src_vocab_size}")
print(f"Target vocabulary size: {tgt_vocab_size}")
print(f"Model dimension: {d_model}")
print(f"Number of heads: {num_heads}")
print(f"Number of encoder/decoder layers: {num_layers}")
print(f"\nInput shapes:")
print(f"  Source: {src.shape}")
print(f"  Target: {tgt.shape}")
print(f"Output shape: {output.shape}")

# Count parameters
total_params = sum(p.numel() for p in transformer.parameters())
trainable_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad)

print(f"\nParameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

# Visualize Transformer architecture
fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(-0.5, 8)
ax.set_ylim(-0.5, 10)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('Transformer Architecture', fontsize=16, weight='bold', pad=20)

# Input embeddings
ax.text(0, 9, 'Input\nEmbeddings', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black', linewidth=2))
ax.text(0, 7.5, 'Positional\nEncoding', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))

# Encoder stack
ax.text(2, 8.5, 'Encoder', fontsize=12, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='orange', edgecolor='black', linewidth=2))
for i in range(3):
    y_pos = 7 - i * 1.5
    # Encoder layer
    ax.text(2, y_pos, f'Encoder Layer {i+1}', fontsize=9, ha='center',
            bbox=dict(boxstyle='round', facecolor='wheat', edgecolor='black'))
    ax.text(2, y_pos-0.5, 'Self-Attn → FFN', fontsize=8, ha='center', style='italic')
    if i < 2:
        ax.arrow(2, y_pos-0.7, 0, -0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Encoder output
ax.text(2, 1.5, 'Encoder\nOutput', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black', linewidth=2))

# Target embeddings
ax.text(5, 9, 'Target\nEmbeddings', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black', linewidth=2))
ax.text(5, 7.5, 'Positional\nEncoding', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))

# Decoder stack
ax.text(7, 8.5, 'Decoder', fontsize=12, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightcoral', edgecolor='black', linewidth=2))
for i in range(3):
    y_pos = 7 - i * 1.5
    # Decoder layer
    ax.text(7, y_pos, f'Decoder Layer {i+1}', fontsize=9, ha='center',
            bbox=dict(boxstyle='round', facecolor='pink', edgecolor='black'))
    ax.text(7, y_pos-0.5, 'Masked Self-Attn', fontsize=8, ha='center', style='italic')
    ax.text(7, y_pos-0.7, '→ Enc-Dec Attn → FFN', fontsize=8, ha='center', style='italic')
    if i < 2:
        ax.arrow(7, y_pos-0.9, 0, -0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Decoder output
ax.text(7, 1.5, 'Output\nProjection', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black', linewidth=2))

# Arrows
ax.arrow(0.5, 8.25, 1, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(0.5, 7.75, 1, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(2.5, 8.25, 0, -0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(5.5, 8.25, 1, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(5.5, 7.75, 1, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(7.5, 8.25, 0, -0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Encoder-decoder attention connection
ax.plot([2.5, 2.5, 6.5, 6.5], [4, 3, 3, 4], 'g--', linewidth=2, alpha=0.7)
ax.arrow(6.5, 3.5, 0, 0.3, head_width=0.15, head_length=0.1, fc='green', ec='green')
ax.text(4.5, 3, 'Encoder-Decoder\nAttention', fontsize=9, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='green', linewidth=2))

plt.tight_layout()
plt.show()

## 8.2.5 Transformer Applications

The Transformer achieved state-of-the-art performance on many tasks using attention alone. Beyond the high level of performance, a Transformer processes all positions of a sequence in parallel, so training is typically much faster than with recurrent networks, which must process tokens one after another; the original paper also reported a lower training cost than the convolutional sequence models it was compared with. Like other models, with Transformer we can also perform **transfer learning**, i.e., take a Transformer trained on one task and adapt it to another task.

In practice, not all applications use the entire Transformer - some applications use only the encoder or only the decoder, depending on the task.

### Machine Translation

Translation of sentences between different languages is the most natural application of the Transformer, and the one it was originally designed for. The task is to take a sentence and output a sentence in another language. The encoder builds a new representation of the original sentence using self-attention, and the decoder uses Encoder-Decoder Attention over that representation to generate the translated sentence word by word.

### Bidirectional Encoder Representations from Transformers (BERT)

A language model based on the encoder only. A classical language model is a function that receives text as input and returns a probability distribution over the next word. The most familiar and intuitive example is automatic completion, which suggests the most likely word or words given what the user has typed so far. BERT changes this objective: instead of predicting the next word, it predicts words that were hidden inside the text.

Because self-attention in the encoder lets every word see the context on both of its sides, the encoder is well suited to this task. The developers of BERT used only the encoder: they trained the model on sentences in which randomly chosen words are masked, and the task is to predict the masked words.

**BERT Key Features:**
- **Bidirectional**: Uses both left and right context
- **Masked Language Modeling**: Predicts masked words in sentences
- **Next Sentence Prediction**: Learns relationships between sentences
- **Pre-training + Fine-tuning**: Trained on large corpus, then fine-tuned for specific tasks
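
The masking step of Masked Language Modeling can be sketched in a few lines of plain Python. This is a simplified illustration rather than BERT's exact procedure: it only replaces selected tokens with `[MASK]`, whereas the original recipe selects about 15% of the tokens and sometimes substitutes a random token or leaves the token unchanged. The function name `mask_tokens`, the example sentence, and the higher masking rate used in the demo are assumptions for illustration.

```python
import random

MASK_TOKEN = "[MASK]"   # placeholder token name, as used by BERT

def mask_tokens(tokens, mask_prob=0.15, seed=0):
    """Return (masked_tokens, labels).

    labels holds the original token at masked positions and None elsewhere,
    so a loss would be computed only where a token was hidden.
    """
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            masked.append(MASK_TOKEN)
            labels.append(tok)       # the model is trained to recover this token
        else:
            masked.append(tok)
            labels.append(None)      # position is ignored by the loss
    return masked, labels

tokens = "the cat sat on the mat and looked at the dog".split()
# A higher rate than 0.15 so that a short sentence is likely to show some masks
masked, labels = mask_tokens(tokens, mask_prob=0.3)
print(masked)
print(labels)
```

Since the encoder sees the words on both sides of every `[MASK]`, it can use the full bidirectional context when predicting the hidden token.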

### Generative Pre-trained Transformer (GPT)

A model for predicting the next word in a sentence. We take a sentence that is cut in the middle, input it to the decoder, and examine which word the decoder suggests as the continuation. In a single forward pass the decoder produces a score for every word in the vocabulary, and a SoftMax over these scores gives the probability of each word being the next word in the sentence.

Inside the decoder, the Query comes from the current position, while the Keys and Values come from that position and all the positions before it, so each prediction relies only on the text seen so far. The predicted word is then appended to the input, and the process is repeated to generate longer text.

**GPT Key Features:**
- **Decoder-only**: Uses only the decoder part
- **Autoregressive**: Generates text one token at a time
- **Causal masking**: Prevents looking at future tokens
- **Pre-training**: Trained on large text corpus
- **Fine-tuning**: Adapted for specific tasks
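
Causal masking and autoregressive generation can be illustrated without a trained model. In the sketch below, `causal_mask` builds the lower-triangular pattern (1 means "may attend"), and `generate` is a greedy decoding loop. The function `toy_next_token` is a hypothetical stand-in for a real model's forward pass followed by an argmax over the vocabulary.

```python
def causal_mask(seq_len):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]

def generate(next_token_fn, prompt, max_new_tokens, eos=None):
    """Greedy autoregressive loop: feed the sequence so far, append the prediction, repeat."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        nxt = next_token_fn(tokens)
        tokens.append(nxt)
        if nxt == eos:
            break
    return tokens

def toy_next_token(tokens):
    """Hypothetical stand-in for a trained model: predicts (last token + 1) modulo 10."""
    return (tokens[-1] + 1) % 10

for row in causal_mask(4):
    print([int(v) for v in row])

print(generate(toy_next_token, prompt=[3, 4], max_new_tokens=5))
```

During training the mask lets the model predict all positions of a sequence in one pass without seeing the future; at generation time the loop produces one token per step and feeds it back as input.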

### Other Applications

1. **DETR (Detection Transformer)**: Object detection in images using Transformer
2. **Vision Transformer (ViT)**: Image classification using Transformer
3. **Speech Recognition**: Transformer for speech-to-text
4. **Text Summarization**: Using encoder-decoder architecture
5. **Question Answering**: BERT-based models
6. **Named Entity Recognition**: Sequence labeling with Transformer

In [None]:
# Example: BERT-style Encoder-only Model
class BERTEncoder(nn.Module):
    """BERT-style encoder-only model for language modeling."""
    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, max_len=512):
        super(BERTEncoder, self).__init__()
        
        # Token and positional embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff) 
            for _ in range(num_layers)
        ])
        
        self.d_model = d_model
    
    def forward(self, x, mask=None):
        # Embeddings + positional encoding
        x = self.pos_encoding(self.token_embedding(x) * math.sqrt(self.d_model))
        
        # Pass through encoder layers
        for layer in self.encoder_layers:
            x = layer(x, mask)
        
        return x

# Example: GPT-style Decoder-only Model
class GPTDecoder(nn.Module):
    """GPT-style decoder-only model for text generation."""
    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, max_len=512):
        super(GPTDecoder, self).__init__()
        
        # Token and positional embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        # Decoder layers (their encoder-decoder attention sub-layer is skipped in forward)
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff) 
            for _ in range(num_layers)
        ])
        
        # Output projection
        self.output_proj = nn.Linear(d_model, vocab_size)
        
        self.d_model = d_model
    
    def forward(self, x, mask=None):
        # Embeddings + positional encoding
        x = self.pos_encoding(self.token_embedding(x) * math.sqrt(self.d_model))
        
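        # Pass through the decoder layers using only two of their sublayers:
        # masked self-attention and the feed-forward network. The encoder-decoder
        # attention sublayer is skipped, since a GPT-style model has no encoder.
        # NOTE: `mask` should be a causal (look-ahead) mask. With mask=None the
        # self-attention is presumably unmasked, so each position can see future tokens.
        for layer in self.decoder_layers:
            attn_output, _ = layer.self_attn(x, x, x, mask)
            x = layer.norm1(x + layer.dropout1(attn_output))
            ff_output = layer.feed_forward(x)
            # norm3/dropout3 belong to the feed-forward sublayer of DecoderLayer
            x = layer.norm3(x + layer.dropout3(ff_output))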
        
        # Output projection
        output = self.output_proj(x)
        
        return output

# Test BERT and GPT
vocab_size = 1000
d_model = 128
num_heads = 8
num_layers = 3

bert = BERTEncoder(vocab_size, d_model, num_heads, num_layers, d_ff=512)
gpt = GPTDecoder(vocab_size, d_model, num_heads, num_layers, d_ff=512)

# Create dummy input
seq_len = 15
batch_size = 2
input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))

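# Forward pass (shape check only: no mask is passed, so the GPT-style model
# is not restricted to past tokens in this call)
bert_output = bert(input_ids)
gpt_output = gpt(input_ids)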

print("Transformer Applications:")
print("=" * 60)
print("\nBERT (Encoder-only):")
print(f"  Input shape: {input_ids.shape}")
print(f"  Output shape: {bert_output.shape}")
print("  Use case: Language understanding, masked language modeling")
print("\nGPT (Decoder-only):")
print(f"  Input shape: {input_ids.shape}")
print(f"  Output shape: {gpt_output.shape}")
print("  Use case: Text generation, next token prediction")

# Visualize architecture differences
fig, axes = plt.subplots(1, 2, figsize=(18, 6))

# BERT architecture
ax1 = axes[0]
ax1.set_xlim(-0.5, 4)
ax1.set_ylim(-0.5, 6)
ax1.set_aspect('equal')
ax1.axis('off')
ax1.set_title('BERT Architecture (Encoder-only)', fontsize=14, weight='bold', pad=20)

ax1.text(0, 5, 'Input\nTokens', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black', linewidth=2))
ax1.text(0, 3.5, 'Embeddings\n+ Position', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))

ax1.text(2, 4.5, 'Encoder\nStack', fontsize=11, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='orange', edgecolor='black', linewidth=2))
for i in range(3):
    y_pos = 3.5 - i * 0.8
    ax1.text(2, y_pos, f'Layer {i+1}', fontsize=9, ha='center',
            bbox=dict(boxstyle='round', facecolor='wheat', edgecolor='black'))
    if i < 2:
        ax1.arrow(2, y_pos-0.4, 0, -0.3, head_width=0.1, head_length=0.08, fc='black', ec='black')

ax1.text(2, 0.5, 'Contextualized\nRepresentations', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black', linewidth=2))

ax1.arrow(0.5, 4.25, 1, 0, head_width=0.1, head_length=0.08, fc='black', ec='black')
ax1.arrow(0.5, 3.75, 1, 0, head_width=0.1, head_length=0.08, fc='black', ec='black')
ax1.arrow(2.5, 4.5, 0, -0.5, head_width=0.1, head_length=0.08, fc='black', ec='black')
ax1.arrow(2.5, 0.5, 0, 0.3, head_width=0.1, head_length=0.08, fc='black', ec='black')

# GPT architecture
ax2 = axes[1]
ax2.set_xlim(-0.5, 4)
ax2.set_ylim(-0.5, 6)
ax2.set_aspect('equal')
ax2.axis('off')
ax2.set_title('GPT Architecture (Decoder-only)', fontsize=14, weight='bold', pad=20)

ax2.text(0, 5, 'Input\nTokens', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black', linewidth=2))
ax2.text(0, 3.5, 'Embeddings\n+ Position', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))

ax2.text(2, 4.5, 'Decoder\nStack', fontsize=11, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightcoral', edgecolor='black', linewidth=2))
for i in range(3):
    y_pos = 3.5 - i * 0.8
    ax2.text(2, y_pos, f'Layer {i+1}', fontsize=9, ha='center',
            bbox=dict(boxstyle='round', facecolor='pink', edgecolor='black'))
    ax2.text(2, y_pos-0.25, 'Masked\nSelf-Attn', fontsize=8, ha='center', style='italic')
    if i < 2:
        ax2.arrow(2, y_pos-0.4, 0, -0.3, head_width=0.1, head_length=0.08, fc='black', ec='black')

ax2.text(2, 0.5, 'Next Token\nPrediction', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black', linewidth=2))

ax2.arrow(0.5, 4.25, 1, 0, head_width=0.1, head_length=0.08, fc='black', ec='black')
ax2.arrow(0.5, 3.75, 1, 0, head_width=0.1, head_length=0.08, fc='black', ec='black')
ax2.arrow(2.5, 4.5, 0, -0.5, head_width=0.1, head_length=0.08, fc='black', ec='black')
ax2.arrow(2.5, 0.5, 0, 0.3, head_width=0.1, head_length=0.08, fc='black', ec='black')

plt.tight_layout()
plt.show()
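
The GPT-style model above only behaves autoregressively when its self-attention receives a causal (look-ahead) mask. The following is a minimal sketch of such a mask and its effect on scaled dot-product attention. The helper names `causal_mask` and `masked_attention` are hypothetical, and the convention used here (`True` means "may attend") is an assumption that may differ from the attention classes defined earlier in this chapter.

```python
import math
import torch

def causal_mask(seq_len):
    """Lower-triangular boolean mask: position i may attend to positions 0..i."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def masked_attention(q, k, v, mask):
    """Scaled dot-product attention; disallowed scores are set to -inf before softmax."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    scores = scores.masked_fill(~mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
x = torch.randn(2, 5, 16)            # (batch, seq_len, d_model)
mask = causal_mask(5)                # (seq_len, seq_len), broadcast over the batch
out, weights = masked_attention(x, x, x, mask)

# Entries above the diagonal are zero: no position attends to a later one
print(weights[0])
```

Because every row keeps at least its diagonal entry, the softmax is always well defined, and each row of `weights` still sums to 1.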

## Summary

In this chapter, we've covered the **Attention Mechanism** - one of the most important breakthroughs in deep learning:

### 8.1 Sequence to Sequence Learning and Attention
- **8.1.1 Attention in Seq2Seq Models**: Removes the fixed-size context vector bottleneck, aligns input and output dynamically, and yields interpretable attention weights
- **8.1.2 Bahdanau and Luong Attention**: Early attention mechanisms, additive (Bahdanau) vs multiplicative (Luong) scoring

### 8.2 Transformer
- **8.2.1 Positional Encoding**: Sinusoidal encoding to represent sequence order without recurrence
- **8.2.2 Self-Attention Layer**: Query, Key, Value matrices, attention scores, weighted aggregation
- **8.2.3 Multi-Head Attention**: Parallel attention heads, diverse relationship modeling
- **8.2.4 Transformer End to End**: Complete encoder-decoder architecture, residual connections, layer normalization
- **8.2.5 Transformer Applications**: BERT, GPT, machine translation, and more

### Key Takeaways

- **Attention** allows models to focus on relevant parts of input when generating output
- **Self-attention** finds relationships within a single sequence
- **Multi-head attention** captures diverse types of relationships in parallel
- **Positional encoding** injects order information without recurrence
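- **Transformer** relies on attention combined with position-wise feed-forward layers, with no recurrent or convolutional layers
- **BERT** uses an encoder-only architecture with bidirectional self-attention for language understanding
- **GPT** uses a decoder-only architecture with masked (causal) self-attention for text generation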

The Transformer architecture underlies BERT, GPT, and modern large language models, and it has been adopted well beyond NLP, for example in computer vision and speech processing.