<a href="https://colab.research.google.com/github/Schwaldlander/nanogpt_demo/blob/main/train_gpt2_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train GPT‑2 from scratch on FineWeb‑Edu (10B‑token sample)

This Colab notebook refactors **`train_gpt2.py`** (from https://github.com/karpathy/build-nanogpt/blob/master/train_gpt2.py) into a step‑by‑step workflow; the code is explained in [this video walkthrough](https://www.youtube.com/watch?v=l8pRSuU81PU). The steps are:

1. **Install** required libraries  
2. **Define** GPT‑2 building blocks  
3. **Prepare** an ultra‑lightweight streaming dataloader  
4. **Configure** (optional) Distributed Data Parallel (DDP)  
5. **Train**, **validate** on HellaSwag, and **sample** text  

Feel free to tweak hyper‑parameters such as `max_steps` and `total_batch_size` to match your compute budget.


In [1]:
# @title Install dependencies
!pip install -q tiktoken hellaswag

[31mERROR: Could not find a version that satisfies the requirement hellaswag (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for hellaswag[0m[31m
[0m

## Imports

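The `hellaswag` module is not a PyPI package. It is `hellaswag.py` from the build-nanogpt repository and provides helpers (`render_example`, `iterate_examples`) for HellaSwag, a commonsense sentence-completion benchmark. Dataset: [Rowan/hellaswag](https://huggingface.co/datasets/Rowan/hellaswag).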

In [2]:
import os, math, time, inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

import tiktoken
import numpy as np
from hellaswag import render_example, iterate_examples


ModuleNotFoundError: No module named 'hellaswag'

## Transformer building blocks

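Unlike ReLU, GELU is smooth and returns small non-zero (negative) values when $x<0$. Its exact form is $\operatorname{GELU}(x) = x\,\Phi(x)$, where $\Phi$ is the standard normal CDF. GPT‑2 uses the tanh approximation (`nn.GELU(approximate='tanh')` in the `MLP` below):

$$\operatorname{GELU}(x) \;\approx\; 0.5\,x\,\Bigl(1 + \tanh\Bigl[\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\Bigr]\Bigr)$$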


In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    One linear layer produces queries, keys and values for every head at once.
    ``F.scaled_dot_product_attention`` with ``is_causal=True`` stops each position
    from attending to later positions.
    """

    def __init__(self, config):
        super().__init__()
        # the embedding width must split evenly across the heads
        assert config.n_embd % config.n_head == 0
        # fused q/k/v projection: n_embd -> 3 * n_embd (one matmul instead of three)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # projection back into the residual stream; the marker attribute asks the
        # model's weight initializer for a reduced standard deviation
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # head bookkeeping for forward() (there is no dropout/regularization here)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        batch, seq_len, channels = x.size()
        head_dim = channels // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim): heads become a batch dimension
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # merge the heads back: (B, n_head, T, head_dim) -> (B, T, C)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.c_proj(merged)


class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, tanh-approximate GELU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width)
        self.gelu = nn.GELU(approximate='tanh')
        # output projection into the residual stream; flagged for the scaled init
        self.c_proj = nn.Linear(hidden_width, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """One pre-LayerNorm transformer layer: causal self-attention, then an MLP.

    Each sub-layer reads a LayerNorm-ed copy of the residual stream and adds its
    output back onto the un-normalized stream, so the residual path stays an
    identity. The projections that write into the stream carry
    ``NANOGPT_SCALE_INIT`` so the model can initialize them with a smaller std.
    That limits how fast the stream's variance grows with depth.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out


## GPT‑2 model wrapper

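The default `GPTConfig` (`n_layer=12`, `n_head=12`, `n_embd=768`, `block_size=1024`) matches the smallest (124M-parameter) GPT‑2 model. `vocab_size=50257` is the size of the GPT‑2 BPE vocabulary.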

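### FlashAttention
`F.scaled_dot_product_attention` can dispatch to FlashAttention, a fused-kernel implementation of *exact* attention. It is fast and memory-efficient because it minimizes reads and writes to GPU high-bandwidth memory (HBM). It stays numerically stable by computing the softmax online, with a running maximum.

 1. Splits Q, K and V into tiles and computes blocks of attention scores in fast on-chip SRAM. The full $T \times T$ attention matrix is never materialized in HBM.

 2. Uses tiling and fused kernel operations to reduce memory reads/writes.

 3. Applies masking, softmax, (dropout) and the matmuls in one fused kernel.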




In [None]:
@dataclass
class GPTConfig:
    """Model hyper-parameters; the defaults describe GPT-2 small (124M parameters)."""

    block_size: int = 1_024   # maximum context length (number of learned positions)
    vocab_size: int = 50_257  # GPT-2 BPE vocabulary size
    n_layer: int = 12         # number of transformer blocks
    n_head: int = 12          # attention heads per block
    n_embd: int = 768         # embedding / residual-stream width


class GPT(nn.Module):
    """GPT-2 language model.

    Token embeddings plus learned position embeddings feed a stack of
    transformer ``Block``s, then a final LayerNorm and a projection to vocabulary
    logits. The projection weight is shared with the token-embedding matrix.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # entries are built in this order: wte, wpe, blocks, ln_f
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),   # token embeddings
            'wpe': nn.Embedding(config.block_size, config.n_embd),   # learned position embeddings
            'h': nn.ModuleList(Block(config) for _ in range(config.n_layer)),
            'ln_f': nn.LayerNorm(config.n_embd),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: the input embedding and the output projection share one
        # (vocab_size, n_embd) matrix, saving that many parameters
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # GPT-2 style init: N(0, 0.02) for Linear and Embedding weights, zero biases.
        # Linears flagged with NANOGPT_SCALE_INIT (residual projections) get their std
        # scaled by (2 * n_layer) ** -0.5, since each block adds two residual branches.
        if isinstance(module, nn.Linear):
            scaled = hasattr(module, 'NANOGPT_SCALE_INIT')
            std = 0.02 * (2 * self.config.n_layer) ** -0.5 if scaled else 0.02
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)``; ``loss`` is None unless ``targets`` is given.

        ``idx`` is unpacked as a 2-D (batch, time) tensor of token ids; ``targets``,
        when given, is flattened alongside the logits for cross-entropy.
        """
        _, seq_len = idx.size()
        assert seq_len <= self.config.block_size, "Sequence length > block size"
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        x = self.transformer.wte(idx) + self.transformer.wpe(positions)
        for block in self.transformer.h:
            x = block(x)
        # final LayerNorm before the vocabulary projection
        logits = self.lm_head(self.transformer.ln_f(x))
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


## Streaming dataloader & helpers

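Not every token in the dataset is used for training. When the next stride would run past the end of a shard, `DataLoaderLite` skips that shard's remaining tokens and moves on to the next shard.

Data augmentation (noise injection, shuffling) is not implemented: shards are read in sorted order, and the tokens within each shard sequentially.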

In [None]:
def load_tokens(filename):
    """Load one ``.npy`` token shard and return it as an int64 (torch.long) tensor.

    The array is cast straight to int64, so unsigned/wide token dtypes are never
    truncated by an intermediate 32-bit cast. ``torch.from_numpy`` then wraps the
    freshly cast array without another copy.
    """
    return torch.from_numpy(np.load(filename).astype(np.int64))


class DataLoaderLite:
    """Ultra‑lightweight streaming dataloader.

    Reads pre-tokenized ``.npy`` shards from ``data_root`` whose file names contain
    ``split`` and serves (B, T) input/target batches of contiguous tokens.

    With ``num_processes`` > 1 the ranks read interleaved, non-overlapping slices:
    rank ``r`` starts at offset ``B * T * r`` and every call advances each rank by
    ``B * T * num_processes`` tokens. When the next stride would overrun the current
    shard, the loader moves on to the next shard, wrapping around after the last one.
    The unread tail of each shard is skipped.
    """

    def __init__(self, B, T, process_rank, num_processes, split, data_root="edu_fineweb10B"):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        shards = sorted(os.path.join(data_root, s) for s in os.listdir(data_root) if split in s)
        assert shards, f"No shards found for split {split}"
        self.shards = shards
        self.reset()

    def _load_shard(self, index):
        """Make shard ``index`` current, checking it can serve at least one full stride."""
        tokens = load_tokens(self.shards[index])
        needed = self.B * self.T * self.num_processes + 1
        if len(tokens) < needed:
            raise ValueError(
                f"shard {self.shards[index]} has {len(tokens)} tokens; need at least "
                f"{needed} for B={self.B}, T={self.T}, num_processes={self.num_processes}"
            )
        self.current_shard = index
        self.tokens = tokens

    def reset(self):
        """Rewind to the first shard and this rank's starting offset."""
        self._load_shard(0)
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        """Return ``(x, y)``, each of shape (B, T); ``y`` is ``x`` shifted one token ahead."""
        B, T = self.B, self.T
        stride = B * T * self.num_processes
        # B*T + 1 tokens: the extra one supplies the target for the last input position
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)
        self.current_position += stride
        # Conservative overflow check (full stride rather than just this rank's B*T):
        # if it may not fit, switch to the next shard and restart at this rank's offset.
        if self.current_position + stride + 1 > len(self.tokens):
            self._load_shard((self.current_shard + 1) % len(self.shards))
            self.current_position = B * T * self.process_rank
        return x, y


def get_most_likely_row(tokens, mask, logits):
    """Return the index of the candidate row with the lowest mean completion loss.

    Expects ``tokens`` and ``mask`` shaped (rows, T) and ``logits`` shaped
    (rows, T, vocab). The loss of predicting token t+1 from position t is averaged
    only over positions where the one-shifted ``mask`` is non-zero.
    """
    n_rows = tokens.size(0)
    vocab = logits.size(-1)
    # position t predicts token t + 1, so drop the last logit and the first token
    pred_logits = logits[..., :-1, :].contiguous()
    next_tokens = tokens[..., 1:].contiguous()
    per_token_loss = F.cross_entropy(
        pred_logits.view(-1, vocab), next_tokens.view(-1), reduction='none'
    ).view(n_rows, -1)
    completion_mask = mask[..., 1:].contiguous()
    masked_total = (per_token_loss * completion_mask).sum(dim=1)
    mean_loss = masked_total / completion_mask.sum(dim=1)
    return mean_loss.argmin().item()


## Device / DDP setup & hyper‑parameters

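Distributed Data Parallel (DDP) spreads training over several GPUs. One process runs per GPU, and each holds a full copy of the model plus its own dataloader, which reads a disjoint slice of the data. Gradients are averaged across processes (all-reduce) during the backward pass, so all replicas stay identical.

`RANK == 0` marks the master process, which prints the logs. Without the `RANK` environment variable (e.g. on a plain Colab runtime) the notebook runs as a single process.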

In [None]:
# Detect a distributed run: torchrun sets RANK, LOCAL_RANK and WORLD_SIZE
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    import torch.distributed as dist
    from torch.distributed import init_process_group
    from torch.nn.parallel import DistributedDataParallel as DDP

    # one process per GPU; NCCL is the collective backend for CUDA tensors
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # rank 0 does the logging
else:
    ddp_rank = 0
    ddp_world_size = 1
    master_process = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

device_type = 'cuda' if device.startswith('cuda') else 'cpu'
print('Using device:', device)

# Reproducible weight initialization and sampling
torch.manual_seed(1337)
# Allow TF32 matmuls on GPUs that support them (much faster, negligible accuracy loss)
torch.set_float32_matmul_precision('high')

# Hyper‑parameters (tweak to taste)
total_batch_size = 524288  # tokens per optimizer step (2**19)
B = 64     # micro-batch size in sequences; lower it (e.g. 16 or 8) on CUDA out-of-memory
T = 1024   # sequence length in tokens
assert total_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
print('Gradient accumulation steps:', grad_accum_steps)

max_lr   = 6e-4
min_lr   = max_lr * 0.1
max_steps  = 500  # shorter default for Colab; raise for full training (~19073 steps = 10B tokens)
# The full-run warmup is 715 steps (~375M tokens). A short run would never finish that
# warmup or reach the cosine decay, so cap it at 10% of max_steps. Keep it >= 1 because
# the LR schedule divides by it.
warmup_steps = max(1, min(715, max_steps // 10))


## Build model & dataloaders

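Learning-rate schedule: for the first `warmup_steps` steps the learning rate increases linearly up to `max_lr`. After that it follows a cosine decay down to `min_lr` at `max_steps`.

The optimizer is AdamW, which applies *decoupled* weight decay. Weights are shrunk toward zero at every step to discourage large values and reduce overfitting. Unlike classic L2 regularization, this decay is not passed through Adam's adaptive per-parameter scaling.

Not implemented here: gradually increasing the batch size during training (as done for GPT‑3) is another option.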

In [None]:
train_loader = DataLoaderLite(B, T, ddp_rank, ddp_world_size, split='train')
val_loader   = DataLoaderLite(B, T, ddp_rank, ddp_world_size, split='val')

model = GPT(GPTConfig())
model.to(device)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model  # DDP keeps the unwrapped model in .module


def get_lr(it):
    """Learning rate for optimizer step ``it``.

    Steps 0..warmup_steps-1 ramp linearly up to ``max_lr``. Between ``warmup_steps``
    and ``max_steps`` the rate follows a cosine curve down to ``min_lr``, and from
    ``max_steps`` onward it stays at ``min_lr``.
    """
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # `>=` (not `>`): the cosine already equals min_lr at it == max_steps, and this
    # avoids a 0/0 below when max_steps == warmup_steps
    if it >= max_steps:
        return min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)


# Apply weight decay only to tensors with >= 2 dims (matmul weights and embeddings).
# Biases and LayerNorm gains/offsets are not decayed, matching the upstream
# train_gpt2.py. named_parameters() yields the tied embedding/lm_head weight once.
param_dict = {name: p for name, p in raw_model.named_parameters() if p.requires_grad}
decay_params = [p for p in param_dict.values() if p.dim() >= 2]
nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
optim_groups = [
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': nodecay_params, 'weight_decay': 0.0},
]
if master_process:
    print(f'decayed tensors: {len(decay_params)} '
          f'({sum(p.numel() for p in decay_params):,} params) | '
          f'non-decayed tensors: {len(nodecay_params)} '
          f'({sum(p.numel() for p in nodecay_params):,} params)')
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters and device_type == 'cuda'
optimizer = torch.optim.AdamW(
    optim_groups,
    lr=max_lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused
)


## Training loop
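As a rule of thumb, the GPT‑2 tokenizer (tiktoken) compresses English text roughly 3–4:1, i.e. about 3–4 characters per token.

When the desired batch is larger than fits in memory, we use gradient accumulation. We run forward and backward on several micro-batches, summing their (scaled) gradients, and call `optimizer.step()` only once every micro-batch of the full batch has been processed.

If `model.require_backward_grad_sync` is `True`, DDP all-reduces (averages) the gradients across all GPUs during `backward()`, so every replica ends up with identical gradients. Setting it to `True` only on the last micro-step avoids a costly synchronization after every micro-batch.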

In [None]:
enc = tiktoken.get_encoding('gpt2')  # GPT-2 BPE tokenizer, used by the sampling cell

for step in range(max_steps):
    t0 = time.time()
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0

    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        if ddp:
            # only all-reduce gradients on the last micro-step of the accumulation
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # scale so the accumulated gradient is the mean over all micro-batches
        loss = loss / grad_accum_steps
        # accumulate the *scaled* loss so loss_accum is the mean, not the sum
        loss_accum += loss.detach()
        loss.backward()

    if ddp:
        # average the logged loss over ranks (DDP already averaged the gradients)
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    lr = get_lr(step)
    for pg in optimizer.param_groups:
        pg['lr'] = lr
    optimizer.step() # updates parameters

    if device_type == 'cuda':
        torch.cuda.synchronize()  # wait for queued GPU work so dt measures the real step time
    dt = time.time() - t0
    if master_process:
        print(f'step {step:04d} | loss {loss_accum.item():.4f} | lr {lr:.2e} | time {dt*1000:.0f} ms')


## Quick top‑k sampling check

In [None]:
model.eval()
prompt = "Hello, I'm a language model,"
tokens = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
max_len = 32
top_k = 50  # sample only among the k most likely next tokens
sample_rng = torch.Generator(device=device)
sample_rng.manual_seed(42)  # reproducible samples, independent of the training RNG state
block_size = raw_model.config.block_size

with torch.no_grad():
    while tokens.size(1) < max_len:
        context = tokens[:, -block_size:]  # never feed more than the model's context window
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, _ = model(context)
        # softmax in float32: bf16 logits are too coarse for a good distribution
        probs = torch.softmax(logits[:, -1, :].float(), dim=-1)
        topk_probs, topk_idx = torch.topk(probs, top_k, dim=-1)
        choice = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng)
        idx = torch.gather(topk_idx, -1, choice)  # map back to vocabulary ids
        tokens = torch.cat([tokens, idx], dim=-1)

if master_process:
    print(enc.decode(tokens[0].tolist()))
