<a href="https://colab.research.google.com/github/BrianZ60/GPT-2/blob/main/train_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F


class CasualSelfAttention(nn.Module):
    """Multi-head causal self-attention, with all heads computed in one batched pass.

    (The class name keeps its original spelling; the attention it computes is causal.)
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # one linear layer yields the query, key and value projections for every head
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection; the SCALE_INIT flag is read by GPT._init_weights
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.SCALE_INIT = 1.0

        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # lower-triangular mask shaped (1, 1, block_size, block_size) so it broadcasts
        # against (B, nh, T, T) attention scores. forward() uses is_causal=True and does
        # not read it, but it is still registered and therefore part of the state dict.
        size = config.block_size
        causal_mask = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
        self.register_buffer("bias", causal_mask)

    def _to_heads(self, t, batch, seq_len):
        """Reshape (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs) so heads act as a batch dim."""
        head_size = t.size(-1) // self.n_head
        return t.view(batch, seq_len, self.n_head, head_size).transpose(1, 2)

    def forward(self, x):
        """Attend over the sequence, letting each position see only itself and earlier positions.

        Args:
            x: tensor of shape (B, T, C) with C == n_embd.

        Returns:
            Tensor of shape (B, T, C).
        """
        batch, seq_len, channels = x.shape
        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)  # each (B, T, C)
        query = self._to_heads(query, batch, seq_len)
        key = self._to_heads(key, batch, seq_len)
        value = self._to_heads(value, batch, seq_len)

        # Fused attention kernel. Equivalent to softmax(q @ k^T / sqrt(hs)) with future
        # positions masked, followed by @ v, without materialising the (T, T) matrix.
        out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

        # back to (B, T, nh, hs) then merge heads; transpose breaks contiguity, which view() needs
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.c_proj(out)



class MLP(nn.Module):
    """Position-wise feed-forward network: widen 4x, GELU (tanh approximation), project back."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width)
        # smooth alternative to ReLU: negative inputs still get a non-zero gradient
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(hidden_width, config.n_embd)
        # flag read by GPT._init_weights to scale down this projection's init
        self.c_proj.SCALE_INIT = 1.0

    def forward(self, x):
        """Map (..., n_embd) -> (..., n_embd) independently at every position."""
        return self.c_proj(self.gelu(self.c_fc(x)))


class Block(nn.Module):
    """Transformer block: LayerNorm is applied before attention and before the MLP,
    and each sub-layer's output is added back onto the residual stream."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CasualSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply attention then the MLP, each as a residual update; output shape equals input shape."""
        # positions exchange information through attention ("communicate")...
        mixed = x + self.attn(self.ln_1(x))
        # ...then each position processes what it gathered on its own ("think")
        return mixed + self.mlp(self.ln_2(mixed))


@dataclass  # generates __init__ (plus __repr__ and __eq__) from the annotated fields
class GPTConfig:
    """Model hyperparameters; the defaults match the 124M-parameter GPT-2 ("gpt2")."""

    block_size: int = 1024  # maximum sequence length (context window)
    vocab_size: int = 50257  # number of distinct token ids
    n_layer: int = 12  # number of transformer blocks
    n_head: int = 12  # attention heads per block
    n_embd: int = 768  # embedding / residual-stream width

