# GPT

The first notebook was written from scratch; this one leverages existing libraries to experiment with actual training and inference.

I also added some improvements over the baseline model:
- Moved attention computation to the optimized `F.scaled_dot_product_attention`
- Moved `LayerNorm` to `RMSNorm`, which is the standard now
- Moved `GELU` to `SwiGLU`
- Moved positional encoding to `RoPE` on `Q` and `K`
- Disabled bias in every linear layer
- Grouped the `Q`, `K`, `V` projections into a single fused projection
- Padded the vocabulary size (embedding table) to the nearest multiple of 128 for GPU efficiency
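
The padding in the last bullet can be sketched as follows. `round_up` is imported from the local `utils` module, whose source is not shown in this notebook, so the version below is only an assumed equivalent.

```python
def round_up(value: int, multiple: int) -> int:
    """Assumed equivalent of utils.round_up: smallest multiple of `multiple` that is >= value."""
    return ((value + multiple - 1) // multiple) * multiple

gpt2_vocab = 50257                        # size of the tiktoken "gpt2" vocabulary
padded_vocab = round_up(gpt2_vocab, 128)
print(padded_vocab)                       # 50304, i.e. 393 * 128
print(padded_vocab - gpt2_vocab)          # 47 extra rows that no token ID ever indexes
```

The extra rows cost a few thousand parameters but give the embedding and output matmuls a GPU-friendly dimension.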


In [1]:
# Standard library
import csv
import inspect
import math
import multiprocessing
import os
import time
from pprint import pprint
from datetime import datetime

# Environment config
from huggingface_hub import login
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
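if os.path.isfile(".env"):
    with open(".env") as f:
        for line in f:
            line = line.strip()
            # Skip blank lines, comments, and anything that is not KEY=VALUE
            if not line or line.startswith("#") or "=" not in line:
                continue
            # Split on the first "=" only, so values may themselves contain "="
            key, value = line.split("=", 1)
            os.environ[key.strip()] = value.strip()

    HF_TOKEN = os.environ["HF_TOKEN"]
    login(HF_TOKEN)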


# Third-party
import tiktoken
import triton
import triton.language as tl
from datasets import concatenate_datasets, load_dataset, Dataset as ds  # 3.6.0 to avoid issues with load_dataset
from rotary_embedding_torch import RotaryEmbedding
from tokenizers import Tokenizer

# PyTorch
import torch
print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LinearLR, SequentialLR

# Custom
from utils import clean_columns, round_up, strip_compile_prefix

# Torch runtime config
torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
dtype = torch.bfloat16  
print(f"torch.cuda.is_bf16_supported() -> {torch.cuda.is_bf16_supported()}\n")

torch.empty(
    1, device=f"cuda:{os.environ.get('LOCAL_RANK', 0)}", requires_grad=True
).backward()  # prevents a bug on some systems

!nvidia-smi

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


2.10.0+cu130
torch.cuda.is_bf16_supported() -> True

Thu Jan 22 13:26:28 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX 5000 Ada Gene...    Off |   00000000:17:00.0 Off |                  Off |
| 30%   42C    P2             51W /  250W |     369MiB /  32760MiB |      6%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+-------------------

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 rotary_emb: RotaryEmbedding,
                 causal: bool = True,
                 dropout: float = 0.1
                ):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads}).")
        
        self.causal = causal
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout_p = dropout
        
        # Fused QKV projection: 3x the output size
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        
        # Shared rotary embedding (passed from GPT model)
        self.rotary_emb = rotary_emb
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
    
    def forward(self, x, k_v_cache=None):
        B, T, _ = x.shape
        using_cache = k_v_cache is not None and "K" in k_v_cache
    
        # 1. Single fused projection
        if using_cache:
            x_q = x[:, -1:, :]
            qkv = self.qkv_proj(x_q)  # (B, 1, 3 x embed_dim)
        else:
            qkv = self.qkv_proj(x)  # (B, T, 3 x embed_dim)
        
        # 2. Split into Q, K, V
        Q, K, V = qkv.chunk(3, dim=-1)  # Each is (B, T, embed_dim)
        
        def split_heads(t):
            return t.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 3. Split heads -> (B, H, T, D_head)
        Q = split_heads(Q)
        K = split_heads(K)
        V = split_heads(V)
    
        # 4. Apply RoPE 
        if using_cache:
            past_len = k_v_cache["K"].shape[-2]
            Q = self.rotary_emb.rotate_queries_or_keys(Q, offset=past_len)
            K = self.rotary_emb.rotate_queries_or_keys(K, offset=past_len)
            
            K = torch.cat([k_v_cache["K"], K], dim=-2)
            V = torch.cat([k_v_cache["V"], V], dim=-2)
            # When using the cache, the "causality" is already enforced by the fact that we are passing 1 query token against all valid past keys 
            # We don't need a mask as we want the current token to attend to everything in the history
            is_causal_step = False
        else:
            Q = self.rotary_emb.rotate_queries_or_keys(Q)
            K = self.rotary_emb.rotate_queries_or_keys(K)
            is_causal_step = self.causal
    
        # 5. Update cache
        if k_v_cache is not None:
            k_v_cache["K"] = K.detach()  # we will never .backward on these
            k_v_cache["V"] = V.detach()  
    
        # 6. Attention
        out = F.scaled_dot_product_attention(
            query=Q,
            key=K,
            value=V,
            attn_mask=None, 
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=is_causal_step
        )
        
        # 7. Merge heads
        out = out.transpose(1, 2).contiguous().view(B, -1, self.embed_dim)

        # 8. Linear projection
        return self.out_proj(out), k_v_cache
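
Why the cached branch passes `offset=past_len`: RoPE rotates each pair of channels by an angle proportional to the token's absolute position, and the resulting attention score depends only on the distance between query and key. A new token decoded against cached keys must therefore be rotated at its true position, not at position 0. Below is a minimal pure-Python sketch on a single 2D channel pair; `rotate`, `score` and `theta` are illustrative names, not part of `rotary_embedding_torch`.

```python
import math

def rotate(vec, position, theta=0.1):
    """Rotate a 2D (x, y) pair by the angle position * theta, as RoPE does per channel pair."""
    x, y = vec
    angle = position * theta
    return (x * math.cos(angle) - y * math.sin(angle),
            x * math.sin(angle) + y * math.cos(angle))

def score(q, k, q_pos, k_pos):
    """Dot product between a query rotated at q_pos and a key rotated at k_pos."""
    qx, qy = rotate(q, q_pos)
    kx, ky = rotate(k, k_pos)
    return qx * kx + qy * ky

q, k = (1.0, 2.0), (0.5, -1.0)

# Same distance (2 positions apart) gives the same score, whatever the absolute positions
same_distance = math.isclose(score(q, k, 3, 1), score(q, k, 10, 8))

# Rotating the new query at position 0 instead of its true position changes the score
wrong_offset = math.isclose(score(q, k, 0, 8), score(q, k, 10, 8))

print(same_distance, wrong_offset)  # True False
```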

In [3]:
class MLP(nn.Module):
    def __init__(self, embed_dim, hidden_dim=None, dropout_prob=0.1):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * embed_dim
        
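        # SwiGLU uses three weight matrices instead of two, so scale hidden_dim by 2/3
        # to keep the parameter count close to a standard MLP with the same ratio.
        # For performance it is important that hidden_dim is a multiple of 8.
        hidden_dim = round_up(2 * hidden_dim // 3, 8)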

        # Fused projection
        self.gate_up_proj = nn.Linear(embed_dim, 2 * hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout_prob)
        

        # Manual implementation would be this
        # self.gate_proj = nn.Linear(embed_dim, hidden_dim, bias=False)
        # self.up_proj = nn.Linear(embed_dim, hidden_dim, bias=False)
        # self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=False)
        # forward pass : return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

        # For performance's sake we should maybe use the xformers fused kernel instead, especially optimized for bf16
        # Actually for very small tensors the manual implementation is almost 2x faster
        # But in a real setting, i.e. 16-32 sequences of 1024 tokens, this implementation is ~25% faster (%timeit)
        # HOWEVER: compile + bf16 + manual is FASTER than bf16 + xformers.swiglu (which can't be compiled without errors)
        # Bottom line: we can keep the manual implementation for the purpose of this notebook
        # from xformers.ops import SwiGLU
        # self.swiglu = SwiGLU(
        #     in_features=embed_dim,
        #     hidden_features=hidden_dim,
        #     out_features=embed_dim,
        #     bias=False,
        #     _pack_weights=True,  # on by default but for clarity
        # )
        # forward pass : return self.dropout(self.swiglu(x))
       
        
    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        gate, up = gate_up.chunk(2, dim=-1)
        return self.dropout(self.down_proj(F.silu(gate) * up))


# ============================================================================
# Fused Cross-Entropy Loss (Triton)
# ============================================================================
#
# Stolen err adapted from modded-nanogpt
# Idea is that computing Softmax CE on large vocab is super expensive so we try to fuse ops together
# TODO: Implement SoftCap on logits
#
# Benefits: 
# - No need to materialize full [B*T, vocab_size] probability tensor
# - Single kernel pass instead of softmax + nll_loss
# - ~2-3x faster for large vocabularies
# ============================================================================


@triton.jit
def fused_cross_entropy_fwd_kernel(
    logits_ptr, losses_ptr, lse_ptr, targets_ptr,
    stride_logits_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    """
    Fused forward pass: computes loss = -logits[target] + logsumexp(logits)
    without materializing softmax probabilities.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    logits_row_ptr = logits_ptr + row_idx * stride_logits_row
    
    # Compute logsumexp in numerically stable way (online algorithm)
    max_val = -float('inf')
    sum_exp = 0.0
    
    for off in range(0, n_cols, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        logits = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32)
        
        # Update running max and sum for logsumexp
        curr_max = tl.max(logits, axis=0)
        new_max = tl.maximum(max_val, curr_max)
        sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(logits - new_max), axis=0)
        max_val = new_max
    
    lse = max_val + tl.log(sum_exp)
    tl.store(lse_ptr + row_idx, lse)
    
    # Load target and compute loss
    target = tl.load(targets_ptr + row_idx).to(tl.int64)
    target_logit = tl.load(logits_row_ptr + target).to(tl.float32)
    loss = lse - target_logit
    
    tl.store(losses_ptr + row_idx, loss)


@triton.jit  
def fused_cross_entropy_bwd_kernel(
    grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr,
    stride_logits_row, stride_grad_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    """
    Fused backward pass: grad[i] = softmax[i] - (1 if i==target else 0)
    Recomputes softmax from logits and lse, avoiding memory overhead.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    logits_row_ptr = logits_ptr + row_idx * stride_logits_row
    grad_row_ptr = grad_input_ptr + row_idx * stride_grad_row
    
    lse = tl.load(lse_ptr + row_idx)
    grad_loss = tl.load(grad_output_ptr + row_idx)
    target = tl.load(targets_ptr + row_idx).to(tl.int64)
    
    for off in range(0, n_cols, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        
        logits = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        # softmax = exp(logits - lse)
        probs = tl.exp(logits - lse)
        # gradient = grad_loss * (probs - one_hot(target))
        is_target = (cols == target).to(tl.float32)
        grad = grad_loss * (probs - is_target)
        
        tl.store(grad_row_ptr + cols, grad.to(tl.bfloat16), mask=mask)


class FusedCrossEntropyLoss(torch.autograd.Function):
    """
    Memory-efficient cross-entropy that doesn't materialize the full probability matrix.
    For vocab=50k, batch=16, seq=1024: saves ~3GB of memory per forward pass.
    """
    @staticmethod
    def forward(ctx, logits, targets):
        n_rows, n_cols = logits.shape
        
        losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
        lse = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
        
        logits = logits.contiguous()
        targets = targets.contiguous()
        
        # One thread block per row (token)
        grid = (n_rows,)
        fused_cross_entropy_fwd_kernel[grid](
            logits, losses, lse, targets,
            logits.stride(0),
            n_cols,
            BLOCK_SIZE=1024,
            num_warps=8,
        )
        
        ctx.save_for_backward(logits, targets, lse)
        return losses
    
    @staticmethod
    def backward(ctx, grad_output):
        logits, targets, lse = ctx.saved_tensors
        n_rows, n_cols = logits.shape
        
        grad_input = torch.empty_like(logits, dtype=torch.bfloat16)
        grad_output = grad_output.contiguous()
        
        grid = (n_rows,)
        fused_cross_entropy_bwd_kernel[grid](
            grad_input, grad_output, lse, logits, targets,
            logits.stride(0), grad_input.stride(0),
            n_cols,
            BLOCK_SIZE=1024,
            num_warps=8,
        )
        return grad_input, None


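def fused_cross_entropy(logits, targets):
    """Mean cross-entropy via the fused kernels; expects (N, vocab) logits and (N,) targets.
    Not a full drop-in for F.cross_entropy: no ignore_index or label smoothing."""
    return FusedCrossEntropyLoss.apply(logits, targets).mean()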
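
The running max / running sum update inside `fused_cross_entropy_fwd_kernel` is the usual online log-sum-exp trick. Here is a plain-Python sketch of the same recurrence, processing one row block by block; the function name and the tiny block size are only for illustration.

```python
import math

def online_logsumexp(row, block_size=4):
    """Streaming log-sum-exp over `row`, one block at a time, as in the forward kernel."""
    max_val = -math.inf
    sum_exp = 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        new_max = max(max_val, max(block))
        # Rescale the sum accumulated so far to the new max, then add this block's terms
        sum_exp = sum_exp * math.exp(max_val - new_max) + sum(math.exp(v - new_max) for v in block)
        max_val = new_max
    return max_val + math.log(sum_exp)

logits = [2.0, -1.0, 0.5, 3.0, 1000.0, 999.0, -5.0]
target = 3

lse = online_logsumexp(logits)
loss = lse - logits[target]          # cross-entropy for this row: logsumexp - target logit

# Reference: one-shot log-sum-exp with the global max subtracted
reference = 1000.0 + math.log(sum(math.exp(v - 1000.0) for v in logits))
print(lse, loss, math.isclose(lse, reference))
```

A naive `log(sum(exp(v)))` would overflow on the `1000.0` entry; the rescaling step is what keeps every exponent at or below zero.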

In [4]:
class TransformerBlock(nn.Module):
    def __init__(self,
                 embed_dim,
                 num_heads,
                 rotary_emb,
                 mlp_ratio=4,
                 dropout_prob=0.1,
                 causal=True,
                ): 
        """
        Initialize a complete transformer block.
        
        COMPONENTS:
        1. 1st normalization (RMSNorm, applied before attention)
        2. Multi-head self-attention for sequence modeling
        3. 2nd normalization (RMSNorm, applied before the MLP)
        4. MLP with the specified expansion ratio
    
        TRANSFORMER BLOCK ARCHITECTURE:
        x → Norm → MultiHeadAttention → + (residual) →
            Norm → MLP → + (residual) → output
    
        NB: We use pre-norm architecture (before attention/MLP)
        """
    
        super().__init__()
        self.norm1 = nn.RMSNorm(embed_dim)
        self.mha = MultiHeadAttention(embed_dim, num_heads, rotary_emb, causal, dropout_prob)  # causal = mask out future tokens
        self.norm2 = nn.RMSNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio * embed_dim, dropout_prob)
    
    def forward(self, x, cache=None):
        x1 = self.norm1(x)
        x2, cache = self.mha(x1, cache)  # will be used when generating tokens during inference
        x2 = x2 + x  # residual path
    
        x3 = self.norm2(x2)
        x3 = self.mlp(x3) + x2  # residual path
        return x3, cache
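
The `2/3` scaling inside `MLP` keeps the SwiGLU feed-forward at the same parameter budget as a classic two-matrix MLP. A quick arithmetic check for the sizes used later in this notebook (`embed_dim = 768`, `mlp_ratio = 4`); `round_up` is re-implemented here as an assumption, since the helper in `utils` is not shown.

```python
def round_up(value: int, multiple: int) -> int:
    """Assumed equivalent of utils.round_up."""
    return ((value + multiple - 1) // multiple) * multiple

embed_dim, mlp_ratio = 768, 4
requested_hidden = mlp_ratio * embed_dim                 # 3072, what TransformerBlock passes to MLP
swiglu_hidden = round_up(2 * requested_hidden // 3, 8)   # 2048, what MLP actually uses

# Classic MLP: up (embed -> hidden) + down (hidden -> embed), no biases
classic_params = 2 * embed_dim * requested_hidden
# SwiGLU: gate + up (fused as embed -> 2 * hidden) + down (hidden -> embed), no biases
swiglu_params = embed_dim * 2 * swiglu_hidden + swiglu_hidden * embed_dim

print(swiglu_hidden, classic_params, swiglu_params)
```

For these sizes the two counts match exactly (4,718,592 weights each).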

In [5]:
class GPT(nn.Module):
    """
    Complete GPT (Generative Pre-trained Transformer) model.

    This combines embeddings, positional encoding, multiple transformer blocks,
    and a language modeling head for text generation.
    """

    def __init__(self,
                 vocab_size,
                 embed_dim,
                 num_layers,
                 num_heads,
                 mlp_ratio=4,
                 dropout_prob=0.1,
                 is_causal=True,
                ):
        """
        Initialize complete GPT model.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
        self.dropout = nn.Dropout(dropout_prob)
        
        # Shared rotary embedding across all layers (more efficient for compilation)
        head_dim = embed_dim // num_heads
        self.rotary_emb = RotaryEmbedding(dim=head_dim)
        
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, self.rotary_emb, mlp_ratio, dropout_prob, is_causal) 
            for _ in range(num_layers)
        ])
        self.norm = nn.RMSNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # weight tying

        # below shamefully stolen from nano-gpt
        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        # don't forget swiglu variant !
        for pn, p in self.named_parameters():
            if pn.endswith(("out_proj.weight", "down_proj.weight")):
                # Residual projections: scale down to prevent variance explosion
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.num_layers))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
       
    def forward(self, tokens):
        embeddings = self.embedding(tokens)
        x = self.dropout(embeddings)
        for b in self.blocks:
            x, _ = b(x)
        features = self.norm(x)  # normalized to stabilize training
        return self.lm_head(features)

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.no_grad()
    def generate(self,
                 prompt_tokens,
                 max_new_tokens=50,
                 temperature=1.0,
                 top_k=0,
                 top_p=0.0,
                 use_cache=True,
                ):
        """
        Auto-regressive text generation loop

        prompt_tokens : tensor with input tokens, shape (1, n_tokens)
        max_new_tokens : how many new tokens we want to generate
        temperature : controls expressivity (lower = less, higher = more funky, degenerate cases : 0 = argmax, +inf = random guess)
        top_k : restrict prediction to top_k tokens to avoid sampling low prob garbage, set to 0 to disable, top_k ∈ [0, vocab_size]
        top_p : sample from smallest set with cumulative prob >= top_p (adapts to model confidence, usually top_k OR top_p), top_p ∈ [0, 1]
        use_cache : set to True to avoid re-computing expensive K, V matrices
        
        """
        
        self.eval()

        tokens_out = prompt_tokens.clone()
        current_tokens = prompt_tokens.clone()
        tokens_out = tokens_out.to(self.device)
        current_tokens = current_tokens.to(self.device)
        cache = [{} if use_cache else None for _ in range(len(self.blocks))]
        
        for _ in range(max_new_tokens):

            x = self.embedding(current_tokens)
            for i, b in enumerate(self.blocks):
                x, c_i = b(x, cache[i])
                cache[i] = c_i
            
            features = self.norm(x)
            logits = self.lm_head(features)    
            last_logits = logits[:, -1, :]
    
            if temperature == 0:
                # Greedy decoding if temp is 0
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                # "reshape the distribution", i.e crushing logits before softmax ~ uniform distribution etc.
                scaled_logits = last_logits / temperature
                
                # Only sample from the top k tokens, so a low-probability garbage token cannot derail the whole generation
                if int(top_k) > 0:
                    # most of the probability mass is on a small number of tokens, maybe 50?
                    values, indices = torch.topk(scaled_logits, top_k)
                    scaled_logits = torch.full_like(scaled_logits, float('-inf'))
                    scaled_logits.scatter_(1, indices, values)

                # TODO : DISABLE top_k + top_p ? Modern implementation *usually* only expose top_p
                if top_p > 0.0 and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True, dim=-1)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right to keep at least one token (the first one)
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    # Set logits to -inf for tokens we want to remove
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    scaled_logits[indices_to_remove] = float('-inf')
                
                # logits -> probability distribution
                probs = torch.softmax(scaled_logits, dim=-1)
                # Sample from the distribution; NB: we don't simply take the most probable token, to allow "creativity"
                next_token = torch.multinomial(probs, num_samples=1)

            # Stop generating if the model thinks the "document" is finished
            # NB: eot_id is a notebook-level global (eot_id = tokenizer.eot_token); .item() assumes batch size 1
            if next_token.item() == eot_id:
                break
            
            tokens_out = torch.cat([tokens_out, next_token], dim=1)

            # If caching, we only need to feed the newest token next time, otherwise full sequence
            current_tokens = next_token if use_cache else tokens_out
       
        return tokens_out
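
The top-p branch of `generate` is the densest part of the sampling code, so here is the same "shift right, always keep the first token" logic on a plain Python list of probabilities. `top_p_keep` is an illustrative helper, not something the model class defines.

```python
def top_p_keep(probs, top_p):
    """Return the indices kept by nucleus filtering, most probable first."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for idx in order:
        # A token is dropped only if the mass *before* it already exceeds top_p;
        # this is what the "shift right" does, and it guarantees the top token survives.
        if cumulative > top_p:
            break
        kept.append(idx)
        cumulative += probs[idx]
    return kept

probs = [0.05, 0.50, 0.15, 0.30]
print(top_p_keep(probs, 0.7))   # [1, 3]
print(top_p_keep(probs, 0.1))   # [1]
```

With `top_p = 0.7` the token that crosses the threshold (index 3) is still kept, so the surviving set is the smallest one whose cumulative probability reaches `top_p`.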

## Full Training

As I do not have unlimited compute, the datasets need to be as clean and high-signal as possible. C4 / FineWeb will overlap massively with OpenWebText (same Common Crawl source). Wikipedia will probably end up oversampled, which should be fine. BookCorpus is mostly self-published fiction; it should overlap less, but it has known issues.

I thought adding maths and code would be a good idea, but it forces the model to learn new semantics and syntax, which might be too ambitious for a small(er)-scale model. arXiv is definitely a no-no; open-web-math might be fine as it's more informal.

In [6]:
#### CONFIG #####

# Basically GPT-2 Small
block_size = 1024  # 512 for debug, 2048 realistically too slow
batch_size = 32    # sweet spot, have some room to increase
embed_dim = 768
num_layers = 12
num_heads = 12
dropout_prob = 0.05  # usually set as 0 but my model being relatively smaller, should help
mlp_ratio = 4  # standard 4x expansion


# Training
MAX_STEPS = 700000                         # Total number of micro-batches to process
GRAD_ACCUM_STEPS = 20                      # Better for GPU ?
LOG_INTERVAL = MAX_STEPS // 1000           # Log every xxx micro-batches
num_workers = 4                            # For data loading
prefetch = 8
device = "cuda"
model_path = "gpt_full_run.pt"             # where the trained model is stored
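
Because `MAX_STEPS` counts micro-batches rather than optimizer updates, it helps to spell out what this configuration implies. The arithmetic below only uses the values from the config cell.

```python
block_size = 1024
batch_size = 32
MAX_STEPS = 700_000
GRAD_ACCUM_STEPS = 20

tokens_per_micro_batch = batch_size * block_size                 # 32,768
tokens_per_update = tokens_per_micro_batch * GRAD_ACCUM_STEPS    # 655,360
optimizer_updates = MAX_STEPS // GRAD_ACCUM_STEPS                # 35,000
total_tokens_seen = MAX_STEPS * tokens_per_micro_batch           # 22,937,600,000

print(f"{tokens_per_update:,} tokens per optimizer update")
print(f"{optimizer_updates:,} optimizer updates")
print(f"{total_tokens_seen:,} tokens processed over the full run")
```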

In [7]:
# ============================================================================
# Setup Tokenizer (tiktoken gpt2 - 50,257 vocab, much better for small models)
# ============================================================================
print("Loading tokenizer...")
tokenizer = tiktoken.get_encoding("gpt2")
eot_id = tokenizer.eot_token
# original gpt2 tokenizer has an awkward dim for GPUs, pad it to nearest 128 multiple, speeds things up decently
vocab_size = round_up(tokenizer.n_vocab, 128)


print(f"✓ Loaded tiktoken gpt2")
print(f"  Vocab size: {vocab_size:,}")
print(f"  EOT token ID: {eot_id}")

# ============================================================================
# Load and Clean Datasets
# ============================================================================
print("\nLoading & cleaning up datasets...")
cleaned_datasets = {
    # Main dataset, high quality web crawl
    "fineweb-edu": load_dataset(
        "HuggingFaceFW/fineweb-edu", 
        name="sample-10BT", 
        split="train",
        trust_remote_code=True
    ),
    
    # Wikipedia - use larger dataset
    "wiki": load_dataset(
        "wikimedia/wikipedia", 
        "20231101.en", 
        split="train[:15%]",  # will have some overlap with fineweb
        trust_remote_code=True
    ),
    
    # For flavor / diversity
    "bookcorpus": load_dataset(
        "bookcorpus", 
        split="train[:10%]",
        trust_remote_code=True
    ),
}

cleaned_datasets = {n: clean_columns(d) for n, d in cleaned_datasets.items()}

print(f"\n✓ Loaded {len(cleaned_datasets)} datasets")
for n, d in cleaned_datasets.items():  # `d`, not `ds`, to avoid shadowing the `Dataset as ds` import
    print(f"  Dataset {n}: {len(d):,} examples")


print("\nConcatenating datasets...")
train_ds = concatenate_datasets(list(cleaned_datasets.values()))
print(f"  Combined size: {len(train_ds):,} examples")

# TODO : Way too slow on my machine
# # ============================================================================
# # Interleave and Shuffle for Better Mixing
# # ============================================================================
# from datasets import interleave_datasets
# print("\nInterleaving datasets...")
# train_ds = interleave_datasets(
#     list(cleaned_datasets.values()),
#     probabilities=[0.70, 0.15, 0.10, 0.05],  # Adjust weights as needed
#     seed=42,
#     stopping_strategy="all_exhausted"
# )

print("Shuffling...")
train_ds = train_ds.shuffle(seed=42)
print(f"✓ Final Train Size: {len(train_ds):,} rows")


def clean_text(example):
    """Remove special tokens and other artifacts"""
    text = example["text"]
    
    # Remove all common special tokens
    special_tokens = [
        "<|endoftext|>", "<|im_start|>", "<|im_end|>",
        "<|system|>", "<|user|>", "<|assistant|>",
        "<s>", "</s>", "[INST]", "[/INST]",
        "<<SYS>>", "<</SYS>>"
    ]
    
    for token in special_tokens:
        text = text.replace(token, "")
    
    example["text"] = text
    return example

# Tokenizer is going to complain if special tokens are found in training data
print("\nCleaning special tokens from datasets...")
train_ds = train_ds.map(
    clean_text,
    num_proc=max(1, (os.cpu_count() or 2) // 2),  # at least one worker
    desc="Cleaning special tokens"
)

# ============================================================================
# Document Packing Function
# ============================================================================
def pack_documents(examples):
    """
    Concatenate all documents in the batch, then slice into fixed-size blocks.
    Each document is terminated with exactly ONE EOT token.
    
    Output chunks are of length (block_size + 1), suitable for
    x = chunk[:-1], y = chunk[1:].
    """
    all_tokens = []

    for text in examples["text"]:
        if not text or not text.strip():
            continue

        try:
            doc_ids = tokenizer.encode(text)
        except Exception:
            continue  # skip documents the tokenizer rejects (e.g. leftover special tokens)

        doc_ids.append(eot_id)  # exactly one end-of-text
        all_tokens.extend(doc_ids)

    # Now chop into blocks of (block_size + 1)
    chunks = []
    total = len(all_tokens)

    for i in range(0, total, block_size):
        chunk = all_tokens[i : i + block_size + 1]
        if len(chunk) == block_size + 1:
            chunks.append(chunk)
        # else: drop the final tiny tail (standard practice)

    return {"chunk_ids": chunks}


# ============================================================================
# Apply Processing
# ============================================================================
print("\nTokenizing and packing documents (this may take a few minutes)...")
train_tokenized = train_ds.map(
    pack_documents,
    batched=True,
    batch_size=1250,
    num_proc=max(1, multiprocessing.cpu_count() // 2),
    remove_columns=train_ds.column_names,
    desc="Packing documents"
)

print(f"\n✓ Tokenization complete!")
print(f"  Total chunks: {len(train_tokenized):,}")
print(f"  Chunk size: {block_size + 1} tokens")
print(f"  Approx total tokens: {len(train_tokenized) * block_size:,}")

# ============================================================================
# Create DataLoader
# ============================================================================
class TokenDataset(Dataset):
    def __init__(self, tokenized_data):
        self.data = tokenized_data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        chunk = self.data[idx]["chunk_ids"]
        # Return input (x) and target (y) shifted by 1
        # [0, 1, 2, ..., block_size-1]
        # [1, 2, 3, ..., block_size]
        return chunk[:-1].long(), chunk[1:].long()

# Can also use numpy format and cast to tensors in the loader, not sure which one is slower
train_tokenized.set_format(type="torch", columns=["chunk_ids"])
token_dataset = TokenDataset(train_tokenized)

train_loader = DataLoader(
    token_dataset,
    batch_size=batch_size,
    prefetch_factor=prefetch,
    shuffle=True,
    drop_last=True,  # CRITICAL for torch.compile! A final batch with a different shape triggers re-compilation and slows training down a lot
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True,
)

print(f"\n✓ DataLoader ready")
print(f"  Batches per epoch: {len(train_loader):,}")
print(f"  Tokens per batch: {batch_size * block_size:,}")

Loading tokenizer...
✓ Loaded tiktoken gpt2
  Vocab size: 50,304
  EOT token ID: 50256

Loading & cleaning up datasets...


Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/98 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]


✓ Loaded 4 datasets
  Dataset fineweb-edu: 9,672,101 examples
  Dataset wiki: 640,781 examples
  Dataset bookcorpus: 7,400,423 examples
  Dataset maths-qa: 200,035 examples

Concatenating datasets...
  Combined size: 17,913,340 examples
Shuffling...
✓ Final Train Size: 17,913,340 rows

Cleaning special tokens from datasets...

Tokenizing and packing documents (this may take a few minutes)...

✓ Tokenization complete!
  Total chunks: 10,450,143
  Chunk size: 1025 tokens
  Approx total tokens: 10,700,946,432

✓ DataLoader ready
  Batches per epoch: 326,566
  Tokens per batch: 32,768
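
To make the packing scheme concrete, here is `pack_documents` reduced to a toy example with already-tokenized documents, `block_size = 4`, and `0` standing in for the EOT token (both values are made up for readability).

```python
def pack(token_docs, block_size, eot_id):
    """Concatenate documents (each followed by one EOT) and cut into chunks of block_size + 1."""
    stream = []
    for doc in token_docs:
        stream.extend(doc + [eot_id])
    chunks = []
    for i in range(0, len(stream), block_size):
        chunk = stream[i:i + block_size + 1]
        if len(chunk) == block_size + 1:   # drop the incomplete tail
            chunks.append(chunk)
    return chunks

docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
chunks = pack(docs, block_size=4, eot_id=0)
print(chunks)                      # [[1, 2, 3, 0, 4], [4, 5, 0, 6, 7]]

x, y = chunks[0][:-1], chunks[0][1:]
print(x, y)                        # [1, 2, 3, 0] [2, 3, 0, 4]
```

Because the stride is `block_size` while each chunk holds `block_size + 1` tokens, neighbouring chunks share one boundary token: it is the last target of one chunk and the first input of the next.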


In [8]:
# --- Verification ---
print("-" * 20)
print(f"Final Train Batches: {len(train_loader)}")
x, y = next(iter(train_loader))
print(f"Input x shape: {x.shape}")  # Should be [Batch, Block_Size]
print(f"Target y shape: {y.shape}") # Should be [Batch, Block_Size]

s = 25
print(f"\nSanity Check (Shifting {s}):")
print(f"x[0, -{s}:]: {x[0, -s:].tolist()}")
print(f"x[0, -{s}:]: {tokenizer.decode(x[0, -s:].tolist())}")
print()
print(f"y[0, -{s}:]: {y[0, -s:].tolist()}")
print(f"y[0, -{s}:]: {tokenizer.decode(y[0, -s:].tolist())}")

--------------------
Final Train Batches: 326566
Input x shape: torch.Size([32, 1024])
Target y shape: torch.Size([32, 1024])

Sanity Check (Shifting 25):
x[0, -25:]: [16815, 286, 362, 4, 284, 604, 4, 287, 19922, 4760, 3871, 13, 17, 198, 5886, 1802, 45752, 7652, 357, 1533, 11, 311, 10426, 16, 14]
x[0, -25:]:  prevalence of 2% to 4% in POAG patients.2
Over 100 genomic regions (eg, SIX1/

y[0, -25:]: [286, 362, 4, 284, 604, 4, 287, 19922, 4760, 3871, 13, 17, 198, 5886, 1802, 45752, 7652, 357, 1533, 11, 311, 10426, 16, 14, 50]
y[0, -25:]:  of 2% to 4% in POAG patients.2
Over 100 genomic regions (eg, SIX1/S


In [9]:
# --- Model ---
model_config = {
    "vocab_size": vocab_size,
    "embed_dim": embed_dim,
    "num_layers": num_layers,
    "num_heads": num_heads,
    "mlp_ratio": mlp_ratio,
    "dropout_prob": dropout_prob,
    "is_causal": True,
}

print("Initializing model with config : ")
pprint(model_config)

model = GPT(**model_config).to(device)

Initializing model with config : 
{'dropout_prob': 0.05,
 'embed_dim': 768,
 'is_causal': True,
 'mlp_ratio': 4,
 'num_heads': 12,
 'num_layers': 12,
 'vocab_size': 50304}
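
A back-of-the-envelope parameter count for this configuration, derived from the layer shapes defined in the model cells. It assumes each `nn.RMSNorm` holds a single weight vector of size `embed_dim`, counts the tied embedding / `lm_head` matrix once, and ignores anything held by `RotaryEmbedding`.

```python
vocab_size, embed_dim, num_layers = 50304, 768, 12
hidden_dim = 2048                                   # round_up(2 * (4 * 768) // 3, 8), as computed in MLP

embedding = vocab_size * embed_dim                  # shared with lm_head (weight tying)
attention = embed_dim * 3 * embed_dim + embed_dim * embed_dim      # qkv_proj + out_proj
mlp = embed_dim * 2 * hidden_dim + hidden_dim * embed_dim          # gate_up_proj + down_proj
norms = 2 * embed_dim                               # norm1 + norm2 weights
per_block = attention + mlp + norms

total = embedding + num_layers * per_block + embed_dim              # + final norm
print(f"per block: {per_block:,}")
print(f"total:     {total:,}")
```

Under these assumptions the model comes out at about 123.6M parameters.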


In [10]:
# TODO: implement resume from checkpoint (i.e. optimizer state, step count, etc.)
ckpt_path = "gpt_full_run_250000.pt"
state_dict = strip_compile_prefix(torch.load(ckpt_path, map_location=device))
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>
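
`torch.compile` wraps the model, so a `state_dict()` saved from the compiled model has every key prefixed with `_orig_mod.`. `strip_compile_prefix` lives in `utils` and is not shown here; the sketch below is only an assumption about what such a helper does, written against a plain dict so it runs without a checkpoint (`strip_prefix_sketch` and the example keys are hypothetical):

```python
def strip_prefix_sketch(state_dict, prefix="_orig_mod."):
    """Return a copy of state_dict with `prefix` removed from the start of each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Hypothetical keys; a real checkpoint maps parameter names to tensors.
compiled_keys = {"_orig_mod.blocks.0.attn.qkv.weight": 1, "lm_head.weight": 2}
print(strip_prefix_sketch(compiled_keys))
```

Keys without the prefix pass through unchanged, so the same helper works for checkpoints saved before compilation.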

Using standard practice:
- linear warmup followed by a cosine schedule
- relatively large weight decay (0.1), empirically shown to help
- gradient accumulation to emulate a large batch size

NB: I have heard contradictory claims about `label_smoothing`, so it is disabled for now. Weight decay is only applied to the linear weights, unlike nanoGPT, which also decays the embeddings (apparently nanoGPT does not follow current best practice here).
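
The warmup + cosine schedule can also be written in closed form. The sketch below is not the scheduler object used in this notebook, just a plain-Python version that should trace the same curve with the values from this run (35,000 optimizer updates, peak `3e-4`, floor `1e-5`, 5% warmup starting at 10% of the peak); `lr_at` is a hypothetical helper:

```python
import math

def lr_at(step, total_steps=35_000, peak_lr=3e-4, min_lr=1e-5,
          warmup_frac=0.05, start_factor=0.1):
    """Closed-form learning rate at a given optimizer step (hypothetical helper)."""
    warmup_steps = int(total_steps * warmup_frac)
    if step < warmup_steps:
        # Linear ramp from start_factor * peak_lr up to peak_lr
        factor = start_factor + (1.0 - start_factor) * step / warmup_steps
        return peak_lr * factor
    # Cosine decay from peak_lr down to min_lr over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for step in (0, 875, 1_750, 18_375, 35_000):
    print(f"step {step:>6}: lr = {lr_at(step):.2e}")
```

Note that `step` counts optimizer updates, not micro-batches, since the scheduler only advances once per accumulation cycle.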

In [11]:
# --- Optimizer ---
# Separate parameters into decay and no-decay groups
decay_params = []
no_decay_params = []

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    
    # Don't apply weight decay to:
    # - Norm parameters (scale/weight etc.)
    # - Embedding table (which is tied to lm_head)
    # - Any bias terms if present
    if any(keyword in name.lower() for keyword in ['norm', 'bias', 'embed', 'embedding', 'lm_head']):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

print(f"\nDecay params: {len(decay_params)}, No decay params: {len(no_decay_params)}")

total_optim_steps = MAX_STEPS // GRAD_ACCUM_STEPS
print(f"Total Micro-batches: {MAX_STEPS}")
print(f"Gradient Accumulation: {GRAD_ACCUM_STEPS}")
print(f"Total Optimizer Updates: {total_optim_steps}")

# Use the fused AdamW implementation when available (fewer kernel launches per update)
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
print("Using fused AdamW : ", use_fused)

optimizer = torch.optim.AdamW([
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': no_decay_params, 'weight_decay': 0.0}
], 
    lr=3e-4, 
    betas=(0.9, 0.95),
    **extra_args
)

warmup_steps = int(total_optim_steps * 0.05)  # 5% of total optim steps
warmup_scheduler = LinearLR(
    optimizer, 
    start_factor=0.1,  # Start at 3e-5
    total_iters=warmup_steps
)
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 
    T_max=total_optim_steps - warmup_steps, 
    eta_min=1e-5
)
scheduler = SequentialLR(
    optimizer, 
    schedulers=[warmup_scheduler, cosine_scheduler], 
    milestones=[warmup_steps]
)

# Use fused cross-entropy for efficiency (defined above with Triton kernels)
# Saves ~3GB memory and is ~2x faster than standard F.cross_entropy for large vocab
print("Using fused cross-entropy loss (Triton kernel)")


Decay params: 48, No decay params: 26
Total Micro-batches: 700000
Gradient Accumulation: 20
Total Optimizer Updates: 35000
Using fused AdamW :  True
Using fused cross-entropy loss (Triton kernel)
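
For reference, the effective batch size follows directly from the numbers printed above. This is simple arithmetic on values already shown in this notebook (32 sequences of 1024 tokens per micro-batch, 20 accumulation steps, 700,000 micro-batches); the variable names are only for illustration:

```python
micro_batch_seqs = 32        # sequences per micro-batch
block_size = 1024            # tokens per sequence
grad_accum_steps = 20        # micro-batches per optimizer update
max_micro_steps = 700_000    # total micro-batches

tokens_per_micro_batch = micro_batch_seqs * block_size
seqs_per_update = micro_batch_seqs * grad_accum_steps
tokens_per_update = tokens_per_micro_batch * grad_accum_steps
optimizer_updates = max_micro_steps // grad_accum_steps

print(f"Tokens per micro-batch:         {tokens_per_micro_batch:,}")
print(f"Sequences per optimizer update: {seqs_per_update:,}")
print(f"Tokens per optimizer update:    {tokens_per_update:,}")
print(f"Optimizer updates:              {optimizer_updates:,}")
```

Dividing each micro-batch loss by `GRAD_ACCUM_STEPS` before `backward()` makes the accumulated gradient match the mean loss over the whole effective batch.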


In [12]:
# Notes on torch.compile settings:
# - mode="reduce-overhead" uses CUDA graphs to cut per-step launch overhead, at the cost of extra memory
# - mode="max-autotune" is overkill here: very long compile times for marginal gains
# - the default mode should be plenty
# - coordinate_descent_tuning = True can help but makes compilation MUCH slower (enabled below)
# - dynamic=False is required, otherwise the compiled model is much slower

torch._inductor.config.coordinate_descent_tuning = True
model = torch.compile(model, dynamic=False, fullgraph=True)
model.to(device)
model.train()

# Multiple warmup passes to force compilation
print("Warming up compilation...")
x = torch.randint(0, vocab_size, (batch_size, block_size), device=device)
y = x.clone()

s = time.time()
for i in range(10):
    start = time.time()
    with torch.autocast("cuda", torch.bfloat16):
        logits = model(x)
        loss = fused_cross_entropy(
            logits.view(-1, logits.size(-1)),
            y.view(-1)
        )
    loss.backward()
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    print(f"  Warmup {i+1}: {time.time() - start:.2f}s")

print(f"✓ Model fully compiled in {time.time() - s:.2f}s")

del x, y  # release the warmup tensors

Warming up compilation...
  Warmup 1: 156.37s
  Warmup 2: 68.05s
  Warmup 3: 0.28s
  Warmup 4: 0.28s
  Warmup 5: 0.28s
  Warmup 6: 0.28s
  Warmup 7: 0.28s
  Warmup 8: 0.28s
  Warmup 9: 0.28s
  Warmup 10: 0.28s
✓ Model fully compiled in 226.65s
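
The steady-state warmup time gives a rough throughput estimate. This is back-of-the-envelope arithmetic, assuming the 0.28 s step time holds during training (the warmup loop skips the optimizer step and data loading, so real throughput will likely be somewhat lower):

```python
step_time_s = 0.28                  # steady-state forward + backward time from the warmup above
tokens_per_micro_batch = 32 * 1024  # batch size x block size
max_micro_steps = 700_000           # total micro-batches planned for the run

tokens_per_sec = tokens_per_micro_batch / step_time_s
total_hours = max_micro_steps * step_time_s / 3600

print(f"~{tokens_per_sec:,.0f} tok/s, ~{total_hours:.1f} h for {max_micro_steps:,} micro-batches")
```

The first two warmup passes are excluded on purpose: they are dominated by compilation, not by compute.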


In [13]:
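# Second pass (should be fast: everything is already compiled).
# Note: there is no torch.cuda.synchronize() before the timer is read, so this likely measures
# mostly kernel launch time; the synchronized warmup steps above (~0.28s) reflect the real step time.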
x = torch.randint(0, vocab_size, (batch_size, block_size), device=device)
y = x.clone()
start = time.time()
with torch.autocast("cuda", torch.bfloat16):
    loss = fused_cross_entropy(model(x).view(-1, vocab_size), y.view(-1))
loss.backward()
model.zero_grad()  # cleanup gradients
print(f"✓ Second pass : {time.time() - start:.2f}s")
del x, y  # release the test tensors

✓ Second pass : 0.03s


Ideally this should be turned into a Python script and run from the terminal, which would remove some overhead and squeeze out a bit more speed.

In [None]:
# --- CSV Logger ---
log_file = f'{model_path.split(".")[0]}__{datetime.now().strftime("%Y-%m-%d")}__pretraining_logs.csv'
print("Saving logs in : ", log_file)
file_exists = os.path.isfile(log_file)
with open(log_file, "a", newline="") as f:
    writer = csv.writer(f)
    if not file_exists:
        writer.writerow(["micro_step", "optim_step", "loss", "lr", "tokens_seen", "tokens_per_sec", "timestamp"])

# --- Training Loop ---
micro_step = 0
optim_step = 0
tokens_seen = 0
# Accumulate loss on GPU to avoid CPU-GPU sync every step
# Only call .item() at log intervals (every 500 steps instead of every step)
running_loss = torch.zeros(1, device=device)

start_time = time.time()
start_training = time.time()
last_tokens_seen = 0

model_params = decay_params + no_decay_params
optimizer.zero_grad(set_to_none=True)
while micro_step < MAX_STEPS:
    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        B, T = x.shape
        tokens_seen += B * T

        # --- Forward ---
        with torch.autocast(device_type="cuda", dtype=dtype):
            logits = model(x)
            loss = fused_cross_entropy(
                logits.view(-1, logits.size(-1)),
                y.view(-1)
            )

        # --- Backward (gradient accumulation) ---
        (loss / GRAD_ACCUM_STEPS).backward()

        # --- Optimizer step ---
        if (micro_step + 1) % GRAD_ACCUM_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model_params, 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            optim_step += 1

        # --- Bookkeeping (.item() triggers cpu-gpu synchro so we avoid it) ---
        running_loss += loss.detach()
        micro_step += 1

        # --- Logging ---
        if micro_step % LOG_INTERVAL == 0:
            # Only sync with CPU here (once per 500 steps, not every step)
            avg_loss = (running_loss / LOG_INTERVAL).item()
            
            elapsed = time.time() - start_time
            tokens_delta = tokens_seen - last_tokens_seen
            tokens_per_sec = tokens_delta / elapsed

            current_lr = optimizer.param_groups[0]["lr"]
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            print(
                f"step {micro_step:06d} | "
                f"opt_step {optim_step:04d} | "
                f"loss {avg_loss:.3f} | "
                f"lr {current_lr:.2e} | "
                f"{tokens_per_sec:,.0f} tok/s"
            )

            with open(log_file, "a", newline="") as f:
                writer = csv.writer(f)
                writer.writerow([
                    micro_step,
                    optim_step,
                    f"{avg_loss:.4f}",
                    f"{current_lr:.2e}",
                    tokens_seen,
                    int(tokens_per_sec),
                    timestamp,
                ])

            running_loss.zero_()  # Reset on GPU
            start_time = time.time()
            last_tokens_seen = tokens_seen

        # --- Checkpointing ---
        if micro_step % 50_000 == 0 and micro_step > 0:
            mid_model_path = model_path.replace(".pt", f"_{micro_step}.pt")
            print(f"Saving intermediate model to {mid_model_path}")
            torch.save(model.state_dict(), mid_model_path)

        # --- Exit ---
        if micro_step >= MAX_STEPS:
            elapsed = int(time.time() - start_training)
            h, m, s = elapsed // 3600, (elapsed % 3600) // 60, elapsed % 60
            print(f"\nProcessed {tokens_seen:,} tokens in {h:02d}:{m:02d}:{s:02d}")
            print(f"Saving final model to {model_path}")
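            torch.save(model.state_dict(), model_path)
            # Avoid keeping the same weights twice: if the last step was also a checkpoint
            # step, drop the intermediate file written just above. (The previous check compared
            # model_path to the suffixed path, which never matched, and raised a NameError
            # when no intermediate checkpoint had been saved.)
            dup_path = model_path.replace(".pt", f"_{micro_step}.pt")
            if micro_step % 50_000 == 0 and dup_path != model_path and os.path.isfile(dup_path):
                os.remove(dup_path)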
            break

Saving logs in :  gpt_full_run__2026-01-22__pretraining_logs.csv


In [36]:
prompt = "Redtoothed triggerfish are normally deep blue or purple with a light blue head. They are" 
x = torch.tensor(tokenizer.encode(prompt))

model.eval()
model.to("cuda")
out = model.generate(
    x.unsqueeze(0).to("cuda"),
    max_new_tokens=200,
    temperature=0.9,
    top_p=0.95,
    top_k=0,
    use_cache=True,
)

print("\nOutput : ", tokenizer.decode(out[0].tolist()))


Output :  Redtoothed triggerfish are normally deep blue or purple with a light blue head. They are often found in calmer waters.
At the Zoo
Diptops (also known as North Sea lampreys)
These are small, fairly small freshwater fish and are only found in the Kerguelen, Ardoch, Vitek and Wankeln rivers. They have a wide flattened body that is most often used as a display tank. They are normally white, but can also be bright green and red. They can grow to over 4 metres in length and may live up to 12 years. They are very beautiful, very colourful and very hardy. They are well suited to tanks that are kept in slow running waters.
Great white mako (Paliwa mako)
The Great white mako (Paliwa mako) is a large fish, about the size of a carp. They are among the most common fish in the Nieuwe Plight. They grow up to 14-17 cm in length and can live up to 11 years in
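
The `temperature` and `top_p` arguments above control how the next token is drawn. `model.generate` is defined earlier in the notebook; the sketch below is a generic plain-Python illustration of temperature scaling followed by nucleus (top-p) filtering, not the notebook's implementation (`nucleus_filter` and the toy logits are hypothetical):

```python
import math

def nucleus_filter(logits, temperature=0.9, top_p=0.95):
    """Return {token_id: prob} for the smallest set of tokens whose mass reaches top_p."""
    # Temperature < 1 sharpens the distribution, > 1 flattens it
    scaled = [logit / temperature for logit in logits]
    # Numerically stable softmax
    max_scaled = max(scaled)
    exps = [math.exp(s - max_scaled) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the most likely tokens until their cumulative mass reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for token_id in order:
        kept.append(token_id)
        mass += probs[token_id]
        if mass >= top_p:
            break
    # Renormalise so the kept probabilities sum to 1 before sampling
    norm = sum(probs[i] for i in kept)
    return {i: probs[i] / norm for i in kept}

toy_logits = [5.0, 1.0, 0.0, -3.0]
print(nucleus_filter(toy_logits, temperature=1.0, top_p=0.99))
```

The next token would then be sampled from the returned distribution, for example with `random.choices` over its keys and values.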


Conclusion : 

2 days of training, 20B tokens seen, final loss ~2.99. The model manages to produce English-looking text, but it is factually incorrect and contains some artifacts.

## TO DO : 

- [x] RoPE for Q and K
- [x] RoPE with the KV cache (needs a double check)
- [x] `AdaptiveLogSoftmaxWithLoss`? The output space is huge with modern tokenizers (went with fused CE instead)
- [x] Figure out the compile stuff + hyperparameters to maximize throughput
- [x] Maybe check if compiled optimizer step does anything
- [x] Double check training loop + if we can use `xformers.swiglu` with compile, maybe with explicit `torch.compiler.cudagraph_mark_step_begin()` ?
- [x] Chunk documents properly to avoid topic jumps
- [ ] Loss for `<eos>` prediction (avoid endless rambling ?)
- [x] Top k sampling
- [x] Top p nucleus
- [x] Temperature
- [x] Add stop token / EOS handling
- [x] Training on a real problem to see how far we can push current model
- [x] Clean up / Revisit markdown / maths
- [ ] Explore hyper connections and manifold constrained HC
- [x] Check newer architectures / design choices (https://github.com/lucidrains git is a gold mine)
- [x] Muon optimizer? -> not worth it for a small model