# GPT

Supervised fine-tuning (SFT) part: teach the pretrained model the user / assistant chat format.

Basically : 
- the prompt tokens are masked (no loss)
- the response tokens are trained on (loss applied)

In [None]:
# Standard library
import csv
import math
import multiprocessing
import os
import random
import time
from pprint import pprint
from datetime import datetime

# Environment config
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Third-party
import numpy as np
from datasets import Dataset as ds, concatenate_datasets, load_dataset
from rotary_embedding_torch import RotaryEmbedding
from tokenizers import Tokenizer

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Torch runtime config
torch.set_float32_matmul_precision("medium")
torch.cuda.empty_cache()

# Custom
from utils import count_parameters, load_synthetic_data, strip_compile_prefix, round_up

## Config & Model Definition

Mostly the same code as the pretraining notebook, including Flash Attention, RMSNorm etc.

In [None]:
#### CONFIG #####
# Notebook-wide hyperparameters; every later cell reads these module-level names.

# Basically GPT-2 Small
block_size = 1024  # max sequence length: tokenized examples are truncated to it and generation crops its context to it
batch_size = 32  # sequences per batch handed to the DataLoader
embed_dim = 768
num_layers = 12
num_heads = 12
dropout_prob = 0  # <!> finetuning, standard practice to disable dropout <!>
mlp_ratio = 4  # standard 4x expansion
pretrained_weights = "gpt_full_run.pt"  # checkpoint path loaded with torch.load before finetuning

# Tokenizer
ROLE_TOKENS = ["<|user|>", "<|assistant|>"]  # NOTE(review): not referenced by the cells visible in this notebook
IGNORE_INDEX = -100  # label value passed as `ignore_index` to cross_entropy to mask out the loss

# Training
NUM_EPOCHS = 3  # not too many or we're going to overfit our Q/A data
num_workers = 4  # DataLoader worker processes
prefetch = 8  # DataLoader `prefetch_factor`
dtype = torch.bfloat16  # autocast dtype used around the forward pass
device = "cuda"
# Prints whether the current CUDA setup reports bf16 support (informational only, nothing is changed if it does not)
print("torch.cuda.is_bf16_supported()", torch.cuda.is_bf16_supported())

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention with a fused QKV projection, rotary position
    embeddings and an optional key/value cache for incremental decoding.

    Args:
        embed_dim: model width; must be a multiple of `num_heads`.
        num_heads: number of attention heads.
        rotary_emb: rotary embedding object exposing `rotate_queries_or_keys`
            (shared with the other layers, it is not created here).
        causal: apply a causal mask when attending without cached history.
        dropout: attention dropout probability, only active in training mode.

    Raises:
        ValueError: if `embed_dim` is not divisible by `num_heads`.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 rotary_emb: RotaryEmbedding,
                 causal: bool = True,
                 dropout: float = 0.1
                ):
        super().__init__()
        per_head, leftover = divmod(embed_dim, num_heads)
        if leftover != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads}).")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = per_head
        self.causal = causal
        self.dropout_p = dropout

        # One matmul produces Q, K and V at once (output is 3x the model width)
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)

        # Rotary embedding handed down by the GPT model
        self.rotary_emb = rotary_emb
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x, k_v_cache=None):
        """
        Attend over `x` (3-D input: batch, time, embed_dim).

        `k_v_cache` is either None (no caching) or a dict. An empty dict is
        filled with the keys/values of the whole input; a dict that already
        holds "K" means "incremental step": only the last position of `x`
        is projected and attended against the cached history.

        Returns the projected attention output and the (updated) cache.
        """
        batch, _, _ = x.shape
        has_history = k_v_cache is not None and "K" in k_v_cache

        # With history, only the newest position needs fresh Q/K/V
        source = x[:, -1:, :] if has_history else x

        # Fused projection, split in three, then reshape to (B, H, T, D_head)
        q, k, v = (
            part.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
            for part in self.qkv_proj(source).chunk(3, dim=-1)
        )

        rotate = self.rotary_emb.rotate_queries_or_keys
        if has_history:
            # New positions start right after the cached ones
            n_cached = k_v_cache["K"].shape[-2]
            q = rotate(q, offset=n_cached)
            k = rotate(k, offset=n_cached)

            k = torch.cat([k_v_cache["K"], k], dim=-2)
            v = torch.cat([k_v_cache["V"], v], dim=-2)
            # A single query against past keys only: nothing from the future to hide,
            # so no mask is needed for this step
            mask_future = False
        else:
            q = rotate(q)
            k = rotate(k)
            mask_future = self.causal

        if k_v_cache is not None:
            # Cached tensors are never back-propagated through
            k_v_cache["K"] = k.detach()
            k_v_cache["V"] = v.detach()

        attended = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=mask_future,
        )

        # (B, H, T, D_head) -> (B, T, embed_dim), then the output projection
        merged = attended.transpose(1, 2).contiguous().view(batch, -1, self.embed_dim)
        return self.out_proj(merged), k_v_cache



class MLP(nn.Module):
    """
    Gated feed-forward block: `down_proj(silu(gate) * up)` followed by dropout.

    `hidden_dim` defaults to 4 * embed_dim; the width actually used is
    2/3 of it, passed through `round_up(..., 8)` (helper from utils,
    presumably rounding up to a multiple of 8 - verify there).
    """

    def __init__(self, embed_dim, hidden_dim=None, dropout_prob=0.1):
        super().__init__()
        requested_dim = 4 * embed_dim if hidden_dim is None else hidden_dim
        inner_dim = round_up(2 * requested_dim // 3, 8)

        # Gate and up projections share a single matmul
        self.gate_up_proj = nn.Linear(embed_dim, 2 * inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Apply the gated MLP to the last dimension of `x`."""
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        gated = F.silu(gate) * up
        return self.dropout(self.down_proj(gated))



class TransformerBlock(nn.Module):
    """
    One pre-norm transformer layer:

        x -> RMSNorm -> MultiHeadAttention -> + x (residual)
          -> RMSNorm -> MLP                -> + (residual) -> output
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 rotary_emb,
                 mlp_ratio=4,
                 dropout_prob=0.1,
                 causal=True,
                ):
        """
        Build the two normalisation layers, the attention module and the MLP.

        `rotary_emb` is forwarded to the attention module, `mlp_ratio * embed_dim`
        is the hidden size requested from the MLP, and `dropout_prob` is shared
        by attention and MLP. `causal` enables masking of future tokens.
        """
        super().__init__()
        self.norm1 = nn.RMSNorm(embed_dim)
        self.mha = MultiHeadAttention(
            embed_dim,
            num_heads,
            rotary_emb,
            causal,
            dropout_prob,
        )
        self.norm2 = nn.RMSNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio * embed_dim, dropout_prob)

    def forward(self, x, cache=None):
        """Run the block; `cache` is the per-layer K/V cache used at inference time."""
        attn_out, cache = self.mha(self.norm1(x), cache)
        after_attn = attn_out + x  # first residual connection

        output = self.mlp(self.norm2(after_attn)) + after_attn  # second residual connection
        return output, cache

In [None]:
class GPT(nn.Module):
    """
    Complete GPT (Generative Pre-trained Transformer) model.

    Pipeline: token embedding -> dropout -> `num_layers` TransformerBlocks
    (all sharing one rotary embedding) -> final RMSNorm -> linear language
    modeling head whose weight is tied to the embedding table.
    """

    def __init__(self,
                 vocab_size,
                 embed_dim,
                 num_layers,
                 num_heads,
                 mlp_ratio=4,
                 dropout_prob=0.1,
                 is_causal=True,
                ):
        """
        Initialize complete GPT model.

        vocab_size : number of rows of the embedding table / size of the output logits
        embed_dim : model width
        num_layers : number of transformer blocks
        num_heads : attention heads per block
        mlp_ratio : hidden size requested from each MLP, as a multiple of embed_dim
        dropout_prob : dropout probability (embeddings, attention, MLP)
        is_causal : forwarded to every block to enable causal masking
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
        self.dropout = nn.Dropout(dropout_prob)

        # Shared rotary embedding across all layers (more efficient for compilation)
        head_dim = embed_dim // num_heads
        self.rotary_emb = RotaryEmbedding(dim=head_dim)

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, self.rotary_emb, mlp_ratio, dropout_prob, is_causal)
            for _ in range(num_layers)
        ])
        self.norm = nn.RMSNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # weight tying

        # below shamefully stolen from nano-gpt
        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        # don't forget swiglu variant !
        for pn, p in self.named_parameters():
            if pn.endswith(("out_proj.weight", "down_proj.weight")):
                # Residual projections: scale down to prevent variance explosion
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.num_layers))

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear / Embedding weights, zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens):
        """Return next-token logits (last dimension = vocab_size) for every position of `tokens`."""
        embeddings = self.embedding(tokens)
        x = self.dropout(embeddings)
        for b in self.blocks:
            x, _ = b(x)  # no K/V cache during training
        features = self.norm(x)  # normalized to stabilize training
        return self.lm_head(features)

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @torch.no_grad()
    def generate(self,
                 prompt_tokens,
                 max_new_tokens=50,
                 temperature=1.0,
                 top_k=0,
                 top_p=0.0,
                 use_cache=True,
                 eot_token_id=None,
                ):
        """
        Auto-regressive text generation loop

        prompt_tokens : tensor with input tokens, shape (1, n_tokens)
        max_new_tokens : how many new tokens we want to generate
        temperature : controls expressivity (lower = less, higher = more funky, degenerate cases : 0 = argmax, +inf = random guess)
        top_k : restrict prediction to top_k tokens to avoid sampling low prob garbage, set to 0 to disable,
                values larger than the vocabulary are clamped to the vocabulary size
        top_p : sample from smallest set with cumulative prob >= top_p (adapts to model confidence, usually top_k OR top_p), top_p ∈ [0, 1]
        use_cache : set to True to avoid re-computing expensive K, V matrices
        eot_token_id : token id that ends generation (it is not appended to the output).
                       When None, a module/notebook-level `eot_id` is used if one is defined,
                       otherwise generation only stops after max_new_tokens.

        Returns the prompt followed by the generated tokens, on the model's device.
        NB: the model is switched to eval mode and left in it.
        """
        if eot_token_id is None:
            # Backward compatibility: this method used to read the notebook-level
            # `eot_id` directly, so keep honouring it when no id is passed explicitly.
            eot_token_id = globals().get("eot_id")

        self.eval()

        tokens_out = prompt_tokens.clone().to(self.device)
        # Never modified in place (torch.cat returns new tensors), so sharing storage is safe
        current_tokens = tokens_out
        cache = [{} if use_cache else None for _ in self.blocks]

        for _ in range(max_new_tokens):

            x = self.embedding(current_tokens)
            for i, b in enumerate(self.blocks):
                x, cache[i] = b(x, cache[i])

            features = self.norm(x)
            logits = self.lm_head(features)
            last_logits = logits[:, -1, :]

            if temperature == 0:
                # Greedy decoding if temp is 0
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                # "reshape the distribution", i.e crushing logits before softmax ~ uniform distribution etc.
                scaled_logits = last_logits / temperature

                # Only sample from top k tokens to avoid garbage prediction derailing whole prediction
                # torch.topk raises if k exceeds the last dimension -> clamp to the vocabulary size
                k = min(int(top_k), scaled_logits.size(-1))
                if k > 0:
                    # most of probability mass in on a small amount of tokens, maybe 50 ?
                    values, indices = torch.topk(scaled_logits, k)
                    scaled_logits = torch.full_like(scaled_logits, float('-inf'))
                    scaled_logits.scatter_(1, indices, values)

                # TODO : DISABLE top_k + top_p ? Modern implementation *usually* only expose top_p
                if top_p > 0.0 and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True, dim=-1)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right to keep at least one token (the first one)
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    # Set logits to -inf for tokens we want to remove
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    scaled_logits[indices_to_remove] = float('-inf')

                # logits -> distribution probability
                probs = torch.softmax(scaled_logits, dim=-1)
                # Sample from prob distribution, nb : we don't simply take max prob token to allow "creativity"
                next_token = torch.multinomial(probs, num_samples=1)

            # Stop generating if model thinks the "document" is finished
            # (.item() requires a single sequence, matching the documented (1, n_tokens) prompt shape)
            if eot_token_id is not None and next_token.item() == eot_token_id:
                break

            tokens_out = torch.cat([tokens_out, next_token], dim=1)

            # If caching, we only need to feed the newest token next time, otherwise full sequence
            current_tokens = next_token if use_cache else tokens_out

        return tokens_out

## Tokenizer

The tokenizer must be exactly the one used for pretraining. On top of it we only add 2 extra special tokens for the user / assistant roles, and we initialize their embeddings from the embedding of the end-of-text token learnt during pretraining.

In [None]:
import tiktoken
from tiktoken.core import Encoding

# Start from the stock GPT-2 encoding and derive a new one that adds two role tokens.
base = tiktoken.get_encoding("gpt2")

# New ids are appended right after the base vocabulary so every existing id keeps its meaning.
special_tokens = {
    "<|endoftext|>": base.eot_token,        # keep the base end-of-text id unchanged
    "<|user|>": base.n_vocab,
    "<|assistant|>": base.n_vocab + 1,
}

# Reuse the base encoding's split pattern and merge table (read from its private attributes);
# only the special-token table differs.
tokenizer = Encoding(
    name="gpt2-with-roles",
    pat_str=base._pat_str,
    mergeable_ranks=base._mergeable_ranks,
    special_tokens=special_tokens,
)

# Ids of the three control tokens, used by the later cells (masking, padding, stopping)
eot_id = tokenizer.eot_token
user_id = tokenizer.encode_single_token("<|user|>")
assistant_id = tokenizer.encode_single_token("<|assistant|>")

# Sanity check: print the ids and make sure they decode back to the token strings
print(user_id, assistant_id, eot_id)
print(tokenizer.decode([user_id, assistant_id, eot_id]))

# Model vocab size is the tokenizer size passed through round_up(..., 128)
# (helper from utils, presumably rounding up to a multiple of 128 - verify there)
vocab_size = round_up(tokenizer.n_vocab, 128)
print("Vocab size:", tokenizer.n_vocab, "→ padded:", vocab_size)

## Supervised Fine-Tuning Dataset

### Goal
Train the model to generate assistant responses, NOT to predict instructions -> Autoregressive LMs predict the NEXT token at each position. We use this by masking instruction tokens in the loss calculation.

---

### The Process

#### 1. Format the Data
```
Input text: "<|user|>\n{instruction}\n<|assistant|>\n{response}"

```

#### 2. Tokenize
```
ids = [user_tok, What, is, 2, +, 2, ?, asst_tok, The, answer, is, 4, eot]
idx:   0        1     2   3  4  5  6  7         8    9      10  11 12
```

#### 3. Create Shifted Labels
```python
labels[:-1] = ids[1:]  # Each label is the NEXT token to predict

labels = [What, is, 2, +, 2, ?, asst_tok, The, answer, is, 4, eot, IGNORE]
```
**Meaning:** `labels[i]` = what should be predicted after seeing `ids[i]`

#### 4. Mask the Instruction
```python
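# Find position of <|assistant|> token (position 7 in example)
# NB: in the real template a "\n" token follows <|assistant|> (left out of this toy example).
# The "+1" also masks the prediction of that newline, so training starts at the first response token.
labels[:assistant_pos+1] = IGNORE_INDEX (-100)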

labels = [IGN, IGN, IGN, IGN, IGN, IGN, IGN, The, answer, is, 4, eot, IGN]
          └─────────instruction masked───────────┘  └────train here────┘
```

#### 5. Compute Loss (during training)
```python
logits = model(ids)  # Model predicts next token at each position
loss = CrossEntropyLoss(logits, labels, ignore_index=-100)
```

**What happens:**
- Position 0-6: `labels[i] = -100` → loss ignored (don't train on instruction)
- Position 7: predict "The" after `<|assistant|>` → **COMPUTE LOSS** ✓
- Position 8: predict "answer" after "The" → **COMPUTE LOSS** ✓
- Position 9: predict "is" after "answer" → **COMPUTE LOSS** ✓
- Position 10: predict "4" after "is" → **COMPUTE LOSS** ✓
- Position 11: predict `<eot>` after "4" → **COMPUTE LOSS** ✓

---

### Why This Works

**Causal Attention Mask**
- Prevents the model from "seeing" future tokens
- At position i, model only attends to tokens 0 to i

**Teacher Forcing**
- Model sees correct previous tokens during training
- Learns to predict the next one

**Masking with -100**
- `CrossEntropyLoss` ignores these positions
- Gradients only flow through response tokens

**Result:** Model learns "given instruction X, generate response Y" without wasting compute trying to predict the instruction itself.

---

### Key Facts

- `IGNORE_INDEX = -100` (standard PyTorch convention)
- Only ~5-20% of tokens typically contribute to loss (just the responses)
- The shift (`labels[i] = ids[i+1]`) aligns predictions with targets
- The masking + smaller size dataset is going to finetune behavior but not (or barely) knowledge !

In [None]:
# Load the three public instruction datasets by their Hugging Face Hub ids
# (each call may download data; the returned objects are indexed by split name below, e.g. ["train"])
print("Loading datasets...")
alpaca_cleaned = load_dataset("yahma/alpaca-cleaned")
platypus = load_dataset("garage-bAInd/Open-Platypus")
no_robots = load_dataset("HuggingFaceH4/no_robots")

# Format functions for each dataset
def format_alpaca(example):
    """
    Convert an Alpaca-style row to the shared {"user", "assistant"} schema.

    The user text is the instruction, followed by a blank line and the
    optional `input` field when that field contains non-whitespace text.
    A missing or None `input` is treated like an empty one instead of crashing.
    """
    instruction = example["instruction"]
    extra_input = example.get("input") or ""

    if extra_input.strip():
        user_text = f"{instruction}\n\n{extra_input}"
    else:
        user_text = instruction

    return {
        "user": user_text,
        "assistant": example["output"],
    }

def format_platypus(example):
    """Map an Open-Platypus row (`instruction`, `output`) onto the shared {"user", "assistant"} schema."""
    question, answer = example["instruction"], example["output"]
    return dict(user=question, assistant=answer)

def format_no_robots(example):
    """
    Convert a no_robots row to the shared {"user", "assistant"} schema.

    The user text is the `prompt` field. The assistant text is the content of
    the first message whose role is "assistant": conversations that start with
    a system message have the user turn (not the answer) at index 1, so the
    index alone cannot be trusted. If no message carries an assistant role,
    fall back to the message at index 1.
    """
    messages = example["messages"]

    reply = next(
        (
            message["content"]
            for message in messages
            if isinstance(message, dict) and message.get("role") == "assistant"
        ),
        None,
    )
    if reply is None:
        # No role information available: keep the positional convention
        reply = messages[1]["content"]

    return {
        "user": example["prompt"],
        "assistant": reply,
    }

# Bring every dataset to the shared {"user", "assistant"} schema, dropping the original columns
print("Formatting datasets...")
alpaca_formatted = alpaca_cleaned.map(format_alpaca, remove_columns=alpaca_cleaned["train"].column_names)
platypus_formatted = platypus.map(format_platypus, remove_columns=platypus["train"].column_names)
no_robots_formatted = no_robots.map(format_no_robots, remove_columns=no_robots["train"].column_names)

# Only the train splits are used for SFT
combined_datasets = [
    formatted["train"]
    for formatted in (alpaca_formatted, platypus_formatted, no_robots_formatted)
]

# OPTIONAL : custom SFT data (e.g. generated with any modern LLM) is appended when the file exists
if os.path.isfile("synthetic_sft_data.jsonl"):
    print("Loading synthetic data...")
    synthetic_data = load_synthetic_data("synthetic_sft_data.jsonl")
    print(f"Loaded {len(synthetic_data)} synthetic examples")

    # list of row dicts -> dict of column lists (the keys of the first row define the columns)
    data_dict = {
        key: [item[key] for item in synthetic_data]
        for key in synthetic_data[0].keys()
    }
    synthetic_dataset = ds.from_dict(data_dict)

    # Show one row so the format can be checked by eye
    print("\nExample from synthetic dataset:")
    pprint(synthetic_dataset[0])

    combined_datasets.append(synthetic_dataset)

# Merge everything into a single training set
print("\n\nCombining datasets...")
combined_train = concatenate_datasets(combined_datasets)

print(f"Total training examples: {len(combined_train)}")
print("\nExample from combined dataset:")
pprint(next(iter(combined_train)))

def tokenize_sft(example):
    """
    Turn one {"user", "assistant"} example into `input_ids` and shifted, masked `labels`.

    Token layout (same as encoding the chat template as one string):
        <|user|> "\n{user}\n" <|assistant|> "\n{assistant}" <|endoftext|>

    The role / end-of-text ids are inserted explicitly and the free text is
    encoded with `encode_ordinary`, which never emits special tokens. A
    literal "<|assistant|>" or "<|endoftext|>" inside the dataset text is
    therefore tokenized as plain text and cannot move the mask boundary.

    `labels[i]` is the token to predict after `input_ids[i]`. Everything up to
    and including the <|assistant|> position is IGNORE_INDEX, as is the final
    position (nothing follows it).

    Returns {"input_ids": [], "labels": []} when truncation to `block_size`
    leaves no supervised label (prompt alone fills the context); such rows are
    dropped by the empty-example filter applied after tokenization.
    """
    prompt_ids = [user_id, *tokenizer.encode_ordinary(f"\n{example['user']}\n"), assistant_id]
    response_ids = [*tokenizer.encode_ordinary(f"\n{example['assistant']}"), eot_id]

    ids = prompt_ids + response_ids

    # Shifted targets: prompt positions (through <|assistant|>) are masked, response
    # positions predict the following response token, the last position has no target.
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids[1:] + [IGNORE_INDEX]

    # Truncate
    ids = ids[:block_size]
    labels = labels[:block_size]

    # An example without any supervised token contributes nothing to the loss
    # (and a batch made only of such rows would give a NaN mean loss).
    if all(label == IGNORE_INDEX for label in labels):
        return {"input_ids": [], "labels": []}

    return {
        "input_ids": ids,
        "labels": labels,
    }

# Tokenize combined dataset
# The raw text columns are removed, so only what tokenize_sft returns ("input_ids", "labels") remains.
print("Tokenizing combined dataset...")
combined_tokenized = combined_train.map(
    tokenize_sft,
    remove_columns=combined_train.column_names,
    num_proc=4,  # tokenize in 4 parallel processes
)

# Filter out empty examples (tokenize_sft returns empty lists for rows it rejects)
pre_filter_len = len(combined_tokenized)
combined_tokenized = combined_tokenized.filter(lambda x: len(x["input_ids"]) > 0)
print(f"Filtered {pre_filter_len - len(combined_tokenized)} invalid examples")
print(f"Final dataset size: {len(combined_tokenized)}")

In [None]:
def collate_fn(batch):
    """
    Pad a list of tokenized examples to the longest one and stack them.

    Conversations have variable length, so each batch is right-padded:
    `input_ids` with `eot_id`, `labels` with `IGNORE_INDEX` (padding never
    contributes to the loss). None entries are dropped first.
    Padding to a fixed length / bucketing would allow compiling the model,
    but the data is small enough to skip that here.

    Returns a dict of two long tensors: "input_ids" and "labels".
    """
    samples = [sample for sample in batch if sample is not None]
    longest = max(len(sample["input_ids"]) for sample in samples)

    # Padding amount is derived from the input length and reused for the labels
    pad_lens = [longest - len(sample["input_ids"]) for sample in samples]

    padded_inputs = [
        sample["input_ids"] + [eot_id] * n_pad
        for sample, n_pad in zip(samples, pad_lens)
    ]
    padded_labels = [
        sample["labels"] + [IGNORE_INDEX] * n_pad
        for sample, n_pad in zip(samples, pad_lens)
    ]

    return {
        "input_ids": torch.tensor(padded_inputs, dtype=torch.long),
        "labels": torch.tensor(padded_labels, dtype=torch.long),
    }

# Batches of tokenized examples, reshuffled every epoch and padded per batch by collate_fn
train_loader = DataLoader(
    combined_tokenized,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,  # worker processes preparing batches
    prefetch_factor=prefetch,  # batches loaded in advance by each worker
    pin_memory=True,  # page-locked host memory for faster copies to the GPU
)

print(f"DataLoader created with {len(train_loader)} batches of {batch_size} seqs")

## Model Loading

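Start from the pretrained model, set `dropout` to 0, and initialize the embeddings of the new role tokens (they occupy previously unused slots of the padded vocabulary, so the embedding table does not need to be resized).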

In [None]:
# Same architecture as pretraining; only dropout_prob comes from the SFT config above
model_config = {
    "vocab_size": vocab_size,  # This should already be 50304 from both pretraining and SFT
    "embed_dim": embed_dim,
    "num_layers": num_layers,
    "num_heads": num_heads,
    "mlp_ratio": mlp_ratio,
    "dropout_prob": dropout_prob,
}

print("Initializing model with config:")
pprint(model_config)
model = GPT(**model_config).to(device)

print(f"Loading: {pretrained_weights}...")
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered
# checkpoint file cannot execute arbitrary code; a state_dict needs nothing more.
# strip_compile_prefix comes from utils (presumably it removes the key prefix added by torch.compile - verify there).
trained_weights = strip_compile_prefix(
    torch.load(pretrained_weights, map_location=device, weights_only=True)
)
# strict=True: every key must match the freshly built model exactly
model.load_state_dict(trained_weights, strict=True)
print("Loaded pretrained weights")

# Initialize the new special token embeddings (they exist but were untrained padding)
# Indices 50257 (<|user|>) and 50258 (<|assistant|>) are within the padded vocab
with torch.no_grad():
    # Option 1: Copy from <|endoftext|> token
    model.embedding.weight[user_id] = model.embedding.weight[eot_id].clone()
    model.embedding.weight[assistant_id] = model.embedding.weight[eot_id].clone()

    # Option 2 (alternative): Use mean of all trained embeddings
    # mean_emb = model.embedding.weight[:base.n_vocab].mean(dim=0)
    # model.embedding.weight[user_id] = mean_emb
    # model.embedding.weight[assistant_id] = mean_emb

# If using weight tying, lm_head already shares weights, no action needed
# If NOT using weight tying, you'd need to update lm_head separately

print(f"Initialized special tokens: <|user|>={user_id}, <|assistant|>={assistant_id}")

# Ready to train!
model.train()
_, _ = count_parameters(model)

Weight decay is large compared to what is typical for CNNs, but it seems to be the standard for LLMs (empirical evidence): usually 0.1 for pretraining and 0.01 for SFT. It helps avoid memorization. The learning rate is small because we are finetuning.

In [None]:
# Two optimizer groups (same logic as pretraining): weights that get decayed, and
# parameters that are exempt from weight decay:
# - Norm parameters (RMSNorm weights)
# - Embedding table (tied with lm_head)
# - Any bias terms
decay_params, no_decay_params = [], []

for name, param in model.named_parameters():
    if param.requires_grad:
        lowered_name = name.lower()
        exempt = any(kw in lowered_name for kw in ('norm', 'bias', 'embed', 'embedding', 'lm_head'))
        (no_decay_params if exempt else decay_params).append(param)

print(f"Decay params: {len(decay_params)}, No decay params: {len(no_decay_params)}")

param_groups = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(param_groups, lr=2e-5, betas=(0.9, 0.95))

# LR schedule: linear warmup over the first 5% of the steps, then cosine decay
total_steps = NUM_EPOCHS * len(train_loader)
warmup_steps = int(total_steps * 0.05)

from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

# Ramp from 0.1 * lr up to lr
warmup_scheduler = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
# Then decay towards the minimum lr over the remaining steps
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=1e-6)
# Switch from the first scheduler to the second at `warmup_steps`
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_steps],
)

print(f"Total steps: {total_steps}, Warmup steps: {warmup_steps}")

We train over `epochs` because SFT data is 1) small enough to iterate over entirely and 2) made of fixed Q/A samples, whereas pretraining data is too large to enumerate and is sampled as random contiguous chunks of text.

In [None]:
import time

# SFT training loop: bf16 autocast forward, masked cross-entropy, gradient clipping at 1.0,
# one optimizer + scheduler step per batch, one checkpoint per epoch.
global_step = 0
total_batches = len(train_loader)
log_interval = 100

print(f"Starting SFT training: {NUM_EPOCHS} epochs, {total_batches} batches/epoch, {total_batches * NUM_EPOCHS} total steps")
print("-" * 80)

training_start = time.time()

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()
    epoch_loss_sum = 0.0
    epoch_loss_count = 0

    for step, batch in enumerate(train_loader):
        # If every label is IGNORE_INDEX the mean cross-entropy is 0/0 = NaN, and a single
        # NaN gradient step would corrupt all weights. Checked on the CPU copy of the labels,
        # so it costs no GPU synchronization.
        if not (batch["labels"] != IGNORE_INDEX).any():
            print(f"epoch {epoch+1}/{NUM_EPOCHS} | step {step:>5}/{total_batches} | skipped: no supervised tokens in batch")
            continue

        # The DataLoader pins host memory, so these copies can overlap with compute
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)

        with torch.amp.autocast('cuda', dtype=dtype):
            logits = model(input_ids)
            # Flatten (batch, time) so every position is one classification example
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=IGNORE_INDEX,
            )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        global_step += 1

        # Track epoch loss (single device sync per step)
        loss_value = loss.item()
        epoch_loss_sum += loss_value
        epoch_loss_count += 1

        if step % log_interval == 0:
            current_lr = optimizer.param_groups[0]["lr"]
            pct_complete = 100 * (step + 1) / total_batches
            elapsed = time.time() - epoch_start

            # Estimate time remaining for epoch
            if step > 0:
                steps_per_sec = step / elapsed
                remaining_steps = total_batches - step
                eta_sec = remaining_steps / steps_per_sec
                eta_str = f"{int(eta_sec // 60):02d}:{int(eta_sec % 60):02d}"
            else:
                eta_str = "--:--"

            avg_loss = epoch_loss_sum / epoch_loss_count

            print(
                f"epoch {epoch+1}/{NUM_EPOCHS} | "
                f"step {step:>5}/{total_batches} ({pct_complete:5.1f}%) | "
                f"loss {loss_value:.4f} (avg {avg_loss:.4f}) | "
                f"lr {current_lr:.2e} | "
                f"ETA {eta_str}"
            )

    # End of epoch summary
    epoch_elapsed = time.time() - epoch_start
    # max(..., 1): avoid a ZeroDivisionError if no batch was trained on in this epoch
    epoch_avg_loss = epoch_loss_sum / max(epoch_loss_count, 1)
    print("-" * 80)
    print(
        f"Epoch {epoch+1} complete | "
        f"avg loss: {epoch_avg_loss:.4f} | "
        f"time: {int(epoch_elapsed // 60):02d}:{int(epoch_elapsed % 60):02d}"
    )
    print("-" * 80)

    # Save model in full precision (file names use the 0-based epoch index)
    torch.save(model.state_dict(), f"GPT_SFT_epoch_{epoch}.pt")
    print(f"Saved checkpoint: GPT_SFT_epoch_{epoch}.pt")

# Final summary
total_time = time.time() - training_start
print("=" * 80)
print(f"Training complete! Total time: {int(total_time // 3600):02d}:{int((total_time % 3600) // 60):02d}:{int(total_time % 60):02d}")
print(f"Final checkpoint: GPT_SFT_epoch_{NUM_EPOCHS - 1}.pt")

In [None]:
@torch.no_grad()
def generate(
    model,
    prompt_text,
    max_new_tokens=150,
    temperature=0.7,
    top_p=0.9,
    stop_token_id=eot_id,
):
    """
    Generate a completion for a prompt string and return the decoded text.

    model : callable returning logits for a (1, n_tokens) tensor of token ids
    prompt_text : prompt, may contain the special role tokens
    max_new_tokens : upper bound on the number of generated tokens
    temperature : logits are divided by it before sampling; 0 means greedy (argmax) decoding
    top_p : nucleus sampling threshold, values >= 1.0 disable the filter
    stop_token_id : generation stops right after this token is produced (None = never stop early)

    The returned string contains the prompt, the generated tokens and the stop
    token if it was produced. The model is switched to eval mode and left in it.
    No K/V cache is used: the whole (cropped) sequence is re-run at every step.
    """
    model.eval()

    enc = tokenizer.encode(prompt_text, allowed_special=set(special_tokens.keys()))
    input_ids = torch.tensor(enc, dtype=torch.long).unsqueeze(0).to(device)

    for _ in range(max_new_tokens):
        # Crop to block_size if needed
        input_ids_cond = input_ids if input_ids.shape[1] <= block_size else input_ids[:, -block_size:]

        logits = model(input_ids_cond)
        next_token_logits = logits[:, -1, :]

        if temperature == 0:
            # Dividing by 0 would turn the logits into inf/nan; temperature 0 means greedy decoding
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        else:
            # Apply temperature
            next_token_logits = next_token_logits / temperature

            # Top-p sampling
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True, dim=-1)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept (at least one token survives)
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                # Map the removal mask from sorted order back to vocabulary order
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        input_ids = torch.cat([input_ids, next_token], dim=1)

        if stop_token_id is not None and next_token.item() == stop_token_id:
            break

    return tokenizer.decode(input_ids[0].tolist())


def format_prompt(user_message):
    """Wrap a user message in the chat template, ending right where the assistant reply begins."""
    template = "<|user|>\n{}\n<|assistant|>\n"
    return template.format(user_message)


# =============================================================================
# Quick Evaluation Suite
# =============================================================================
# Qualitative check: sample one response per prompt category and print it for manual inspection.
print("=" * 80)
print("POST-TRAINING EVALUATION")
print("=" * 80)

# (category label, user message) pairs
eval_prompts = [
    # Basic instruction following
    ("Simple instruction", "Write a short greeting message."),
    
    # Knowledge recall (don't expect accuracy, just coherence)
    ("Knowledge", "What is the capital of France?"),
    
    # Reasoning (basic)
    ("Basic reasoning", "If I have 3 apples and buy 2 more, how many do I have?"),
    
    # Creative
    ("Creative", "Write a haiku about programming."),
    
    # Explanation
    ("Explanation", "Explain what a neural network is in simple terms."),
    
    # Code (if trained on code data)
    ("Code", "Write a Python function that adds two numbers."),
    
    # Multi-step
    ("Multi-step", "List 3 benefits of exercise."),
    
    # Refusal/boundary (interesting to see behavior)
    ("Edge case", "Summarize the following text: "),
]

model.eval()
for category, user_msg in eval_prompts:
    # Wrap the message in the chat template used during SFT
    prompt = format_prompt(user_msg)
    
    print(f"\n[{category}]")
    print(f"User: {user_msg}")
    print("-" * 40)
    
    # Best-effort: a failure on one prompt is printed and the loop moves on to the next prompt
    try:
        response = generate(model, prompt, max_new_tokens=150, temperature=0.7)
        # Extract just the assistant response (generate returns prompt + completion)
        if "<|assistant|>" in response:
            assistant_part = response.split("<|assistant|>\n")[-1]
            # Clean up any trailing special tokens
            assistant_part = assistant_part.replace("<|endoftext|>", "").strip()
        else:
            assistant_part = response
        print(f"Assistant: {assistant_part}")
    except Exception as e:
        print(f"Error: {e}")
    
    print("-" * 40)

# =============================================================================
# Quantitative checks
# =============================================================================
print("\n" + "=" * 80)
print("FORMAT COMPLIANCE CHECK")
print("=" * 80)

# Check if model properly terminates with EOT:
# greedy-decode up to 200 tokens for the first 3 eval prompts and count how often EOT is produced.
test_prompts = [format_prompt(p) for _, p in eval_prompts[:3]]
eot_count = 0
total_length = 0

model.eval()
# Inference only: without no_grad every forward pass would record an autograd graph
# for the trainable parameters and waste memory.
with torch.no_grad():
    for prompt in test_prompts:
        enc = tokenizer.encode(prompt, allowed_special=set(special_tokens.keys()))
        input_ids = torch.tensor(enc, dtype=torch.long).unsqueeze(0).to(device)

        # Generate with greedy decoding for consistency
        for _ in range(200):
            # Keep at most the last block_size tokens as context
            context = input_ids[:, -block_size:] if input_ids.shape[1] > block_size else input_ids
            logits = model(context)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if next_token.item() == eot_id:
                eot_count += 1
                break

        # Number of generated tokens (includes the EOT token when one was produced)
        total_length += input_ids.shape[1] - len(enc)

avg_response_len = total_length / len(test_prompts)
eot_rate = 100 * eot_count / len(test_prompts)

print(f"EOT termination rate: {eot_count}/{len(test_prompts)} ({eot_rate:.0f}%)")
print(f"Avg response length: {avg_response_len:.0f} tokens")
print(f"  → {'Good: Model learns to stop' if eot_rate > 50 else 'Warning: Model may ramble'}")
print(f"  → {'Good: Reasonable length' if 20 < avg_response_len < 150 else 'Check: Unusual response length'}")

print("\n" + "=" * 80)
print("Evaluation complete!")
print("=" * 80)