class GPT(nn.Module):
    """GPT-2 style language model.

    The model sums token and position embeddings, applies a stack of transformer
    Blocks and a final LayerNorm, then applies an LM head whose weight is tied to
    the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # use module dict to replicate structure of the hf model (so state-dict keys line up)
        self.transformer = nn.ModuleDict(dict(
            wte =  nn.Embedding(config.vocab_size, config.n_embd),
            wpe =  nn.Embedding(config.block_size, config.n_embd),
            h =    nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight sharing scheme: the input embedding and output projection map between
        # the same token space and model space, so they are (to an extent) inverses.
        # Synonyms should embed similarly and receive similar output scores.
        # See https://arxiv.org/pdf/1608.05859. This also saves a lot of parameters.
        self.transformer.wte.weight = self.lm_head.weight

        # init params: apply() visits every submodule
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule in place.

        Linear weights are drawn from N(0, 0.02). Modules flagged with SCALE_INIT
        (the residual-path projections) instead use std 0.02 / sqrt(2 * n_layer),
        because every layer adds two contributions to the residual stream. Linear
        biases are zeroed and Embedding weights are drawn from N(0, 0.02).
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits, and optionally the loss.

        Args:
            idx: integer token ids of shape (B, T), with T <= config.block_size.
            targets: optional token ids of shape (B, T) to score the logits against.

        Returns:
            (logits, loss): logits of shape (B, T, vocab_size); loss is the mean
            cross-entropy over all B*T positions, or None when targets is None.
        """
        B, T = idx.shape
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # (T)
        pos_emb = self.transformer.wpe(pos)  # (T, n_embd)
        tok_emb = self.transformer.wte(idx)  # (B, T, n_embd)
        x = tok_emb + pos_emb  # position embeddings broadcast over the batch

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # flatten logits into (B*T, vocab_size) and targets into (B*T);
            # reshape also accepts non-contiguous targets, where view() would fail
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Load pretrained GPT-2 weights from Hugging Face into a new model.

        Args:
            model_type: one of "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl".

        Returns:
            A new model whose weights are copied from
            ``transformers.GPT2LMHeadModel.from_pretrained(model_type)``.

        Raises:
            AssertionError: for an unknown model_type, or on a key or shape mismatch.
        """
        # Architecture per checkpoint. The accepted names are exactly this dict's keys,
        # so the validity check cannot drift out of sync with the lookup.
        pretrained_configs = {
            "gpt2":        dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            "gpt2-large":  dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            "gpt2-xl":     dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }
        assert model_type in pretrained_configs, (
            f"unknown model_type {model_type!r}; expected one of {sorted(pretrained_configs)}"
        )
        from transformers import GPT2LMHeadModel
        print(f"loading weights from pretrained gpt: {model_type}")

        config_args = dict(pretrained_configs[model_type])
        # every GPT-2 checkpoint uses the same vocabulary size and context length
        config_args["vocab_size"] = 50257
        config_args["block_size"] = 1024

        config = GPTConfig(**config_args)
        model = cls(config)

        # state_dict tensors share storage with the model's parameters, so copy_ below
        # writes straight into the model
        sd = model.state_dict()
        sd_keys = [k for k in sd.keys() if not k.endswith(".attn.bias")]  # skip the mask buffer

        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # Some transformers versions include HF's attention mask buffers in state_dict.
        # They are not weights, so drop them; this is a no-op when they are absent.
        sd_keys_hf = [
            k for k in sd_hf.keys()
            if not k.endswith(".attn.masked_bias") and not k.endswith(".attn.bias")
        ]
        # HF implements these layers as Conv1D, storing weights transposed relative to nn.Linear
        transposed = ["attn.c_attn.weight", "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight"]

        mismatched = set(sd_keys_hf) ^ set(sd_keys)
        assert len(sd_keys_hf) == len(sd_keys) and not mismatched, (
            f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}; differing: {sorted(mismatched)}"
        )

        with torch.no_grad():
            for k in sd_keys_hf:
                if any(k.endswith(w) for w in transposed):
                    assert sd_hf[k].shape[::-1] == sd[k].shape, f"Shape mismatch at key: {k}. {sd_hf[k].shape} (transposed) != {sd[k].shape}"
                    sd[k].copy_(sd_hf[k].T)
                else:
                    assert sd_hf[k].shape == sd[k].shape, f"Shape mismatch at key: {k}. {sd_hf[k].shape} != {sd[k].shape}"
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, device):
        """Build an AdamW optimizer with decoupled weight decay on matrices only.

        Weight decay adds a penalty that grows with the size of a weight, which
        discourages relying too heavily on any single weight. Parameters with
        dim >= 2 (linear weight matrices, embeddings) are decayed. Parameters with
        dim < 2 (biases, LayerNorm gains and shifts) are not.

        Args:
            weight_decay: decay coefficient for the matrix parameter group.
            learning_rate: initial learning rate.
            device: a device string such as "cuda" or "cpu", or a torch.device.
                The fused AdamW kernel is requested when it names a CUDA device.

        Returns:
            A configured torch.optim.AdamW instance.
        """
        # named_parameters() yields tied weights once, so the shared embedding is not duplicated
        param_dict = {name: param for name, param in self.named_parameters() if param.requires_grad}

        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        no_decay_params = [p for p in param_dict.values() if p.dim() < 2]

        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        num_decay_params = sum(p.numel() for p in decay_params)  # total element count
        num_no_decay_params = sum(p.numel() for p in no_decay_params)

        print(f"num decayed param tensors: {len(decay_params)}, totaling {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(no_decay_params)}, totaling {num_no_decay_params:,} parameters")

        # str() so a torch.device works too ("in" on a torch.device raises TypeError)
        use_fused = "cuda" in str(device)
        print(f"fused AdamW: {use_fused}")

        optimizer = torch.optim.AdamW(params=optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)

        return optimizer






import tiktoken

class DataLoaderLite:
    """Minimal sequential loader that serves (input, target) token batches from one text file.

    The whole file is tokenized up front with the tiktoken "gpt2" encoding and kept
    in memory as a 1-D tensor. Batches are consecutive, non-overlapping windows that
    wrap back to the start when the next window would run past the end.
    """

    DATA_URL = "https://raw.githubusercontent.com/BrianZ60/GPT-2/refs/heads/main/input.txt"

    def __init__(self, B, T, path="input.txt", url=DATA_URL):
        """Tokenize the text at `path`, downloading it from `url` first if it is missing.

        Args:
            B: micro-batch size (rows per batch).
            T: sequence length (tokens per row).
            path: local text file to read.
            url: where to fetch the file from when `path` does not exist.

        Raises:
            ValueError: if the corpus is too short for even one (B, T) batch.
        """
        import os
        import urllib.request

        self.B = B
        self.T = T

        # load and store tokens
        enc = tiktoken.get_encoding("gpt2")

        # Download only when missing. This is plain Python, so it also runs outside
        # IPython, and repeated construction does not re-fetch the file.
        if not os.path.exists(path):
            with urllib.request.urlopen(url) as response, open(path, "wb") as out:
                out.write(response.read())
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        if len(self.tokens) < B * T + 1:
            # a batch needs B*T inputs plus one extra token for the shifted targets
            raise ValueError(f"corpus has {len(self.tokens)} tokens; need at least {B * T + 1} for B={B}, T={T}")
        print(f"Loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")

        self.current_position = 0

    def next_batch(self):
        """Return the next (x, y) pair, each of shape (B, T); y is x shifted left by one token."""
        B, T = self.B, self.T
        buf = self.tokens[self.current_position:self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)  # targets: the token that follows each input position

        self.current_position += B * T

        # Loop back when the next window (B*T inputs + 1 target token) would not fit.
        # A window that ends exactly at the last token still fits, hence `>` rather than `>=`.
        if self.current_position + B * T + 1 > len(self.tokens):
            self.current_position = 0

        return x, y


import time

device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

# Total tokens per optimizer step. 2**19 = 524288 is close to OpenAI's ~0.5M-token batch;
# it is divided by 4 here. Integer division matters: with `/` this becomes a float,
# grad_accum_steps becomes a float too, and range(grad_accum_steps) raises TypeError.
total_batch_size = 524288 // 4  # measured in total number of tokens

# We simulate the large batch by doing many forward/backward passes, accumulating the gradient
B = 8  # micro batch size
T = 512  # sequence length (num tokens)

assert total_batch_size % (B * T) == 0, "make sure total_batch_size is divisible by B*T"
grad_accum_steps = total_batch_size // (B * T)
print(f"desired total batch size: {total_batch_size}")
print(f"=> caculated gradient accumulation steps: {grad_accum_steps}")


train_loader = DataLoaderLite(B=B, T=T)

# torch.set_float32_matmul_precision("high") # use TF32 (lower precision than FP32, but faster)
# variables are still FP32, but the matrix multiplies are TF32


# use 50304 instead of 50257 b/c it is a much nicer number (divisible by 128)
model = GPT(GPTConfig(vocab_size=50304)).to(device)
model = torch.compile(model)
# What torch.compile() does:
# 1. Views the entire network as a whole, allowing more efficient processing and minimizing Python interpreter overhead
# 2. Reduces read/write time between GPU and memory with operator fusion, which also mitigates memory-bandwidth cost.

# learning-rate schedule constants (read by get_lr)
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 50

def get_lr(it):
    """Learning rate for optimizer step `it`: linear warmup, cosine decay, then a constant floor.

    Reads the module-level max_lr, min_lr, warmup_steps and max_steps.
    """
    # warmup: climb linearly from max_lr / warmup_steps up to max_lr
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps

    # past the end of the schedule: hold at the floor
    if it > max_steps:
        return min_lr

    # in between: cosine from max_lr (progress 0) down to min_lr (progress 1),
    # with progress measured from the end of warmup
    progress = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= progress <= 1
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes 1 -> 0
    return min_lr + (max_lr - min_lr) * cosine_weight



optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device=device)


for step in range(max_steps):

    t0 = time.time()

    # zero once per optimizer step; gradients accumulate over the micro-steps below
    optimizer.zero_grad()

    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)

        # with torch.autocast(device_type=device, dtype=torch.bfloat16):
        #     logits, loss = model(x, y) # actually changes the datatype of logits, but others remain FP32 (mixed precision)
        logits, loss = model(x, y)
        # The loss is a mean over one micro-batch. Scaling by 1/grad_accum_steps makes the
        # accumulated gradient the mean over the whole batch rather than a sum of means.
        # Done out-of-place so the autograd output tensor is not modified in place.
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()  # logging only; detached so no graph is retained
        loss.backward()

    # global gradient norm (sqrt of the sum of squared grads); grads are rescaled so it is <= 1.0
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    optimizer.step()
    if "cuda" in device:
        # wait for queued GPU work so dt measures the real step time;
        # guarded because torch.cuda.synchronize() fails on CPU-only machines
        torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0  # time diff in seconds
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps
    tokens_per_sec = tokens_processed / dt
    print(f"step: {step} | loss: {loss_accum.item():.6f} | lr: {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms, tok/sec: {tokens_per_sec:.2f}")


# Sampling is skipped for now; set to True to generate text after training.
run_sampling = False

if run_sampling:
    # DataLoaderLite's encoder is local to its __init__, so create one here
    enc = tiktoken.get_encoding("gpt2")

    model.eval()
    num_return_sequences = 5
    max_length = 30

    tokens = enc.encode("Hello, I'm a language model,")
    tokens = torch.tensor(tokens, dtype=torch.long)  # (8,)
    tokens = tokens.unsqueeze(dim=0).repeat(num_return_sequences, 1)  # (5, 8)
    x = tokens.to(device)

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    while x.size(1) < max_length:
        with torch.inference_mode():
            logits, _ = model(x)  # forward returns (logits, loss); logits is (B, T, vocab_size)
            logits = logits[:, -1, :]  # (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            # top-k sampling: only the 50 most likely tokens can be chosen
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # (B, 50)
            ix = torch.multinomial(topk_probs, 1)  # (B, 1) index into the top-k list
            xcol = torch.gather(topk_indices, -1, ix)  # map back to vocabulary ids
            x = torch.cat((x, xcol), dim=1)
    for i in range(num_return_sequences):
        decoded = enc.decode(x[i].tolist())
        print(">", decoded)

--2025-08-11 03:44:46--  https://raw.githubusercontent.com/BrianZ60/GPT-2/refs/heads/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.3’


2025-08-11 03:44:46 (26.5 MB/s) - ‘input.txt.3’ saved [1115394/1115394]

Loaded 338025 tokens
1 epoch = 82 batches
num decayed param tensors: 50, totaling 124,354,560 parameters
num non-decayed parameter tensors: 98, totaling 121,344 parameters
fused AdamW: True
step: 0 | loss: 10.875685 | lr: 6.0000e-05 | norm: 30.3851 | dt: 8319.55ms, tok/sec: 492.33
step: 1 | loss: 9.491011 | lr: 1.2000e-04 | norm: 8.0198 | dt: 872.72ms, tok/sec: 4693.37
step: 2 | loss: 9.094246 | lr: 1.8000e-04 | norm: 4.0901 | dt: 883.77ms, tok/sec: 4634.71
step: 3 | loss: 8.982092 | lr:

SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
