# Building a Modern Large Language Model from Scratch

**Objective:** We will build a Generative Pre-trained Transformer (GPT) from scratch. This is the same technology behind ChatGPT, Claude, and Llama.

**No prior knowledge required.** We will explain every concept as if you are new to AI.

## Table of Contents

| Stage | Topic | What we are building |
|-------|-------|----------------------|
| 0 | Setup | Preparing our tools |
| 1 | Bigram Model | A model that guesses the next character based on just the previous one |
| 2 | Tokenization | Converting words into numbers the computer understands |
| 3 | Attention | Giving the model 'memory' to look at previous words |
| 4 | Modern Components | Upgrading the engine with the components used in Llama 3 (2024) |
| 5 | Full GPT Model | Assembling the robot |
| 6 | Training | Teaching the robot to read and write |
| 7 | Inference & Chat | Talking to our creation |
| 8 | RLHF Alignment | Teaching the robot manners |

---


## Stage 0: Setup

Before we build, we need tools. We will use **PyTorch**, which is a library that helps us do the heavy math required for AI.

We also need a **GPU** (Graphics Processing Unit). Think of a CPU (your computer's main brain) as a math professor: very smart, but working through only a few problems at a time. A GPU is like a thousand elementary school students who can each solve a simple problem at the same moment. AI needs this parallel speed.


In [None]:
# Setup local checkpoint directory
# Since we are not using Google Drive, we will save files locally.
# WARNING: These files will be DELETED if the Colab runtime disconnects.
import os
PROJECT_DIR = "checkpoints"
os.makedirs(PROJECT_DIR, exist_ok=True)
print(f"Local checkpoint directory created at: {PROJECT_DIR}")
print("Remember to DOWNLOAD your checkpoint file ('ckpt.pt') to your computer to save your progress!")


In [None]:
# Install the necessary libraries.
# 'torch': The framework for building neural networks.
# 'tiktoken': A tool from OpenAI to turn text into numbers.
# 'datasets': A library to download books and text for training.
!pip install -q torch tiktoken datasets matplotlib

import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time

# Check if we have a GPU available.
# 'cuda' means we have an NVIDIA GPU. 'cpu' means we are using the slow processor.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


## Stage 1: The Simplest Language Model (Bigram)

### What is a Language Model?
A language model is just a machine that plays a guessing game. You give it a sequence of words, and it guesses what comes next.

**Example:** "The cat sat on the _____."
You probably guessed "mat" or "floor". That's because you have a language model in your brain.

### The Bigram Model
We will start with the simplest possible version: a **Bigram Model**.
This model is very forgetful. It looks **only at the very last word** to guess the next one. It ignores everything else.

- Input: "The cat sat on the"
- Bigram sees only: "the"
- Bigram guesses: "end" (because "the end" is common)

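It has no memory of the "cat" or the "sitting". To keep things simple, the code below applies the same idea to single characters instead of whole words: it guesses the next character from the current one. Let's build it to see how bad it is.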
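
Before the neural version, the bigram idea can be sketched with plain counting. The toy sentence and the helper name `most_likely_next` are assumptions made up for this illustration; they are not part of the model built below.

```python
from collections import Counter, defaultdict

# Toy corpus (an assumption for illustration only).
words = "the cat sat on the mat and the cat slept".split()

# Count how often each word follows each other word.
follow_counts = defaultdict(Counter)
for prev, nxt in zip(words, words[1:]):
    follow_counts[prev][nxt] += 1

def most_likely_next(prev_word):
    """Return the word seen most often after prev_word (hypothetical helper)."""
    return follow_counts[prev_word].most_common(1)[0][0]

print(dict(follow_counts["the"]))   # counts of the words seen after "the"
print(most_likely_next("the"))
```

The neural bigram model learns a table that plays the same role as `follow_counts`, except that the entries are learned scores rather than raw counts.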


In [None]:
# Download a small dataset (Shakespeare) to practice on.
!wget -q https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(f"Dataset length: {len(text):,} characters")
print("First 200 characters:")
print(text[:200])


In [None]:
# Computers cannot understand letters like 'a' or 'b'. They only understand numbers.
# We need a 'Tokenizer' to convert characters into integers.

# 1. Find all unique characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)

# 2. Create a map from character to integer (stoi) and integer to character (itos)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

# 3. Define helper functions to convert back and forth
encode = lambda s: [stoi[c] for c in s] # Text -> Numbers
decode = lambda l: ''.join([itos[i] for i in l]) # Numbers -> Text

print(f"Vocabulary size: {vocab_size} unique characters")
print(f"Example: 'hello' becomes {encode('hello')}")


In [None]:
# Convert the entire text into a long list of numbers (tensor)
data = torch.tensor(encode(text), dtype=torch.long)

# Split into training (90%) and validation (10%) sets.
# We train on the 90% and test on the 10% to make sure the model isn't just memorizing.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

block_size = 8   # The maximum number of characters we look at (context)
batch_size = 32  # How many independent sequences we process at once (parallelism)

def get_batch(split):
    # Pick a random chunk of data
    d = train_data if split == 'train' else val_data
    ix = torch.randint(len(d) - block_size, (batch_size,))
    
    # x is the input, y is the target (the very next character)
    x = torch.stack([d[i:i+block_size] for i in ix])
    y = torch.stack([d[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)


### Understanding Inputs (x) and Targets (y)

We train the model by showing it a sequence and asking it to predict the next character.

If the text is "Hello":
- When input is `[H]`, target is `e`
- When input is `[H, e]`, target is `l`
- When input is `[H, e, l]`, target is `l`

We do this for every position simultaneously.
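
The pairing of inputs and targets can be shown without any tensors. This is a small sketch using a context length of 4, chosen only for the example; the names `text_sample`, `block`, and `pairs` are illustrative.

```python
text_sample = "Hello"
block = 4  # context length for this toy example (assumption)

x = text_sample[:block]        # inputs:  "Hell"
y = text_sample[1:block + 1]   # targets: "ello" (x shifted by one character)

# Each prefix of x is paired with the character that follows it.
pairs = [(x[:t + 1], y[t]) for t in range(block)]
for context, target in pairs:
    print(f"input {list(context)} -> target {target!r}")
```

One chunk of length `block` therefore yields `block` training examples, which is why `y` is simply `x` shifted by one position.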


In [None]:
class BigramModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # The Embedding layer is a lookup table.
        # For every input token, it looks up a row of scores (logits) for the next token.
        self.embed = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets=None):
        # idx is the input numbers (Batch, Time)
        logits = self.embed(idx)
        loss = None
        
        if targets is not None:
            # If we have targets, calculate how wrong we are (loss)
            B, T, C = logits.shape
            # Reshape to match what PyTorch expects
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss
    
    def generate(self, idx, max_new):
        # This function generates new text
        for _ in range(max_new):
            # Get predictions
            logits, _ = self(idx)
            # Focus only on the last time step (Bigram only cares about the last token)
            logits = logits[:, -1, :]
            # Convert scores to probabilities
            probs = F.softmax(logits, dim=-1)
            # Sample the next character from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append to the sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

bigram = BigramModel(vocab_size).to(device)
print(f"Model parameters: {sum(p.numel() for p in bigram.parameters()):,}")


In [None]:
# Let's see what the untrained model generates.
# Since it hasn't learned anything, it should just be random gibberish.
print("Output BEFORE training:")
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(bigram.generate(context, 100)[0].tolist()))


In [None]:
# Create an optimizer.
# The optimizer is the algorithm that updates the model's numbers to reduce the error (loss).
optimizer = torch.optim.AdamW(bigram.parameters(), lr=1e-3)

print("Starting training loop...")
for step in range(1000):
    # 1. Get a batch of data
    xb, yb = get_batch('train')
    
    # 2. Forward pass: Make predictions and calculate loss
    _, loss = bigram(xb, yb)
    
    # 3. Backward pass: Calculate gradients (how to change weights)
    optimizer.zero_grad()
    loss.backward()
    
    # 4. Update weights
    optimizer.step()
    
    if step % 200 == 0: print(f"Step {step}: loss={loss.item():.4f}")

print(f"Final loss: {loss.item():.4f}")


In [None]:
# Now let's generate text again.
# It should look slightly better (maybe like real words), but still nonsensical
# because it only looks at one character at a time.
print("Output AFTER training:")
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(bigram.generate(context, 200)[0].tolist()))


### Section Summary: Stage 1

**What we learned:**
1.  **Training Loop:** The cycle of Guess -> Measure Error -> Update Weights.
2.  **Limitation:** The Bigram model is too simple. It has no "memory". It sees "t" and guesses "h", but it doesn't know if the previous word was "ca" (cat) or "ba" (bat).
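
To judge whether a loss value is good, it helps to have a baseline. A minimal sketch: the cross-entropy of a model that guesses every character with equal probability. The value 65 is an assumed vocabulary size for this dataset; use the `vocab_size` printed by your own run.

```python
import math

def uniform_loss(vocab_size):
    """Cross-entropy (in nats) of guessing every token with equal probability."""
    return -math.log(1.0 / vocab_size)

# 65 is an assumed character vocabulary size; substitute your own value.
print(f"{uniform_loss(65):.3f}")
```

A training loss clearly below this number means the model has learned something beyond blind guessing. An untrained model may start somewhat above it, because its random scores are not perfectly uniform.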

**Next Step:** We need a model that can look back at the entire history of words. That is what **Attention** does (Stage 3). First, Stage 2 upgrades how we turn text into numbers.


## Stage 2: Tokenization

In Stage 1, we used characters (a, b, c). But modern models like GPT-4 don't read letters; they read **tokens**.

### Why not characters?
If we use characters, the model has to learn that 't', 'h', 'e' always go together. That's a waste of time.

### Why not whole words?
There are too many words in English. The vocabulary would be huge (millions).

### The Solution: Byte Pair Encoding (BPE)
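BPE is a middle ground. It starts from individual characters (bytes, in practice) and repeatedly merges the most frequent adjacent pair into a new token. After many merges: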
- "ing" becomes one token.
- "the" becomes one token.
- Rare words like "Zylophone" might be split into "Zy", "lo", "phone".
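
The merging idea can be sketched in a few lines. This is a toy version that works on characters of one short string; real BPE tokenizers are trained on a large corpus and operate on bytes, and the helper names here are made up for the example.

```python
from collections import Counter

def most_frequent_pair(tokens):
    """Return the adjacent pair that occurs most often."""
    return Counter(zip(tokens, tokens[1:])).most_common(1)[0][0]

def merge_pair(tokens, pair):
    """Replace every occurrence of `pair` with a single merged token."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

tokens = list("the then they")   # start from single characters
for _ in range(2):               # two merge rounds
    pair = most_frequent_pair(tokens)
    tokens = merge_pair(tokens, pair)
print(tokens)
```

After two rounds, the three letters of "the" have fused into one token everywhere they appear, while the rarer endings stay as single characters.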

We will use `tiktoken`, OpenAI's open-source tokenizer library, with the GPT-2 vocabulary.


In [None]:
import tiktoken

# Load the tokenizer used by GPT-2 (and GPT-3)
enc = tiktoken.get_encoding("gpt2")

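# Use a separate name so we don't overwrite the Shakespeare dataset stored in `text`
sample_text = "Hello, I am learning AI!"
tokens = enc.encode(sample_text)

print(f"Original text: {sample_text}")
print(f"Token IDs: {tokens}")
print(f"Compression: {len(sample_text)} characters -> {len(tokens)} tokens")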

# Let's see what each token represents
for t in tokens:
    print(f"  Token {t} -> '{enc.decode([t])}'")


In [None]:
# We wrap the tokenizer in a class to make it easy to use later.
class Tokenizer:
    def __init__(self):
        self.enc = tiktoken.get_encoding("gpt2")
        self.vocab_size = 50304  # GPT-2 has 50,257 tokens; padded up to a multiple of 64 for GPU efficiency
    
    def encode(self, text):
        # Convert string to list of integers (tensor)
        return torch.tensor(self.enc.encode(text), dtype=torch.long)
    
    def decode(self, tokens):
        # Convert list of integers back to string
        if isinstance(tokens, torch.Tensor): tokens = tokens.tolist()
        return self.enc.decode(tokens)

tokenizer = Tokenizer()
print(f"Vocabulary size: {tokenizer.vocab_size:,} tokens")


### Section Summary: Stage 2

**What we learned:**
- **Tokens** are the atoms of language for AI.
- **BPE** allows us to represent text efficiently.

Now that we have proper data, let's build the brain.


## Stage 3: Attention (The Heart of Transformers)

This is the most important concept in modern AI.

**The Problem:** In a sentence like *"The bank was closed because it was flooded"*, how does the machine know that "it" refers to "bank" and not "closed"?

**The Solution:** **Self-Attention**. Every word looks at every other word and decides how relevant it is.
- "it" asks: "What am I?"
- "bank" answers: "I am a noun, and I can be flooded."
- "it" decides: "Okay, I am referring to the bank."

We will build this mechanism in 4 steps.


### Version 1: The Naive Approach (Averaging)

The simplest way to get context is to just **average** the features of all previous words.
If I am at word 5, I take the average of words 1, 2, 3, 4, and 5.

This is weak (it loses order information), but it's a start.


In [None]:
torch.manual_seed(42)
B, T, C = 4, 8, 2  # Batch=4, Time=8, Channels=2
x = torch.randn(B, T, C)

# We want x[b,t] to be the average of x[b, 0...t]
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        # Slice all previous tokens and average them
        xbow[b, t] = x[b, :t+1].mean(dim=0)

print("Example row 0 (original):", x[0,0])
print("Example row 1 (average of 0 and 1):", xbow[0,1])


### Version 2: Matrix Multiplication (Efficiency)

Loops in Python are slow. We can do the exact same averaging using **Matrix Multiplication**.
We create a triangular matrix of 1s and normalize it. When we multiply our data by this matrix, it automatically computes the running averages.


In [None]:
# Create a lower triangular matrix of 1s
wei = torch.tril(torch.ones(T, T))
# Normalize so rows sum to 1
wei = wei / wei.sum(1, keepdim=True)

print("Weight matrix (notice the lower triangle):")
print(wei)

# Perform matrix multiplication
xbow2 = wei @ x
print(f"\nDoes it match the loop version? {torch.allclose(xbow, xbow2)}")


### Version 3: Softmax (Making it smart)

Averaging is boring. We want the model to **choose** which past words are important.

We use **Softmax**. Softmax takes a list of numbers and turns them into probabilities (they sum to 1).
We also use a **Mask** (setting future positions to negative infinity) so the model can't cheat and look at the future.


In [None]:
# Start with zeros (no preference yet)
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))

# Masking: Set future positions to -infinity
# When we take softmax, e^-inf becomes 0. This prevents looking ahead.
wei = wei.masked_fill(tril == 0, float('-inf'))

# Softmax normalizes the weights
wei = F.softmax(wei, dim=-1)

print("Weights after softmax:")
print(wei)

xbow3 = wei @ x
print(f"Matches? {torch.allclose(xbow, xbow3)}")


### Version 4: Self-Attention (The Real Deal)

Now for the magic. We don't want uniform averages. We want data-dependent attention.

Every token emits three vectors:
1.  **Query (Q):** What am I looking for?
2.  **Key (K):** What do I contain?
3.  **Value (V):** If you find me interesting, here is my information.

**The Analogy:**
- Imagine a library.
- **Query:** You search for "dinosaurs".
- **Key:** The book titles on the spine.
- **Value:** The content inside the book.

The attention score is the match between your Query and the book's Key.
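
Here is the library analogy with actual numbers, as a plain-Python sketch. The query, key, and value vectors are hand-picked assumptions, and there is no causal mask in this toy version.

```python
import math

def softmax(scores):
    """Turn raw scores into weights that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hand-picked toy vectors (assumptions for illustration only).
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

# Score = dot product of the query with each key, scaled by 1/sqrt(d).
d = len(query)
scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
weights = softmax(scores)

# Output = weighted average of the values.
output = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(d)]
print(weights, output)
```

The first key points in the same direction as the query, so it gets the largest weight and the output leans toward the first value.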


In [None]:
torch.manual_seed(42)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)
head_size = 16

# Three linear layers to generate Q, K, V
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Generate Q, K, V from the input x
k, q, v = key(x), query(x), value(x)

# Calculate scores: Query dot Key
# (B, T, 16) @ (B, 16, T) -> (B, T, T)
wei = q @ k.transpose(-2, -1) * (head_size ** -0.5)  # Scale by 1/sqrt(head_size) for stability

# Masking (don't look at future)
wei = wei.masked_fill(torch.tril(torch.ones(T,T)) == 0, float('-inf'))

# Softmax to get probabilities
wei = F.softmax(wei, dim=-1)

# Aggregate values based on attention weights
out = wei @ v

print("Learned attention weights (first batch):")
print(wei[0])


### Visualizing Attention

Let's look at the "Heatmap".
- The **x-axis** is the token being looked at.
- The **y-axis** is the token doing the looking.
- **Color intensity** shows how much attention is being paid.

Notice the triangle shape? That's the causal mask. You can't look at the future!


In [None]:
import matplotlib.pyplot as plt

# Visualize attention weights for batch 0
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Attention heatmap
ax1 = axes[0]
attn_weights = wei[0].detach().numpy()  # Shape: (T, T)
im1 = ax1.imshow(attn_weights, cmap='Blues', aspect='auto')
ax1.set_xlabel('Key Position (tokens being attended to)', fontsize=10)
ax1.set_ylabel('Query Position (current token)', fontsize=10)
ax1.set_title('Attention Heatmap\n(Causal Mask: can only see past + current)', fontsize=11)
ax1.set_xticks(range(T))
ax1.set_yticks(range(T))
plt.colorbar(im1, ax=ax1, label='Attention Weight')

# Add visual markers for masked positions
for i in range(T):
    for j in range(T):
        if j > i:  # Masked (future) positions
            ax1.text(j, i, 'X', ha='center', va='center', color='red', fontsize=10, fontweight='bold')

# Plot 2: Line plot showing attention distribution per position
ax2 = axes[1]
colors = plt.cm.viridis([0.2, 0.4, 0.6, 0.8])
positions_to_show = [0, 2, 5, 7]
for idx, pos in enumerate(positions_to_show):
    ax2.plot(range(T), attn_weights[pos, :], 'o-', color=colors[idx], 
             label=f'Position {pos}', linewidth=2, markersize=8)
ax2.set_xlabel('Key Position (tokens being attended to)', fontsize=10)
ax2.set_ylabel('Attention Weight', fontsize=10)
ax2.set_title('Attention Distribution per Query Position', fontsize=11)
ax2.legend(loc='upper right')
ax2.grid(True, alpha=0.3)
ax2.set_xticks(range(T))
ax2.set_ylim(-0.05, 1.05)

plt.tight_layout()
plt.show()

print("Key observations:")
print("1. Lower-triangular pattern: Can only attend to past + current (causal mask)")
print("2. Position 0 attends ONLY to itself (weight = 1.0)")
print("3. Later positions can distribute attention across more tokens")
print("4. Each row sums to 1.0 (softmax normalization)")
print(f"\nRow sums: {[f'{s:.2f}' for s in attn_weights.sum(axis=1)]}")

### Why do we divide by sqrt(head_size)?

This is a technical detail called **Scaling**.
If we don't divide, the numbers in the dot product get huge. When numbers are huge, Softmax becomes extremely "peaky" (one value is 1, the rest are 0).
This kills the gradient (the signal for learning). Scaling keeps the numbers in a nice range so the model can learn smoothly.
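
Why the factor is exactly $\sqrt{d}$ (with $d$ the head size) can be sketched with a short variance calculation. It assumes the components of $q$ and $k$ are independent with mean 0 and variance 1, which is true for random test vectors but only approximately true inside a trained model:

$$
q \cdot k = \sum_{i=1}^{d} q_i k_i, \qquad
\operatorname{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = 1
\;\Longrightarrow\;
\operatorname{Var}(q \cdot k) = d, \qquad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d}}\right) = 1 .
$$

So with $d = 64$ the unscaled scores have a variance of about 64 (a typical size of 8), and dividing by $\sqrt{64} = 8$ brings the variance back to about 1.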


In [None]:
# Demonstrate the scaling problem visually
q_demo, k_demo = torch.randn(8, 64), torch.randn(8, 64)
raw = q_demo @ k_demo.T
scaled = raw * (64 ** -0.5)

print(f"Raw variance: {raw.var():.1f} (grows with dimension!)")
print(f"Scaled variance: {scaled.var():.1f} (stable around 1.0)")
print()

# Visualize the effect on softmax
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Without scaling
raw_probs = F.softmax(raw, dim=-1).detach().numpy()
im1 = axes[0].imshow(raw_probs, cmap='hot', aspect='auto')
axes[0].set_title('WITHOUT Scaling\n(Peaky - almost one-hot!)', fontsize=11)
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')
plt.colorbar(im1, ax=axes[0])

# With scaling
scaled_probs = F.softmax(scaled, dim=-1).detach().numpy()
im2 = axes[1].imshow(scaled_probs, cmap='hot', aspect='auto')
axes[1].set_title('WITH Scaling (divide by sqrt(d))\n(Diffuse - can learn!)', fontsize=11)
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

print("Without scaling: Network is 'overconfident' before learning!")
print("With scaling: Attention is diffuse, allowing gradients to flow and learn.")

### Section Summary: Stage 3

**What we learned:**
- **Self-Attention** allows tokens to talk to each other.
- **Q, K, V** is the mechanism: Query matches Key to get Value.
- **Masking** ensures we generate text one word at a time without cheating.

We now have the brain. Let's upgrade it.


## Stage 4: Modern Components (Llama 3)

The original Transformer was invented in 2017. AI moves fast. We will use the modern improvements found in **Llama 3 (2024)**.

1.  **RMSNorm:** A simpler, faster way to normalize numbers.
2.  **RoPE:** A better way to handle position (angles instead of addition).
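3.  **SwiGLU:** A more expressive neuron activation function.

### 4.1 RMSNorm

Normalization keeps the numbers flowing through the network at a consistent scale. RMSNorm divides each vector by its root mean square and multiplies by a learnable weight. Unlike the older LayerNorm, it skips subtracting the mean, which makes it cheaper.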
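
To see what RMSNorm does to actual numbers, here is a minimal plain-Python sketch of the same formula used by the `RMSNorm` class in this stage. The function name `rms_norm` and the input vector are assumptions for the example.

```python
import math

def rms_norm(x, weight=None, eps=1e-6):
    """Divide x by its root mean square, then apply an optional per-element scale."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    weight = weight or [1.0] * len(x)
    return [v * inv_rms * w for v, w in zip(x, weight)]

x = [3.0, -4.0, 12.0, 0.0]          # RMS of this vector is 6.5
y = rms_norm(x)
rms_after = math.sqrt(sum(v * v for v in y) / len(y))
print(y, rms_after)
```

Whatever the scale of the input, the output has a root mean square of about 1, and the signs and relative sizes of the entries are preserved.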


In [None]:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learnable scaling parameter
        self.weight = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        # Calculate the reciprocal of the Root Mean Square (RMS)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Normalize and scale
        return x * inv_rms * self.weight


### 4.2 RoPE (Rotary Positional Embeddings)

- **Old way:** Add a position vector to the embedding to say "I am at position 5".
- **New way (RoPE):** Rotate the query and key vectors by an angle that grows with the position.

Think of a clock. Position 1 is 1:00, Position 2 is 2:00. The angle tells you the position. This works much better for long sequences.
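
A useful consequence of rotating by position is that the attention score between two tokens depends only on how far apart they are. The sketch below checks this for one 2-D pair stored as a complex number. The vectors and the single angle `theta=0.1` are assumptions for the example; the real implementation uses a different angle for each pair of dimensions.

```python
import cmath

def rotate(vec, position, theta=0.1):
    """Rotate a 2-D vector (stored as a complex number) by position * theta radians."""
    return vec * cmath.exp(1j * position * theta)

def score(q, k):
    """Dot product of two 2-D vectors stored as complex numbers."""
    return (q * k.conjugate()).real

q = complex(1.0, 2.0)    # toy query (assumption)
k = complex(0.5, -1.0)   # toy key (assumption)

# Same offset (3 positions apart) at two different absolute locations.
near = score(rotate(q, 5), rotate(k, 2))
far = score(rotate(q, 105), rotate(k, 102))
# A different offset (1 position apart) gives a different score.
other = score(rotate(q, 5), rotate(k, 4))
print(near, far, other)
```

`near` and `far` agree up to floating-point error even though the absolute positions differ by 100, while `other` is different because the offset changed.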


In [None]:
def precompute_freqs_cis(dim, max_len, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_len)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_emb(xq, xk, freqs_cis):
    # xq, xk: (B, T, nh, head_dim)
    # freqs_cis: (T, head_dim//2) complex
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast freqs_cis to (1, T, 1, head_dim//2) for multi-head
    freqs_cis = freqs_cis[:xq.shape[1]].unsqueeze(0).unsqueeze(2)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

### 4.3 SwiGLU

This is the "neuron" part of the brain.
Standard neurons (ReLU) are simple switches: On or Off.
**SwiGLU** is a gated mechanism. It's like a door that can be partially open, controlled by another neuron. It allows for much more complex decision making.
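
The door idea can be tried on single numbers. This is a scalar sketch, not the layer itself: in the `SwiGLU` class the gate and the value each come from their own linear layer (`w1` and `w3`). The function names are assumptions for the example.

```python
import math

def silu(x):
    """SiLU (also called Swish): x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def gated_unit(gate_input, value_input):
    """One scalar SwiGLU-style unit: the SiLU of the gate scales the value."""
    return silu(gate_input) * value_input

for g in (-4.0, 0.0, 4.0):
    print(g, round(gated_unit(g, 1.0), 3))
```

A strongly negative gate input almost closes the door (output near 0), while a positive one lets the value through roughly in proportion to the gate.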


In [None]:
class SwiGLU(nn.Module):
    def __init__(self, dim, hidden=None, bias=False):
        super().__init__()
        hidden = hidden or int(dim * 4 * 2/3)
        self.w1 = nn.Linear(dim, hidden, bias=bias)
        self.w2 = nn.Linear(hidden, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden, bias=bias)
    
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

### Section Summary: Stage 4

We have gathered the state-of-the-art parts:
- **RMSNorm** for stability.
- **RoPE** for position.
- **SwiGLU** for thinking.

Now, let's put them all together into a full GPT model.


## Stage 5: Assembling the Full GPT Model

We will now build the full architecture.

**The Blueprint:**
1.  **Embedding:** Convert Token IDs to Vectors.
2.  **Transformer Blocks:** A stack of layers (Communication + Computation).
    - **Communication:** Attention (Tokens talk to each other).
    - **Computation:** FeedForward (Tokens think about what they heard).
3.  **Output Head:** Convert vectors back to probabilities for the next token.

### Configuration Presets

| Mode | Layers | Embed | Params | Memory | Best For |
|------|--------|-------|--------|--------|----------|
| **COLAB_MODE=True** | 6 | 384 | ~30M | ~2GB | Free Colab T4 |
| COLAB_MODE=False | 12 | 768 | ~124M | ~6GB | Better GPU |
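
The parameter counts in the table can be reproduced with simple arithmetic. The sketch below assumes the layer shapes used by this notebook's model: no bias terms, the output head sharing weights with the token embedding, and a SwiGLU hidden width of `int(n_embd * 4 * 2 / 3)`. The function name `count_params` is made up for the example.

```python
def count_params(n_layer, n_embd, vocab_size=50304):
    """Estimate parameters for a bias-free GPT with tied embeddings."""
    hidden = int(n_embd * 4 * 2 / 3)                    # SwiGLU hidden width
    attention = n_embd * 3 * n_embd + n_embd * n_embd   # QKV + output projection
    mlp = 3 * n_embd * hidden                           # w1, w2, w3
    norms = 2 * n_embd                                  # two RMSNorm weights per block
    per_block = attention + mlp + norms
    embedding = vocab_size * n_embd                     # shared with the output head
    return embedding + n_layer * per_block + n_embd     # + final RMSNorm

print(f"{count_params(6, 384) / 1e6:.1f}M")
print(f"{count_params(12, 768) / 1e6:.1f}M")
```

In the small configuration, the embedding table alone accounts for roughly two thirds of the parameters, which is typical for small models with a large vocabulary.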

In [None]:
from dataclasses import dataclass

# Configuration flags
COLAB_MODE = True  # Use this for free Google Colab T4 GPU

@dataclass
class GPTConfig:
    # block_size: How far back the model can look (context window)
    block_size: int = 512 if COLAB_MODE else 1024
    # vocab_size: Number of unique tokens
    vocab_size: int = 50304
    # n_layer: Number of Transformer blocks (depth)
    n_layer: int = 6 if COLAB_MODE else 12
    # n_head: Number of attention heads (parallel focus)
    n_head: int = 6 if COLAB_MODE else 12
    # n_embd: Size of the vector representing each token (width)
    n_embd: int = 384 if COLAB_MODE else 768
    dropout: float = 0.1 if COLAB_MODE else 0.0
    bias: bool = False

config = GPTConfig()
print(f"Model Configuration:")
print(f"  - Layers: {config.n_layer}")
print(f"  - Embedding Size: {config.n_embd}")
print(f"  - Context Window: {config.block_size}")
print(f"  - Parameter Count: {'~30M' if COLAB_MODE else '~124M'}")


In [None]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        
        # Key, Query, Value projections combined into one layer for efficiency
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        # Check for PyTorch's fused attention function (PyTorch 2.0+), which can use Flash Attention kernels
        self.flash = hasattr(F, 'scaled_dot_product_attention')
    
    def forward(self, x, freqs_cis=None):
        B, T, C = x.size()
        # Calculate Q, K, V
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        
        # Reshape for multi-head attention: (B, T, n_head, head_dim)
        # Split the embedding into 'n_head' smaller vectors
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        
        # Apply RoPE (Rotary Positional Embeddings) while the layout is still (B, T, n_head, head_dim)
        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis)
        
        # Move the heads next to the batch dimension: (B, n_head, T, head_dim)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        
        # Calculate attention scores
        if self.flash:
            # Use the optimized PyTorch implementation if available
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # Manual implementation (for understanding)
            att = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
            att = att.masked_fill(torch.tril(torch.ones(T, T, device=x.device)) == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            y = att @ v
        
        # Reassemble the heads back into one big vector
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.dropout(self.c_proj(y))


In [None]:
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = RMSNorm(config.n_embd)
        self.mlp = SwiGLU(config.n_embd)
    
    def forward(self, x, freqs_cis):
        # Residual Connection: x = x + ...
        # This allows the signal to flow through the network easily (like a highway)
        x = x + self.attn(self.ln_1(x), freqs_cis)
        x = x + self.mlp(self.ln_2(x))
        return x


In [None]:
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # The container for all our layers
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embeddings
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # The layers
            ln_f = RMSNorm(config.n_embd), # Final normalization
        ))
        # The output head that predicts the next token
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Weight Tying: We use the same weights for input embedding and output prediction.
        # This makes sense because if 'cat' has a certain vector input, it should have the same vector output.
        self.transformer.wte.weight = self.lm_head.weight
        
        # Precompute RoPE frequencies
        self.freqs_cis = precompute_freqs_cis(
            config.n_embd // config.n_head, config.block_size * 2
        ).to(device)
        
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        # Initialize weights with small random numbers
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None: torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
    
    def forward(self, idx, targets=None):
        B, T = idx.size()
        
        # 1. Convert token IDs to vectors
        x = self.transformer.wte(idx)
        
        # 2. Get position frequencies
        freqs_cis = self.freqs_cis[:T]
        
        # 3. Run through all Transformer blocks
        for block in self.transformer.h:
            x = block(x, freqs_cis)
        
        # 4. Final normalization
        x = self.transformer.ln_f(x)
        
        # 5. Calculate output (logits) and loss
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            # Inference-time optimization: only compute for the last token
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss
    
    @torch.no_grad()
    def generate(self, idx, max_new, temperature=1.0, top_k=None):
        # Generation loop
        for _ in range(max_new):
            # Crop context if it gets too long
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # Get prediction
            logits, _ = self(idx_cond)
            # Scale by temperature (higher = more random, lower = more focused)
            logits = logits[:, -1, :] / temperature
            # Top-K sampling (optional): only keep the top K most likely options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            # Sample
            probs = F.softmax(logits, dim=-1)
            idx = torch.cat((idx, torch.multinomial(probs, 1)), dim=1)
        return idx

model = GPT(config).to(device)
print(f"Model instantiated with {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters.")

# Compile the model (PyTorch 2.0 optimization) for faster training
if hasattr(torch, 'compile'):
    print("Compiling model...")
    model = torch.compile(model)
    print("Model compiled.")


### Section Summary: Stage 5

We have built the robot! It has:
- **Eyes:** Tokenizer + Embedding
- **Brain:** Transformer Blocks (Attention + SwiGLU)
- **Mouth:** LM Head

But it hasn't gone to school yet. It knows nothing.


## Stage 6: Training

We will train the model on **FineWeb-Edu**, a dataset of high-quality educational content.

### Checkpointing Strategy (Important!)
If you are running on a free cloud machine, it might disconnect at any time.
To save our work, we will:
1.  **Save** the model to a local folder named `checkpoints`.
2.  **Download** the `ckpt.pt` file to our own computer manually.
3.  **Upload** it back if we need to resume training later.
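
A save-and-resume round trip might look like the sketch below. It uses a tiny stand-in model and a temporary folder so it runs on its own; the dictionary keys (`"model"`, `"optimizer"`, `"step"`) are an assumed layout, not necessarily the one this notebook's training loop uses.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model and optimizer (assumptions; the real ones are the GPT and its optimizer).
tiny_model = nn.Linear(4, 2)
tiny_opt = torch.optim.AdamW(tiny_model.parameters(), lr=1e-3)

ckpt_dir = tempfile.mkdtemp()               # the notebook uses "checkpoints"
ckpt_path = os.path.join(ckpt_dir, "ckpt.pt")

# Save everything needed to resume: weights, optimizer state, and step counter.
torch.save(
    {"model": tiny_model.state_dict(), "optimizer": tiny_opt.state_dict(), "step": 500},
    ckpt_path,
)

# Resume: rebuild the objects, then load the saved state into them.
restored = nn.Linear(4, 2)
restored_opt = torch.optim.AdamW(restored.parameters(), lr=1e-3)
ckpt = torch.load(ckpt_path, map_location="cpu")
restored.load_state_dict(ckpt["model"])
restored_opt.load_state_dict(ckpt["optimizer"])
start_step = ckpt["step"]
print(f"Resuming from step {start_step}")
```

Saving the optimizer state and the step counter alongside the weights lets training continue where it stopped instead of restarting the learning rate schedule. If the model has been wrapped by `torch.compile`, its state-dict keys may carry an extra prefix, so it is worth printing a few keys before loading.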

### Training Techniques
1.  **Mixed Precision:** We run most of the math in `float16` (half precision) instead of `float32`. This roughly halves the memory needed for activations and is usually much faster on GPUs with tensor cores (such as the T4), with almost no loss in quality.
2.  **Gradient Accumulation:** We want to simulate a large batch size (e.g., 16) but our GPU can only fit 4. So we run 4 small batches, add up the gradients, and *then* update the weights.
3.  **Cosine Decay:** We start with a high learning rate to learn fast, then slowly lower it to fine-tune the details.
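
Gradient accumulation gives the same gradient as one large batch, provided each small batch is scaled by the number of accumulation steps. The sketch below checks this on a one-parameter toy model in plain Python. The model, the data, and the scaling pattern (dividing by `accum_steps` before adding) are assumptions about how such a loop is usually written.

```python
# Toy setup (assumption): model y = w * x with mean squared error loss.
w = 0.5
data = [(float(i), 2.0 * i) for i in range(1, 17)]   # 16 samples of y = 2x

def mean_grad(batch, w):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

# One big batch of 16.
full_grad = mean_grad(data, w)

# Four micro-batches of 4; each gradient is divided by the number of steps
# before being added, so the sum equals the average over all 16 samples.
accum_steps = 4
accum_grad = 0.0
for i in range(accum_steps):
    micro = data[i * 4:(i + 1) * 4]
    accum_grad += mean_grad(micro, w) / accum_steps

print(full_grad, accum_grad)
```

The two numbers match, so the weight update is the same; only the peak memory differs, because just 4 samples are processed at a time.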


### What to Expect

- **Step 0:** Complete gibberish.
- **Step 500:** It starts to spell words correctly.
- **Step 2000:** It forms sentences.
- **Step 5000:** It writes coherent paragraphs.

**Note:** This is a small model trained for a short time. It won't be Einstein, but it will speak English.
With 6 layers and 512 context, this is a "baby" model.
Production models are far larger: GPT-3, for example, has 96 layers, and recent chat models accept 100K+ tokens of context.
The goal here is **understanding**, not production quality.

### Troubleshooting: Out of Memory

If your GPU runs out of memory (OOM), try:
1.  Restarting the runtime (Runtime -> Restart Runtime).
2.  Reducing `batch_size` in the code below.
3.  Reducing `block_size` (context window).
4.  Clearing the GPU cache by running this cell:

```python
import gc
gc.collect()
torch.cuda.empty_cache()
```


In [None]:
from datasets import load_dataset
import matplotlib.pyplot as plt
from IPython.display import clear_output

class DataLoader:
    def __init__(self, batch_size, block_size, split='train'):
        self.batch_size = batch_size
        self.block_size = block_size
        # Stream the dataset from Hugging Face so we don't have to download it all up front
        self.dataset = load_dataset("HuggingFaceFW/fineweb-edu", 
            name="sample-10BT", split=split, streaming=True)
        self.iterator = iter(self.dataset)
        self.buffer = []
    
    def __iter__(self): return self
    
    def __next__(self):
        # Fill buffer with tokens until we have enough for a batch
        needed = self.batch_size * self.block_size + 1
        while len(self.buffer) < needed:
            try:
                text = next(self.iterator)['text']
                self.buffer.extend(tokenizer.enc.encode(text))
            except StopIteration:
                self.iterator = iter(self.dataset)
        
        # Take a chunk
        chunk = self.buffer[:needed]
        self.buffer = self.buffer[needed:]
        
        # Create tensors
        data = torch.tensor(chunk, dtype=torch.long)
        x = data[:-1].view(self.batch_size, self.block_size)
        y = data[1:].view(self.batch_size, self.block_size)
        return x.to(device), y.to(device)


In [None]:
# Training Settings
max_iters = 5000              # Total training steps
warmup_iters = 200            # Warmup steps (start slow)
eval_interval = 250           # How often to check progress
save_interval = 500           # How often to save a local checkpoint
batch_size = 4 if COLAB_MODE else 8
grad_accum_steps = 4          # Accumulate gradients to simulate larger batch
max_lr = 3e-4                 # Maximum learning rate
min_lr = 1e-5                 # Minimum learning rate

# Learning rate schedule
def get_lr(it):
    # 1. Linear Warmup
    if it < warmup_iters:
        return max_lr * (it + 1) / warmup_iters
    # 2. Cosine Decay
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

print("Training Configuration Ready.")
print(f"  - Max iterations: {max_iters:,}")
print(f"  - Batch size: {batch_size} × {grad_accum_steps} = {batch_size * grad_accum_steps} effective")
print(f"  - Tokens per step: {batch_size * grad_accum_steps * config.block_size:,}")
print(f"  - LR: linear warmup to {max_lr}, then cosine decay to {min_lr}")
print(f"  - Checkpointing every {save_interval} steps to the local '{PROJECT_DIR}' folder")

### How to Resume Training

If you have a `ckpt.pt` file from a previous session on your computer:
1.  Run the cell below.
2.  Click **"Choose Files"**.
3.  Select your `ckpt.pt` file.
4.  The code will automatically move it to the correct folder so training can resume.


In [None]:
# Run this ONLY if you want to resume from a saved checkpoint file.
from google.colab import files
import shutil

print("Upload your 'ckpt.pt' file here (if you have one)...")
uploaded = files.upload()

for filename in uploaded.keys():
    # Move the uploaded file to our checkpoint directory
    destination = os.path.join(PROJECT_DIR, "ckpt.pt")
    shutil.move(filename, destination)
    print(f"Restored checkpoint to {destination}")


In [None]:
train_loader = DataLoader(batch_size, config.block_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95), weight_decay=0.1)
scaler = torch.amp.GradScaler('cuda')  # For mixed precision

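# Checkpoint logic: Resume if file exists in our local folder
CKPT = os.path.join(PROJECT_DIR, "ckpt.pt")
start_iter = 0
losses = []
if os.path.exists(CKPT):
    print(f"Found checkpoint at {CKPT}. Resuming...")
    # weights_only=False is needed because the checkpoint also stores the config object.
    # Only load checkpoint files you created yourself.
    ckpt = torch.load(CKPT, map_location=device, weights_only=False)
    state_dict = ckpt['model']
    # Fix for compiled model names
    new_state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    # Load into the underlying module so the key names match even if the model is compiled
    raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
    result = raw_model.load_state_dict(new_state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"Warning: missing keys {result.missing_keys}, unexpected keys {result.unexpected_keys}")
    optimizer.load_state_dict(ckpt['optimizer'])
    if 'scaler' in ckpt:
        scaler.load_state_dict(ckpt['scaler'])
    start_iter = ckpt['iter'] + 1
    losses = ckpt.get('losses', [])  # Restore the loss history so the plot continues
    print(f"Resuming from step {start_iter}.")
else:
    print("No checkpoint found. Starting fresh.")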


In [None]:
print("Starting training loop...")
model.train()
t0 = time.time()
running_loss = 0.0

for it in range(start_iter, max_iters):
    # Set learning rate
    lr = get_lr(it)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    # Gradient Accumulation
    optimizer.zero_grad(set_to_none=True)
    step_loss = 0.0
    for micro_step in range(grad_accum_steps):
        xb, yb = next(train_loader)
        with torch.amp.autocast('cuda', dtype=torch.float16):
            logits, loss = model(xb, yb)
            loss = loss / grad_accum_steps  # So the summed gradients average over micro-batches
        scaler.scale(loss).backward()
        step_loss += loss.item()  # Adds up to the mean loss over all micro-batches
    
    # Clip gradients (prevent instability)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    
    # Update weights
    scaler.step(optimizer)
    scaler.update()
    
    running_loss += step_loss
    
    # Logging
    if it % 10 == 0:
        avg_loss = running_loss / 10 if it > 0 else running_loss
        losses.append(avg_loss)
        running_loss = 0.0
        dt = time.time() - t0; t0 = time.time()
        print(f"Step {it:5d}/{max_iters} | loss={avg_loss:.4f} | lr={lr:.2e} | {dt*1000:.0f}ms")
    
    # Plotting
    if it % eval_interval == 0 and it > 0:
        clear_output(wait=True)
        plt.figure(figsize=(10, 4))
        plt.plot(losses)
        plt.xlabel('Step (x10)')
        plt.ylabel('Loss')
        plt.title(f'Training Progress - Step {it}/{max_iters}')
        plt.grid(True, alpha=0.3)
        plt.show()
    
    # Saving (every save_interval steps, and on the final step so no progress is lost)
    if (it % save_interval == 0 and it > 0) or it == max_iters - 1:
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
        torch.save({
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'iter': it,
            'config': config,
            'losses': losses
        }, CKPT)
        print(f"Checkpoint saved to {CKPT}")
        print("Remember to download this file if you want to keep it!")

print("Training complete.")


In [None]:
# Run this cell to DOWNLOAD your trained model to your computer.
from google.colab import files

if os.path.exists(CKPT):
    print(f"Downloading {CKPT}...")
    files.download(CKPT)
else:
    print("No checkpoint found to download!")


### Section Summary: Stage 6

We have successfully trained the model. Over 5,000 steps it has seen tens of millions of tokens and adjusted its weights to predict the next token.
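
As a rough check on that figure, the number of tokens processed follows directly from the training settings. The sketch below assumes the Colab configuration (`batch_size = 4`) and a `block_size` of 512, taken from the "512 context" mentioned earlier in this stage; the helper name `tokens_seen` is hypothetical.

```python
def tokens_seen(max_iters, batch_size, grad_accum_steps, block_size):
    """One optimizer step consumes batch_size * grad_accum_steps * block_size tokens."""
    tokens_per_step = batch_size * grad_accum_steps * block_size
    return tokens_per_step, tokens_per_step * max_iters

# Assumed values: Colab batch size and a 512-token context window.
per_step, total = tokens_seen(max_iters=5000, batch_size=4,
                              grad_accum_steps=4, block_size=512)
print(f"Tokens per step: {per_step:,}")
print(f"Total tokens:    {total:,}")
```

With these settings the run processes about 41 million tokens; adjust the arguments if your configuration differs.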

Now comes the fun part: testing it.


## Stage 7: Inference & Chat

We will now switch the model to **Evaluation Mode** and interact with it.
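We use a `temperature` parameter to control how random the sampling is:
- **Low Temperature (0.1):** Focused and nearly deterministic, but often repetitive.
- **High Temperature (1.0):** Samples from the model's unmodified probabilities. More varied and creative, sometimes incoherent.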
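
Temperature is usually applied by dividing the logits by the temperature before the softmax. The pure-Python sketch below uses made-up logits for three candidate tokens to show the effect; `softmax_with_temperature` is an illustrative helper, and the notebook's own `generate` method may implement this differently.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Turn raw scores into probabilities after scaling by 1/temperature."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)                        # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5]                   # made-up scores for three candidate tokens

for temp in (0.1, 0.8, 1.0):
    probs = softmax_with_temperature(logits, temp)
    print(f"temp={temp}: " + ", ".join(f"{p:.3f}" for p in probs))
```

At 0.1 almost all of the probability lands on the top-scoring token, while at 1.0 the lower-scoring tokens keep a real chance of being sampled.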


In [None]:
def chat(max_tokens=100, temp=0.8):
    model.eval()  # Turn off dropout for consistent results
    print("Chat Interface (Type 'exit' to stop)")
    print("-" * 40)
    while True:
        prompt = input("You: ")
        if prompt.strip().lower() == 'exit': break
        if not prompt.strip(): continue  # Skip empty input; there is nothing to encode
        
        # Encode prompt
        idx = tokenizer.encode(prompt).unsqueeze(0).to(device)
        print("AI: ", end="", flush=True)
        
        # Generate tokens one by one
        for _ in range(max_tokens):
            out = model.generate(idx, max_new=1, temperature=temp)
            new_tok = out[0, -1].item()
            print(tokenizer.decode([new_tok]), end="", flush=True)
            idx = out
        print("\n" + "-" * 40)

# Uncomment the line below to run the chat
# chat()


### Section Summary: Stage 7

You now have a working AI chatbot running on your own code.
However, you might notice it just completes sentences rather than answering questions helpfully. That's because it is a **Base Model**.

To make it an **Assistant**, we need Alignment.


## Stage 8: RLHF Alignment (Conceptual)

**Reinforcement Learning from Human Feedback (RLHF)** is how we turn a text completer into a helpful assistant.

**The Process:**
1.  **SFT (Supervised Fine-Tuning):** Fine-tune the base model on example prompts paired with good answers.
2.  **Reward Modeling:** Train a judge model to rate answers, using human comparisons of which answer is better.
3.  **PPO (Proximal Policy Optimization):** Use the judge to train the AI to get high scores.
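
For step 2, a widely used training objective is the pairwise loss `-log(sigmoid(r_chosen - r_rejected))`, which is small when the judge scores the human-preferred answer higher. The sketch below evaluates it on made-up scores; `pairwise_reward_loss` is an illustrative name, and real systems compute this over batches of neural-network outputs.

```python
import math

def pairwise_reward_loss(r_chosen, r_rejected):
    """-log(sigmoid(r_chosen - r_rejected)), written in a numerically stable form."""
    diff = r_chosen - r_rejected
    # Equivalent to log(1 + exp(-diff)), but safe for large |diff|
    return max(-diff, 0.0) + math.log1p(math.exp(-abs(diff)))

# Made-up judge scores for illustration.
agree = pairwise_reward_loss(2.0, -1.0)      # judge prefers the answer the human chose
disagree = pairwise_reward_loss(-1.0, 2.0)   # judge prefers the rejected answer
print(f"judge agrees with human:    loss = {agree:.4f}")
print(f"judge disagrees with human: loss = {disagree:.4f}")
```

Minimizing this loss pushes the judge to score preferred answers above rejected ones.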

Below is a simple demonstration of the **Reward** concept.


In [None]:
def get_reward(text):
    """A dummy reward function. Real ones are complex neural networks."""
    # We reward the model for using positive words
    positive = ["happy", "good", "great", "excellent", "love", "wonderful", "amazing"]
    return sum(1 for w in positive if w in text.lower())

# RLHF Demonstration - CONCEPTUAL ONLY (no learning occurs!)
# Missing for true RLHF: log_probs collection, value function, PPO loss, gradient update
# This demo only shows: sampling/scoring loop (the reward signal concept)
print("Initializing RLHF demonstration...")
model.eval()  # Sampling only: turn off dropout, as in the chat cell
prompts = ["Today I feel", "The weather is", "My work is"]

for i in range(50):
    prompt = prompts[i % len(prompts)]
    idx = tokenizer.encode(prompt).unsqueeze(0).to(device)
    
    # 1. Generate multiple samples
    responses = []
    for _ in range(4):
        out = model.generate(idx, max_new=10, temperature=0.9)
        responses.append(tokenizer.decode(out[0].tolist()))
    
    # 2. Score them
    rewards = [get_reward(r) for r in responses]
    avg_reward = sum(rewards) / len(rewards)
    
    if i % 10 == 0:
        print(f"Step {i}: Average Reward = {avg_reward:.2f}")
        best_idx = rewards.index(max(rewards))
        print(f"  Best Response: {responses[best_idx]}")

print("RLHF demonstration complete.")


### Section Summary: Stage 8

In a real scenario, we would use these rewards to update the model's weights using PPO. This encourages the model to generate more "positive" (high reward) text in the future.
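
To make "use these rewards to update the model's weights" concrete, here is a toy policy-gradient update in pure Python. It is a simplified REINFORCE-style update using the exact expected gradient, not PPO (no sampling, no clipping, no value function), and the three-word vocabulary, rewards, and learning rate are all made up for illustration.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["bad", "okay", "good"]
rewards = [0.0, 0.0, 1.0]        # made-up reward: only "good" scores
logits = [0.0, 0.0, 0.0]         # the toy "policy": one logit per word, initially uniform
lr = 0.5                         # illustrative learning rate

initial_probs = softmax(logits)
for step in range(20):
    probs = softmax(logits)
    baseline = sum(p * r for p, r in zip(probs, rewards))   # expected reward
    # For a softmax policy, the gradient of expected reward with respect to
    # logit j is p_j * (R_j - baseline).
    grads = [p * (r - baseline) for p, r in zip(probs, rewards)]
    # Gradient ascent: raise logits of words that beat the average reward.
    logits = [l + lr * g for l, g in zip(logits, grads)]

final_probs = softmax(logits)
for word, p0, p1 in zip(vocab, initial_probs, final_probs):
    print(f"{word:>5}: {p0:.3f} -> {p1:.3f}")
```

The probability of the rewarded word rises while the others fall. PPO follows the same idea but limits how far each update can move the policy.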


## Glossary

| Term | Definition |
|------|------------|
| **BPE** | Byte Pair Encoding. The method for chopping text into tokens. |
| **Logits** | The raw scores output by the model before Softmax. |
| **Softmax** | A math function that turns scores into probabilities (each between 0 and 1, summing to 1). |
| **Temperature** | A setting that controls how random the model's sampling is. |
| **Epoch** | One full pass through the entire training dataset. |
| **Inference** | Using the trained model to generate text. |

---

## Conclusion

You have successfully implemented a complete GPT architecture with modern enhancements.

**Key Concepts Covered:**
- Byte Pair Encoding (BPE)
- Causal Self-Attention and Multi-Head Attention
- RMSNorm, RoPE, and SwiGLU
- Mixed-precision training and gradient accumulation

**Next steps:**
1. Train longer (100k+ steps)
2. Use larger/better datasets
3. Fine-tune on dialogue for assistant behavior
