In [4]:
# train_colab.py
# Complete training script (warmup + cosine annealing + grad clipping + save export)
# Assumes you have input.txt in the working directory.

import os
import math
import json
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken


# -----------------------------
# Model components
# -----------------------------

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # flag used for scaled init in GPT._init_weights
        self.c_proj.NANOGPT_SCALE_INIT = 1

        self.n_head = config.n_head
        self.n_embd = config.n_embd

        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
                .view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc   = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu   = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp  = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte  = nn.Embedding(config.vocab_size, config.n_embd),
            wpe  = nn.Embedding(config.block_size, config.n_embd),
            h    = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight tying
        self.transformer.wte.weight = self.lm_head.weight

        # init weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, (
            f"Seq len {T} > block_size {self.config.block_size}"
        )

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)      # (B, T, C)
        pos_emb = self.transformer.wpe(pos)      # (T, C)
        x = tok_emb + pos_emb

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)                # (B, T, vocab)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return logits, loss


# -----------------------------
# Data loader
# -----------------------------

class DataLoaderLite:
    def __init__(self, B, T, enc, path="input.txt"):
        self.B = B
        self.T = T
        self.enc = enc

        if not os.path.exists(path):
            raise FileNotFoundError(
                f"{path} not found. Upload/put your corpus as input.txt"
            )

        with open(path, "r", encoding="utf-8") as f:
            text = f.read()

        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens, dtype=torch.long)

        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")

        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T

        buf = self.tokens[self.current_position : self.current_position + B*T + 1]
        if buf.size(0) < B*T + 1:
            self.current_position = 0
            buf = self.tokens[self.current_position : self.current_position + B*T + 1]

        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)

        self.current_position += B*T
        if self.current_position + (B*T + 1) > len(self.tokens):
            self.current_position = 0

        return x, y


# -----------------------------
# Train + save
# -----------------------------

def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)

    enc = tiktoken.get_encoding("gpt2")

    # ---- model config (GPT-2 small = 124M) ----
    config = GPTConfig(
        block_size=1024,
        vocab_size=50257,
        n_layer=12,
        n_head=12,
        n_embd=768
    )

    model = GPT(config).to(device)

    # ---- data loader ----
    train_loader = DataLoaderLite(B=8, T=256, enc=enc, path="input.txt")

    # ---- training hyperparams ----
    base_lr = 3e-4
    total_steps = 5000
    warmup_steps = 500
    max_grad_norm = 1.0

    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step + 1) / float(warmup_steps)
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    model.train()
    last_loss = None

    for step in range(total_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        scheduler.step()

        last_loss = loss.item()
        if step % 50 == 0 or step == total_steps - 1:
            print(
                f"step {step}/{total_steps} | "
                f"loss {last_loss:.4f} | "
                f"lr {scheduler.get_last_lr()[0]:.6e}"
            )

    print("final loss:", last_loss)

    # ---- export for HF Spaces inference ----
    export_dir = "hf_export"
    os.makedirs(export_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(export_dir, "model.pt"))

    with open(os.path.join(export_dir, "config.json"), "w") as f:
        json.dump(asdict(config), f)

    with open(os.path.join(export_dir, "meta.json"), "w") as f:
        json.dump({
            "tokenizer": "gpt2",
            "vocab_size": config.vocab_size,
            "block_size": config.block_size
        }, f)

    print(f"Saved model + config to {export_dir}/")


if __name__ == "__main__":
    main()


Using device: cuda
loaded 338025 tokens
1 epoch = 165 batches
step 0/5000 | loss 10.9527 | lr 1.200000e-06
step 50/5000 | loss 8.3372 | lr 3.120000e-05
step 100/5000 | loss 6.3455 | lr 6.120000e-05
step 150/5000 | loss 6.0740 | lr 9.120000e-05
step 200/5000 | loss 5.2532 | lr 1.212000e-04
step 250/5000 | loss 5.3142 | lr 1.512000e-04
step 300/5000 | loss 4.7946 | lr 1.812000e-04
step 350/5000 | loss 5.0287 | lr 2.112000e-04
step 400/5000 | loss 4.6069 | lr 2.412000e-04
step 450/5000 | loss 4.8475 | lr 2.712000e-04
step 500/5000 | loss 4.9737 | lr 3.000000e-04
step 550/5000 | loss 4.7973 | lr 2.999049e-04
step 600/5000 | loss 4.9622 | lr 2.996273e-04
step 650/5000 | loss 3.9560 | lr 2.991673e-04
step 700/5000 | loss 4.4178 | lr 2.985256e-04
step 750/5000 | loss 4.0366 | lr 2.977029e-04
step 800/5000 | loss 3.6059 | lr 2.967003e-04
step 850/5000 | loss 3.8397 | lr 2.955190e-04
step 900/5000 | loss 4.1377 | lr 2.941604e-04
step 950/5000 | loss 4.0849 | lr 2.926261e-04
step 1000/5000 | los