In [1]:
#
# SETUP: Imports and Model Loading for the entire notebook
#
import torch
import torch.nn as nn
import time
from transformers import GPT2LMHeadModel, GPT2Tokenizer

print("Setting up models...")
# Load pre-trained model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
# from_pretrained already returns the model in eval mode; made explicit because
# every cell below only runs inference (training-only behavior such as dropout
# must stay disabled).
model.eval()
print("Setup complete.")

#
# LISTING 2.1: Visualizing autoregressive generation with GPT-2
#
prompt = "The next day is bright"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids
NUM_NEW_TOKENS = 20  # how many tokens the manual decoding loop appends

print(f"Prompt: '{prompt}'", end="")

# Inference only: without no_grad every step would record an autograd graph over
# the whole (growing) sequence, wasting memory and time for no benefit.
with torch.no_grad():
    for _ in range(NUM_NEW_TOKENS):
        # Pass the entire sequence to the model (no KV cache reuse here on purpose:
        # this listing shows the naive loop that Listing 2.2 speeds up)
        outputs = model(input_ids)
        logits = outputs.logits

        # Get the logits for the very last token
        next_token_logits = logits[:, -1, :]

        # Get the ID of the most likely next token (greedy decoding)
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        # Append the new token ID to the input sequence
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

        # Decode and print the new token
        new_token = tokenizer.decode(next_token_id[0])
        print(new_token, end="", flush=True)

print("\n")

Setting up models...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Setup complete.
Prompt: 'The next day is bright' and sunny, and the sun is shining. The sun is shining, and the moon is shining.



In [2]:
#
# LISTING 2.2: Demonstrating the Speedup of KV Caching
#
prompt = "The next day is bright"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask
MAX_NEW_TOKENS = 100  # tokens generated in each timed run


def timed_generate(use_cache: bool, max_new_tokens: int = MAX_NEW_TOKENS):
    """Run ``model.generate`` on the prompt and return ``(output_ids, seconds)``.

    Args:
        use_cache: Whether ``generate`` may reuse past key/value tensors.
        max_new_tokens: Number of tokens to generate.

    ``pad_token_id`` is passed explicitly because GPT-2 has no pad token; it is
    the same value (eos, 50256) ``generate`` otherwise falls back to with a
    warning on every call.
    """
    start = time.perf_counter()  # monotonic, high-resolution clock for benchmarks
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        use_cache=use_cache,
        pad_token_id=tokenizer.eos_token_id,
    )
    return output_ids, time.perf_counter() - start


# --- Timing without KV cache ---
print("Generating without KV Cache...")
output_without_cache, duration_without_cache = timed_generate(use_cache=False)
print(f"Time without KV Cache: {duration_without_cache:.4f} seconds\n")


# --- Timing with KV cache ---
print("Generating with KV Cache...")
output_with_cache, duration_with_cache = timed_generate(use_cache=True)
print(f"Time with KV Cache: {duration_with_cache:.4f} seconds\n")


# --- Calculate and print the speedup ---
speedup = duration_without_cache / duration_with_cache
print(f"KV Cache Speedup: {speedup:.2f}x")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Generating without KV Cache...


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Time without KV Cache: 57.6215 seconds

Generating with KV Cache...
Time with KV Cache: 6.3934 seconds

KV Cache Speedup: 9.01x


