# GPT

SFT (supervised fine-tuning) part: teach the pretrained model the user / assistant conversation format.

In short:
- the prompt tokens are masked (no loss)
- the response tokens are trained on (loss applied)

In [1]:
# Standard library
import csv
import math
import multiprocessing
import os
import random
import time
from pprint import pprint
from datetime import datetime

# Environment config
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Third-party
import numpy as np
from datasets import Dataset as ds, concatenate_datasets, load_dataset
from rotary_embedding_torch import RotaryEmbedding
from tokenizers import Tokenizer

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Torch runtime config
torch.set_float32_matmul_precision("medium")
torch.cuda.empty_cache()

# Custom
from utils import count_parameters, load_synthetic_data, strip_compile_prefix

## Config & Model Definition

Mostly the same code as the pretraining notebook: Flash Attention (via `F.scaled_dot_product_attention`), RoPE, RMSNorm, SwiGLU, etc.

In [2]:
#### CONFIG #####

# Basically GPT-2 Small
block_size = 1024
batch_size = 16
embed_dim = 768
num_layers = 12
num_heads = 12
dropout_prob = 0  # <!> finetuning, standard practice to disable dropout <!>
mlp_ratio = 4  # standard 4x expansion
pretrained_weights = "gpt_model_1024_158417.pt"

# Tokenizer
ROLE_TOKENS = ["<|user|>", "<|assistant|>"]
IGNORE_INDEX = -100  # to mask out the loss

# Training
NUM_EPOCHS = 3  # not too many or we're going to overfit our Q/A data
num_workers = 4
prefetch = 4
dtype = torch.bfloat16
device = "cuda"
print("torch.cuda.is_bf16_supported()", torch.cuda.is_bf16_supported())

torch.cuda.is_bf16_supported() True


In [3]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention with a fused QKV projection, rotary position
    embeddings (RoPE) and an optional key/value cache for incremental decoding.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (each of size embed_dim // num_heads).
        causal: mask out future positions (only applied when more than one
            query position is processed in the call).
        dropout: dropout probability on the attention weights, active only in
            training mode.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int, causal: bool = True, dropout: float = 0.1):
        super().__init__()
        head_dim, leftover = divmod(embed_dim, num_heads)
        if leftover != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads}).")

        self.causal = causal
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout_p = dropout

        # One matmul produces Q, K and V side by side (3 * embed_dim outputs)
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.rotary_emb = RotaryEmbedding(dim=head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def _to_heads(self, t, batch):
        """Reshape (B, T, embed_dim) -> (B, num_heads, T, head_dim)."""
        return t.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, k_v_cache=None):
        """
        Args:
            x: (B, T, embed_dim) input.
            k_v_cache: ``None`` disables caching. A dict without "K" is filled
                with this call's keys/values (prefill). A dict that already
                holds "K"/"V" triggers one decoding step: only the last position
                of ``x`` is projected, rotated with the cached length as RoPE
                offset, and its K/V are appended to the cache.

        Returns:
            (out, k_v_cache): ``out`` is (B, T, embed_dim), or (B, 1, embed_dim)
            for a decoding step. The cache dict (or ``None``) is returned after
            being updated in place.
        """
        batch, _, _ = x.shape
        decoding = k_v_cache is not None and "K" in k_v_cache

        # During a decoding step only the newest position needs fresh Q/K/V
        fused = self.qkv_proj(x[:, -1:, :] if decoding else x)
        queries, keys, values = (self._to_heads(part, batch) for part in fused.chunk(3, dim=-1))

        if decoding:
            offset = k_v_cache["K"].shape[-2]  # number of positions already cached
            queries = self.rotary_emb.rotate_queries_or_keys(queries, offset=offset)
            keys = self.rotary_emb.rotate_queries_or_keys(keys, offset=offset)
            keys = torch.cat([k_v_cache["K"], keys], dim=-2)
            values = torch.cat([k_v_cache["V"], values], dim=-2)
        else:
            queries = self.rotary_emb.rotate_queries_or_keys(queries)
            keys = self.rotary_emb.rotate_queries_or_keys(keys)

        if k_v_cache is not None:
            k_v_cache["K"], k_v_cache["V"] = keys, values

        # Fused scaled-dot-product attention (may dispatch to a Flash kernel);
        # a single query position needs no causal mask
        attended = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=self.causal and queries.shape[-2] > 1,
        )

        merged = attended.transpose(1, 2).contiguous().view(batch, -1, self.embed_dim)
        return self.out_proj(merged), k_v_cache

In [4]:
class MLP(nn.Module):
    """
    Position-wise feed-forward network, bias-free, with dropout on its output.

    With ``use_swiglu`` (default) it computes the SwiGLU variant
    (https://arxiv.org/pdf/2002.05202):  down_proj(silu(gate_proj(x)) * up_proj(x)).
    Otherwise it is the classic  linear2(gelu(linear1(x))).

    Args:
        embed_dim: input/output width.
        hidden_dim: inner width, defaults to 4 * embed_dim. For SwiGLU it is
            scaled by 2/3 (truncated to int) so that its three matrices hold
            roughly as many parameters as the two matrices of the GELU variant.
        dropout_prob: dropout applied to the block output.
        use_swiglu: SwiGLU (True) or GELU (False).
    """
    # The SwiGLU paper, on why it works:
    # "We offer no explanation as to why these architectures seem to work;
    #  we attribute their success, as all else, to divine benevolence."

    def __init__(self, embed_dim, hidden_dim=None, dropout_prob=0.1, use_swiglu=True):
        super().__init__()
        self.use_swiglu = use_swiglu
        inner = 4 * embed_dim if hidden_dim is None else hidden_dim
        if use_swiglu:
            inner = int(2 * inner / 3)
            self.gate_proj = nn.Linear(embed_dim, inner, bias=False)
            self.up_proj = nn.Linear(embed_dim, inner, bias=False)
            self.down_proj = nn.Linear(inner, embed_dim, bias=False)
        else:
            self.linear1 = nn.Linear(embed_dim, inner, bias=False)
            self.act = nn.GELU()
            self.linear2 = nn.Linear(inner, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        if not self.use_swiglu:
            return self.dropout(self.linear2(self.act(self.linear1(x))))
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self,
                 embed_dim,
                 num_heads,
                 mlp_ratio=4,
                 dropout_prob=0.1,
                 causal=True,
                 use_swiglu=True,
                ): 
        """
        Initialize a complete transformer block.
        
        APPROACH:
        1. Multi-head self-attention for sequence modeling
        2. 1st Normalization (pre-norm architecture)
        3. MLP with specified expansion ratio
        4. 2nd Normalization
    
        TRANSFORMER BLOCK ARCHITECTURE:
        x → Norm → MultiHeadAttention → + (residual) →
            Norm → MLP → + (residual) → output
    
        NB: We use pre-norm architecture (before attention/MLP)
        """
    
        super().__init__()
        self.norm1 = nn.RMSNorm(embed_dim)  # Nb : modern architectures seem to use rmsnorm instead
        self.mha = MultiHeadAttention(embed_dim, num_heads, causal, dropout_prob)  # causal = masking out tokens
        self.norm2 = nn.RMSNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio * embed_dim, dropout_prob, use_swiglu)
    
    def forward(self, x, cache=None):
        x1 = self.norm1(x)
        x2, cache = self.mha(x1, cache)  # will be used when generating tokens during inference
        x2 = x2 + x  # residual path
        x3 = self.norm2(x2)
        x3 = self.mlp(x3) + x2  # residual path
        return x3, cache

In [6]:
class GPT(nn.Module):
    """
    Complete Generative Pre-trained Transformer model.
    """

    def __init__(self,
                 vocab_size,
                 embed_dim,
                 num_layers,
                 num_heads,
                 mlp_ratio=4,
                 dropout_prob=0.1,
                 use_swiglu=True,
                ):
        """
        Initialize complete GPT model.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
        self.dropout = nn.Dropout(dropout_prob)
        self.blocks = nn.ModuleList([TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout_prob, use_swiglu) for _ in range(num_layers)])
        self.norm = nn.RMSNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # weight tying

    def forward(self, tokens):
        embeddings = self.embedding(tokens)
        x = self.dropout(embeddings)
        for b in self.blocks:
            x, _ = b(x)
        features = self.norm(x)
        return self.lm_head(features)

## Tokenizer

It must be exactly the same tokenizer as the one used for pretraining. On top of it we add 2 extra special tokens for the user / assistant roles, and initialize their embeddings with the `<|endoftext|>` embedding learnt during pretraining.

In [7]:
# Need to re-use the exact same tokenizer for SFT + 2 new tokens for conversation
tokenizer = Tokenizer.from_pretrained("GPT2")
eot_id = tokenizer.token_to_id("<|endoftext|>")
assert eot_id is not None, "tokenizer has no <|endoftext|> token"
base_vocab_size = tokenizer.get_vocab_size()  # vocab the pretrained embedding table covers

tokenizer.add_special_tokens(ROLE_TOKENS)
user_id = tokenizer.token_to_id("<|user|>")
assistant_id = tokenizer.token_to_id("<|assistant|>")

# The model-loading cell appends one embedding row per role token right after the
# pretrained rows, so the role ids must be base_vocab_size, base_vocab_size + 1, ...
# (this fails if a role token already existed in the vocab and got an old id).
role_ids = [tokenizer.token_to_id(tok) for tok in ROLE_TOKENS]
assert role_ids == list(range(base_vocab_size, base_vocab_size + len(ROLE_TOKENS))), role_ids

print(user_id, assistant_id, eot_id)  # 50257 50258 50256
tokenizer.decode([user_id, assistant_id, eot_id], skip_special_tokens=False)  # '<|user|><|assistant|><|endoftext|>'

50257 50258 50256


'<|user|><|assistant|><|endoftext|>'

## Supervised Fine-Tuning Dataset

### Goal
Train the model to generate assistant responses, NOT to predict the instructions. An autoregressive LM predicts the NEXT token at every position, so we simply exclude from the loss every position whose target is an instruction token (masking).

---

### The Process

#### 1. Format the Data
```
Input text: "<|user|>\n{instruction}\n<|assistant|>\n{response}"   (+ the <|endoftext|> id appended after encoding)

```

#### 2. Tokenize
```
ids = [user_tok, What, is, 2, +, 2, ?, asst_tok, The, answer, is, 4, eot]
idx:   0        1     2   3  4  5  6  7         8    9      10  11 12
```

#### 3. Create Shifted Labels
```python
labels[:-1] = ids[1:]  # Each label is the NEXT token to predict

labels = [What, is, 2, +, 2, ?, asst_tok, The, answer, is, 4, eot, IGNORE]
```
**Meaning:** `labels[i]` = what should be predicted after seeing `ids[i]`

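#### 4. Mask the Instruction
```python
# Find position of <|assistant|> token (position 7 in example)
labels[:assistant_pos] = IGNORE_INDEX (-100)   # masks positions 0..6

labels = [IGN, IGN, IGN, IGN, IGN, IGN, IGN, The, answer, is, 4, eot, IGN]
          └─────────instruction masked───────────┘  └────train here────┘
```
**Note:** the real template also has a `\n` right after `<|assistant|>`; it is omitted in this simplified example. That newline belongs to the prompt. The code therefore masks one extra position, `labels[:assistant_pos + 1]`, so the loss starts at the first real response token.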

#### 5. Compute Loss (during training)
```python
logits = model(ids)  # Model predicts next token at each position
loss = CrossEntropyLoss(logits, labels, ignore_index=-100)
```

**What happens:**
- Position 0-6: `labels[i] = -100` → loss ignored (don't train on instruction)
- Position 7: predict "The" after `<|assistant|>` → **COMPUTE LOSS** ✓
- Position 8: predict "answer" after "The" → **COMPUTE LOSS** ✓
- Position 9: predict "is" after "answer" → **COMPUTE LOSS** ✓
- Position 10: predict "4" after "is" → **COMPUTE LOSS** ✓
- Position 11: predict `<eot>` after "4" → **COMPUTE LOSS** ✓

---

### Why This Works

**Causal Attention Mask**
- Prevents the model from "seeing" future tokens
- At position i, model only attends to tokens 0 to i

**Teacher Forcing**
- Model sees correct previous tokens during training
- Learns to predict the next one

**Masking with -100**
- `CrossEntropyLoss` ignores these positions
- Gradients only flow through response tokens

**Result:** Model learns "given instruction X, generate response Y" without wasting compute trying to predict the instruction itself.

---

### Key Facts

- `IGNORE_INDEX = -100` (standard PyTorch convention)
- Only the response tokens contribute to the loss. The fraction of the sequence they represent depends on the dataset (long prompts with short answers → a small fraction).
- The shift (`labels[i] = ids[i+1]`) aligns predictions with targets
- Masking plus a comparatively small dataset mostly fine-tunes *behavior* (format, style) and adds little or no new *knowledge*!

In [8]:
# Load the three public instruction-tuning datasets (their "train" splits are used below)
print("Loading datasets...")
alpaca_cleaned, platypus, no_robots = (
    load_dataset(hub_name)
    for hub_name in (
        "yahma/alpaca-cleaned",
        "garage-bAInd/Open-Platypus",
        "HuggingFaceH4/no_robots",
    )
)

# Format functions for each dataset
def format_alpaca(example):
    """
    Map an Alpaca row to the common ``{"user", "assistant"}`` schema.

    A non-blank ``input`` is appended to the instruction after a blank line;
    ``output`` becomes the assistant turn.
    """
    extra = example["input"]
    instruction = example["instruction"]
    user_text = f"{instruction}\n\n{extra}" if extra.strip() else instruction
    return {"user": user_text, "assistant": example["output"]}

def format_platypus(example):
    """Map an Open-Platypus row to ``{"user": instruction, "assistant": output}``."""
    return dict(user=example["instruction"], assistant=example["output"])

def format_no_robots(example):
    """
    Map a no_robots row to the common ``{"user", "assistant"}`` schema.

    ``prompt`` is the user turn. The response is the first message whose role
    is "assistant". Indexing ``messages[1]`` is not enough: if the conversation
    starts with a system message, ``messages[1]`` is the user turn itself.
    For multi-turn conversations only this first exchange is kept.

    Raises:
        ValueError: if the conversation contains no assistant message.
    """
    for message in example["messages"]:
        if message["role"] == "assistant":
            return {"user": example["prompt"], "assistant": message["content"]}
    raise ValueError(f"no assistant message for prompt {example['prompt'][:50]!r}")

# Map every dataset onto the common {"user", "assistant"} schema,
# dropping all of its original columns
print("Formatting datasets...")
alpaca_formatted, platypus_formatted, no_robots_formatted = (
    raw.map(formatter, remove_columns=raw["train"].column_names)
    for raw, formatter in (
        (alpaca_cleaned, format_alpaca),
        (platypus, format_platypus),
        (no_robots, format_no_robots),
    )
)

# Only the "train" split of each dataset is used for SFT
combined_datasets = [
    formatted["train"]
    for formatted in (alpaca_formatted, platypus_formatted, no_robots_formatted)
]

# OPTIONAL : can use any modern LLM to generate custom SFT data
if os.path.isfile("synthetic_sft_data.jsonl"):
    print("Loading synthetic data...")
    synthetic_data = load_synthetic_data("synthetic_sft_data.jsonl")  # list of dicts
    print(f"Loaded {len(synthetic_data)} synthetic examples")

    # An empty file would otherwise crash on synthetic_data[0]
    if synthetic_data:
        # Transform list of dicts to dict of lists, keeping only the shared schema:
        # concatenate_datasets requires identical columns across datasets, so any
        # extra key in the synthetic records would make the concatenation fail.
        synthetic_dataset = ds.from_dict({
            key: [item[key] for item in synthetic_data]
            for key in ("user", "assistant")
        })

        # Print example to verify format
        print("\nExample from synthetic dataset:")
        pprint(synthetic_dataset[0])

        # Add to all datasets
        combined_datasets.append(synthetic_dataset)

# Stack every formatted dataset into one training set and show its first row
print("\n\nCombining datasets...")
combined_train = concatenate_datasets(combined_datasets)

n_train_examples = len(combined_train)
print(f"Total training examples: {n_train_examples}")
print("\nExample from combined dataset:")
first_train_example = next(iter(combined_train))
pprint(first_train_example)

# Tokenization function
def tokenize_sft(example):
    """
    Tokenize one ``{"user", "assistant"}`` example into SFT ids and labels.

    Layout: ``<|user|>\\n{user}\\n<|assistant|>\\n{assistant}`` + the eot id.

    ``labels[i]`` is the token to predict after ``input_ids[i]``
    (``input_ids[i + 1]``). It is ``IGNORE_INDEX`` wherever no loss is wanted:
    - every position up to and including the template's ``<|assistant|>``
      token, so the prediction of the token right after it (normally the
      template "\\n") is not trained either;
    - the last position, which has no next token.
    Both lists are then truncated to ``block_size``.

    The prompt (ending with ``<|assistant|>``) and the response are encoded
    separately. The prompt ends with an added special token, which the
    tokenizer splits out before running BPE, so the concatenation gives the
    same ids as encoding the whole text at once. The prompt/response boundary
    is then known exactly, even when the user text itself contains the literal
    string ``<|assistant|>``; searching with ``ids.index`` would stop at that
    earlier occurrence.

    Returns:
        dict with ``input_ids`` and ``labels`` of equal length (<= block_size).
    """
    prompt_ids = tokenizer.encode(f"<|user|>\n{example['user']}\n<|assistant|>").ids
    response_ids = tokenizer.encode(f"\n{example['assistant']}").ids
    ids = prompt_ids + response_ids + [eot_id]
    assistant_pos = len(prompt_ids) - 1  # index of the template's <|assistant|> token

    # Shift: labels[i] = ids[i + 1]; the last position has nothing to predict
    labels = ids[1:] + [IGNORE_INDEX]

    # Mask out everything up to the start of the assistant response
    labels[:assistant_pos + 1] = [IGNORE_INDEX] * (assistant_pos + 1)

    # Truncate
    return {
        "input_ids": ids[:block_size],
        "labels": labels[:block_size],
    }

# Tokenize combined dataset
print("Tokenizing combined dataset...")
combined_tokenized = combined_train.map(
    tokenize_sft,
    remove_columns=combined_train.column_names,
    num_proc=4,
)

# Drop examples whose response was entirely truncated away (prompt >= block_size
# tokens). They carry no trainable label. A batch made only of such examples
# would make the mean cross-entropy 0/0 = NaN and poison the weights.
def has_trainable_labels(example):
    """True if at least one label of the example is not IGNORE_INDEX."""
    return any(label != IGNORE_INDEX for label in example["labels"])

n_before_filter = len(combined_tokenized)
combined_tokenized = combined_tokenized.filter(has_trainable_labels, num_proc=4)
print(f"Dropped {n_before_filter - len(combined_tokenized)} examples with no trainable tokens")
print("Done !")

Loading datasets...
Formatting datasets...
Loading synthetic data...
Loaded 21 synthetic examples

Example from synthetic dataset:
{'assistant': '**\n'
              '\n'
              'To solve for the initial investment $X$, we must break the '
              'problem down into sequential steps, calculating the value of '
              'each component in terms of $X$.\n'
              '\n'
              '### **Step 1: Calculate Revenue from Inventory Sales**\n'
              '*   Initial Inventory Cost = $0.60X$\n'
              '*   Value of sellable inventory (after $10\\%$ loss) = $0.60X '
              '\\times (1 - 0.10) = 0.54X$\n'
              '*   The markup is $150\\%$, which means the sales multiplier is '
              '$2.5$ ($100\\%$ original cost + $150\\%$ markup).\n'
              '*   **Total Sales Revenue** = $0.54X \\times 2.5 = 1.35X$\n'
              '\n'
              '### **Step 2: Calculate the Marketing Rebate**\n'
              '*   Initial Marketing Cost = 

In [9]:
def collate_fn(batch):
    """
    Right-pad a list of tokenized examples into equal-length ``torch.long`` tensors.

    ``None`` entries are skipped. ``input_ids`` are padded with ``eot_id`` and
    ``labels`` with ``IGNORE_INDEX``, both by the amount the example's
    ``input_ids`` falls short of the longest one, so padding never reaches the
    loss. With causal attention, right padding cannot influence the real tokens
    before it.

    Returns:
        {"input_ids": (B, max_len) tensor, "labels": (B, max_len) tensor}
    """
    # Conversations have variable lengths -> pad to the longest in the batch.
    # Bucketing or fixed-length padding would allow torch.compile, but the data
    # here is small enough to skip it.
    examples = [ex for ex in batch if ex is not None]
    longest = max(len(ex["input_ids"]) for ex in examples)
    gaps = [longest - len(ex["input_ids"]) for ex in examples]

    input_ids = torch.tensor(
        [ex["input_ids"] + [eot_id] * gap for ex, gap in zip(examples, gaps)],
        dtype=torch.long,
    )
    labels = torch.tensor(
        [ex["labels"] + [IGNORE_INDEX] * gap for ex, gap in zip(examples, gaps)],
        dtype=torch.long,
    )
    return {"input_ids": input_ids, "labels": labels}

# Shuffled, pinned-memory loader; collate_fn pads each batch to its longest example
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    prefetch_factor=prefetch,
    pin_memory=True,
)
train_loader = DataLoader(combined_tokenized, **loader_kwargs)

print(f"DataLoader created with {len(train_loader)} batches of {batch_size} seqs")

DataLoader created with 5388 batches of 16 seqs


## Model Loading

Start from the pretrained model (with `dropout` set to 0) and extend the embedding table with rows for the new role tokens.

In [10]:
model_config = {
    # The checkpoint was trained on the base vocab, i.e. without the role tokens
    "vocab_size": tokenizer.get_vocab_size() - len(ROLE_TOKENS),
    "embed_dim": embed_dim,
    "num_layers": num_layers,
    "num_heads": num_heads,
    "mlp_ratio": mlp_ratio,
    "dropout_prob": dropout_prob,
    "use_swiglu": True,
}

print("Initializing model with config : ")
pprint(model_config)

model = GPT(**model_config).to(device)

print(f"Loading : {pretrained_weights}...")
# weights_only=True: unpickle only tensors / plain containers (no arbitrary code execution)
checkpoint = torch.load(pretrained_weights, map_location=device, weights_only=True)
trained_weights = strip_compile_prefix(checkpoint)  # handles naming in case of compiled model

model.load_state_dict(trained_weights, strict=True)
print("Loaded pretrained weights")

# Need to handle the embeddings for the 2 new tokens we added
old_vocab_size, dim = model.embedding.weight.shape
new_vocab_size = old_vocab_size + len(ROLE_TOKENS)
assert new_vocab_size == tokenizer.get_vocab_size(), (new_vocab_size, tokenizer.get_vocab_size())

new_embedding = torch.nn.Embedding(new_vocab_size, dim).to(device)
with torch.no_grad():
    new_embedding.weight[:old_vocab_size].copy_(model.embedding.weight)
    # Role tokens start from the pretrained end-of-text embedding (broadcast to every new row)
    new_embedding.weight[old_vocab_size:].copy_(model.embedding.weight[eot_id])

model.embedding = new_embedding
# Re-tie the output head to the enlarged table and keep the size bookkeeping in sync
model.lm_head.weight = model.embedding.weight
model.lm_head.out_features = new_vocab_size
model.vocab_size = new_vocab_size

# Ready to train !
# torch.compile stays off: variable-length batches would keep triggering
# recompilation unless we implemented bucketing or fixed-size padding.
model.train()
_, _ = count_parameters(model)

Initializing model with config : 
{'dropout_prob': 0,
 'embed_dim': 768,
 'mlp_ratio': 4,
 'num_heads': 12,
 'num_layers': 12,
 'use_swiglu': True,
 'vocab_size': 50257}
Loading : gpt_model_1024_158417.pt...
Loaded pretrained weights
Parameter Breakdown:
embeddings          :   38,598,912 (31.24%)
norms               :       19,200 ( 0.02%)
other               :   28,311,936 (22.91%)
mlp                 :   56,623,104 (45.83%)
TOTAL               :  123,553,152


Weight decay is much larger than what is typical for CNNs, but it appears to be the empirical standard for LLMs: usually 0.1 for pretraining and 0.01 for SFT, where it is said to help limit memorization. The learning rate is small because we are fine-tuning.

In [11]:
# Standard LLM practice: apply weight decay only to matrices (linear / embedding
# weights), not to 1-D parameters such as the RMSNorm gains, which decay would
# pull toward 0. Parameters with requires_grad=False are not handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
decay_params = [p for p in trainable_params if p.dim() >= 2]
no_decay_params = [p for p in trainable_params if p.dim() < 2]

optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": 0.01},  # good practice
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=2e-5,
    betas=(0.9, 0.95),
)

We train for several epochs because SFT data is 1) small enough to iterate over several times and 2) a fixed set of Q/A samples. Pretraining data, by contrast, is far too large to loop over and is consumed as random contiguous chunks of text.

In [14]:
# generate() calls model.eval(); make sure a re-run of this cell trains in train mode
model.train()
for epoch in range(NUM_EPOCHS):
    for step, batch in enumerate(train_loader):
        # Skip batches where every label is masked (e.g. prompts truncated past
        # block_size): the mean cross-entropy over zero targets is NaN and would
        # corrupt the weights. Checked on the CPU tensor -> no GPU sync.
        if not (batch["labels"] != IGNORE_INDEX).any():
            continue

        # DataLoader pins memory, so the host->device copies can be asynchronous
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)

        with torch.amp.autocast('cuda', dtype=dtype):
            logits = model(input_ids)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=IGNORE_INDEX,
            )

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % 100 == 0:
            print(
                f"epoch {epoch} | step {step} | loss {loss.item():.4f}"
            )

    # Save model in full precision
    torch.save(model.state_dict(), f"GPT_SFT_epoch_{epoch}.pt")

epoch 0 | step 0 | loss 2.7657
epoch 0 | step 100 | loss 2.2716
epoch 0 | step 200 | loss 2.2547
epoch 0 | step 300 | loss 2.2833
epoch 0 | step 400 | loss 2.5284
epoch 0 | step 500 | loss 2.1692
epoch 0 | step 600 | loss 2.1431
epoch 0 | step 700 | loss 2.4482
epoch 0 | step 800 | loss 2.5566
epoch 0 | step 900 | loss 2.0463
epoch 0 | step 1000 | loss 2.4977
epoch 0 | step 1100 | loss 2.1114
epoch 0 | step 1200 | loss 2.3105
epoch 0 | step 1300 | loss 2.1581
epoch 0 | step 1400 | loss 2.3696
epoch 0 | step 1500 | loss 2.4273
epoch 0 | step 1600 | loss 2.4766
epoch 0 | step 1700 | loss 2.1687
epoch 0 | step 1800 | loss 2.4495
epoch 0 | step 1900 | loss 2.5453
epoch 0 | step 2000 | loss 2.2902
epoch 0 | step 2100 | loss 1.8036
epoch 0 | step 2200 | loss 1.8021
epoch 0 | step 2300 | loss 1.9774
epoch 0 | step 2400 | loss 2.0191
epoch 0 | step 2500 | loss 1.9997
epoch 0 | step 2600 | loss 1.9025
epoch 0 | step 2700 | loss 2.6555
epoch 0 | step 2800 | loss 2.2636
epoch 0 | step 2900 | loss

KeyboardInterrupt: 

In [18]:
# Chat-formatted prompt ending right after the assistant role token + newline,
# exactly like the training template
prompt = "<|user|>\n" "Explain why sorting by hash is a bad idea.\n" "<|assistant|>\n"

prompt_ids = tokenizer.encode(prompt).ids
input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)  # (1, T)

@torch.no_grad()
def generate(
    model,
    input_ids,
    max_new_tokens=100,
    stop_token_id=eot_id,
    temperature=0.7,
    greedy=False,
):
    """
    Autoregressively extend ``input_ids`` (shape (B, T)) by up to ``max_new_tokens`` tokens.

    Each step re-runs the full forward pass on the last ``block_size`` tokens
    (no KV cache). It then either takes the argmax (``greedy=True``) or samples
    from ``softmax(logits / temperature)``. Generation stops early once every
    sequence in the batch emits ``stop_token_id`` at the same step; the stop
    token is kept in the output. Pass ``stop_token_id=None`` to disable stopping.

    The model is switched to eval mode while generating, and its previous
    train/eval mode is restored afterwards, even if an error occurs.

    Returns:
        (B, T + n) tensor: the prompt followed by the n generated tokens.

    Raises:
        ValueError: if sampling (``greedy=False``) with ``temperature <= 0``.
    """
    if not greedy and temperature <= 0:
        raise ValueError(f"temperature must be > 0 when sampling, got {temperature}")

    was_training = model.training
    model.eval()
    try:
        for _ in range(max_new_tokens):
            # Crop to the last block_size tokens (no-op for shorter sequences)
            input_ids_cond = input_ids[:, -block_size:]

            # Logits for the next token only
            next_token_logits = model(input_ids_cond)[:, -1, :]

            if greedy:
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop if every sequence just produced the stop token
            if stop_token_id is not None and (next_token == stop_token_id).all():
                break
    finally:
        model.train(was_training)

    return input_ids

# Hopefully 1) it follows the conversation workflow 2) it's not slop !
out = generate(model, input_ids, max_new_tokens=150)
generated_ids = out[0].tolist()  # single prompt -> first (only) row
print(tokenizer.decode(generated_ids, skip_special_tokens=False))

<|user|>
Explain why sorting by hash is a bad idea.
<|assistant|>
Dictionaries can be unethical, depending on their reasoning. While hash is a good idea, it can be harmful for many people. For example, if you try to sell hash books to someone else, you can sell them to someone else. This makes it difficult for both your likely and future users to know how much they have in store for you.

Dictionaries can also be harmful to certain groups of people, such as children or teenagers. For instance, if you try to sell your family's home to someone else, you can sell it to someone else for the same purpose. This can make it difficult for many people to know the exact value you want them to know about your decision.

Dictionaries are also harmful
