# Lesson 07: Training Tricks (Warmup + Cosine + AMP + Accumulation)

This notebook builds a tiny GPT-like decoder and uses it to show **training-engineering tricks** that matter in real systems:

- **AdamW + weight decay** (why it is different from L2 in Adam)
- **learning rate warmup + cosine decay** (stability and smooth convergence)
- **gradient accumulation** (simulate larger batches without more GPU memory)
- **optional AMP** (mixed precision for speed on GPU)

The goal is not state-of-the-art performance. The goal is a *clear, runnable training loop* where you can see how each trick changes the mechanics.


## 1) Setup

We set seeds for reproducibility, pick a device, and define configuration knobs. The values are intentionally small so the notebook can run on a laptop.


In [None]:
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset


In [None]:
# Configuration
SEED = 42  # random seed
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # compute device

BATCH_SIZE = 32  # batch size
BLOCK_SIZE = 128  # context length (tokens)
EMB_SIZE = 256  # embedding size
N_LAYERS = 4  # transformer layers
N_HEADS = 4  # attention heads
DROPOUT = 0.1  # dropout prob

MAX_STEPS = 1000  # training steps
EVAL_INTERVAL = 100  # eval interval (steps)
EVAL_ITERS = 20  # eval batches

LEARNING_RATE = 3e-4  # learning rate
WEIGHT_DECAY = 0.1  # weight decay
WARMUP_STEPS = 100  # lr warmup steps
ACCUM_STEPS = 4  # gradient accumulation steps
AMP_ENABLED = torch.cuda.is_available()  # use mixed precision

random.seed(SEED)
torch.manual_seed(SEED)

print('Device:', DEVICE)


## 2) Dataset + tokenizer (WikiText-2)

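We load the `wikitext-2-raw-v1` dataset with the Hugging Face `datasets` library. To keep things simple we use a **character-level** tokenizer: the vocabulary is the set of unique characters in the text, and the whole corpus is encoded as one 1D tensor of character ids. The train and validation texts are concatenated and then re-split 90/10, so the held-out 10% is not exactly WikiText-2's official validation split.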

**Note:** The first run will download the dataset.


In [None]:
# Load dataset (wikitext-2-raw-v1) using Hugging Face datasets
ds = load_dataset("wikitext", "wikitext-2-raw-v1")  # dataset object
train_text = "\n".join(ds["train"]["text"])  # training text
val_text = "\n".join(ds["validation"]["text"]) if "validation" in ds else ""  # validation text

# Build a simple character-level vocabulary
text = train_text + "\n" + val_text  # raw text
chars = sorted(list(set(text)))  # unique characters
VOCAB_SIZE = len(chars)  # vocab size
stoi = {ch: i for i, ch in enumerate(chars)}  # string-to-id map
itos = {i: ch for i, ch in enumerate(chars)}  # id-to-string map

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return ''.join([itos[i] for i in ids])

data = torch.tensor(encode(text), dtype=torch.long)  # whole corpus as character ids
n = int(0.9 * len(data))  # split point: first 90% train, last 10% validation
train_data = data[:n]  # training split data
val_data = data[n:]  # validation split data

print('Vocab size:', len(chars))
print('Train tokens:', len(train_data), 'Val tokens:', len(val_data))
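The tokenizer above is easiest to see on a toy string. This standalone sketch re-creates the same character-level scheme under hypothetical `toy_*` names (chosen so running it does not overwrite the notebook's `encode`/`decode`), and also shows the next-character pairing that `get_batch` in section 4 uses to build training targets.

```python
toy_corpus = "hello world"
toy_chars = sorted(set(toy_corpus))                    # unique characters = vocabulary
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}   # char -> id
toy_itos = {i: ch for ch, i in toy_stoi.items()}       # id -> char

def toy_encode(s):
    return [toy_stoi[c] for c in s]

def toy_decode(ids):
    return "".join(toy_itos[i] for i in ids)

toy_ids = toy_encode("hello")
print(toy_chars)            # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
print(toy_ids)              # [3, 2, 4, 4, 5]
print(toy_decode(toy_ids))  # hello

# Input/target pairing: the target is the input shifted by one character.
toy_block_size = 4
toy_data = toy_encode(toy_corpus)
x_ids = toy_data[:toy_block_size]
y_ids = toy_data[1 : toy_block_size + 1]
print(toy_decode(x_ids), "->", toy_decode(y_ids))  # hell -> ello
```

Every position in `x_ids` is trained to predict the character at the same position in `y_ids`, which is why one sampled window yields `block_size` training examples at once.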

## 3) GPT-like decoder (simple)

We implement a tiny, decoder-only Transformer with:
- token + position embeddings
- stacked self-attention blocks with causal mask
- linear head for next-token prediction

This is intentionally small and minimal so the training loop stays readable.


In [None]:
class Block(nn.Module):
    def __init__(self, emb_size, n_heads, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_size)
        self.attn = nn.MultiheadAttention(
            embed_dim=emb_size,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(emb_size)
        self.mlp = nn.Sequential(
            nn.Linear(emb_size, 4 * emb_size),
            nn.GELU(),
            nn.Linear(4 * emb_size, emb_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Causal mask prevents attending to future tokens
        T = x.size(1)  # sequence length
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        x_norm = self.ln1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, attn_mask=mask, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, vocab_size, block_size, emb_size, n_layers, n_heads, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_emb = nn.Embedding(vocab_size, emb_size)
        self.pos_emb = nn.Embedding(block_size, emb_size)
        self.blocks = nn.ModuleList([
            Block(emb_size, n_heads, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(emb_size)
        self.head = nn.Linear(emb_size, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError("Sequence length exceeds block size")

        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.token_emb(idx) + self.pos_emb(pos)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            # Cross-entropy on next-token prediction
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


model = GPT(  # model instance
    vocab_size=VOCAB_SIZE,
    block_size=BLOCK_SIZE,
    emb_size=EMB_SIZE,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    dropout=DROPOUT,
).to(DEVICE)

print("Model parameters:", sum(p.numel() for p in model.parameters()) / 1e6, "M")
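The causal mask in `Block.forward` is what makes this a decoder. A quick standalone check (names such as `demo_attn` and `causal_mask` are illustrative, and the sizes are arbitrary): perturb only the last token and confirm that outputs at earlier positions do not change, while the last position does. With `nn.MultiheadAttention`, a `True` entry in a boolean `attn_mask` means "not allowed to attend".

```python
import torch
import torch.nn as nn

seq_len, emb = 5, 8
# True above the diagonal: position t may not attend to positions > t
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
demo_attn = nn.MultiheadAttention(embed_dim=emb, num_heads=2, batch_first=True).eval()

x_a = torch.randn(1, seq_len, emb)
x_b = x_a.clone()
x_b[:, -1, :] += 10.0  # change only the last token

with torch.no_grad():
    out_a, _ = demo_attn(x_a, x_a, x_a, attn_mask=causal_mask, need_weights=False)
    out_b, _ = demo_attn(x_b, x_b, x_b, attn_mask=causal_mask, need_weights=False)

print(causal_mask.int())
earlier_unchanged = torch.allclose(out_a[:, :-1], out_b[:, :-1], atol=1e-6)
last_changed = not torch.allclose(out_a[:, -1], out_b[:, -1])
print("earlier positions unchanged:", earlier_unchanged)
print("last position changed:", last_changed)
```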

## 4) Training utilities

We will:
- sample random batches from the token stream
- compute evaluation loss + perplexity
- build a warmup + cosine learning-rate schedule

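**Why AdamW?** If weight decay is implemented as an L2 term inside Adam, the decay gradient `wd * w` is divided by Adam's adaptive denominator along with the loss gradient, so parameters with a large gradient history are decayed less than intended. `AdamW` decouples weight decay from the gradient-based update and shrinks the weights directly (`w ← w - lr * wd * w`), which often generalizes better in practice. For simplicity this notebook decays every parameter, including biases and LayerNorm weights; many setups exclude those.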
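To make the difference concrete, here is a one-step sketch on two weights of very different size (1 and 100) with a zero loss gradient, so only weight decay acts. `demo_lr = 0.1` and `demo_wd = 0.1` are illustrative values. With `torch.optim.Adam(weight_decay=...)` the decay is folded into the gradient and then normalized, so on the first step both weights move by roughly `lr`; with `torch.optim.AdamW` each weight shrinks by the same fraction `lr * wd`.

```python
import torch

demo_lr, demo_wd = 0.1, 0.1
start_weights = torch.tensor([1.0, 100.0])

def one_update(optimizer_cls):
    w = torch.nn.Parameter(start_weights.clone())
    opt = optimizer_cls([w], lr=demo_lr, weight_decay=demo_wd)
    w.grad = torch.zeros_like(w)  # the loss contributes no gradient; only decay acts
    opt.step()
    return w.detach()

adam_l2_result = one_update(torch.optim.Adam)  # decay added to the grad, then divided by sqrt(v)
adamw_result = one_update(torch.optim.AdamW)   # decay applied to the weights directly

print("Adam + L2:", adam_l2_result.tolist())  # roughly [0.9, 99.9]: same absolute step
print("AdamW    :", adamw_result.tolist())    # roughly [0.99, 99.0]: same relative shrink
```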


In [None]:
@torch.no_grad()
def estimate_loss(model, train_data, val_data, eval_iters):
    model.eval()
    out = {}
    for split, data in [("train", train_data), ("val", val_data)]:
        losses = []
        for _ in range(eval_iters):
            xb, yb = get_batch(data, BATCH_SIZE, BLOCK_SIZE)
            _, loss = model(xb, yb)
            losses.append(loss.item())
        out[split] = sum(losses) / len(losses)
    model.train()
    return out


def get_batch(data, batch_size, block_size):
    # Randomly sample starting positions
    max_start = len(data) - block_size - 1
    idx = torch.randint(0, max_start, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in idx])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in idx])
    return x.to(DEVICE), y.to(DEVICE)


def warmup_cosine_lr(step, warmup_steps, total_steps):
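    # Returns a multiplier in [0, 1] for the base LR.
    # Linear warmup: step 0 -> 1/warmup_steps, ..., step warmup_steps-1 -> 1.0
    # (starting at step + 1 avoids a wasted first update with LR = 0).
    if step < warmup_steps:
        return float(step + 1) / float(max(1, warmup_steps))
    # Cosine decay from 1.0 at the end of warmup to 0.0 at total_steps (clamped afterwards)
    progress = min(1.0, float(step - warmup_steps) / float(max(1, total_steps - warmup_steps)))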
    return 0.5 * (1.0 + math.cos(math.pi * progress))
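The cell above only defines the schedule. The sketch below traces the learning rate that each optimizer update actually receives when a schedule is driven through `LambdaLR`, using small illustrative values (`demo_base_lr = 3e-4`, 10 warmup steps, 100 total steps) and a dummy one-parameter optimizer. `warmup_cosine_lr_sketch` is a standalone copy that starts warmup at `(step + 1) / warmup_steps`: `LambdaLR` evaluates the lambda at step 0 before the first update, so a schedule that returns 0 there spends the first update with a learning rate of zero.

```python
import math
import torch

def warmup_cosine_lr_sketch(step, warmup_steps, total_steps):
    # Linear warmup from 1/warmup_steps up to 1.0, then cosine decay towards 0.0
    if step < warmup_steps:
        return float(step + 1) / float(max(1, warmup_steps))
    progress = min(1.0, float(step - warmup_steps) / float(max(1, total_steps - warmup_steps)))
    return 0.5 * (1.0 + math.cos(math.pi * progress))

demo_base_lr, demo_warmup, demo_total = 3e-4, 10, 100
demo_param = torch.nn.Parameter(torch.zeros(1))
demo_opt = torch.optim.AdamW([demo_param], lr=demo_base_lr)
demo_sched = torch.optim.lr_scheduler.LambdaLR(
    demo_opt, lr_lambda=lambda s: warmup_cosine_lr_sketch(s, demo_warmup, demo_total)
)

used_lrs = []
for _ in range(demo_total):
    used_lrs.append(demo_opt.param_groups[0]["lr"])  # LR the upcoming update will use
    demo_param.grad = torch.zeros_like(demo_param)
    demo_opt.step()
    demo_sched.step()  # called after optimizer.step(), as PyTorch expects

for i in [0, 4, 9, 10, 55, 99]:
    print(f"update {i:3d}: lr = {used_lrs[i]:.2e}")
```

The LR ramps linearly to the base value over the first 10 updates, is at its peak at the warmup boundary, falls to half the peak midway through the decay phase, and approaches zero at the end.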

## 5) Training loop (accumulation + warmup/cosine + AMP)

Key mechanics:
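- **Gradient accumulation** runs `ACCUM_STEPS` micro-batches, divides each loss by `ACCUM_STEPS` so the accumulated gradient equals the gradient of the average loss, and then steps the optimizer once. The effective batch size is `BATCH_SIZE * ACCUM_STEPS` (32 * 4 = 128 sequences per update), while activation memory stays at the size of one micro-batch.
- **Warmup + cosine** updates the LR after every optimizer step (not once per epoch) for smoother optimization.
- **AMP** runs eligible ops in float16 on GPU via `autocast`; `GradScaler` multiplies the loss by a scale factor so small float16 gradients do not underflow, and unscales the gradients before the optimizer step.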

We also print the learning rate occasionally so you can see the schedule in action.
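The claim that dividing by `ACCUM_STEPS` makes accumulation equivalent to a bigger batch can be checked directly. In this standalone sketch (a toy `nn.Linear` with MSE loss; all names are illustrative), four backward passes over micro-batches of 8, each loss divided by 4, produce the same gradient as one backward pass over all 32 examples. The equivalence assumes equal-sized micro-batches and a mean-reduced loss, which matches the notebook (`F.cross_entropy` averages by default).

```python
import torch
import torch.nn.functional as F

n_micro, micro_bs = 4, 8  # play the roles of ACCUM_STEPS and BATCH_SIZE, but tiny
acc_model = torch.nn.Linear(3, 1)
feats = torch.randn(n_micro * micro_bs, 3)
targets = torch.randn(n_micro * micro_bs, 1)

# (a) One large batch of 32 examples.
acc_model.zero_grad(set_to_none=True)
F.mse_loss(acc_model(feats), targets).backward()
full_batch_grad = acc_model.weight.grad.clone()

# (b) Four micro-batches of 8; .grad accumulates across backward() calls.
acc_model.zero_grad(set_to_none=True)
for xb_small, yb_small in zip(feats.chunk(n_micro), targets.chunk(n_micro)):
    (F.mse_loss(acc_model(xb_small), yb_small) / n_micro).backward()
accumulated_grad = acc_model.weight.grad.clone()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-5))
```

Without the division by `n_micro`, the accumulated gradient would be 4 times larger, which behaves like silently multiplying the learning rate by 4.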


In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)  # optimizer instance

# LR scheduler uses a lambda in [0, 1] multiplied by the base LR
lr_lambda = lambda step: warmup_cosine_lr(step, WARMUP_STEPS, MAX_STEPS)  # learning-rate schedule fn
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)  # LR scheduler

scaler = torch.cuda.amp.GradScaler(enabled=AMP_ENABLED)  # grad scaler for AMP

model.train()
optimizer.zero_grad(set_to_none=True)

# optimization steps
for step in range(1, MAX_STEPS + 1):
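    # Accumulate gradients over ACCUM_STEPS micro-batches
    for micro_step in range(ACCUM_STEPS):
        xb, yb = get_batch(train_data, BATCH_SIZE, BLOCK_SIZE)
        # Mixed-precision forward pass; a no-op when AMP_ENABLED is False
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=AMP_ENABLED):
            _, loss = model(xb, yb)
            loss = loss / ACCUM_STEPS  # average the loss over micro-batches
        # A disabled GradScaler returns the loss unchanged, so one code path covers both cases
        scaler.scale(loss).backward()

    # One optimizer step after accumulation: scaler.step() unscales the gradients and
    # skips the update if they contain inf/NaN (a disabled scaler just calls optimizer.step())
    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad(set_to_none=True)
    scheduler.step()  # advance the warmup/cosine schedule once per optimizer step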

    # Periodic evaluation
    if step % EVAL_INTERVAL == 0 or step == 1:
        losses = estimate_loss(model, train_data, val_data, EVAL_ITERS)
        train_ppl = math.exp(losses["train"])
        val_ppl = math.exp(losses["val"])
        lr = optimizer.param_groups[0]["lr"]
        print(
            f"step {step:5d} | lr {lr:.2e} | "
            f"train loss {losses['train']:.3f} (ppl {train_ppl:.1f}) | "
            f"val loss {losses['val']:.3f} (ppl {val_ppl:.1f})"
        )
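Two properties of AMP can be seen in isolation in the standalone sketch below. First, why `GradScaler` exists: float16 cannot represent magnitudes below its smallest subnormal, 2^-24 (about 6e-8), so a tiny gradient rounds to zero, while the same value multiplied by 65536 (2^16, `GradScaler`'s default initial scale) survives. Second, what `autocast` does: eligible ops such as linear layers run in the lower-precision dtype while the parameters stay float32. On CPU the sketch assumes bfloat16, the dtype CPU autocast is typically used with; names like `demo_linear` are illustrative.

```python
import torch

# Why GradScaler: values below float16's smallest subnormal round to zero.
tiny = torch.tensor(1e-8)
underflowed = tiny.half().item()          # rounds to 0.0
scaled = (tiny * 65536.0).half().item()   # about 6.55e-04, representable
print(underflowed, scaled)

# What autocast does: eligible ops run in lower precision, parameters stay float32.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
demo_linear = torch.nn.Linear(16, 16).to(device_type)
demo_input = torch.randn(4, 16, device=device_type)
with torch.autocast(device_type=device_type, dtype=amp_dtype):
    demo_out = demo_linear(demo_input)
print(demo_linear.weight.dtype, demo_out.dtype)
```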

## 6) Inference: tiny text generation

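By default we use greedy decoding: always pick the most likely next character. Greedy decoding ignores `temperature`, because dividing the logits by a positive constant does not change which one is largest. For more varied text, pass `greedy=False` to sample from the softmax instead; lower temperatures make the output more conservative and higher temperatures make it more random.


In [None]:
@torch.no_grad()
def generate(model, prompt, max_new_tokens=50, temperature=1.0, greedy=True):
    model.eval()
    ids = encode(prompt)  # raises KeyError for characters missing from the vocabulary
    if not ids:
        ids = [0]
    idx = torch.tensor([ids], dtype=torch.long, device=DEVICE)

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -BLOCK_SIZE:]  # crop to the model's context window
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :]  # logits for the next position only
        if greedy:
            # Argmax is unaffected by temperature: scaling keeps the logits' order
            next_id = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / max(1e-8, temperature), dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

    return decode(idx[0].tolist())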

print(generate(model, 'The meaning of life is'))
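Temperature rescales the logits before the softmax. Since dividing by a positive constant keeps the logits in the same order, argmax decoding picks the same token at every temperature; only sampling is affected. The standalone sketch below uses four made-up logits (`demo_logits` is illustrative) and shows the top probability sharpening at low temperature and flattening at high temperature.

```python
import torch
import torch.nn.functional as F

demo_logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # made-up next-token logits
top_prob = {}
greedy_pick = {}
for temp in [0.5, 1.0, 2.0]:
    probs = F.softmax(demo_logits / temp, dim=-1)
    top_prob[temp] = probs.max().item()
    greedy_pick[temp] = torch.argmax(probs).item()
    sampled = torch.multinomial(probs, num_samples=1).item()  # random; varies run to run
    print(f"T={temp}: p(top)={top_prob[temp]:.3f}  greedy={greedy_pick[temp]}  sampled={sampled}")
```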


## 7) Scaling notes

Real training setups go far beyond this notebook:

- **Distributed Data Parallel (DDP)** to split large batches across many GPUs.
- **Larger batch sizes** using both data-parallelism and gradient accumulation.
- **Mixed precision** plus fused kernels for higher throughput.
- **Activation checkpointing** to save memory.
- **Longer schedules** and more aggressive regularization.

The mechanics you just saw (warmup, cosine decay, AdamW, accumulation, AMP) are still the core building blocks.
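Of these, activation checkpointing is easy to try on a single device. A minimal sketch with `torch.utils.checkpoint.checkpoint` on a toy MLP standing in for one transformer block (names such as `ckpt_mlp` are illustrative): the checkpointed module's intermediate activations are not stored during the forward pass and are recomputed in backward, trading extra compute for memory, while the gradients stay the same.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

ckpt_mlp = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32))
ckpt_input = torch.randn(4, 32)

# Regular forward/backward: intermediate activations are kept for the backward pass.
ckpt_mlp(ckpt_input).sum().backward()
plain_grads = [p.grad.clone() for p in ckpt_mlp.parameters()]
ckpt_mlp.zero_grad(set_to_none=True)

# Checkpointed: activations inside ckpt_mlp are dropped and recomputed during backward.
checkpoint(ckpt_mlp, ckpt_input, use_reentrant=False).sum().backward()
ckpt_grads = [p.grad.clone() for p in ckpt_mlp.parameters()]

grads_match = all(torch.allclose(a, b) for a, b in zip(plain_grads, ckpt_grads))
print(grads_match)
```

In this notebook's `GPT.forward`, the analogous change would be `x = checkpoint(block, x, use_reentrant=False)` inside the block loop; the notebook itself does not do this.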


## 8) Exercises

1. Increase `BLOCK_SIZE` and compare training loss and generated text quality.
2. Turn off AMP and measure training speed on GPU.
3. Increase `ACCUM_STEPS` and reduce `BATCH_SIZE`. Does training stay stable?
4. Replace the cosine schedule with a constant LR. Compare perplexity.
5. Add gradient clipping and see if it changes stability.
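A hint for exercise 5: when AMP is on, the gradients in `.grad` are multiplied by the scaler's scale factor, so they must be unscaled with `scaler.unscale_(optimizer)` before `clip_grad_norm_`; otherwise the threshold is compared against scaled values. The standalone sketch below shows that ordering on a toy linear model (all `clip_*` names are hypothetical). Autocast is left out to keep it small; the order of `unscale_`, clipping, and `step` is the point.

```python
import torch
import torch.nn.functional as F

clip_device = "cuda" if torch.cuda.is_available() else "cpu"
CLIP_MAX_NORM = 1.0

clip_model = torch.nn.Linear(8, 1).to(clip_device)
clip_opt = torch.optim.AdamW(clip_model.parameters(), lr=1e-3)
clip_scaler = torch.cuda.amp.GradScaler(enabled=(clip_device == "cuda"))

clip_inputs = torch.randn(16, 8, device=clip_device) * 10.0  # large inputs -> large gradients
clip_labels = torch.randn(16, 1, device=clip_device)
clip_scaler.scale(F.mse_loss(clip_model(clip_inputs), clip_labels)).backward()

clip_scaler.unscale_(clip_opt)  # grads back to true scale (no-op when the scaler is disabled)
pre_clip_norm = torch.nn.utils.clip_grad_norm_(clip_model.parameters(), CLIP_MAX_NORM).item()
post_clip_norm = torch.norm(torch.stack([p.grad.norm() for p in clip_model.parameters()])).item()

clip_scaler.step(clip_opt)  # does not unscale a second time
clip_scaler.update()
clip_opt.zero_grad(set_to_none=True)
print(f"grad norm before clipping: {pre_clip_norm:.2f}, after: {post_clip_norm:.4f}")
```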