In [3]:
#
# LISTING 2.3: Implementing an MQA layer from scratch
#
class MultiQueryAttention(nn.Module):
    """Causal multi-query attention (MQA).

    All ``num_heads`` query heads share a single key head and a single value
    head, so the K/V projections produce only ``d_head`` features.

    Args:
        d_model: Model (embedding) width; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
        dropout: Dropout probability applied to the attention weights.
        max_seq_len: Size of the precomputed causal mask. Longer inputs (or a
            value <= 0) fall back to building the mask on the fly.
    """

    def __init__(self, d_model, num_heads, dropout=0.0, max_seq_len: int = 1024):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, self.d_head) # Single projection for K
        self.W_v = nn.Linear(d_model, self.d_head) # Single projection for V
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        # Boolean mask, True above the diagonal (= future positions to hide).
        # Registered under the original buffer name "mask" so state_dict keys
        # are unchanged; registering None is allowed and disables the cache.
        causal = None
        if max_seq_len > 0:
            causal = torch.triu(torch.ones(1, 1, max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
        self.register_buffer("mask", causal)

    def _get_causal_mask(self, seq_len, device):
        """Return a bool (1, 1, seq_len, seq_len) mask that is True for future positions."""
        if self.mask is not None and self.mask.size(-1) >= seq_len:
            return self.mask[:, :, :seq_len, :seq_len]
        # Precomputed mask missing or too small: build one for this length.
        return torch.triu(torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)

    def forward(self, x):
        """Apply causal MQA.

        Args:
            x: Input of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        # Query: (B, num_heads, seq_len, d_head)
        q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        # Key & Value: (B, 1, seq_len, d_head) -- one head shared by every query head
        k = self.W_k(x).view(batch_size, seq_len, 1, self.d_head).transpose(1, 2)
        v = self.W_v(x).view(batch_size, seq_len, 1, self.d_head).transpose(1, 2)

        # No need to copy K/V num_heads times: matmul broadcasts the size-1 head
        # dimension against the query heads, giving (B, num_heads, seq_len, seq_len).
        attn_scores = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)

        # Apply causal mask (works for any seq_len, not only <= max_seq_len)
        attn_scores = attn_scores.masked_fill(self._get_causal_mask(seq_len, x.device), float('-inf'))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (B, num_heads, seq_len, d_head) via broadcasting over v's single head
        context_vector = (attn_weights @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        output = self.W_o(context_vector)
        return output

# --- Usage Example ---
d_model = 512
num_heads = 8
batch_size = 4
seq_len = 64

mqa_layer = MultiQueryAttention(d_model, num_heads)
dummy_input = torch.randn(batch_size, seq_len, d_model)
output = mqa_layer(dummy_input)

# Back the "successful" message with real checks: the layer preserves the
# input shape, and K/V are projected to a single head (d_head features).
assert output.shape == dummy_input.shape, "MQA must preserve (B, T, d_model)"
assert mqa_layer.W_k.out_features == d_model // num_heads, "K must be a single head"

print("MQA Layer successful!")
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")

MQA Layer successful!
Input shape: torch.Size([4, 64, 512])
Output shape: torch.Size([4, 64, 512])


In [4]:
#
# LISTING 2.4: Implementing a GQA layer from scratch
#
class GroupedQueryAttention(nn.Module):
    """Causal grouped-query attention (GQA).

    Query heads are split into ``num_groups`` groups; all heads in a group
    share one key/value head, so the K/V projections produce
    ``num_groups * d_head`` features instead of ``d_model``.

    Args:
        d_model: Model (embedding) width; must be divisible by ``num_heads``.
        num_heads: Number of query heads; must be divisible by ``num_groups``.
        num_groups: Number of shared key/value heads.
        dropout: Dropout probability applied to the attention weights.
        max_seq_len: Size of the precomputed causal mask. Longer inputs (or a
            value <= 0) fall back to building the mask on the fly.
    """

    def __init__(self, d_model, num_heads, num_groups, dropout=0.0, max_seq_len: int = 1024):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"

        self.d_model = d_model
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.d_head = d_model // num_heads

        kv_width = num_groups * self.d_head  # one d_head slice per shared K/V group
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, kv_width)
        self.W_v = nn.Linear(d_model, kv_width)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        self._register_mask_buffer(max_seq_len)

    def _register_mask_buffer(self, max_seq_len):
        """Cache a bool upper-triangular mask (True = future position), kept out of state_dict."""
        if max_seq_len > 0:
            upper = torch.ones(1, 1, max_seq_len, max_seq_len, dtype=torch.bool).triu(diagonal=1)
            self.register_buffer("causal_mask", upper, persistent=False)
            return
        self.causal_mask = None

    def _get_causal_mask(self, seq_len, device):
        """Return a (1, 1, seq_len, seq_len) causal mask, reusing the cache when it is big enough."""
        cached = self.causal_mask
        if cached is None or cached.size(-1) < seq_len:
            # No usable cache: build a fresh mask on the input's device.
            return torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=device).triu(diagonal=1)
        return cached[:, :, :seq_len, :seq_len]

    def forward(self, x):
        """Apply causal GQA to ``x`` of shape (batch, seq_len, d_model); returns the same shape."""
        batch, steps, _ = x.shape
        repeats = self.num_heads // self.num_groups

        def split_heads(t, n_heads):
            # (B, T, n_heads * d_head) -> (B, n_heads, T, d_head)
            return t.view(batch, steps, n_heads, self.d_head).transpose(1, 2)

        queries = split_heads(self.W_q(x), self.num_heads)
        # Duplicate each K/V group for the consecutive query heads that share it.
        keys = split_heads(self.W_k(x), self.num_groups).repeat_interleave(repeats, dim=1)
        values = split_heads(self.W_v(x), self.num_groups).repeat_interleave(repeats, dim=1)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * (self.d_head**-0.5)
        scores = scores.masked_fill(self._get_causal_mask(steps, x.device), float("-inf"))

        weights = self.dropout(torch.softmax(scores, dim=-1))
        merged = torch.matmul(weights, values).transpose(1, 2).contiguous().view(batch, steps, self.d_model)
        return self.W_o(merged)

# --- Usage Example ---
d_model = 512
num_heads = 32
num_groups = 4 # 32 query heads share 4 K/V heads (8 per group); for comparison, Llama 2 70B uses 8 K/V heads for 64 query heads
batch_size = 4
seq_len = 64

gqa_layer = GroupedQueryAttention(d_model, num_heads, num_groups)
dummy_input = torch.randn(batch_size, seq_len, d_model)
output = gqa_layer(dummy_input)

# Back the "successful" message with real checks: the layer preserves the
# input shape, and K/V are projected to num_groups heads of d_head features.
assert output.shape == dummy_input.shape, "GQA must preserve (B, T, d_model)"
assert gqa_layer.W_k.out_features == num_groups * (d_model // num_heads), "K must have num_groups heads"

print("GQA Layer successful!")
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")

GQA Layer successful!
Input shape: torch.Size([4, 64, 512])
Output shape: torch.Size([4, 64, 512])
