<a href="https://colab.research.google.com/github/addamit/LMExperiments/blob/main/prefill_phase.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Imagine we have the following input sequence tokens
[10, 20, 30, 40, 50, 60].

First, we convert these tokens into embeddings.

For simplicity, let's say our embedding table gives us:

Token 10 → [0.1, 0.2, 0.3, 0.4]  
Token 20 → [0.2, 0.3, 0.4, 0.5]  
Token 30 → [0.3, 0.4, 0.5, 0.6]  
Token 40 → [0.4, 0.5, 0.6, 0.7]  
Token 50 → [0.5, 0.6, 0.7, 0.8]  
Token 60 → [0.6, 0.7, 0.8, 0.9]  



In [27]:
#Our embedding matrix E is now:
E = [
    [0.1, 0.2, 0.3, 0.4],  # token 10
    [0.2, 0.3, 0.4, 0.5],  # token 20
    [0.3, 0.4, 0.5, 0.6],  # token 30
    [0.4, 0.5, 0.6, 0.7],  # token 40
    [0.5, 0.6, 0.7, 0.8],  # token 50
    [0.6, 0.7, 0.8, 0.9]   # token 60
]

In [31]:
# converting E into torch tensor
import torch
E = torch.tensor(E)
E.shape

torch.Size([6, 4])

In [34]:
# Now, for a single attention head, let's define some example weight matrices:
W_q = torch.tensor([  # Query weights
    [1.0, 0.5],
    [0.6, 0.3],
    [0.4, 0.2],
    [0.8, 0.7]
])

W_k = torch.tensor([  # Key weights
    [0.9, 0.2],
    [0.5, 0.4],
    [0.3, 0.8],
    [0.1, 0.6]
])

W_v = torch.tensor([  # Value weights
    [0.8, 0.3],
    [0.2, 0.5],
    [0.7, 0.4],
    [0.1, 0.9]
])


In [35]:
W_q.shape, W_k.shape, W_v.shape

(torch.Size([4, 2]), torch.Size([4, 2]), torch.Size([4, 2]))

In [36]:
# Step 1: Calculate Q, K, V matrices
# For each token, compute Q, K, V vectors
Q = E @ W_q  # Shape: [6, 2]
K = E @ W_k  # Shape: [6, 2]
V = E @ W_v  # Shape: [6, 2]

In [37]:
Q.shape, K.shape, V.shape

(torch.Size([6, 2]), torch.Size([6, 2]), torch.Size([6, 2]))

In [40]:
# attention scores
attention_scores = Q @ K.transpose(-1, -2)
attention_scores = attention_scores / math.sqrt(2) # our d_k = 2


In [41]:
attention_scores.shape

torch.Size([6, 6])

In [42]:
attention_scores

tensor([[0.3339, 0.4815, 0.6292, 0.7768, 0.9245, 1.0721],
        [0.4670, 0.6743, 0.8816, 1.0889, 1.2963, 1.5036],
        [0.6001, 0.8671, 1.1341, 1.4011, 1.6681, 1.9351],
        [0.7331, 1.0598, 1.3865, 1.7132, 2.0399, 2.3665],
        [0.8662, 1.2526, 1.6389, 2.0253, 2.4117, 2.7980],
        [0.9993, 1.4453, 1.8914, 2.3374, 2.7835, 3.2295]])

In [43]:
# Create causal mask (lower triangular matrix)
mask = torch.tril(torch.ones(6, 6))
mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [44]:
# Apply mask (set upper triangular part to -infinity)
masked_scores = attention_scores * mask + -1e9 * (1 - mask)
masked_scores

tensor([[ 3.3390e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09],
        [ 4.6697e-01,  6.7430e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09],
        [ 6.0005e-01,  8.6705e-01,  1.1341e+00, -1.0000e+09, -1.0000e+09,
         -1.0000e+09],
        [ 7.3313e-01,  1.0598e+00,  1.3865e+00,  1.7132e+00, -1.0000e+09,
         -1.0000e+09],
        [ 8.6621e-01,  1.2526e+00,  1.6389e+00,  2.0253e+00,  2.4117e+00,
         -1.0000e+09],
        [ 9.9928e-01,  1.4453e+00,  1.8914e+00,  2.3374e+00,  2.7835e+00,
          3.2295e+00]])

In [46]:
# Apply softmax row-wise
attention_weights = F.softmax(masked_scores, dim=1)  # Shape: [6, 6]
attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4484, 0.5516, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2493, 0.3256, 0.4252, 0.0000, 0.0000, 0.0000],
        [0.1434, 0.1988, 0.2756, 0.3821, 0.0000, 0.0000],
        [0.0799, 0.1176, 0.1731, 0.2547, 0.3748, 0.0000],
        [0.0415, 0.0649, 0.1014, 0.1584, 0.2474, 0.3864]])

In [47]:
# Context vectors = attention_weights @ V
context_vectors = attention_weights @ V  # Shape: [6, 2]

In [48]:
context_vectors.shape

torch.Size([6, 2])

In [49]:
kv_cache = {
    'K': K,  # Shape: [6, 2]
    'V': V   # Shape: [6, 2]
}

