### GPT-2 Block Architecture

This section recreates the `transformers.models.gpt2.modeling_gpt2.GPT2Block` architecture so that the reimplementation satisfies $\text{GPT2Block}(H_i) = H_{i+1}$, where $H_i$ is the hidden state entering block $i$.

The GPT-2 Block comprises three key components:
1. **Layer Normalization**
2. **Attention Block**
3. **Multi-Layer Perceptron**

We will reconstruct each component and explain how they contribute to the final output.


The self-attention mechanism allows a model to weigh the importance of each word in a sentence when making predictions. For the sentence "The students opened their books," self-attention helps the model understand that "their" refers to "students" and not "books." It does this by computing attention scores between each word and every other word. These scores determine how much focus to place on each word when generating a representation of the sentence, enabling the model to capture relationships and dependencies across the sequence.

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer, AutoModel
from itertools import islice

query_start, query_end = 0, 64
key_start, key_end = 768, 832
value_start, value_end = 1536, 1600

def rel_error(x, y):
    """ Returns relative error using PyTorch """
    # Calculate the absolute difference
    abs_diff = torch.abs(x - y)
    
    # Calculate the sum of absolute values
    abs_sum = torch.abs(x) + torch.abs(y)
    
    # Compute the relative error, ensuring no division by zero
    return torch.max(abs_diff / torch.clamp(abs_sum, min=1e-8))



In [2]:
model_id = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id, output_hidden_states=True)
model = AutoModel.from_pretrained(model_id, config=config)

batch = ["The students opened their books", "The cat chased the mouse", "The chef prepared a meal"]
inputs = tokenizer(batch, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

hidden_states = outputs.hidden_states
gpt2_blocks = model.h
gpt2_block = gpt2_blocks[0]



In [3]:
model

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

First, `tokenizer` maps the input text to integer indices, much like a dictionary or hashmap lookup.
GPT-2 builds its vocabulary with byte-level byte-pair encoding (BPE). If you're curious about how subword vocabularies are created, the [SentencePiece](https://github.com/google/sentencepiece) project is a good place to learn more.

In [4]:
for i in range(3):
    print(f"{batch[i]} = {inputs.input_ids[i].tolist()}")

The students opened their books = [464, 2444, 4721, 511, 3835]
The cat chased the mouse = [464, 3797, 26172, 262, 10211]
The chef prepared a meal = [464, 21221, 5597, 257, 9799]
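
The dictionary analogy can be made concrete with a toy lookup. This is only a sketch, not the GPT-2 tokenizer: `toy_vocab`, `encode`, and `decode` are hypothetical names, the table holds just the five IDs printed above, and it splits on whitespace where the real tokenizer works on byte-level subword pieces.

```python
# Toy vocabulary: a hand-built table that reuses the IDs printed above.
toy_vocab = {"The": 464, "students": 2444, "opened": 4721, "their": 511, "books": 3835}

def encode(text, vocab):
    """Map whitespace-separated words to integer IDs (toy stand-in for a tokenizer)."""
    return [vocab[word] for word in text.split()]

def decode(ids, vocab):
    """Invert the lookup to recover the words."""
    id_to_word = {index: word for word, index in vocab.items()}
    return " ".join(id_to_word[i] for i in ids)

ids = encode("The students opened their books", toy_vocab)
print(ids)
print(decode(ids, toy_vocab))
```

A word missing from `toy_vocab` raises a `KeyError` here. Subword tokenizers avoid that by breaking unknown words into smaller known pieces.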


# Embeddings

- Word and position embeddings are lookup tables: one is indexed by token ID, the other by position in the sequence.
- Each table is stored as a 2-dimensional matrix.

1) **Word Embeddings**:
    - Embeddings are represented as a matrix $\in \mathbb{R}^{V \times D}$.
    - $V$ is the vocabulary size; $D$ is the embedding dimension.
    - Each row in the matrix represents a specific token index.
    - Entries in the row provide information about the corresponding token.
    
$$
\begin{array}{ccccc}
\text{The} & \text{students} & \text{opened} & \text{their} & \text{books} \\
\downarrow & \downarrow & \downarrow & \downarrow & \downarrow \\
464 & 2444 & 4721 & 511 & 3835 \\
\downarrow & \downarrow & \downarrow & \downarrow & \downarrow \\
\begin{bmatrix}
t_{1,1} \\
t_{1,2} \\
\vdots \\
t_{1,D}
\end{bmatrix}^\top &
\begin{bmatrix}
t_{2,1} \\
t_{2,2} \\
\vdots \\
t_{2,D}
\end{bmatrix}^\top &
\begin{bmatrix}
t_{3,1} \\
t_{3,2} \\
\vdots \\
t_{3,D}
\end{bmatrix}^\top &
\begin{bmatrix}
t_{4,1} \\
t_{4,2} \\
\vdots \\
t_{4,D}
\end{bmatrix}^\top &
\begin{bmatrix}
t_{5,1} \\
t_{5,2} \\
\vdots \\
t_{5,D}
\end{bmatrix}^\top \\
\end{array}
\quad\implies
\begin{array}{c}
\text{The students opened their books} \\
\downarrow \\
\begin{bmatrix}
t_{1,1} & t_{1,2} & \cdots & t_{1,D} \\
t_{2,1} & t_{2,2} & \cdots & t_{2,D} \\
\vdots & \vdots & \ddots & \vdots \\
t_{T,1} & t_{T,2} & \cdots & t_{T,D}
\end{bmatrix}
\end{array}
$$

2) **Position Embeddings**:
    - Embeddings are represented as a matrix $\in \mathbb{R}^{L \times D}$.
    - $L$ is the maximum context length; $D$ is the length/dimension of the embedding.
    - Each row in the matrix corresponds to one position index in `range(L)`.
    - Entries in the row provide information about the position of each token in the sequence.

$$
\begin{array}{ccccc}
0 & 1 & 2 & 3 & 4 \\
\downarrow & \downarrow & \downarrow & \downarrow & \downarrow \\
\begin{bmatrix}
p_{1,1} \\
p_{1,2} \\
\vdots \\
p_{1,D}
\end{bmatrix}^\top &
\begin{bmatrix}
p_{2,1} \\
p_{2,2} \\
\vdots \\
p_{2,D}
\end{bmatrix}^\top &
\begin{bmatrix}
p_{3,1} \\
p_{3,2} \\
\vdots \\
p_{3,D}
\end{bmatrix}^\top &
\begin{bmatrix}
p_{4,1} \\
p_{4,2} \\
\vdots \\
p_{4,D}
\end{bmatrix}^\top &
\begin{bmatrix}
p_{5,1} \\
p_{5,2} \\
\vdots \\
p_{5,D}
\end{bmatrix}^\top \\
\end{array}
\quad\implies
\begin{array}{c}
\text{[0, 1, 2, 3, 4]} \\
\downarrow \\
\begin{bmatrix}
p_{1,1} & p_{1,2} & \cdots & p_{1,D} \\
p_{2,1} & p_{2,2} & \cdots & p_{2,D} \\
\vdots & \vdots & \ddots & \vdots \\
p_{T,1} & p_{T,2} & \cdots & p_{T,D}
\end{bmatrix}
\end{array}
$$

In [5]:
print(f"Word Embeddings    : {model.wte} -> [50257 tokens,   768 dimensional]")
print(f"Position Embeddings: {model.wpe}  -> [1024 positions, 768 dimensional]")

Word Embeddings    : Embedding(50257, 768) -> [50257 tokens,   768 dimensional]
Position Embeddings: Embedding(1024, 768)  -> [1024 positions, 768 dimensional]


In [6]:
class Embeddings(nn.Module):
    def __init__(self, num_entries, entry_dimension):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_entries, entry_dimension))

    def forward(self, indices):
        return self.weight[indices]
    
    @classmethod
    def from_GPT2(cls, source):
        num_entries, entry_dimension = source.weight.shape
        embeddings = cls(num_entries, entry_dimension)
        embeddings.load_state_dict(source.state_dict())
        return embeddings

token_ids = inputs.input_ids
position_ids = torch.arange(5).expand(3, -1)

token_embeddings = Embeddings.from_GPT2(model.wte)
position_embeddings = Embeddings.from_GPT2(model.wpe)

with torch.no_grad():
    predicted_token_embeddings = token_embeddings(token_ids)
    expected_token_embeddings = model.wte(token_ids)
    
    predicted_position_embeddings = position_embeddings(position_ids)
    expected_position_embeddings = model.wpe(position_ids)
   

    print(f"Token Embedding Error:    {rel_error(predicted_token_embeddings, expected_token_embeddings)}")
    print(f"Position Embedding Error: {rel_error(predicted_position_embeddings, expected_position_embeddings)}")

Token Embedding Error:    0.0
Position Embedding Error: 0.0
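
One way to see why indexing `self.weight[indices]` counts as a layer is that a row lookup gives the same result as multiplying a one-hot vector by the embedding matrix. The sketch below checks this in plain Python on a hypothetical 4 x 3 table; `lookup` and `one_hot_matmul` are illustrative names, not part of any library.

```python
# Toy embedding table with V = 4 rows (vocabulary entries) and D = 3 columns.
table = [
    [0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
    [3.0, 3.1, 3.2],
]

def lookup(table, indices):
    """Select one row per index, like indexing the weight matrix with a list of IDs."""
    return [table[i] for i in indices]

def one_hot_matmul(table, indices):
    """Build a one-hot vector per index and multiply it by the table."""
    vocab_size, dim = len(table), len(table[0])
    rows = []
    for i in indices:
        one_hot = [1.0 if v == i else 0.0 for v in range(vocab_size)]
        rows.append([sum(one_hot[v] * table[v][d] for v in range(vocab_size)) for d in range(dim)])
    return rows

indices = [2, 0, 2]
print(lookup(table, indices))
print(lookup(table, indices) == one_hot_matmul(table, indices))
```

The direct lookup is preferred in practice because it skips the multiplications by zero.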


3) **Hidden State**:
    - The first hidden state of GPT-2 is created by adding the position embeddings to the token embeddings, element-wise.
    - This lets the model account for both the content and the order of the tokens.

$$
\begin{array}{c}
\text{Token Embeddings} \\
\downarrow \\
\begin{bmatrix}
t_{1,1} & t_{1,2} & \cdots & t_{1,D} \\
t_{2,1} & t_{2,2} & \cdots & t_{2,D} \\
\vdots & \vdots & \ddots & \vdots \\
t_{T,1} & t_{T,2} & \cdots & t_{T,D}
\end{bmatrix}
\end{array}
+
\begin{array}{c}
\text{Position Embeddings} \\
\downarrow \\
\begin{bmatrix}
p_{1,1} & p_{1,2} & \cdots & p_{1,D} \\
p_{2,1} & p_{2,2} & \cdots & p_{2,D} \\
\vdots & \vdots & \ddots & \vdots \\
p_{T,1} & p_{T,2} & \cdots & p_{T,D}
\end{bmatrix}
\end{array}
=
\begin{array}{c}
\text{Hidden States} \\
\downarrow \\
\begin{bmatrix}
h_{1,1} & h_{1,2} & \cdots & h_{1,D} \\
h_{2,1} & h_{2,2} & \cdots & h_{2,D} \\
\vdots & \vdots & \ddots & \vdots \\
h_{T,1} & h_{T,2} & \cdots & h_{T,D}
\end{bmatrix}
\end{array}
$$

In [7]:
with torch.no_grad():
    predicted_hidden_state0 = predicted_token_embeddings + predicted_position_embeddings
    expected_hidden_state0 = hidden_states[0]
    
    print(f"Hidden State Error: {rel_error(predicted_hidden_state0, expected_hidden_state0)}")

Hidden State Error: 0.0
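
To illustrate why the position term matters, the sketch below uses hypothetical 2-dimensional embeddings (the values are invented for illustration) and shows that the same token placed at two different positions ends up with two different hidden vectors.

```python
# Hypothetical 2-dimensional embeddings, chosen only for illustration.
token_table = {"the": [1.0, 2.0], "cat": [3.0, 4.0]}
position_table = [[0.5, 0.0], [0.0, 0.5], [0.25, 0.25]]

def first_hidden_state(tokens):
    """Add each token's embedding to the embedding of its position, element-wise."""
    return [
        [t + p for t, p in zip(token_table[token], position_table[position])]
        for position, token in enumerate(tokens)
    ]

hidden = first_hidden_state(["the", "cat", "the"])
print(hidden)
# "the" occurs at positions 0 and 2, and the two rows differ.
print(hidden[0] != hidden[2])
```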


# Attention Block
- The attention mechanism starts with layer normalization for training stability.
- The hidden state is projected into queries (q), keys (k), and values (v).
- Attention scores are calculated using queries and keys to determine token importance.
- Outputs from the attention mechanism are projected back into the required dimensions.

1) **Layer Normalization**

$$
\text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
\quad;\quad
\mu = \frac{1}{D} \sum_{d=1}^{D} x_d
\quad;\quad
\sigma^2 = \frac{1}{D} \sum_{d=1}^{D} (x_d)^2 - \mu^2
$$

Where:
- $\mu$ and $\sigma^2$ are the mean and variance of the input.
- $\epsilon$ is a small constant for numerical stability.
- $\gamma$ and $\beta$ are learned parameters for scaling and shifting.

Notes:
- $D$ is the size of the last (feature) dimension, over which the mean and variance are computed.
- The variance uses the raw-moment form, $\sigma^2 = \mathbb{E}[x^2] - \mathbb{E}[x]^2$.
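
As a quick sanity check on the formula, the plain-Python sketch below computes the variance both ways (centered and raw-moment) on a small hand-picked vector and then normalizes it. The function name `layer_norm` and the input values are assumptions made for this illustration.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize one vector using the raw-moment form of the variance."""
    dim = len(x)
    mean = sum(x) / dim
    mean_of_squares = sum(v * v for v in x) / dim
    var = mean_of_squares - mean * mean  # E[x^2] - (E[x])^2
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]

x = [1.0, 2.0, 3.0, 4.0]
mean = sum(x) / len(x)
centered_var = sum((v - mean) ** 2 for v in x) / len(x)
raw_moment_var = sum(v * v for v in x) / len(x) - mean ** 2
print(centered_var, raw_moment_var)

# With gamma = 1 and beta = 0 the output has mean close to 0 and variance close to 1.
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print(y)
```

The two variance forms are algebraically identical. In low precision, however, the raw-moment form subtracts two nearly equal numbers when the mean is large relative to the spread, so it can lose a few digits.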

In [8]:
class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5) -> None:
        super().__init__()
        # Learned per-feature scale (gamma) and shift (beta)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.eps = eps

    def forward(self, x):
        # Statistics are taken over the last (feature) dimension
        mean = x.mean(dim=-1, keepdim=True)
        mean_x2 = torch.square(x).mean(dim=-1, keepdim=True)
        var = mean_x2 - torch.square(mean)  # raw-moment form: E[x^2] - E[x]^2

        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        x_norm = self.weight * x_norm + self.bias

        return x_norm

    @classmethod
    def from_GPT2(cls, source):
        num_features = source.weight.shape[0]
        layer_norm = cls(num_features, eps=source.eps)
        layer_norm.load_state_dict(source.state_dict())

        return layer_norm

layer_norm = LayerNorm.from_GPT2(gpt2_block.ln_1)
with torch.no_grad():
    expected_layer_norm = gpt2_block.ln_1(expected_hidden_state0)
    predicted_layer_norm = layer_norm(expected_hidden_state0)
    
    print(f"Layer Normalization Error: {rel_error(predicted_layer_norm, expected_layer_norm)}")

Layer Normalization Error: 2.8775712053175084e-05


2) **Conv1d**

$$
\text{Input Matrix} = 
\begin{bmatrix}
    \begin{bmatrix}
    h_{11} & h_{12} & \cdots & h_{1D}
    \end{bmatrix}_{h_1} \\
    \begin{bmatrix}
    h_{21} & h_{22} & \cdots & h_{2D}
    \end{bmatrix}_{h_2} \\
    \vdots \\
    \begin{bmatrix}
    h_{L1} & h_{L2} & \cdots & h_{LD}
    \end{bmatrix}_{h_L}
\end{bmatrix}_{L \times D}
\quad
\text{Filter Weight} = 
\begin{bmatrix}
    \begin{bmatrix}
    w_{11} \\
    w_{21} \\
    \vdots \\
    w_{D1}
    \end{bmatrix}_{w_1}
    \begin{bmatrix}
    w_{12} \\
    w_{22} \\
    \vdots \\
    w_{D2}
    \end{bmatrix}_{w_2}
    \cdots \quad
    \begin{bmatrix}
    w_{1E} \\
    w_{2E} \\
    \vdots \\
    w_{DE}
    \end{bmatrix}_{w_E}
\end{bmatrix}_{D \times E}
\quad
\text{Filter Bias} =
\begin{bmatrix}
b_{1} \\
b_{2} \\
\vdots \\
b_{E}
\end{bmatrix}_{E \times 1}
$$

$$
\text{out}[L_i, E_j] = \text{bias}[E_j] + \sum_{d=1}^{D} \text{input}[L_i, d] \cdot \text{weight}[d, E_j]
$$
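
Despite its name, this layer is a plain linear map: each output entry is a bias plus a dot product between one input row and one weight column. The sketch below spells the sum out with explicit loops on toy shapes. The function name `conv1d` and all the numbers are assumptions for illustration, and the indices are 0-based as usual in Python.

```python
def conv1d(inputs, weight, bias):
    """Compute out[i][j] = bias[j] + sum_d inputs[i][d] * weight[d][j]."""
    num_rows, in_dim, out_dim = len(inputs), len(weight), len(bias)
    return [
        [bias[j] + sum(inputs[i][d] * weight[d][j] for d in range(in_dim)) for j in range(out_dim)]
        for i in range(num_rows)
    ]

# Toy shapes: L = 2 positions, D = 3 input features, E = 2 output features.
inputs = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # D x E
bias = [0.5, -0.5]
out = conv1d(inputs, weight, bias)
print(out)
```

Storing the weight as $D \times E$ is what allows the forward pass to be written as `input @ weight + bias` with no transpose.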

In [39]:
class Conv1d(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, input):
        output = input @ self.weight + self.bias
        return output

    @classmethod
    def from_GPT2(cls, source, start, end):
        input_dim = source.weight.shape[0]
        output_dim = end - start
        conv1d = cls(input_dim, output_dim) 
        with torch.no_grad():
            conv1d.weight.copy_(source.weight[:, start:end])
            conv1d.bias.copy_(source.bias[start:end])

        return conv1d

query_conv1d = Conv1d.from_GPT2(gpt2_block.attn.c_attn, query_start, query_end)
key_conv1d = Conv1d.from_GPT2(gpt2_block.attn.c_attn, key_start, key_end)
value_conv1d = Conv1d.from_GPT2(gpt2_block.attn.c_attn, value_start, value_end)   

with torch.no_grad():
    expected_conv1d = gpt2_block.attn.c_attn(predicted_layer_norm)

    expected_query = expected_conv1d[:, :, query_start: query_end]
    expected_key = expected_conv1d[:, :, key_start: key_end]
    expected_value = expected_conv1d[:, :, value_start: value_end]

    predicted_query = query_conv1d(predicted_layer_norm)
    predicted_key = key_conv1d(predicted_layer_norm)
    predicted_value = value_conv1d(predicted_layer_norm)

    print(f"Predicted Query: {rel_error(predicted_query, expected_query)}")
    print(f"Predicted Key  : {rel_error(predicted_key, expected_key)}")
    print(f"Predicted Value: {rel_error(predicted_value, expected_value)}")

Predicted Query: 1.3251878044684418e-05
Predicted Key  : 7.21017386240419e-06
Predicted Value: 3.2268785616906825e-06


In [42]:
class Conv1d(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, input):
        output = input @ self.weight + self.bias
        return output
    
    def split_qkv(self, input, num_heads, head_dim):
        query, key, value = input.split(num_heads * head_dim, dim=2)
        return query, key, value

    @classmethod
    def from_GPT2(cls, source):
        input_dim, output_dim = source.weight.shape
        conv1d = cls(input_dim, output_dim)
        conv1d.load_state_dict(source.state_dict())

        return conv1d

conv1d = Conv1d.from_GPT2(gpt2_block.attn.c_attn)

with torch.no_grad():
    expected_qkv = gpt2_block.attn.c_attn(predicted_layer_norm)
    expected_query, expected_key, expected_value = expected_qkv.split(768, dim=2)

    predicted_qkv = conv1d(predicted_layer_norm)
    predicted_query, predicted_key, predicted_value = predicted_qkv.split(768, dim=2)

    print(f"QKV Error        : {rel_error(predicted_qkv, expected_qkv)}")

    print(f"Batch Query Error: {rel_error(predicted_query, expected_query)}")
    print(f"Batch Key Error  : {rel_error(predicted_key, expected_key)}")
    print(f"Batch Value Error: {rel_error(predicted_value, expected_value)}")

QKV Error        : 8.590293145971373e-05
Batch Query Error: 5.80195446673315e-05
Batch Key Error  : 8.590293145971373e-05
Batch Value Error: 3.689480581670068e-05


3) **Attention Mechanism**

$$
\text{Attention}(Q, K, V) = \underbrace{\text{softmax} \left( \frac{Q K^T}{\sqrt{\text{dim}_k}} \right)}_{\text{Attention Score}} \underbrace{\phantom{\frac{V}{V}} V \phantom{\frac{V}{V}}}_{\text{Hidden State}}
$$

$\quad$ 3.1. **Attention Score**:

The sentence tokens are represented as embeddings in the query (Q) and key (K) matrices. Each token, such as "books," "The," "students," etc., is associated with query and key vectors:

$$
\text{softmax}\left(
\begin{array}{ccccccc}
\text{books/} q_4 & & \text{The/} k_0 & \text{students/} k_1 & \text{opened/} k_2 & \text{their/} k_3 & \text{books/} k_4 \\
\downarrow & & \downarrow & \downarrow & \downarrow & \downarrow & \downarrow \\
\begin{bmatrix}
q_{4,0} & q_{4,1} & \cdots & q_{4,E-1}
\end{bmatrix} &
\cdot &
\begin{bmatrix}
k_{0,0} \\
k_{1,0} \\
\vdots \\
k_{E-1,0}
\end{bmatrix} &
\begin{bmatrix}
k_{0,1} \\
k_{1,1} \\
\vdots \\
k_{E-1,1}
\end{bmatrix} &
\begin{bmatrix}
k_{0,2} \\
k_{1,2} \\
\vdots \\
k_{E-1,2}
\end{bmatrix} &
\begin{bmatrix}
k_{0,3} \\
k_{1,3} \\
\vdots \\
k_{E-1,3}
\end{bmatrix} &
\begin{bmatrix}
k_{0,4} \\
k_{1,4} \\
\vdots \\
k_{E-1,4}
\end{bmatrix} \\
\end{array}
\right)
=
\begin{array}{cccccc}
\text{The} & \text{students} & \text{opened} & \text{their} & \text{books} \\
\downarrow & \downarrow & \downarrow & \downarrow & \downarrow \\
0.051 & 0.511 & 0.069 & 0.057 & 0.310
\end{array}
$$

- The dot product $QK^T$ measures how similar each query is to every key.
- Dividing by $\sqrt{\text{dim}_k}$ keeps the scores from growing with the key dimension, which would otherwise push the softmax toward saturation.
- Applying the softmax ensures that the attention scores for each query **sum to 1**, producing a weighted distribution over the values ($V$).
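
The three points above can be traced on a small example. The sketch below is plain Python with a hypothetical 4-dimensional query and three hypothetical keys. It computes scaled dot-product scores, applies a softmax, and compares the result with the unscaled version.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical query and keys, chosen only for illustration.
query = [1.0, 0.0, 1.0, 0.0]
keys = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]

raw_scores = [dot(query, key) for key in keys]
scaled_scores = [s / math.sqrt(len(query)) for s in raw_scores]

weights = softmax(scaled_scores)
unscaled_weights = softmax(raw_scores)
print(weights, sum(weights))
print(unscaled_weights)
```

The orthogonal key receives the smallest weight, and the scaled distribution is flatter than the unscaled one.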

$\quad$ 3.2. **Attention Outputs**:

Each value vector $v$ is multiplied by its corresponding attention weight. These weighted values are then combined to form the final output of the attention mechanism, which is a weighted sum of the value vectors.
$$
\begin{array}{cccccc}
    & \text{The} & \text{students} & \text{opened} & \text{their} & \text{books} \\
    & \downarrow & \downarrow & \downarrow & \downarrow & \downarrow \\
\text{books} \rightarrow & 0.051 \cdot v_0 \; + & 0.511 \cdot v_1 \; + & 0.069 \cdot v_2 \; + & 0.057 \cdot v_3 \; + & 0.310 \cdot v_4
\end{array}
$$

- The right-hand factor, $V$, holds the actual information the attention mechanism passes forward.
- The attention scores act as weights applied to the value vectors, determining how much attention each query should give to each corresponding value vector.
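
As a small worked instance of the weighted sum, the sketch below reuses the attention weights from the example above and combines them with hypothetical 2-dimensional value vectors (invented for illustration).

```python
def weighted_sum(weights, values):
    """Combine value vectors into one output vector: sum_t weights[t] * values[t]."""
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]

# Attention weights for the query "books", taken from the example above.
weights = [0.051, 0.511, 0.069, 0.057, 0.310]
# Hypothetical value vectors for The, students, opened, their, books.
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [2.0, 2.0]]

output = weighted_sum(weights, values)
print(output)
```

The output has the same dimension as a single value vector, no matter how many tokens contribute to it.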

In [23]:
def attention_score(query, key, values):
    key_dim = key.shape[-1]
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(key_dim)
    attn_weights = F.softmax(attn_weights, dim=-1)

    attn_output = torch.matmul(attn_weights, values)
    return attn_output

with torch.no_grad():
    expected_attn_output = F.scaled_dot_product_attention(expected_query, expected_key, expected_value, is_causal=False)
    predicted_attn_output = attention_score(expected_query, expected_key, expected_value)

    print(rel_error(expected_attn_output, predicted_attn_output))

tensor(0.0002)


$\quad$ 3.3. **Causal Mask**:

In [20]:
def attention_score(query, key, values, is_causal=True):
    query_length = query.shape[-2]
    key_length, key_dim = key.shape[-2], key.shape[-1]
    
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(key_dim)
    
    if is_causal:
        causal_mask = torch.ones(query_length, key_length, dtype=torch.bool).tril(diagonal=0)
        attn_weights = torch.where(causal_mask, attn_weights, float("-inf"))

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, values)
    return attn_output

with torch.no_grad():
    expected_attn_output = gpt2_block.attn._attn(expected_query, expected_key, expected_value)[0]
    predicted_attn_output = attention_score(expected_query, expected_key, expected_value)

    print(rel_error(expected_attn_output, predicted_attn_output))

tensor(0.)
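
To make the effect of the mask visible, the sketch below applies the same idea in plain Python to a hypothetical 3 x 3 score matrix with all scores equal. Position $i$ may attend only to positions $j \le i$, and masked entries are set to $-\infty$ so the softmax assigns them zero weight. `causal_softmax` is an illustrative name.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where position i may only attend to positions j <= i."""
    rows = []
    for i, row in enumerate(scores):
        masked = [s if j <= i else float("-inf") for j, s in enumerate(row)]
        peak = max(masked)
        exps = [math.exp(s - peak) for s in masked]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows

# Hypothetical raw scores for a 3-token sequence (all equal, to isolate the mask).
scores = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
masked_weights = causal_softmax(scores)
for row in masked_weights:
    print(row)
```

Each row still sums to 1, but the weight is spread only over the current and earlier positions, so the first token attends entirely to itself.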


$\quad$ 3.4. **Batched Multi-Head Attention**: