# Building a Modern Large Language Model from Scratch

**Objective:** Construct a Generative Pre-trained Transformer (GPT) with modern architectural components (Llama 3 style) and train it on a high-quality dataset.

## Table of Contents

| Stage | Topic | Time |
|-------|-------|------|
| 0 | Setup | 5 min |
| 1 | Bigram Model (Simplest LM) | 15 min |
| 2 | Tokenization | 15 min |
| 3 | Attention (4 Versions) | 30 min |
| 4 | Modern Components | 20 min |
| 5 | Full GPT Model | 15 min |
| 6 | Training | 25 min |
| 7 | Inference & Chat | 10 min |
| 8 | RLHF Alignment | 15 min |

---

## Stage 0: Setup

In [None]:
# Mount Google Drive for checkpoints
from google.colab import drive
import os
drive.mount('/content/drive')
PROJECT_DIR = "/content/drive/MyDrive/nanochat_zero"
os.makedirs(PROJECT_DIR, exist_ok=True)
print(f"Project: {PROJECT_DIR}")

In [None]:
# Install dependencies
!pip install -q torch tiktoken datasets matplotlib

import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

## Stage 1: The Simplest Language Model (Bigram)

Before Transformers, let's understand the core idea with the SIMPLEST possible model.

### What is Language Modeling?
Predict the next token: "The cat sat on the ___" → "mat"

### Bigram Model
Only looks at the LAST token to predict. No context at all!

In [None]:
# Download tiny dataset for fast iteration
!wget -q https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r') as f:
    text = f.read()
print(f"Dataset: {len(text):,} characters")
print(text[:200])

In [None]:
# Character-level tokenizer (simplest possible)
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
print(f"Vocab: {vocab_size} chars")
print(f"'hello' -> {encode('hello')} -> {decode(encode('hello'))}")

In [None]:
# Prepare data
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

block_size, batch_size = 8, 32
def get_batch(split):
    d = train_data if split == 'train' else val_data
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i+block_size] for i in ix])
    y = torch.stack([d[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

### Input/Target Relationship
```
Input:  [H, e, l, l, o]
Target: [e, l, l, o, !]
```
Model learns: "After H, predict e. After e, predict l."
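
The shift between input and target can be made concrete with a minimal sketch. The string and `block_size` here are illustrative values, not the notebook's training data.

```python
# Hypothetical toy sequence; any string longer than block_size works.
sequence = "Hello!"
block_size = 5

x = sequence[:block_size]        # inputs:  "Hello"
y = sequence[1:block_size + 1]   # targets: "ello!"

# At position t the model has seen x[:t+1] and must predict y[t].
pairs = [(x[:t + 1], y[t]) for t in range(block_size)]
for context, target in pairs:
    print(f"context={context!r:8} -> target={target!r}")
```

One window of `block_size` tokens therefore yields `block_size` training examples. A bigram model uses only the last character of each context.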

In [None]:
class BigramModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets=None):
        logits = self.embed(idx)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss
    
    def generate(self, idx, max_new):
        for _ in range(max_new):
            logits, _ = self(idx)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx = torch.cat((idx, torch.multinomial(probs, 1)), dim=1)
        return idx

bigram = BigramModel(vocab_size).to(device)
print(f"Params: {sum(p.numel() for p in bigram.parameters()):,}")

In [None]:
# Before training - random garbage
print("BEFORE training:")
print(decode(bigram.generate(torch.zeros((1,1), dtype=torch.long, device=device), 100)[0].tolist()))

In [None]:
# Train bigram
opt = torch.optim.AdamW(bigram.parameters(), lr=1e-3)
for step in range(1000):
    xb, yb = get_batch('train')
    _, loss = bigram(xb, yb)
    opt.zero_grad(); loss.backward(); opt.step()
    if step % 200 == 0: print(f"Step {step}: loss={loss.item():.4f}")
print(f"Final: {loss.item():.4f}")

In [None]:
# After training - slightly better garbage
print("AFTER training:")
print(decode(bigram.generate(torch.zeros((1,1), dtype=torch.long, device=device), 200)[0].tolist()))

### Section Summary: Stage 1

**Learned:** Training loop (forward → loss → backward → update)

**Problem:** Bigram only sees the LAST token. In "The cat sat", it only sees "t" from "sat".

**Solution:** ATTENTION - look at ALL previous tokens!

---

## Stage 2: Tokenization

### Character vs Subword
| Text | Chars | BPE (GPT) |
|------|-------|-----------|
| "Hello" | 5 | 1 |
| "The quick fox" | 13 | 3 |

BPE compresses common patterns into single tokens.

### BPE Intuition
Start with characters, merge frequent pairs:
```
['H','e','l','l','o'] → ['He','l','l','o'] → ['Hel','l','o'] → ['Hell','o'] → ['Hello']
```

Building BPE from scratch = ~150 lines. We'll use tiktoken.
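
To make the merge loop concrete, here is a toy sketch of the core BPE step. It works on characters of a made-up string. Real tokenizers such as the GPT-2 encoding work on bytes and use a merge table learned from a large corpus. The helper names `most_frequent_pair` and `merge_pair` are illustrative, not part of tiktoken.

```python
from collections import Counter

def most_frequent_pair(tokens):
    """Return the most common adjacent pair of symbols."""
    pairs = Counter(zip(tokens, tokens[1:]))
    return pairs.most_common(1)[0][0]

def merge_pair(tokens, pair):
    """Replace every occurrence of `pair` with the concatenated symbol."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

tokens = list("hello hello hello")   # hypothetical tiny corpus
for step in range(4):
    pair = most_frequent_pair(tokens)
    tokens = merge_pair(tokens, pair)
    print(f"merge {step + 1}: {pair} -> {tokens}")
```

Each merge adds one symbol to the vocabulary and shortens the sequence, which is where the compression comes from.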

In [None]:
import tiktoken
enc = tiktoken.get_encoding("gpt2")

sample = "Hello, I am learning AI!"  # separate name so the Stage 1 `text` is not overwritten
tokens = enc.encode(sample)
print(f"Text: {sample}")
print(f"Tokens: {tokens}")
print(f"{len(sample)} chars -> {len(tokens)} tokens ({len(sample)/len(tokens):.1f}x compression)")

# Decode each
for t in tokens:
    print(f"  {t} -> {repr(enc.decode([t]))}")

In [None]:
# Tokenizer wrapper
class Tokenizer:
    def __init__(self):
        self.enc = tiktoken.get_encoding("gpt2")
        self.vocab_size = 50304  # GPT-2's 50,257 tokens padded up to a multiple of 64 for efficiency
    
    def encode(self, text):
        return torch.tensor(self.enc.encode(text), dtype=torch.long)
    
    def decode(self, tokens):
        if isinstance(tokens, torch.Tensor): tokens = tokens.tolist()
        return self.enc.decode(tokens)

tokenizer = Tokenizer()
print(f"Vocab size: {tokenizer.vocab_size:,}")

### Section Summary: Stage 2

**Learned:** BPE merges frequent patterns for compression.

---

## Stage 3: Attention (The Heart of Transformers)

Goal: For each position, aggregate info from ALL previous positions.

We'll build this in 4 versions, each adding one insight.

### Version 1: Naive Averaging (For Loop)
For position t, average all positions 0...t.

In [None]:
torch.manual_seed(42)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)

# Slow but clear
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t+1].mean(dim=0)

print("Row 0 = just x[0,0]")
print("Row 1 = avg(x[0,0:2])")
print(xbow[0])

### Version 2: Matrix Multiplication Trick
Lower-triangular matrix does the same thing, but batched!

In [None]:
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
print("Weight matrix:")
print(wei)
xbow2 = wei @ x
print(f"\nSame result? {torch.allclose(xbow, xbow2)}")

### Version 3: Softmax for Learnable Weights
Use softmax to convert scores to probabilities.
Mask future with -inf (becomes 0 after softmax).
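
A minimal pure-Python sketch of the masking step, assuming all raw scores are zero so the result is easy to check by hand. The `softmax` helper is written out here for illustration.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats (-inf allowed)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

T = 4
NEG_INF = float("-inf")
# Raw scores are all zero; future positions (j > i) are masked with -inf.
scores = [[0.0 if j <= i else NEG_INF for j in range(T)] for i in range(T)]
weights = [softmax(row) for row in scores]
for row in weights:
    print([round(w, 2) for w in row])
```

Each row is a uniform average over the current and earlier positions, matching the lower-triangular weights from Version 2.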

In [None]:
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
print(wei)
xbow3 = wei @ x
print(f"Same? {torch.allclose(xbow, xbow3)}")

### Version 4: Self-Attention
Instead of fixed weights, LEARN what to attend to!

- **Query (Q):** "What am I looking for?"
- **Key (K):** "What do I contain?"
- **Value (V):** "What information do I provide?"

Attention weight = how well Q matches K.

In [None]:
torch.manual_seed(42)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)
head_size = 16

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k, q, v = key(x), query(x), value(x)

# Attention scores
wei = q @ k.transpose(-2, -1) * (head_size ** -0.5)  # Scale!
wei = wei.masked_fill(torch.tril(torch.ones(T,T)) == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v

print("Attention weights (learned!):")
print(wei[0])

### Visualizing Attention Patterns

Let's **see** what the attention weights actually look like!

The heatmap shows: **"How much does position i attend to position j?"**
- **Rows** = query positions (current token asking "what should I look at?")
- **Columns** = key positions (tokens being looked at)
- **Brighter colors** = higher attention weight

In [None]:
import matplotlib.pyplot as plt

# Visualize attention weights for batch 0
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Attention heatmap
ax1 = axes[0]
attn_weights = wei[0].detach().numpy()  # Shape: (T, T)
im1 = ax1.imshow(attn_weights, cmap='Blues', aspect='auto')
ax1.set_xlabel('Key Position (tokens being attended to)', fontsize=10)
ax1.set_ylabel('Query Position (current token)', fontsize=10)
ax1.set_title('Attention Heatmap\n(Causal Mask: can only see past + current)', fontsize=11)
ax1.set_xticks(range(T))
ax1.set_yticks(range(T))
plt.colorbar(im1, ax=ax1, label='Attention Weight')

# Add visual markers for masked positions
for i in range(T):
    for j in range(T):
        if j > i:  # Masked (future) positions
            ax1.text(j, i, 'X', ha='center', va='center', color='red', fontsize=10, fontweight='bold')

# Plot 2: Line plot showing attention distribution per position
ax2 = axes[1]
colors = plt.cm.viridis([0.2, 0.4, 0.6, 0.8])
positions_to_show = [0, 2, 5, 7]
for idx, pos in enumerate(positions_to_show):
    ax2.plot(range(T), attn_weights[pos, :], 'o-', color=colors[idx], 
             label=f'Position {pos}', linewidth=2, markersize=8)
ax2.set_xlabel('Key Position (tokens being attended to)', fontsize=10)
ax2.set_ylabel('Attention Weight', fontsize=10)
ax2.set_title('Attention Distribution per Query Position', fontsize=11)
ax2.legend(loc='upper right')
ax2.grid(True, alpha=0.3)
ax2.set_xticks(range(T))
ax2.set_ylim(-0.05, 1.05)

plt.tight_layout()
plt.show()

print("Key observations:")
print("1. Lower-triangular pattern: Can only attend to past + current (causal mask)")
print("2. Position 0 attends ONLY to itself (weight = 1.0)")
print("3. Later positions can distribute attention across more tokens")
print("4. Each row sums to 1.0 (softmax normalization)")
print(f"\nRow sums: {[f'{s:.2f}' for s in attn_weights.sum(axis=1)]}")

### Why Scale by √d?
Without scaling, large dot products → peaky softmax → vanishing gradients.

Let's visualize this problem:

In [None]:
# Demonstrate the scaling problem visually
q_demo, k_demo = torch.randn(8, 64), torch.randn(8, 64)
raw = q_demo @ k_demo.T
scaled = raw * (64 ** -0.5)

print(f"Raw variance: {raw.var():.1f} (grows with dimension!)")
print(f"Scaled variance: {scaled.var():.1f} (stable around 1.0)")
print()

# Visualize the effect on softmax
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Without scaling
raw_probs = F.softmax(raw, dim=-1).detach().numpy()
im1 = axes[0].imshow(raw_probs, cmap='hot', aspect='auto')
axes[0].set_title('WITHOUT Scaling\n(Peaky - almost one-hot!)', fontsize=11)
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')
plt.colorbar(im1, ax=axes[0])

# With scaling
scaled_probs = F.softmax(scaled, dim=-1).detach().numpy()
im2 = axes[1].imshow(scaled_probs, cmap='hot', aspect='auto')
axes[1].set_title('WITH Scaling (divide by sqrt(d))\n(Diffuse - can learn!)', fontsize=11)
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

print("Without scaling: Network is 'overconfident' before learning!")
print("With scaling: Attention is diffuse, allowing gradients to flow and learn.")

### Section Summary: Stage 3

**4 Versions of Attention:**
1. For-loop averaging (slow but clear)
2. Matrix multiply (fast, batched)
3. Softmax + mask (differentiable, causal)
4. Q, K, V (learned, data-dependent)

**Key insight:** Self-attention = learned weighted average!

**Visualization showed:**
- Causal mask creates lower-triangular attention pattern
- Scaling prevents softmax from becoming one-hot
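
The reason for the `head_size ** -0.5` factor can be checked numerically without PyTorch. This sketch assumes query and key entries are independent standard-normal values, which is an idealization of the state at initialization. The sample count and dimensions are arbitrary.

```python
import random
import statistics

random.seed(0)

def dot_variance(d, n_samples=2000, scale=1.0):
    """Sample variance of scale * (q . k) for random standard-normal q, k."""
    dots = []
    for _ in range(n_samples):
        q = [random.gauss(0, 1) for _ in range(d)]
        k = [random.gauss(0, 1) for _ in range(d)]
        dots.append(scale * sum(a * b for a, b in zip(q, k)))
    return statistics.variance(dots)

for d in (4, 16, 64):
    raw = dot_variance(d)
    scaled = dot_variance(d, scale=d ** -0.5)
    print(f"d={d:3d}  raw variance ~ {raw:6.1f}   scaled variance ~ {scaled:4.2f}")
```

The raw variance grows roughly in proportion to `d`. Multiplying by `d ** -0.5` keeps it near 1, so the softmax inputs stay in a range where gradients are useful.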

---

## Stage 4: Modern Components

### 4.1 RMSNorm
- **LayerNorm:** subtract the mean, divide by the standard deviation (2 statistics)
- **RMSNorm:** divide by the root mean square only (1 statistic, so slightly cheaper to compute)
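
The difference is easiest to see on a single vector. This is a plain-Python sketch that leaves out the learned gain; the input values are arbitrary.

```python
import math

def layer_norm(x, eps=1e-6):
    """Subtract the mean, then divide by the standard deviation."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-6):
    """Divide by the root mean square; no mean subtraction."""
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]

x = [1.0, 2.0, 3.0, 4.0]
print("LayerNorm:", [round(v, 3) for v in layer_norm(x)])
print("RMSNorm:  ", [round(v, 3) for v in rms_norm(x)])
```

LayerNorm output is centered at zero. RMSNorm only rescales the vector, so the sign and relative size of each entry are preserved.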

In [None]:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight

### 4.2 RoPE (Rotary Position Embeddings)
- **Old:** add a position vector to the embedding (position information can get drowned out)
- **New:** rotate the embeddings (relative position = angle difference)
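
The "angle difference" claim can be verified with Python's built-in complex numbers. The sketch below treats each pair of channels as one complex number and uses the same frequency formula as `precompute_freqs_cis`. The query and key values are made up for illustration.

```python
import cmath

def rotate(vec, pos, theta=10000.0):
    """Apply RoPE to a vector of complex pairs at position `pos`."""
    n_pairs = len(vec)
    out = []
    for i, z in enumerate(vec):
        freq = theta ** (-i / n_pairs)          # one frequency per channel pair
        out.append(z * cmath.exp(1j * pos * freq))
    return out

def score(q, k):
    """Real dot product of two vectors stored as complex pairs."""
    return sum((a * b.conjugate()).real for a, b in zip(q, k))

# Hypothetical 4-dimensional query and key, written as 2 complex pairs each.
q = [1.0 + 2.0j, 0.5 - 1.0j]
k = [0.3 + 0.7j, -1.2 + 0.4j]

near = score(rotate(q, 3), rotate(k, 1))     # positions 3 and 1  (offset 2)
far = score(rotate(q, 10), rotate(k, 8))     # positions 10 and 8 (offset 2)
other = score(rotate(q, 3), rotate(k, 2))    # offset 1
print(near, far, other)
```

The first two scores agree because only the offset between positions enters the dot product. Changing the offset changes the score.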

In [None]:
def precompute_freqs_cis(dim, max_len, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_len)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_emb(xq, xk, freqs_cis):
    # xq, xk: (B, T, nh, head_dim)
    # freqs_cis: (T, head_dim//2) complex
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast freqs_cis to (1, T, 1, head_dim//2) for multi-head
    freqs_cis = freqs_cis[:xq.shape[1]].unsqueeze(0).unsqueeze(2)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

### 4.3 SwiGLU
- **Old:** Linear → ReLU → Linear
- **New:** Linear → SiLU × Gate → Linear (more expressive)
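
SwiGLU uses three weight matrices instead of two, so the hidden width is usually shrunk by a factor of 2/3 to keep the parameter count level with a standard 4x MLP. The arithmetic below is a sketch of that bookkeeping, using the hidden-size rule `int(dim * 4 * 2/3)` from this stage's `SwiGLU` class and ignoring biases.

```python
def mlp_params(dim):
    """Classic feed-forward block: dim -> 4*dim -> dim (two matrices)."""
    return 2 * dim * (4 * dim)

def swiglu_params(dim):
    """SwiGLU block: three matrices with hidden = int(dim * 4 * 2/3)."""
    hidden = int(dim * 4 * 2 / 3)
    return 3 * dim * hidden

for dim in (384, 768):
    print(f"dim={dim}: MLP={mlp_params(dim):,}  SwiGLU={swiglu_params(dim):,}")
```

For both embedding sizes used in this notebook the two counts match exactly, so the extra gate matrix adds no parameters.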

In [None]:
class SwiGLU(nn.Module):
    def __init__(self, dim, hidden=None, bias=False):
        super().__init__()
        hidden = hidden or int(dim * 4 * 2/3)
        self.w1 = nn.Linear(dim, hidden, bias=bias)
        self.w2 = nn.Linear(hidden, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden, bias=bias)
    
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

### Section Summary: Stage 4
**Modern upgrades:** RMSNorm, RoPE, SwiGLU

---

## Stage 5: Full GPT Model

Now we assemble all components!

```
Tokens → Embed → [Block × N] → Norm → Logits
Block = Attention + FFN (with residuals)
```

### Configuration Presets

| Mode | Layers | Embed | Params | Memory | Best For |
|------|--------|-------|--------|--------|----------|
| **COLAB_MODE=True** | 6 | 384 | ~30M | ~2GB | Free Colab T4 |
| COLAB_MODE=False | 12 | 768 | ~124M | ~6GB | Better GPU |
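
The parameter counts in the table can be estimated by hand. The helper below is a rough sketch and `estimate_params` is a hypothetical name. It assumes no biases, tied input/output embeddings, fused QKV plus an output projection, and SwiGLU with hidden width `int(n_embd * 4 * 2/3)`.

```python
def estimate_params(n_layer, n_embd, vocab_size=50304):
    """Rough parameter count for the architecture in this stage (no biases)."""
    embedding = vocab_size * n_embd            # shared with the output head
    attention = 4 * n_embd * n_embd            # fused QKV (3x) + output projection
    hidden = int(n_embd * 4 * 2 / 3)           # SwiGLU hidden width
    mlp = 3 * n_embd * hidden                  # w1, w2, w3
    norms = 2 * n_embd                         # two RMSNorm gains per block
    final_norm = n_embd
    return embedding + n_layer * (attention + mlp + norms) + final_norm

for name, n_layer, n_embd in [("Colab", 6, 384), ("Full", 12, 768)]:
    print(f"{name}: {estimate_params(n_layer, n_embd) / 1e6:.1f}M parameters")
```

In the small configuration the embedding table accounts for most of the parameters, which is why weight tying matters at this scale.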

In [None]:
from dataclasses import dataclass

# Toggle for Colab Free Tier compatibility
COLAB_MODE = True  # Set False if you have a better GPU (A100, H100, RTX 4090)

@dataclass
class GPTConfig:
    # For Colab T4: smaller model that fits in 16GB VRAM
    # For better GPU: larger model for better quality
    block_size: int = 512 if COLAB_MODE else 1024
    vocab_size: int = 50304  # GPT-2 vocab, padded for efficiency
    n_layer: int = 6 if COLAB_MODE else 12
    n_head: int = 6 if COLAB_MODE else 12
    n_embd: int = 384 if COLAB_MODE else 768
    dropout: float = 0.1 if COLAB_MODE else 0.0  # Regularization helps small models
    bias: bool = False

config = GPTConfig()
print(f"Config: {config.n_layer}L, {config.n_embd}d, context={config.block_size}")
print(f"Mode: {'Colab-Friendly (30M params)' if COLAB_MODE else 'Full (124M params)'}")
print(f"Estimated memory: {'~2GB' if COLAB_MODE else '~6GB'}")
print()
print("To use full config, set COLAB_MODE = False above.")

In [None]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.flash = hasattr(F, 'scaled_dot_product_attention')
    
    def forward(self, x, freqs_cis=None):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        
        if freqs_cis is not None:
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            q, k = apply_rotary_emb(q, k, freqs_cis)
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
        
        if self.flash:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            att = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
            att = att.masked_fill(torch.tril(torch.ones(T, T, device=x.device)) == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            y = att @ v
        
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.dropout(self.c_proj(y))

In [None]:
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = RMSNorm(config.n_embd)
        self.mlp = SwiGLU(config.n_embd)
    
    def forward(self, x, freqs_cis):
        x = x + self.attn(self.ln_1(x), freqs_cis)
        x = x + self.mlp(self.ln_2(x))
        return x

In [None]:
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = RMSNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Weight tying: share embedding weights with output layer
        # Why? The "meaning" of token embeddings and predictions should be consistent
        # Bonus: Saves ~19M parameters (vocab_size × n_embd = 50304 × 384 = 19.3M)
        self.transformer.wte.weight = self.lm_head.weight
        
        self.freqs_cis = precompute_freqs_cis(
            config.n_embd // config.n_head, config.block_size * 2
        ).to(device)
        
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None: torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
    
    def forward(self, idx, targets=None):
        B, T = idx.size()
        x = self.transformer.wte(idx)
        freqs_cis = self.freqs_cis[:T]
        for block in self.transformer.h:
            x = block(x, freqs_cis)
        x = self.transformer.ln_f(x)
        
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss
    
    @torch.no_grad()
    def generate(self, idx, max_new, temperature=1.0, top_k=None):
        for _ in range(max_new):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            idx = torch.cat((idx, torch.multinomial(probs, 1)), dim=1)
        return idx

model = GPT(config).to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

# Apply torch.compile for ~1.5-2x speedup (PyTorch 2.0+)
if hasattr(torch, 'compile'):
    print("Compiling model with torch.compile...")
    model = torch.compile(model)
    print("Model compiled successfully!")

### Section Summary: Stage 5
**Built:** Full GPT with RMSNorm, RoPE, SwiGLU, Flash Attention, torch.compile

---

## Stage 6: Training

Using FineWeb-Edu (high-quality educational text).

### Training Features
- **Mixed Precision**: float16 autocast for faster matrix multiplies and lower activation memory
- **Gradient Accumulation**: Larger effective batch size
- **LR Warmup + Cosine Decay**: Standard best practice
- **Checkpointing**: Resume training across Colab sessions
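
Gradient accumulation works because gradients add up across `backward()` calls. A small sketch with a one-parameter model shows that averaging over equal-sized micro-batches reproduces the full-batch gradient. The data and the model `y = w * x` are made up for illustration.

```python
def grad_mse(w, batch):
    """Gradient of mean squared error for the 1-parameter model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

# Hypothetical data: 8 (x, y) pairs split into 4 micro-batches of 2.
data = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 8.0),
        (5.0, 9.0), (6.0, 12.5), (7.0, 13.0), (8.0, 17.0)]
w = 0.5
accum_steps = 4
micro_batches = [data[i:i + 2] for i in range(0, len(data), 2)]

full_grad = grad_mse(w, data)

# Each micro-batch contributes grad / accum_steps, mirroring
# `loss = loss / grad_accum_steps` before backward() in a training loop.
accumulated = 0.0
for mb in micro_batches:
    accumulated += grad_mse(w, mb) / accum_steps

print(full_grad, accumulated)
```

The two values agree up to floating-point rounding, so a small GPU can emulate a larger batch at the cost of more forward and backward passes per optimizer step.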

### What to Expect During Training

| Steps | Loss | Text Quality |
|-------|------|--------------|
| 0 | ~10 | Random noise |
| 500 | ~6-7 | Some words recognizable |
| 2000 | ~5 | Sentences form (broken grammar) |
| 5000+ | ~4 | Coherent paragraphs |
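
The starting loss in the table has a simple explanation: an untrained model guesses roughly uniformly over the vocabulary, and the cross-entropy of a uniform guess is `ln(vocab_size)`. The sketch below checks that figure. It also converts the other loss values into perplexity, which the table does not mention but which is the usual companion metric.

```python
import math

vocab_size = 50304
uniform_loss = math.log(vocab_size)   # cross-entropy of a uniform guess
print(f"Expected initial loss: {uniform_loss:.2f}")

# Perplexity = exp(loss): the effective number of equally likely next tokens.
for loss in (10.0, 6.5, 5.0, 4.0):
    print(f"loss={loss:4.1f} -> perplexity ~ {math.exp(loss):,.0f}")
```

If the first logged loss is far above this value, something is likely wrong with initialization or the data pipeline.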

> **Note:** With 6 layers and 512 context, this is a "baby" model.
> Production models are far larger: GPT-3, for example, has 96 layers and 175B parameters.
> The goal here is **understanding**, not production quality.

### If You Get "CUDA Out of Memory"

1. **Reduce batch_size:** Change `batch_size = 4` → `batch_size = 2` or `1`
2. **Reduce layers:** Set `COLAB_MODE = True` above (uses 6 layers)
3. **Restart runtime:** Runtime → Restart runtime
4. **Clear GPU cache:** Run this cell:

```python
import gc
gc.collect()
torch.cuda.empty_cache()
```

5. **Reduce context:** Change `block_size` from 512 to 256

In [None]:
from datasets import load_dataset
import matplotlib.pyplot as plt
from IPython.display import clear_output

class DataLoader:
    def __init__(self, batch_size, block_size, split='train'):
        self.batch_size = batch_size
        self.block_size = block_size
        self.dataset = load_dataset("HuggingFaceFW/fineweb-edu", 
            name="sample-10BT", split=split, streaming=True)
        self.iterator = iter(self.dataset)
        self.buffer = []
    
    def __iter__(self): return self
    
    def __next__(self):
        needed = self.batch_size * self.block_size + 1
        while len(self.buffer) < needed:
            try:
                text = next(self.iterator)['text']
                self.buffer.extend(tokenizer.enc.encode(text))
            except StopIteration:
                self.iterator = iter(self.dataset)
        
        chunk = self.buffer[:needed]
        self.buffer = self.buffer[needed:]
        data = torch.tensor(chunk, dtype=torch.long)
        x = data[:-1].view(self.batch_size, self.block_size)
        y = data[1:].view(self.batch_size, self.block_size)
        return x.to(device), y.to(device)

In [None]:
# Training hyperparameters
max_iters = 5000              # Long enough for meaningful training
warmup_iters = 200            # Linear LR warmup for stability
eval_interval = 250           # Show progress every 250 steps
save_interval = 500           # Checkpoint every 500 steps (important for Colab!)
batch_size = 4 if COLAB_MODE else 8  # Smaller batch for Colab
grad_accum_steps = 4          # Effective batch = batch_size × 4 = 16 or 32
max_lr = 3e-4
min_lr = 1e-5

# Learning rate schedule: warmup + cosine decay
def get_lr(it):
    # Linear warmup
    if it < warmup_iters:
        return max_lr * (it + 1) / warmup_iters
    # Cosine decay after warmup
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

print(f"Training config:")
print(f"  - Max iterations: {max_iters:,}")
print(f"  - Batch size: {batch_size} × {grad_accum_steps} = {batch_size * grad_accum_steps} effective")
print(f"  - Tokens per step: {batch_size * grad_accum_steps * config.block_size:,}")
print(f"  - LR: {min_lr} → {max_lr} → {min_lr} (warmup + cosine)")
print(f"  - Checkpointing every {save_interval} steps to Google Drive")

In [None]:
# Initialize training components
train_loader = DataLoader(batch_size, config.block_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95), weight_decay=0.1)
scaler = torch.amp.GradScaler('cuda')  # For mixed precision training

# Resume if checkpoint exists
CKPT = os.path.join(PROJECT_DIR, "ckpt.pt")
start_iter = 0
if os.path.exists(CKPT):
    ckpt = torch.load(CKPT, map_location=device, weights_only=False)  # Contains config dict
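    # torch.compile wraps the model and prefixes its state-dict keys with
    # '_orig_mod.', so load into the underlying module instead of the wrapper.
    raw_model = getattr(model, '_orig_mod', model)
    # Strip the prefix in case the checkpoint was saved from a compiled model
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in ckpt['model'].items()}
    # Default strict=True: a key mismatch raises instead of silently loading nothing
    raw_model.load_state_dict(state_dict)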
    optimizer.load_state_dict(ckpt['optimizer'])
    start_iter = ckpt['iter'] + 1
    print(f"Resumed from step {start_iter}")
else:
    print("Starting fresh training")

losses = []

In [None]:
# Training loop with gradient accumulation
print("Starting training loop...")
print(f"Will train for {max_iters:,} iterations")
print(f"Estimated time on T4: {max_iters * 0.3 / 60:.1f} minutes")
print()

model.train()
t0 = time.time()
running_loss = 0.0

for it in range(start_iter, max_iters):
    # Update learning rate (warmup + cosine decay)
    lr = get_lr(it)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    # Gradient accumulation loop
    optimizer.zero_grad(set_to_none=True)
    for micro_step in range(grad_accum_steps):
        xb, yb = next(train_loader)
        with torch.amp.autocast('cuda', dtype=torch.float16):
            logits, loss = model(xb, yb)
            loss = loss / grad_accum_steps  # Scale loss for accumulation
        scaler.scale(loss).backward()
    
    # Gradient clipping for stability
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    
    scaler.step(optimizer)
    scaler.update()
    
    running_loss += loss.item() * grad_accum_steps
    
    # Logging
    if it % 10 == 0:
        avg_loss = running_loss / 10 if it > 0 else running_loss
        losses.append(avg_loss)
        running_loss = 0.0
        dt = time.time() - t0; t0 = time.time()
        print(f"Step {it:5d}/{max_iters} | loss={avg_loss:.4f} | lr={lr:.2e} | {dt*1000:.0f}ms")
    
    # Visualization
    if it % eval_interval == 0 and it > 0:
        clear_output(wait=True)
        plt.figure(figsize=(10, 4))
        plt.plot(losses)
        plt.xlabel('Step (x10)')
        plt.ylabel('Loss')
        plt.title(f'Training Progress - Step {it}/{max_iters}')
        plt.grid(True, alpha=0.3)
        plt.show()
    
    # Checkpointing to Google Drive (critical for Colab!); the last step is saved too
    if (it % save_interval == 0 and it > 0) or it == max_iters - 1:
        # Get raw model state dict (handle torch.compile)
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
        torch.save({
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'iter': it,
            'config': config,
            'losses': losses
        }, CKPT)
        print(f"💾 Checkpoint saved to Google Drive (step {it})")

print()
print("✅ Training complete!")
print(f"Final loss: {losses[-1]:.4f}")
print(f"Checkpoint saved to: {CKPT}")

### Section Summary: Stage 6
**Training features:**
- Mixed precision (float16 autocast) for faster matrix multiplies and lower activation memory
- Gradient accumulation for larger effective batch
- LR warmup + cosine decay for stable training
- Google Drive checkpointing for Colab session recovery
- Gradient clipping for stability
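
The training loop pairs float16 autocast with a `GradScaler` because small gradients can underflow to zero in half precision. The sketch below shows the effect using only the standard library's half-precision packing. The gradient value and the fixed scale of 2**16 are illustrative assumptions; a real scaler adjusts its scale dynamically.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest float16 value (IEEE half precision)."""
    return struct.unpack("e", struct.pack("e", x))[0]

tiny_grad = 1e-8                      # hypothetical small gradient value
scale = 65536.0                       # 2**16, an assumed fixed loss scale

unscaled = to_fp16(tiny_grad)         # too small for float16: becomes 0.0
scaled = to_fp16(tiny_grad * scale)   # representable after scaling
recovered = scaled / scale            # unscale in higher precision

print(unscaled, recovered)
```

Scaling the loss before `backward()` scales every gradient by the same factor. Dividing it back out before the optimizer step recovers the value that would otherwise have been lost.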

---

## Stage 7: Inference & Chat

Now talk to your model!
Note: This is a BASE model (completion), not an assistant (yet).

In [None]:
def chat(max_tokens=100, temp=0.8):
    model.eval()  # Switch to evaluation mode
    print("💬 Chat (type 'exit' to stop)")
    print("-" * 40)
    while True:
        prompt = input("You: ")
        if prompt.lower() == 'exit': break
        
        idx = tokenizer.encode(prompt).unsqueeze(0).to(device)
        print("AI: ", end="", flush=True)
        
        for _ in range(max_tokens):
            out = model.generate(idx, max_new=1, temperature=temp)
            new_tok = out[0, -1].item()
            print(tokenizer.decode([new_tok]), end="", flush=True)
            idx = out
        print("\n" + "-" * 40)

# Uncomment to start interactive chat:
# chat()

### Section Summary: Stage 7
**Built:** Interactive chat with streaming output
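
The `temperature` and `top_k` arguments control how adventurous the sampling is. Here is a dependency-free sketch of the same logic as `generate`, applied to a made-up four-token vocabulary. `sample_next` is a hypothetical helper, not part of the model class.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Sample an index from logits after temperature scaling and top-k filtering."""
    scaled = [l / temperature for l in logits]
    if top_k is not None:
        # Keep the k largest scores; everything below the cutoff is masked out.
        cutoff = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [s if s >= cutoff else float("-inf") for s in scaled]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    index = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return index, probs

logits = [2.0, 1.0, 0.5, -1.0]        # hypothetical scores for 4 tokens
for temp in (0.5, 1.0, 2.0):
    _, probs = sample_next(logits, temperature=temp)
    print(temp, [round(p, 3) for p in probs])
```

Lower temperatures concentrate probability on the top token. Higher temperatures flatten the distribution, and `top_k` removes the unlikely tail entirely.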

---

## Stage 8: RLHF Alignment

**Goal:** Teach the model specific behaviors using Reinforcement Learning.

**Simple example:** Make it "positive" (reward words like happy, good, great).

In [None]:
def get_reward(text):
    """Simple reward: +1 for each positive word"""
    positive = ["happy", "good", "great", "excellent", "love", "wonderful", "amazing"]
    return sum(1 for w in positive if w in text.lower())

# RLHF Demonstration - CONCEPTUAL ONLY (no learning occurs!)
# Missing for true RLHF: log_probs collection, value function, PPO loss, gradient update
# This demo only shows: sampling/scoring loop (the reward signal concept)
print("Initializing RLHF demonstration...")
prompts = ["Today I feel", "The weather is", "My work is"]

for i in range(50):
    prompt = prompts[i % len(prompts)]
    idx = tokenizer.encode(prompt).unsqueeze(0).to(device)
    
    # Generate 4 samples (in real RLHF, we'd track log_probs here)
    responses = []
    for _ in range(4):
        out = model.generate(idx, max_new=10, temperature=0.9)
        responses.append(tokenizer.decode(out[0].tolist()))
    
    # Score responses
    rewards = [get_reward(r) for r in responses]
    avg_reward = sum(rewards) / len(rewards)
    
    if i % 10 == 0:
        print(f"Step {i}: Avg reward = {avg_reward:.2f}")
        print(f"  Best: {responses[rewards.index(max(rewards))]}")

print("RLHF demonstration complete.")

### Section Summary: Stage 8

**RLHF basics:**
1. Generate multiple samples
2. Score with reward function
3. Encourage high-scoring outputs (the policy update, which the demonstration above does not implement)
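
Step 3 is the part that needs a gradient update. As a minimal sketch of the idea, the toy below runs REINFORCE, a basic policy-gradient method, on a three-word "policy" whose only parameters are three logits. This is a swapped-in simplification: it is not PPO, it has no value function, and it does not touch the GPT model. The word list, reward table, and hyperparameters are all made up.

```python
import math
import random

random.seed(0)
words = ["sad", "okay", "happy"]              # hypothetical 3-word vocabulary
reward = {"sad": 0.0, "okay": 0.0, "happy": 1.0}
logits = [0.0, 0.0, 0.0]                      # the policy's only parameters
lr, batch = 0.5, 8

def probs(logits):
    """Softmax over the policy logits."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

for step in range(100):
    p = probs(logits)
    # 1. Generate multiple samples from the current policy
    samples = random.choices(range(len(words)), weights=p, k=batch)
    # 2. Score them with the reward function
    rewards = [reward[words[a]] for a in samples]
    baseline = sum(rewards) / batch           # mean reward as a simple baseline
    # 3. Encourage high-scoring outputs: REINFORCE update, using
    #    d log p(a) / d logit_j = 1[j == a] - p[j]
    for a, r in zip(samples, rewards):
        advantage = r - baseline
        for j in range(len(logits)):
            grad = (1.0 if j == a else 0.0) - p[j]
            logits[j] += lr * advantage * grad / batch

print({w: round(q, 3) for w, q in zip(words, probs(logits))})
```

Samples that beat the batch average have their probability pushed up and the rest are pushed down, so probability mass moves toward "happy". Full RLHF applies the same principle to token sequences, typically with a learned reward model and a penalty for drifting away from the base model.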

---

## Conclusion

You have successfully implemented a complete GPT architecture with modern enhancements.

**Key Concepts Covered:**
- Byte Pair Encoding (BPE)
- Causal Self-Attention and Multi-Head Attention
- RMSNorm, RoPE, and SwiGLU
- Mixed-precision training and gradient accumulation

**Next steps:**
1. Train longer (100k+ steps)
2. Use larger/better datasets
3. Fine-tune on dialogue for assistant behavior