In [50]:
kv_cache

{'K': tensor([[0.3200, 0.5800],
         [0.5000, 0.7800],
         [0.6800, 0.9800],
         [0.8600, 1.1800],
         [1.0400, 1.3800],
         [1.2200, 1.5800]]),
 'V': tensor([[0.3700, 0.6100],
         [0.5500, 0.8200],
         [0.7300, 1.0300],
         [0.9100, 1.2400],
         [1.0900, 1.4500],
         [1.2700, 1.6600]])}

In [1]:
import torch
import torch.nn.functional as F
import math
from transformers import GPT2Model, GPT2Config, GPT2Tokenizer

In [2]:
# Load pre-trained GPT2 model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
model.eval()  # S

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [3]:
# Sample input
text = "Hello, this is a test of the prefill phase."
input_ids = tokenizer.encode(text, return_tensors='pt')


In [4]:
input_ids.shape

torch.Size([1, 12])

In [5]:
# Extract single layer params (let's use the first layer)
layer_idx = 0
layer = model.h[layer_idx]

In [6]:

layer

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D(nf=2304, nx=768)
    (c_proj): Conv1D(nf=768, nx=768)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D(nf=3072, nx=768)
    (c_proj): Conv1D(nf=768, nx=3072)
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [9]:
# Configuration
config = model.config


In [10]:
# Get token embeddings and position embeddings
with torch.no_grad():
    # Get initial embeddings (token embeddings + position embeddings)
    token_embeddings = model.wte(input_ids)
    position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long).unsqueeze(0)
    position_embeddings = model.wpe(position_ids)
    hidden_states = token_embeddings + position_embeddings

In [11]:

hidden_states.shape

torch.Size([1, 12, 768])

In [12]:
# GPT2 uses causal attention mask
batch_size, seq_length = hidden_states.size()[:2]
causal_mask = torch.tril(torch.ones((seq_length, seq_length))).view(
    1, 1, seq_length, seq_length
)

# batch_size, seq_length, causal_mask

(1, 12)

In [14]:
# Layer norm before attention
ln_1_out = layer.ln_1(hidden_states)
ln_1_out.shape

torch.Size([1, 12, 768])

In [16]:
# Multi-head attention
query = layer.attn.c_attn(ln_1_out)
query.shape

torch.Size([1, 12, 2304])

In [17]:
# Split into query, key, value
query, key, value = query.split(config.hidden_size, dim=2)
query.shape, key.shape, value.shape

(torch.Size([1, 12, 768]), torch.Size([1, 12, 768]), torch.Size([1, 12, 768]))

In [18]:
# Reshape for multi-head attention
query = query.view(batch_size, seq_length, config.n_head, config.hidden_size // config.n_head).transpose(1, 2)
key = key.view(batch_size, seq_length, config.n_head, config.hidden_size // config.n_head).transpose(1, 2)
value = value.view(batch_size, seq_length, config.n_head, config.hidden_size // config.n_head).transpose(1, 2)
query.shape, key.shape, value.shape

(torch.Size([1, 12, 12, 64]),
 torch.Size([1, 12, 12, 64]),
 torch.Size([1, 12, 12, 64]))

In [19]:
# Compute attention scores
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = attn_weights / math.sqrt(config.hidden_size // config.n_head)
attn_weights.shape

torch.Size([1, 12, 12, 12])

In [20]:
# Apply causal mask
attn_weights = attn_weights.masked_fill(causal_mask == 0, float('-inf'))

# Softmax and dropout
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = F.dropout(attn_weights, p=config.attn_pdrop, training=False)
attn_weights.shape

torch.Size([1, 12, 12, 12])

In [21]:
# Apply attention to values
attn_output = torch.matmul(attn_weights, value)

# Reshape and project
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, config.hidden_size)
attn_output = layer.attn.c_proj(attn_output)
attn_output.shape

torch.Size([1, 12, 768])

In [22]:
# Add residual connection
hidden_states = hidden_states + attn_output

# Layer norm before MLP
ln_2_out = layer.ln_2(hidden_states)

# MLP
mlp_output = layer.mlp.c_fc(ln_2_out)
mlp_output = F.gelu(mlp_output)
mlp_output = layer.mlp.c_proj(mlp_output)

# Add residual connection
hidden_states = hidden_states + mlp_output

In [23]:

hidden_states.shape

torch.Size([1, 12, 768])

In [24]:
# Compare with original output
with torch.no_grad():
    # Extract the output from the same layer in the original model
    original_layer_output = model(input_ids, output_hidden_states=True).hidden_states[layer_idx + 1]

In [25]:
original_layer_output.shape

torch.Size([1, 12, 768])

In [26]:
# Check if outputs match
diff = (hidden_states - original_layer_output).abs().max().item()
print(f"Maximum difference between custom and original: {diff}")

if diff < 1e-5:
    print("✓ Implementation matches the original model!")
else:
    print("✗ Implementation does not match the original model.")

Maximum difference between custom and original: 0.0105743408203125
✗ Implementation does not match the original model.
