# Train GPT‑2 from scratch on TinyStories

This Colab notebook refactors **`train_gpt2.py`** (from https://github.com/karpathy/build-nanogpt/blob/master/train_gpt2.py) into a step-by-step workflow. A detailed explanation is given in [Andrej Karpathy's video walkthrough](https://www.youtube.com/watch?v=l8pRSuU81PU). The notebook is organized into the following components:

1. **Install** required libraries
2. **Define** GPT‑2 building blocks
3. **Prepare** an ultra‑lightweight streaming dataloader
4. **Configure** (optional) Distributed Data Parallel (DDP)
5. **Train** on the dataset and **sample** text (no held-out validation loop is included)

Feel free to tweak hyper-parameters such as `max_steps` and `total_batch_size` to match your compute budget.

In [None]:
!pip install -U "datasets>=2.15" "huggingface_hub[fsspec]" fsspec

Collecting fsspec
  Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting huggingface_hub[fsspec]
  Downloading huggingface_hub-0.33.2-py3-none-any.whl.metadata (14 kB)
Collecting fsspec
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting hf-xet<2.0.0,>=1.1.2 (from huggingface_hub[fsspec])
  Downloading hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (879 bytes)
Downloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m89.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hDownloading huggingface_hub-0.33.2-py3-none-any.whl (515 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m515.4/515.4 kB[0m [31m32.2 MB/s

In [None]:
!pip install ftfy


Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1


In [None]:
!huggingface-cli login

Token is valid (permission: fineGrained).
The token `Sevres` has been saved to /root/.cache/huggingface/stored_tokens
[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credential helper as default.

git config --global credential.helper store

Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.[0m
Token has not been saved to git credential helper.
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `Sevres`


# Design nanoGPT Blocks

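In the attention mechanism, taking the dot product of a query with every key is analogous to nearest-neighbor search: each token looks up the tokens whose keys best match its query. From a MapReduce perspective, attention is the *reduce* step (tokens exchange and aggregate information) and the MLP is the *map* step (each token is processed independently).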

In [None]:
import re
import os, math, time, inspect
from dataclasses import dataclass
from ftfy import fix_text
import torch
import torch.nn as nn
from torch.nn import functional as F

import tiktoken
import numpy as np
# from hellaswag import render_example, iterate_examples
import datasets
import torch, itertools
from datasets import load_dataset
from transformers import AutoTokenizer

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection (GPT-2 layout).

    ``config`` must provide ``n_embd`` and ``n_head`` (``n_embd`` divisible by
    ``n_head``). Input and output have shape (B, T, n_embd).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # A single linear layer emits queries, keys and values for every head at once.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Projection back into the residual stream; the flag below asks
        # GPT._init_weights for a depth-scaled init std.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        batch, seq_len, channels = x.size()
        head_dim = channels // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim) so heads act like a batch dimension
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))
        # is_causal=True masks future positions (FlashAttention kernel when available)
        attended = F.scaled_dot_product_attention(query, key, value, is_causal=True)
        # Re-assemble the heads: (B, n_head, T, head_dim) -> (B, T, C)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.c_proj(merged)


## Transformer building blocks

Unlike ReLU, GELU is smooth and returns small non-zero values for $x<0$. GPT‑2 uses its tanh approximation (`nn.GELU(approximate='tanh')`):

$\operatorname{GELU}(x) \;\approx\; 0.5 \, x \,\Bigl(1 + \tanh\!\Bigl[\sqrt{\tfrac{2}{\pi}}\,
  \bigl(x + 0.044715\,x^{3}\bigr)\Bigr]\Bigr)$

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4 * n_embd, GELU (tanh approximation), project back."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden_dim, config.n_embd)
        # Picked up by GPT._init_weights to shrink the init std of this residual projection.
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))

# Residual connection to control activation variance growth in the stream in forward propogation,
#and to preserve information and gradient flow, making very deep nets trainable
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp  = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

## GPT‑2 model wrapper

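`n_embd`, `n_layer` and `n_head` follow the 124M GPT‑2 configuration from the GPT‑2 paper.
`vocab_size` matches the GPT‑2 tokenizer. For smaller datasets a reduced vocabulary would be preferable, since embedding rows for tokens that never occur in the data stay poorly trained and can surface as odd tokens when sampling.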

The sequence length `T` must not exceed `block_size` (1024), the number of learned position embeddings; training on longer sequences would require changing the architecture.

Note that LayerNorm is applied **before** each attention and MLP sub-layer (pre-LN, as in GPT‑2), and a final LayerNorm `ln_f` is applied before the LM head. Keeping the residual stream itself un-normalized gives more stable gradients in very deep networks.

# FlashAttention
FlashAttention is a fused-kernel implementation of exact attention. It is fast, memory-efficient and numerically stable because it minimizes reads/writes to GPU memory; `F.scaled_dot_product_attention` dispatches to it when the hardware and dtype allow.

1. Computes attention scores block by block in fast on-chip SRAM instead of materializing the full $N \times N$ attention matrix in high-bandwidth memory (HBM).

2. Uses tiling (with an online softmax that normalizes each tile incrementally) and kernel fusion to reduce memory reads/writes.

3. Applies masking, softmax, dropout and the matmuls in one fused kernel (dropout is disabled at inference time).

In [None]:
@dataclass
class GPTConfig:
    """Architecture hyper-parameters; the defaults describe the 124M-parameter GPT-2."""
    block_size: int = 1024   # max sequence length = number of learned position embeddings
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding / residual-stream width


class GPT(nn.Module):
    """Decoder-only GPT-2 language model.

    Token and learned position embeddings feed a stack of ``Block``s, followed by
    a final LayerNorm and a bias-free linear LM head whose weight is shared with
    the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),   # token embeddings
            'wpe': nn.Embedding(config.block_size, config.n_embd),   # learned position embeddings
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd),                     # final LayerNorm before the head
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: input embedding and output projection share one matrix,
        # which saves parameters and tends to work well.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-2 style initialisation, applied to every submodule through ``self.apply``.

        Embedding and Linear weights ~ N(0, 0.02). Linear layers tagged with
        ``NANOGPT_SCALE_INIT`` use std 0.02 * (2 * n_layer) ** -0.5 instead.
        Linear biases are zeroed; other module types keep their default init.
        """
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            return
        if not isinstance(module, nn.Linear):
            return
        std = 0.02
        if hasattr(module, 'NANOGPT_SCALE_INIT'):
            # each Block adds two branches (attention, MLP) onto the residual stream
            std = std * (2 * self.config.n_layer) ** -0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape (B, T).

        ``logits`` has shape (B, T, vocab_size). ``loss`` is the mean
        cross-entropy against ``targets`` (same shape as ``idx``), or None
        when no targets are given.
        """
        _, seq_len = idx.size()
        assert seq_len <= self.config.block_size, "Sequence length > block size"
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        # (B, T, C) token embeddings broadcast-added to (T, C) position embeddings
        hidden = self.transformer.wte(idx) + self.transformer.wpe(positions)
        for layer in self.transformer.h:
            hidden = layer(hidden)
        logits = self.lm_head(self.transformer.ln_f(hidden))
        if targets is None:
            return logits, None
        flat_logits = logits.view(-1, logits.size(-1))
        return logits, F.cross_entropy(flat_logits, targets.view(-1))

## Streaming dataloader & helpers

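In build-nanogpt, `vocab_size` is padded from 50257 to 50304 (a multiple of 128) because CUDA kernels run faster on dimensions divisible by large powers of two. The extra token ids never occur in the data, so the model simply learns to give them near-zero probability, and throughput still improves despite the larger vocabulary. This notebook keeps the unpadded `vocab_size = 50257`.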

Data augmentation (e.g. noise injection) is not used; the stream is only shuffled through a 10,000-example buffer.

In [None]:
import unicodedata


# A paragraph consisting only of "The end" (optional period, any case).
END_RE = re.compile(r'^\s*The end\.?\s*$', re.I)
# "The end" anywhere in running text (not referenced in this cell).
END_SPLIT = re.compile(r'\b[Tt]he end\.?\s*')
_re_unk   = re.compile(r"\s*<unk>\s*")     # '<unk>' placeholders plus surrounding whitespace
_re_at_at = re.compile(r"\s*@\s*")         # '@' separators such as '@-@', '@,@'
_re_head  = re.compile(r"={2,}")            # '== Heading =='


def clean_fragments(chars):
    """Normalize raw dataset text into a single whitespace-collapsed string.

    ``chars`` is joined with ``"".join`` (so a plain ``str`` passes through
    unchanged). The text is repaired with ``fix_text`` and stripped of
    ``<unk>``, ``@`` and ``==`` markers. It is then NFKC-normalized and split
    on blank lines, and paragraphs that are just "The end" are dropped.
    Finally all whitespace runs are collapsed to single spaces.
    """
    joined = "".join(chars)
    # Repair mojibake first so the regexes below see the intended characters.
    cleaned = fix_text(joined)
    for marker in (_re_unk, _re_at_at, _re_head):
        cleaned = marker.sub(" ", cleaned)
    cleaned = unicodedata.normalize("NFKC", cleaned)

    kept_paragraphs = []
    for paragraph in cleaned.split("\n\n"):
        paragraph = paragraph.strip()
        if paragraph and not END_RE.match(paragraph):
            kept_paragraphs.append(paragraph)

    return re.sub(r"\s+", " ", " ".join(kept_paragraphs)).strip()

class StreamingBatchLoader:
    """
    Replaces the original GPTDataset in nanoGPT for streaming use.
    Generates (x, y) token blocks from a streaming Hugging Face dataset.

    Each example's "text" field is cleaned with clean_fragments, encoded with the
    GPT-2 tokenizer and terminated with the EOS token. The ids of consecutive
    examples are concatenated in ``self.buf``, and next() cuts that buffer into
    (batch_size, block_size) input/target blocks.

    Args:
        repo_id: Hugging Face dataset repository id.
        version_name: dataset config name; "" or None selects the default config.
        split: dataset split to stream.
        block_size: sequence length T of every row.
        batch_size: number of rows B per batch.
        seed: seed of the 10,000-example shuffle buffer.

    NOTE(review): the stream is not sharded by DDP rank, so every process
    reads the same examples.
    """
    def __init__(self, repo_id, version_name, split="train", block_size=1024, batch_size=8, seed=42):
        self.block_size = block_size
        self.batch_size = batch_size
        self.buf = []

        # GPT-2 BPE tokenizer; the huge model_max_length silences length warnings
        # because whole documents are encoded before being cut into blocks.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True, model_max_length=10_000_000)
        self.tokenizer.pad_token = self.tokenizer.eos_token  # For safety

        # An empty / None version_name means "use the dataset's default config".
        ds = load_dataset(repo_id, version_name or None, split=split, streaming=True)
        ds = ds.shuffle(buffer_size=10_000, seed=seed)
        ds = ds.map(self._tokenize, remove_columns=["text"])
        self._dataset = ds  # kept so the stream can be restarted once exhausted
        self.iterator = iter(ds)

    def _tokenize(self, example):
        """Clean one example's text and return its GPT-2 ids followed by EOS."""
        text = clean_fragments(example["text"])
        text = text.replace("The end.", " ")
        # Wiki '==' heading markers are already removed inside clean_fragments.
        ids = self.tokenizer.encode(text)
        return {"ids": ids + [self.tokenizer.eos_token_id]}

    def _next_example(self):
        """Return the next tokenized example, restarting the stream when it runs out.

        Raises:
            RuntimeError: if the dataset yields no examples at all.
        """
        try:
            return next(self.iterator)
        except StopIteration:
            # End of one pass over the stream: start a new one, reshuffled if supported.
            self._epoch = getattr(self, "_epoch", 0) + 1
            set_epoch = getattr(self._dataset, "set_epoch", None)
            if set_epoch is not None:
                set_epoch(self._epoch)
            self.iterator = iter(self._dataset)
        try:
            return next(self.iterator)
        except StopIteration:
            raise RuntimeError("dataset stream yielded no examples") from None

    def next(self):
        """
        Returns one (x, y) batch of shape (B, T)

        y is x shifted left by one token. Consecutive batches are contiguous:
        the token after the last x position is kept in the buffer and becomes
        the first x token of the next batch.
        """
        required = self.batch_size * self.block_size + 1
        while len(self.buf) < required:
            self.buf.extend(self._next_example()["ids"])

        tokens = torch.tensor(self.buf[:required], dtype=torch.long)
        del self.buf[:required - 1]

        x = tokens[:-1].view(self.batch_size, self.block_size)
        y = tokens[1:].view(self.batch_size, self.block_size)
        return x, y

    def humanize(self, tensor_batch):
        """
        tensor_batch: (B, T) on *CPU*.
        Returns readable words list[str] of length B.
        """
        return [
            self.tokenizer.decode(row.tolist(), skip_special_tokens=True)
            for row in tensor_batch
        ]

# Hyper-parameters (tweak to taste)
# total_batch_size = 524288  # 2**19, ~0.5M tokens per step (FineWeb-Edu setting in build-nanogpt)
total_batch_size = 8192  # tokens per optimizer step = B * T * grad_accum_steps * world_size
B = 8     # micro-batch size: sequences per forward pass
T = 1024  # sequence length: must not exceed GPTConfig.block_size

max_lr   = 3e-4
min_lr   = max_lr * 0.1
warmup_steps = 715
max_steps  = 3200  # shorter default for Colab; raise for full training

# block_size is the sequence length (T) and batch_size the number of rows (B).
# Swapping them silently trains on (1024, 8) batches with an 8-token context.
train_loader = StreamingBatchLoader(repo_id="roneneldan/TinyStories", version_name="", split="train", block_size=T, batch_size=B)
x, y = train_loader.next()
assert x.shape == (B, T) and y.shape == (B, T), (x.shape, y.shape)
print(x.shape, y.shape)


## Device / DDP setup & hyper‑parameters

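Distributed Data Parallel (DDP) speeds up training by running one process per GPU: each process holds its own copy of the model and its own data loader, and gradients are averaged across processes during the backward pass. Note that `StreamingBatchLoader` does not shard the stream by rank, so under DDP every process currently reads the same data.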

The process with `RANK == 0` is the master process, which handles logging.

In [None]:

"""
"""

# Detect distributed run
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    import torch.distributed as dist
    from torch.distributed import init_process_group
    from torch.nn.parallel import DistributedDataParallel as DDP

    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0
else:
    ddp_rank = 0
    ddp_world_size = 1
    master_process = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

device_type = 'cuda' if device.startswith('cuda') else 'cpu'
print('Using device:', device, 'rank', ddp_rank)

assert total_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
print('Gradient accumulation steps:', grad_accum_steps)








tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

8308
torch.Size([8193])
torch.Size([1024, 8]) torch.Size([1024, 8])
Using device: cuda rank 0
Gradient accumulation steps: 1


'## Training loop\nTypically the tiktoken compress rate is 3:1\n\nWhen we wish to increase batch size but constainted to memory size, we can accumulate gradient on multiple mini batches, and defer backward propogation until all mini batches in one batch are processed.\n\nIf model.require_backward_grad_sync=True, then DDP uses al-reduce operation to average gradients on all gpus, then send back to ensure all the gradients are synchornized.\n\n**why autocast to bfloat16:**\n\n16-bit numbers use half the memory of 32-bit. bfloat16 has the same exponent range as float32, but fewer mantissa bits, so it can represent very large/small numbers **without** underflow/overflow.\n'

## Build model & dataloaders

For the first `warmup_steps` steps the learning rate increases linearly; after that it follows a cosine decay down to `min_lr` at `max_steps`.

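AdamW applies weight decay *decoupled* from the gradient update, directly shrinking the weights at each step. This differs from L2 regularization added to the loss, whose effect Adam would rescale by its adaptive step sizes. Weight decay penalizes large weights to reduce overfitting. Compared to plain SGD, Adam's momentum and per-parameter step sizes make updates less sensitive to noisy gradient fluctuations.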

Beyond what is implemented here, the batch size could also be increased gradually during training (as done for GPT‑3).

In [None]:


model = GPT(GPTConfig())
model.to(device)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model  # unwrapped model for optimizer setup / sampling


def get_lr(it):
    """Learning rate for optimizer step ``it``: linear warmup, then cosine decay to ``min_lr``.

    Reads the globals warmup_steps, max_steps, max_lr and min_lr at call time,
    so reassigning max_steps later changes the schedule.
    """
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    if it > max_steps:
        return min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)


# Decay only matrices (Linear weights, embeddings). Biases and LayerNorm
# parameters are 1-D and are excluded, as in build-nanogpt's configure_optimizers;
# decaying LayerNorm gains toward zero would fight the normalization.
weight_decay = 0.1
trainable = [p for _, p in raw_model.named_parameters() if p.requires_grad]
decay_params = [p for p in trainable if p.dim() >= 2]
no_decay_params = [p for p in trainable if p.dim() < 2]
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters and device_type == 'cuda'
optimizer = torch.optim.AdamW(
    [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ],
    lr=max_lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused,
)


We can regard the attention as the reduce, and the MLP as the map.


In [None]:
# Smoke test: show the shapes of the batch drawn from train_loader and decode its first row.
print("x shape:", x.shape, "y shape:", y.shape)
print("--- x[0] as text ----------------")
print(train_loader.humanize(x[:1])[0])

## Training loop
The GPT‑2 tokenizer (tiktoken) typically compresses English text at roughly 3:1, i.e. about three characters per token. Plan on more than 5000 iterations so the model gets enough gradient steps to learn.

When we want a larger batch than fits in memory, we can accumulate gradients over several micro-batches. We call `backward()` on each micro-batch, with its loss scaled by `1 / grad_accum_steps`, and defer the optimizer step until all micro-batches of the full batch have been processed.

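While `model.require_backward_grad_sync` is `True`, DDP runs an all-reduce during `backward()` that averages the gradients across all GPUs, so every replica ends up with the same gradients. The training loop sets it to `True` only on the last micro-step, so this synchronization happens once per optimizer step.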

**why autocast to bfloat16:**

16-bit numbers use half the memory of 32-bit ones. bfloat16 keeps float32's 8-bit exponent (the same dynamic range) but has fewer mantissa bits. Unlike float16, it therefore rarely overflows or underflows and needs no gradient scaling; the cost is lower precision.


In [None]:
max_steps = 13200  # overrides the earlier default; get_lr reads this global at call time
enc = tiktoken.get_encoding('gpt2')
sample_every = 400                       # print a sample every this many steps
sample_prompt = "Our family have a picnic "
sample_max_len = 32                      # total tokens (prompt + generated) per sample

for step in range(max_steps):
    t0 = time.time()
    model.train()
    optimizer.zero_grad()
    loss_accum = torch.zeros((), device=device)  # mean loss over this step's micro-batches

    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next()
        x, y = x.to(device), y.to(device)
        if ddp:
            # All-reduce gradients only on the last micro-step; earlier ones accumulate locally.
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        # increase throughput with bfloat16
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # Scale so gradients (and the logged loss) are means over micro-batches, not sums.
        loss = loss / grad_accum_steps
        loss.backward()
        loss_accum += loss.detach()

    if ddp:
        # Average the logged loss across ranks (DDP already averaged the gradients).
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    lr = get_lr(step)
    for pg in optimizer.param_groups:
        pg['lr'] = lr
    optimizer.step()  # updates parameters
    if device_type == 'cuda':
        torch.cuda.synchronize()  # wait for queued GPU work so dt is the real step time
    dt = time.time() - t0

    if master_process:
        print(f'step {step:04d} | loss {loss_accum.item():.4f} | lr {lr:.2e} | time {dt*1000:.0f} ms')

    # Periodic sample to eyeball progress: plain multinomial sampling from the full
    # softmax (no top-k), showing how the model continues a fixed prompt.
    if master_process and step % sample_every == 0:
        model.eval()
        tokens = torch.tensor(enc.encode(sample_prompt), dtype=torch.long, device=device).unsqueeze(0)
        for _ in range(sample_max_len - tokens.size(1)):
            with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                logits, _ = raw_model(tokens)
            probs = torch.softmax(logits[:, -1], dim=-1)
            idx = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, idx], dim=-1)
        print(enc.decode(tokens[0].tolist()))


8296
torch.Size([8193])
step 0000 | loss 11.0258 | lr 4.20e-07 | time 1712 ms
Our family have a picnic ICSachelfaced majorCurrently violet witnessed ju Shepardounterirensuctiveeper697 lament desk psychologyeItemTracker_{ Sanctuary podcasts var educateenfranch Berman
8249
torch.Size([8193])
step 0001 | loss 10.9970 | lr 8.39e-07 | time 118 ms
8287
torch.Size([8193])
step 0002 | loss 10.9442 | lr 1.26e-06 | time 123 ms
8415
torch.Size([8193])
step 0003 | loss 10.8401 | lr 1.68e-06 | time 119 ms
8456
torch.Size([8193])
step 0004 | loss 10.7126 | lr 2.10e-06 | time 117 ms
8356
torch.Size([8193])
step 0005 | loss 10.5229 | lr 2.52e-06 | time 130 ms
8445
torch.Size([8193])
step 0006 | loss 10.4107 | lr 2.94e-06 | time 118 ms
8314
torch.Size([8193])
step 0007 | loss 10.2317 | lr 3.36e-06 | time 114 ms
8432
torch.Size([8193])
step 0008 | loss 10.1011 | lr 3.78e-06 | time 117 ms
8216
torch.Size([8193])
step 0009 | loss 9.9814 | lr 4.20e-06 | time 117 ms
8432
torch.Size([8193])
step 0010 | loss 

Below are a few prompts for qualitatively checking the model's generations.

In [None]:
# Sample a continuation of a fixed prompt until the sequence is max_len tokens long.
model.eval()
prompt = "Hello, I'm a LLM,"
max_len = 64
tokens = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
while tokens.size(1) < max_len:
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits, _ = model(tokens)
    next_token = torch.multinomial(torch.softmax(logits[:, -1], dim=-1), num_samples=1)
    tokens = torch.cat([tokens, next_token], dim=-1)

print(enc.decode(tokens[0].tolist()))

In [None]:
# Sample a continuation of a fixed prompt until the sequence is max_len tokens long.
model.eval()
prompt = "Hello, I'm a language model,"
max_len = 64
tokens = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
while tokens.size(1) < max_len:
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits, _ = model(tokens)
    next_token = torch.multinomial(torch.softmax(logits[:, -1], dim=-1), num_samples=1)
    tokens = torch.cat([tokens, next_token], dim=-1)

print(enc.decode(tokens[0].tolist()))

In [None]:
# Sample a longer continuation (max_len tokens in total) of a fixed prompt.
model.eval()
prompt = "Civil war is "
max_len = 128
tokens = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
while tokens.size(1) < max_len:
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits, _ = model(tokens)
    next_token = torch.multinomial(torch.softmax(logits[:, -1], dim=-1), num_samples=1)
    tokens = torch.cat([tokens, next_token], dim=-1)

print(enc.decode(tokens[0].tolist()))

In [None]:
# NOTE: this rebinds train_loader; re-running the training cell afterwards trains on FineWeb-Edu.
# block_size is the sequence length (T) and batch_size the number of rows (B).
train_loader = StreamingBatchLoader(repo_id="HuggingFaceFW/fineweb-edu", version_name="CC-MAIN-2024-10", split="train", block_size=T, batch_size=B)

x, y = train_loader.next()
print("x shape:", x.shape, "y shape:", y.shape)
print("--- x[0] as text ----------------")
print(train_loader.humanize(x[:1])[0])

In [None]:
# Unconditional sampling: start from the GPT-2 end-of-text token alone (no prompt).
# enc.eot_token (50256) is the same id as the Hugging Face GPT-2 eos_token_id,
# so no extra tokenizer download is needed.
model.eval()
tokens = torch.tensor([[enc.eot_token]], dtype=torch.long, device=device)
max_len = 128
for _ in range(max_len - tokens.size(1)):
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits, _ = model(tokens)
    probs = torch.softmax(logits[:, -1], dim=-1)
    idx = torch.multinomial(probs, num_samples=1)
    tokens = torch.cat([tokens, idx], dim=-1)

print(enc.decode(tokens[0].tolist()))