In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from get_device import get_device

# Resolve the compute device through the project-local `get_device` helper
# (its selection logic is not visible in this notebook) and report the choice.
device = get_device()
print(f"Using device: {device}")

Using device: mps


In [2]:
from pathlib import Path

# Read the whole corpus into memory as a single string. The encoding is given
# explicitly so decoding does not depend on the platform's locale default.
text = Path('../../data/tiny-shakespeare.txt').read_text(encoding='utf-8')

In [3]:
print(text[0:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [4]:
import torch

class CharTokenizer:
    """
    Character-level tokenizer that maps single characters to integer IDs and back.

    The vocabulary is fixed at construction time. Duplicate characters in the
    supplied vocabulary are ignored (the first occurrence wins), so token IDs
    always form the contiguous range ``0 .. vocabulary_size() - 1``.

    Attributes:
        token_id_for_char (dict): Maps each vocabulary character to its unique
            integer ID.
        char_for_token_id (dict): Reverse mapping from each integer ID back to
            its character.
    """

    def __init__(self, vocabulary):
        """
        Initializes the CharTokenizer with a predefined vocabulary.

        Args:
            vocabulary (list or str): An ordered iterable of characters. Order
                determines the token IDs; repeated characters are dropped.
        """
        # dict.fromkeys preserves first-occurrence order while removing
        # duplicates, which keeps both mappings consistent with each other.
        unique_vocab = list(dict.fromkeys(vocabulary))
        self.token_id_for_char = {char: token_id for token_id, char in enumerate(unique_vocab)}
        self.char_for_token_id = dict(enumerate(unique_vocab))

    @staticmethod
    def train_from_text(text):
        """
        Creates a new CharTokenizer whose vocabulary is every distinct
        character of ``text``, in sorted order.

        Sorting makes the character-to-ID assignment reproducible for the same
        corpus regardless of set iteration order.

        Args:
            text (str): The corpus of text from which to build the vocabulary.

        Returns:
            CharTokenizer: A tokenizer covering all characters in ``text``.
        """
        return CharTokenizer(sorted(set(text)))

    def encode(self, text):
        """
        Encodes a string into a tensor of token IDs, one ID per character.

        Args:
            text (str): The string to encode.

        Returns:
            torch.Tensor: A 1D tensor of dtype torch.long with the token IDs
                (empty for an empty string).

        Raises:
            KeyError: If ``text`` contains a character that is not in the
                vocabulary. The message names the offending character.
        """
        lookup = self.token_id_for_char
        try:
            token_ids = [lookup[char] for char in text]
        except KeyError as missing:
            # Same exception type as a plain dict lookup, but with a message
            # that explains what went wrong.
            raise KeyError(
                f"character {missing.args[0]!r} is not in the tokenizer vocabulary"
            ) from None
        return torch.tensor(token_ids, dtype=torch.long)

    def decode(self, token_ids):
        """
        Decodes a sequence of token IDs back into a string.

        Args:
            token_ids (torch.Tensor or iterable of int): A 1D tensor of token
                IDs, or any plain iterable of integer IDs. A 0-dimensional
                tensor is treated as a single ID.

        Returns:
            str: The decoded string.

        Raises:
            KeyError: If an ID is not part of the vocabulary.
        """
        if isinstance(token_ids, torch.Tensor):
            # .tolist() yields plain Python ints, which are the dict keys.
            ids = token_ids.tolist()
            if isinstance(ids, int):
                # A 0-d tensor converts to a bare int rather than a list.
                ids = [ids]
        else:
            ids = list(token_ids)
        return ''.join(self.char_for_token_id[token_id] for token_id in ids)

    def vocabulary_size(self):
        """
        Returns the number of unique characters in the vocabulary.

        Returns:
            int: The size of the vocabulary.
        """
        return len(self.token_id_for_char)

In [5]:
# Build the character vocabulary from the full corpus held in `text`.
tokenizer = CharTokenizer.train_from_text(text)

In [6]:
print(tokenizer.encode("Hello world"))
print(tokenizer.decode(tokenizer.encode("Hello world")))

tensor([20, 43, 50, 50, 53,  1, 61, 53, 56, 50, 42])
Hello world


In [7]:
print(f"Vocabulary size: {tokenizer.vocabulary_size()}")

Vocabulary size: 65



---

# Language Model Dataset: Token Sequence Generation

###### Purpose and Functionality

The `TokenIdsDataset` class creates training data for autoregressive language models by converting a sequence of tokens into input-target pairs. This dataset implements the standard "next token prediction" training paradigm where the model learns to predict the subsequent token given a context window.

###### Core Concept: Shifted Sequences

**Training Objective:** Given a sequence of tokens, predict the next token
**Implementation:** For each position, create pairs where the target is the input shifted by one position

###### Detailed Example Walkthrough

**Setup:**
```python
data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])  # Token sequence
block_size = 4  # Context window size
dataset = TokenIdsDataset(data, block_size)
```

**Dataset Length Calculation:**
```python
len(dataset) = len(data) - block_size = 9 - 4 = 5
# 5 possible starting positions for complete sequences
```

**Sample Generation:**
```python
# Position 0: x = [1, 2, 3, 4], y = [2, 3, 4, 5]
# Position 1: x = [2, 3, 4, 5], y = [3, 4, 5, 6]  
# Position 2: x = [3, 4, 5, 6], y = [4, 5, 6, 7]
# Position 3: x = [4, 5, 6, 7], y = [5, 6, 7, 8]
# Position 4: x = [5, 6, 7, 8], y = [6, 7, 8, 9]
```

###### Token-by-Token Prediction Logic

For each input-target pair, the model learns multiple next-token predictions simultaneously:

**Example with Position 0:**
```python
Input:  [1, 2, 3, 4]
Target: [2, 3, 4, 5]

# Training signals:
# Given context [1] → predict 2
# Given context [1, 2] → predict 3  
# Given context [1, 2, 3] → predict 4
# Given context [1, 2, 3, 4] → predict 5
```

###### Implementation Analysis

**Memory Efficiency:**
The dataset doesn't store all possible sequences but generates them on-demand using tensor slicing, making it memory-efficient for large corpora.

**Boundary Handling:**
```python
def __len__(self):
    return len(self.data) - self.block_size
```
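This keeps every start position far enough from the end of the data that both slices are complete: `x` and `y` each contain exactly `block_size` tokens. Without this bound, slicing past the end would not raise an error — it would silently return shorter (truncated) sequences.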

**Tensor Slicing:**
```python
x = self.data[pos:pos + block_size]        # Input: 4 tokens
y = self.data[pos + 1:pos + 1 + block_size] # Target: 4 tokens (shifted)
```

###### Integration with Training Loop

**DataLoader Usage:**
```python
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch_x, batch_y in dataloader:
    # batch_x shape: (32, block_size)  
    # batch_y shape: (32, block_size)
    predictions = model(batch_x)
    loss = criterion(predictions.view(-1, vocab_size), batch_y.view(-1))
```

The `-1` in `.view(-1, vocab_size)` tells PyTorch to automatically calculate that dimension based on the tensor's total size. Here's why this is necessary:

###### Tensor Shape Problem

**Before reshaping:**
```python
# predictions shape: (batch_size, block_size, vocab_size)
# batch_y shape: (batch_size, block_size)

# Example with batch_size=32, block_size=4, vocab_size=1000:
predictions.shape = torch.Size([32, 4, 1000])
batch_y.shape = torch.Size([32, 4])
```

**Loss function requirement:**
Most loss functions expect:
- Predictions: 2D tensor `(num_samples, num_classes)`
- Targets: 1D tensor `(num_samples,)`

###### What `.view(-1, vocab_size)` Does

**Automatic dimension calculation:**
```python
predictions.view(-1, vocab_size)
# PyTorch calculates: total_elements / vocab_size = first_dimension
# (32 * 4 * 1000) / 1000 = 128
# Result shape: (128, 1000)

batch_y.view(-1)  
# Flattens to: (128,)
```

**Step-by-step breakdown:**
```python
# Original: (32, 4, 1000) - 32 sequences in the batch, 4 tokens each, 1000 vocabulary scores (logits) per token
# Reshaped: (128, 1000) - 128 individual token predictions, 1000 vocabulary scores each

# Original targets: (32, 4) - 32 sequences, 4 target tokens each
# Reshaped: (128,) - 128 individual target tokens
```

###### Why Use -1 Instead of Hard-coding?

**Flexibility:**
```python
# Hard-coded (brittle):
predictions.view(32 * 4, vocab_size)  # Breaks if batch_size changes

# Auto-calculated (robust):  
predictions.view(-1, vocab_size)  # Works with any batch_size
```

**Real example:**
```python
batch_size = 32
block_size = 4
vocab_size = 1000

# Before reshaping
print(predictions.shape)     # torch.Size([32, 4, 1000])
print(batch_y.shape)         # torch.Size([32, 4])

# After reshaping  
print(predictions.view(-1, vocab_size).shape)  # torch.Size([128, 1000])
print(batch_y.view(-1).shape)                  # torch.Size([128])
```

###### What This Accomplishes

The reshaping converts from "batch-of-sequences" format to "individual-predictions" format, where each token prediction is treated as a separate classification problem. This allows the loss function to compute the cross-entropy between each predicted token distribution and its corresponding target token.

The `-1` makes the code robust to different batch sizes without requiring manual calculation of the flattened dimension.


###### Practical Considerations

**Context Length Trade-offs:**
- Larger `block_size`: Better long-range dependencies, more memory usage
- Smaller `block_size`: Less memory, limited context understanding

**Data Utilization:**
From a sequence of length N with block size B, the dataset generates N-B training examples, maximizing data utilization through overlapping windows.

**Computational Efficiency:**
The sliding window approach creates multiple training examples from a single sequence, effectively augmenting the training data without additional storage requirements.

This dataset design is fundamental to training transformer-based language models, providing the structured input-target pairs necessary for learning autoregressive text generation through next-token prediction.

---

In [8]:
import torch
from torch.utils.data import Dataset

class TokenIdsDataset(Dataset):
    """
    A PyTorch Dataset of (input, target) pairs for next-token prediction.

    The dataset views one long sequence of token IDs through a sliding window
    of ``block_size`` tokens. The input ``x`` is a window of the data and the
    target ``y`` is the same window shifted one position to the right, so
    ``y[i]`` is the token that follows ``x[: i + 1]``.

    For example, if the data is [1, 2, 3, 4, 5] and block_size is 3:
    - Sample 0 is x = [1, 2, 3], y = [2, 3, 4].
    - Sample 1 is x = [2, 3, 4], y = [3, 4, 5].
    """

    def __init__(self, data, block_size):
        """
        Initializes the dataset.

        Args:
            data (torch.Tensor): A 1D tensor containing the entire sequence of
                token IDs for the text corpus. Only ``len()`` and slicing are
                used on it.
            block_size (int): The context length, i.e. the number of tokens in
                every input and target sequence.
        """
        self.data = data
        self.block_size = block_size

    def __len__(self):
        """
        Returns the number of complete (input, target) pairs.

        A pair starting at ``pos`` needs tokens ``pos .. pos + block_size``,
        so there are ``len(data) - block_size`` valid start positions. The
        result is clamped at zero because ``len()`` must never be negative,
        which would otherwise happen when the data is shorter than the block.

        Returns:
            int: The total number of samples in the dataset.
        """
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, pos):
        """
        Retrieves the input-target pair that starts at position ``pos``.

        Args:
            pos (int): Sample index. Negative values count from the end, as
                with Python sequences.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The input ``x`` and the target
                ``y``, each of length ``block_size``.

        Raises:
            IndexError: If ``pos`` is outside the valid range. Raising
                IndexError (rather than asserting) is what lets plain
                ``for x, y in dataset`` iteration terminate.
        """
        samples_num = len(self)
        requested = pos
        if pos < 0:
            # Rebind instead of mutating in place so a caller-owned index
            # object is never modified.
            pos = pos + samples_num
        if not 0 <= pos < samples_num:
            raise IndexError(
                f"sample index {requested} is out of range for a dataset of {samples_num} samples"
            )

        window_end = pos + self.block_size
        # The input sequence starts at `pos` and has length `block_size`.
        x = self.data[pos:window_end]
        # The target sequence is the same window shifted one token to the right.
        y = self.data[pos + 1:window_end + 1]
        return x, y


---

The `//` operator performs **floor division** (integer division), which is crucial for ensuring the head size is always an integer.

###### Floor Division vs Regular Division

**Regular division (`/`):**
```python
embedding_dim = 768
heads_num = 12

head_size = embedding_dim / heads_num
# Result: 64.0 (float)
```

**Floor division (`//`):**
```python
head_size = embedding_dim // heads_num  
# Result: 64 (integer)
```

###### Why This Matters for Neural Networks

**Integer requirement:**
Neural network dimensions must be integers. You cannot have 64.5 neurons or create a tensor with fractional dimensions.

**Example where it makes a difference:**
```python
embedding_dim = 770  # Not perfectly divisible
heads_num = 12

regular_division = embedding_dim / heads_num  # 64.16666... (float)
floor_division = embedding_dim // heads_num   # 64 (integer)
```

###### Multi-Head Attention Context

In transformer architecture, the embedding dimension must be evenly divided among attention heads:
```python
# Each attention head gets head_size dimensions
# Total: heads_num × head_size = embedding_dim

12 heads × 64 dimensions = 768 total embedding dimensions
```

**Configuration validation:**
```python
assert embedding_dim % heads_num == 0, "embedding_dim must be divisible by heads_num"
head_size = embedding_dim // heads_num
```

The `//` operator ensures you get a clean integer division result, which is essential for creating properly sized tensor operations in the attention mechanism. Using regular division would produce floats that would cause errors when used as tensor dimension specifications.

---

In [9]:
# Hyperparameters shared by every module in the notebook.
config = {
  "vocabulary_size": tokenizer.vocabulary_size(),
  "context_size": 256,   # maximum sequence length the causal mask is built for
  "embedding_dim": 768,
  "heads_num": 12,
  "layers_num": 10,
  "dropout_rate": 0.1,
  "use_bias": False,
}

# The heads' outputs are concatenated back into `embedding_dim` features, so
# the embedding must split evenly; floor division would otherwise silently
# drop the remainder and the shapes would stop matching.
if config["embedding_dim"] % config["heads_num"] != 0:
  raise ValueError("embedding_dim must be divisible by heads_num")

config["head_size"] = config["embedding_dim"] // config["heads_num"]


---

# Single-Head Attention Mechanism Implementation

###### Core Purpose

The `AttentionHead` class implements a single attention head from the multi-head attention mechanism used in transformer architectures. It performs the fundamental attention operation: allowing each token to attend to (focus on) relevant tokens in the sequence while respecting causal constraints for language modeling.

###### Architecture Components

**Linear Projection Layers:**
```python
self.Q_weights = nn.Linear(config["embedding_dim"], config["head_size"], config["use_bias"])
self.K_weights = nn.Linear(config["embedding_dim"], config["head_size"], config["use_bias"])  
self.V_weights = nn.Linear(config["embedding_dim"], config["head_size"], config["use_bias"])
```

These layers project the input embeddings into three different subspaces:
- **Query (Q)**: What the current token is looking for
- **Key (K)**: What each token represents/offers  
- **Value (V)**: The actual content each token contributes

###### Causal Attention Mask

```python
casual_attention_mask = torch.tril(torch.ones(config["context_size"], config["context_size"]))
self.register_buffer('casual_attention_mask', casual_attention_mask)
```

**Purpose:** Ensures tokens can only attend to previous tokens (including themselves), not future tokens. This is crucial for autoregressive language modeling.

**Structure:** Lower triangular matrix where:
- 1 = allowed attention (current and previous positions)
- 0 = blocked attention (future positions)

**Example for context_size=4:**
```
[[1, 0, 0, 0],
 [1, 1, 0, 0], 
 [1, 1, 1, 0],
 [1, 1, 1, 1]]
```

###### Forward Pass Breakdown

**Step 1: Linear Projections**
```python
Q = self.Q_weights(input) # (B, C, head_size)
K = self.K_weights(input) # (B, C, head_size)  
V = self.V_weights(input) # (B, C, head_size)
```
Input shape: `(batch_size, sequence_length, embedding_dim)`
Output shapes: `(batch_size, sequence_length, head_size)`

**Step 2: Attention Score Computation**
```python
attention_scores = Q @ K.transpose(1, 2)  # (B, C, C)
```
Computes similarity between queries and keys using dot product. Result is a `(batch_size, sequence_length, sequence_length)` tensor where, for each sequence `b` in the batch, `attention_scores[b, i, j]` represents how much token `i` should attend to token `j`.

**Step 3: Causal Masking**
```python
attention_scores = attention_scores.masked_fill(
    self.casual_attention_mask[:tokens_num,:tokens_num] == 0,
    -torch.inf
)
```
Sets attention scores for future positions to negative infinity, ensuring they become zero after softmax.

**Step 4: Scaled Dot-Product Attention**
```python
attention_scores = attention_scores / (K.shape[-1] ** 0.5)
```
Divides the scores by `√(head_size)` so that large dot products do not saturate the softmax.

**Step 5: Attention Probabilities**
```python
attention_scores = torch.softmax(attention_scores, dim=-1)
attention_scores = self.dropout(attention_scores)
```
Converts scores to probabilities that sum to 1 for each query position, then applies dropout for regularization.

**Step 6: Weighted Value Aggregation**
```python
return attention_scores @ V # (B, C, head_size)
```
Multiplies attention probabilities with values to get the final attended representation.

###### Mathematical Formulation

The complete attention operation can be expressed as:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T + M}{\sqrt{d_k}}\right)V$$

Where:
- $M$ is the causal mask (0 for allowed, $-\infty$ for blocked)
- $d_k$ is the head size (key dimension)

###### Key Design Decisions

**Head Size Calculation:** `head_size = embedding_dim // heads_num` ensures the total dimension is preserved when multiple heads are concatenated.

**Buffer Registration:** Using `register_buffer` for the mask ensures it moves with the model to GPU/CPU without being treated as a trainable parameter.

**Dropout Placement:** Applied to attention weights rather than the final output, providing regularization on the attention patterns themselves.

This implementation forms the building block for multi-head attention, where multiple such heads operate in parallel and their outputs are concatenated to capture different types of relationships in the data.

---


---

```python
attention_scores = Q @ K.transpose(1, 2)  # Shape: (B, T, T)
```

The `@` symbol is the **matrix multiplication** operator in Python.

In the context of PyTorch and NumPy, it's used to perform matrix multiplication on tensors or arrays. It's a more readable, infix alternative to calling a function like `torch.matmul()`.

##### In Your Code: `attention_scores = Q @ K.transpose(1, 2)`

Let's break down this specific line:

1.  **`Q`**: This is the "Query" tensor, with a shape of `(Batch, Tokens, Head_size)`.
2.  **`K`**: This is the "Key" tensor, also with a shape of `(Batch, Tokens, Head_size)`.
3.  **`K.transpose(1, 2)`**: This transposes the Key tensor, swapping its last two dimensions. Its new shape becomes `(Batch, Head_size, Tokens)`. This is done to make the dimensions compatible for matrix multiplication.
4.  **`Q @ ...`**: This performs the matrix multiplication:
    *   `Q` shape: `(B, T, H)`
    *   `K.transpose` shape: `(B, H, T)`
    *   Resulting `attention_scores` shape: `(B, T, T)`

The resulting `(B, T, T)` tensor holds the attention scores, where each token in the sequence has a score indicating how much it should "attend to" every other token.

---

In [None]:
import torch
import torch.nn as nn

class AttentionHead(nn.Module):
    """
    A single head of causal self-attention for a transformer model.

    The module implements scaled dot-product attention. Each token's output is
    a weighted sum of Value vectors, where the weights come from comparing the
    token's Query with the Keys of itself and of earlier tokens only.

    The key components are:
    - Q, K, V linear layers to project the input.
    - A causal mask to prevent tokens from attending to future tokens.
    - Scaled dot-product attention calculation.
    - Dropout on the attention weights for regularization.
    """
    def __init__(self, config):
        """
        Initializes the AttentionHead module.

        Args:
            config (dict): A configuration dictionary containing the following keys:
                - "embedding_dim" (int): The dimensionality of the input token embeddings.
                - "head_size" (int): The dimensionality of the Query, Key, and Value projections.
                - "use_bias" (bool): Whether to use a bias term in the linear layers.
                - "dropout_rate" (float): The dropout rate to apply to the attention weights.
                - "context_size" (int): The maximum sequence length (block size).
        """
        super().__init__()

        # Linear layers to project input embeddings into Query, Key, and Value spaces.
        self.Q_weights = nn.Linear(config["embedding_dim"], config["head_size"], bias=config["use_bias"])
        self.K_weights = nn.Linear(config["embedding_dim"], config["head_size"], bias=config["use_bias"])
        self.V_weights = nn.Linear(config["embedding_dim"], config["head_size"], bias=config["use_bias"])

        # Dropout layer to regularize the attention weights.
        self.dropout = nn.Dropout(config["dropout_rate"])

        # Lower triangular matrix: entry (i, j) is 1 when token i may attend
        # to token j (j <= i) and 0 for future positions.
        casual_attention_mask = torch.tril(torch.ones(config["context_size"], config["context_size"]))

        # `register_buffer` makes the mask part of the module's state without
        # making it a trainable parameter, so it follows the model across
        # devices. The buffer name (spelled "casual", meaning "causal") is kept
        # as-is because it is a state_dict key.
        self.register_buffer('casual_attention_mask', casual_attention_mask)


    def forward(self, input_embeddings):
        """
        Performs the forward pass of the attention mechanism.

        Args:
            input_embeddings (torch.Tensor): A tensor of shape (B, T, E) where
                B is the batch size, T is the sequence length (tokens_num), and
                E is the embedding dimension. T must not exceed the configured
                context size.

        Returns:
            torch.Tensor: The output tensor of shape (B, T, H) where H is the
                head size. This is the weighted aggregation of the Value vectors.

        Raises:
            ValueError: If the sequence is longer than the causal mask
                (i.e. longer than "context_size").
        """
        # Unpacking three values also enforces a 3D input.
        _, tokens_num, _ = input_embeddings.shape

        # The mask only covers `context_size` positions; fail early with a
        # clear message instead of an opaque broadcasting error further down.
        context_size = self.casual_attention_mask.shape[0]
        if tokens_num > context_size:
            raise ValueError(
                f"sequence length {tokens_num} exceeds the context size {context_size}"
            )

        # 1. Project input into Query, Key, and Value tensors.
        Q = self.Q_weights(input_embeddings) # Shape: (B, T, H)
        K = self.K_weights(input_embeddings) # Shape: (B, T, H)
        V = self.V_weights(input_embeddings) # Shape: (B, T, H)

        # 2. Calculate attention scores by taking the dot product of Q and K.
        # K is transposed to align dimensions for matrix multiplication.
        attention_scores = Q @ K.transpose(1, 2)  # Shape: (B, T, T)

        # 3. Apply the causal mask to prevent future-peeking.
        # Scores for future positions become negative infinity so that they
        # turn into exactly zero after the softmax.
        attention_scores = attention_scores.masked_fill(
            self.casual_attention_mask[:tokens_num, :tokens_num] == 0,
            -torch.inf
        )

        # 4. Scale by the square root of the Key dimension (head_size).
        # Masked entries stay at -inf, so masking before scaling is safe.
        attention_scores = attention_scores / (K.shape[-1] ** 0.5)

        # 5. Softmax over the key axis turns scores into weights that sum to 1
        # for every query position.
        attention_weights = torch.softmax(attention_scores, dim=-1)

        # 6. Apply dropout for regularization.
        attention_weights = self.dropout(attention_weights)

        # 7. Compute the final output as a weighted sum of the Value vectors.
        return attention_weights @ V # Shape: (B, T, H)

In [None]:
# Random batch of 8 sequences, each `context_size` tokens long with
# `embedding_dim` features per token, used to smoke-test the attention modules.
# NOTE(review): `input` shadows Python's builtin of the same name; renaming it
# also requires updating the later cells that call `ah(input)` and `mha(input)`.
input = torch.rand(8, config["context_size"], config["embedding_dim"]) 

In [13]:
ah = AttentionHead(config)

In [14]:
output = ah(input)

In [15]:
output.shape

torch.Size([8, 256, 64])


---

# Multi-Head Attention Implementation

###### Core Purpose

The `MultiHeadAttention` class implements the complete multi-head attention mechanism by combining multiple parallel attention heads and processing their combined output. This allows the model to attend to different types of relationships and patterns simultaneously across multiple representation subspaces.

###### Architecture Overview

**Parallel Head Structure:**
```python
heads_list = [AttentionHead(config) for _ in range(config["heads_num"])]
self.heads = nn.ModuleList(heads_list)
```

Creates `heads_num` independent attention heads (12 in this notebook's config). Each head has its own Query/Key/Value projections and processes the same input, so each can focus on different aspects of the input relationships.

**Output Processing:**
```python
self.linear = nn.Linear(config["embedding_dim"], config["embedding_dim"])
self.dropout = nn.Dropout(config["dropout_rate"])
```

A linear projection layer that processes the concatenated head outputs, followed by dropout for regularization.

###### Forward Pass Breakdown

**Step 1: Parallel Head Computation**
```python
heads_outputs = [head(input) for head in self.heads]
```

Each attention head processes the input independently:
- Input shape: `(batch_size, sequence_length, embedding_dim)`
- Each head output shape: `(batch_size, sequence_length, head_size)`
- Number of outputs: `heads_num` (e.g., 12)

**Step 2: Concatenation**
```python
scores_change = torch.cat(heads_outputs, dim=-1)
```

Concatenates all head outputs along the feature dimension:
- Individual head: `(B, C, head_size)` where `head_size = embedding_dim // heads_num`
- After concatenation: `(B, C, heads_num × head_size) = (B, C, embedding_dim)`

**Numerical Example:**
```python
# Config: embedding_dim=768, heads_num=12, head_size=64
# Input: (32, 256, 768)  # batch_size=32, sequence_length=256

# Each head output: (32, 256, 64)
# After concatenation: (32, 256, 768)  # 12 × 64 = 768
```

**Step 3: Linear Projection**
```python
scores_change = self.linear(scores_change)
```

Applies a learned linear transformation to the concatenated outputs:
- Weight matrix: `(embedding_dim, embedding_dim)` = `(768, 768)`
- Allows heads to interact and combine their representations
- Maintains the original embedding dimension

**Step 4: Regularization**
```python
return self.dropout(scores_change)
```

Applies dropout to the projected output for regularization. (Dropout on the attention weights themselves is applied separately inside each `AttentionHead`.)

###### Why Multiple Heads?

**Representation Diversity:**
Each head can specialize in different types of relationships:
- Head 1: Syntactic dependencies (subject-verb relationships)
- Head 2: Semantic similarity (related concepts)
- Head 3: Positional patterns (sequential ordering)
- Head 4: Long-range dependencies (paragraph-level connections)

**Parallel Processing:**
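The heads are independent of one another, so they can in principle be computed simultaneously while capturing multiple perspectives on the same input. Note that this implementation evaluates them one after another in a list comprehension; optimized implementations batch all heads into a single tensor operation.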

###### Mathematical Formulation

The complete multi-head attention can be expressed as:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

Where:
- $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$
- $W^O$ is the output projection matrix (`self.linear`)
- $h$ is the number of heads

###### Dimension Preservation

**Key insight:** The total computational cost remains similar to single-head attention:
- Single head: `(B, C, embedding_dim)` → `(B, C, embedding_dim)`
- Multi-head: `heads_num × (B, C, head_size)` → `(B, C, embedding_dim)`
- Total parameters: Similar due to dimension splitting

**Head Size Relationship:**
```python
head_size = embedding_dim // heads_num  # 768 // 12 = 64
total_dim = heads_num × head_size       # 12 × 64 = 768
```

###### Integration Benefits

**Enhanced Representation:** Captures multiple types of attention patterns simultaneously rather than learning a single averaged attention pattern.

**Computational Efficiency:** Parallel computation across heads with dimension splitting maintains reasonable computational cost.

**Learning Flexibility:** Different heads can specialize during training without interfering with each other's learning process.

This multi-head structure is fundamental to transformer architectures, enabling the model to build rich, multi-faceted representations of sequence relationships that single-head attention cannot achieve.

---

In [18]:
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """
    Implements the Multi-Head Attention mechanism for a transformer model.

    Several `AttentionHead` modules process the same input independently.
    Their outputs are concatenated along the feature dimension and passed
    through an output projection followed by dropout. This lets the model
    jointly attend to information from different representation subspaces.
    """
    def __init__(self, config):
        """
        Initializes the MultiHeadAttention module.

        Args:
            config (dict): A configuration dictionary containing the following keys:
                - "heads_num" (int): The number of attention heads to use.
                - "head_size" (int): The output dimensionality of each head.
                - "embedding_dim" (int): The dimensionality of the input and output.
                - "use_bias" (bool): Whether linear layers use a bias term. It is
                  applied to the output projection as well as to the heads'
                  Q/K/V layers, so the setting holds for every linear layer.
                - "dropout_rate" (float): The dropout rate for the final output.
                - Other keys required by `AttentionHead` (context_size, etc.).
        """
        super().__init__()

        # `nn.ModuleList` registers every head as a sub-module so that its
        # parameters and buffers are tracked by the parent module.
        self.heads = nn.ModuleList(
            AttentionHead(config) for _ in range(config["heads_num"])
        )

        # The projection consumes the concatenated head outputs, whose width is
        # heads_num * head_size. That equals embedding_dim whenever the
        # embedding splits evenly across the heads.
        concatenated_dim = config["heads_num"] * config["head_size"]
        self.linear = nn.Linear(
            concatenated_dim,
            config["embedding_dim"],
            bias=config["use_bias"],
        )

        # A dropout layer for regularization on the final output.
        self.dropout = nn.Dropout(config["dropout_rate"])

    def forward(self, input_embeddings):
        """
        Performs the forward pass for Multi-Head Attention.

        Args:
            input_embeddings (torch.Tensor): A tensor of shape (B, T, E) where
                B is the batch size, T is the sequence length, and E is the
                embedding dimension.

        Returns:
            torch.Tensor: The final output tensor of shape (B, T, E).
        """
        # 1. Run every head on the same input; each returns (B, T, H).
        heads_outputs = [head(input_embeddings) for head in self.heads]

        # 2. Concatenate along the feature dimension: (B, T, N * H) for N heads.
        concatenated_heads = torch.cat(heads_outputs, dim=-1)

        # 3. Mix the heads' information and project to shape (B, T, E).
        projected_output = self.linear(concatenated_heads)

        # 4. Apply dropout for regularization.
        return self.dropout(projected_output)

In [19]:
mha = MultiHeadAttention(config)

In [20]:
output = mha(input)

In [21]:
output.shape

torch.Size([8, 256, 768])