# GPT

SFT part : teach the model the assistant format

Basically : 
- the prompt tokens are masked (no loss)
- the response tokens are trained on (loss applied)

In [1]:
# Standard library
import csv
import math
import multiprocessing
import os
import random
import time
from pprint import pprint
from datetime import datetime

# Environment config
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Third-party
import numpy as np
from datasets import concatenate_datasets, load_dataset
from rotary_embedding_torch import RotaryEmbedding
from tokenizers import Tokenizer

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Torch runtime config
torch.set_float32_matmul_precision("medium")
torch.cuda.empty_cache()

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, causal: bool = True, dropout: float = 0.1):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads}).")
        
        self.causal = causal
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout_p = dropout
        
        # Fused QKV projection: 3x the output size
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
    
    def forward(self, x, k_v_cache=None):
        B, T, _ = x.shape
        using_cache = k_v_cache is not None and "K" in k_v_cache
    
        # 1. Single fused projection
        if using_cache:
            x_q = x[:, -1:, :]
            qkv = self.qkv_proj(x_q)  # (B, 1, 3*embed_dim)
        else:
            qkv = self.qkv_proj(x)  # (B, T, 3*embed_dim)
        
        # 2. Split into Q, K, V
        Q, K, V = qkv.chunk(3, dim=-1)  # Each is (B, T, embed_dim)
        
        def split_heads(t):
            return t.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 3. Split heads -> (B, H, T, D_head)
        Q = split_heads(Q)
        K = split_heads(K)
        V = split_heads(V)
    
        # 4. Apply RoPE 
        if using_cache:
            past_len = k_v_cache["K"].shape[-2]
            Q = self.rotary_emb.rotate_queries_or_keys(Q, offset=past_len)
            K = self.rotary_emb.rotate_queries_or_keys(K, offset=past_len)
            
            K = torch.cat([k_v_cache["K"], K], dim=-2)
            V = torch.cat([k_v_cache["V"], V], dim=-2)
        else:
            Q = self.rotary_emb.rotate_queries_or_keys(Q)
            K = self.rotary_emb.rotate_queries_or_keys(K)
    
        # 5. Update cache
        if k_v_cache is not None:
            k_v_cache["K"] = K
            k_v_cache["V"] = V
    
        # 6. Attention
        out = F.scaled_dot_product_attention(
            query=Q,
            key=K,
            value=V,
            attn_mask=None, 
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=self.causal and (Q.shape[-2] > 1)
        )
        
        # 7. Merge heads
        out = out.transpose(1, 2).contiguous().view(B, -1, self.embed_dim)
    
        return self.out_proj(out), k_v_cache

In [3]:
class MLP(nn.Module):
    def __init__(self, embed_dim, hidden_dim=None, dropout_prob=0.1, use_swiglu=True):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * embed_dim
        
        self.use_swiglu = use_swiglu
        
        if use_swiglu:
            # Adjust hidden_dim for param count matching
            hidden_dim = int(2 * hidden_dim / 3)
            self.gate_proj = nn.Linear(embed_dim, hidden_dim, bias=False)
            self.up_proj = nn.Linear(embed_dim, hidden_dim, bias=False)
            self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=False)
        else:
            self.linear1 = nn.Linear(embed_dim, hidden_dim, bias=False)
            self.act = nn.GELU()
            self.linear2 = nn.Linear(hidden_dim, embed_dim, bias=False)
        
        self.dropout = nn.Dropout(dropout_prob)
    
    def forward(self, x):
        if self.use_swiglu:
            return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
        else:
            return self.dropout(self.linear2(self.act(self.linear1(x))))

In [4]:
class TransformerBlock(nn.Module):
    def __init__(self,
                 embed_dim,
                 num_heads,
                 mlp_ratio=4,
                 dropout_prob=0.1,
                 causal=True,
                 use_swiglu=True,
                ): 
        """
        Initialize a complete transformer block.
        
        APPROACH:
        1. Multi-head self-attention for sequence modeling
        2. 1st Normalization (pre-norm architecture)
        3. MLP with specified expansion ratio
        4. 2nd Normalization
    
        TRANSFORMER BLOCK ARCHITECTURE:
        x → Norm → MultiHeadAttention → + (residual) →
            Norm → MLP → + (residual) → output
    
        NB: We use pre-norm architecture (before attention/MLP)
        """
    
        super().__init__()
        self.norm1 = nn.RMSNorm(embed_dim)
        self.mha = MultiHeadAttention(embed_dim, num_heads, causal, dropout_prob)  # causal = masking out tokens
        self.norm2 = nn.RMSNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio * embed_dim, dropout_prob, use_swiglu)
    
    def forward(self, x, cache=None):
        x1 = self.norm1(x)
        x2, cache = self.mha(x1, cache)  # will be used when generating tokens during inference
        x2 = x2 + x  # residual path
    
        x3 = self.norm2(x2)
        x3 = self.mlp(x3) + x2  # residual path
        return x3, cache

In [5]:
class GPT(nn.Module):
    """
    Complete GPT (Generative Pre-trained Transformer) model.

    This combines embeddings, positional encoding, multiple transformer blocks,
    and a language modeling head for text generation.
    """

    def __init__(self,
                 vocab_size,
                 embed_dim,
                 num_layers,
                 num_heads,
                 mlp_ratio=4,
                 dropout_prob=0.1,
                 use_swiglu=True,
                ):
        """
        Initialize complete GPT model.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
        self.dropout = nn.Dropout(dropout_prob)
        self.blocks = nn.ModuleList([TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout_prob, use_swiglu) for _ in range(num_layers)])
        self.norm = nn.RMSNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        # weight tying
        self.lm_head.weight = self.embedding.weight

        # below shamefully stolen from nano-gpt
        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * self.n_layer))

        # report number of parameters
        print("Number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.embedding.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
       
    def forward(self, tokens):
        embeddings = self.embedding(tokens)
        x = self.dropout(embeddings)
        for b in self.blocks:
            x, _ = b(x)  # iteratively refines features from initial embeddings
        features = self.norm(x)  # normalized to stabilize training
        return self.lm_head(features)

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.no_grad()
    def generate(self,
                 prompt_tokens,
                 max_new_tokens=50,
                 temperature=1.0,
                 use_cache=True,
                 use_top_k=False,
                ):
        self.eval()

        tokens_out = prompt_tokens.clone()
        current_tokens = prompt_tokens.clone()
        tokens_out = tokens_out.to(self.device)
        current_tokens = current_tokens.to(self.device)
        cache = [{} if use_cache else None for _ in range(len(self.blocks))]
        
        for _ in range(max_new_tokens):

            x = self.embedding(current_tokens)
            for i, b in enumerate(self.blocks):
                x, c_i = b(x, cache[i])
                cache[i] = c_i
            
            features = self.norm(x)
            logits = self.lm_head(features)
                    
            last_logits = logits[:, -1, :]
    
            if temperature > 0:
                scaled_logits = last_logits / temperature
                # Only sample from top k tokens to avoid garbage prediction derailing whole prediction
                # We don't simply take max prob token to allow "creativity"
                if use_top_k:
                    # heuristic that is ok for toy project
                    # most of probability mass in on a small amount of tokens
                    k = min(max(5, int(0.01 * self.vocab_size)), 100)
                    values, indices = torch.topk(scaled_logits, k)
                    scaled_logits = torch.full_like(scaled_logits, float('-inf'))
                    scaled_logits.scatter_(1, indices, values)
                probs = torch.softmax(scaled_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding if temp is 0 (prevents division by zero)
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
    
            tokens_out = torch.cat([tokens_out, next_token], dim=1)

            # If caching, we only need to feed the newest token next time, otherwise full sequence
            current_tokens = next_token if use_cache else tokens_out
       
        return tokens_out

In [6]:
# Need to re-use the exact same tokenizer for SFT

tokenizer = Tokenizer.from_pretrained("GPT2")
eot_id = tokenizer.token_to_id("<|endoftext|>")

ROLE_TOKENS = ["<|user|>", "<|assistant|>"]

tokenizer.add_special_tokens(ROLE_TOKENS)

user_id = tokenizer.token_to_id("<|user|>")
assistant_id = tokenizer.token_to_id("<|assistant|>")

print(user_id, assistant_id, eot_id)  # 50257 50258 50256

tokenizer.decode([user_id, assistant_id, eot_id], skip_special_tokens=False)  # '<|user|><|assistant|><|endoftext|>'

50257 50258 50256


'<|user|><|assistant|><|endoftext|>'

In [7]:
IGNORE_INDEX = -100
block_size = 1024

# Load multiple datasets
print("Loading datasets...")
alpaca_cleaned = load_dataset("yahma/alpaca-cleaned")
platypus = load_dataset("garage-bAInd/Open-Platypus")
no_robots = load_dataset("HuggingFaceH4/no_robots")

# Format functions for each dataset
def format_alpaca(example):
    if example["input"].strip():
        user_text = (
            f"{example['instruction']}\n\n"
            f"{example['input']}"
        )
    else:
        user_text = example["instruction"]
    return {
        "user": user_text,
        "assistant": example["output"]
    }

def format_platypus(example):
    return {
        "user": example["instruction"],
        "assistant": example["output"]
    }

def format_no_robots(example):
    return {
        "user": example["prompt"],
        "assistant": example["messages"][1]["content"]  # Fixed: assistant message is in messages list
    }

# Map each dataset to common format
print("Formatting datasets...")
alpaca_formatted = alpaca_cleaned.map(
    format_alpaca, 
    remove_columns=alpaca_cleaned["train"].column_names
)
platypus_formatted = platypus.map(
    format_platypus, 
    remove_columns=platypus["train"].column_names
)
no_robots_formatted = no_robots.map(
    format_no_robots, 
    remove_columns=no_robots["train"].column_names
)

combined_datasets = [
    alpaca_formatted["train"],
    platypus_formatted["train"],
    no_robots_formatted["train"]
]

# OPTIONAL : can use any modern LLM to generate custom SFT data
if os.path.isfile("synthetic_sft_data.jsonl"):
    def load_synthetic_data(file_path):
        data = []
        with open(file_path, 'r') as f:
            for line in f:
                data.append(json.loads(line))
        return data
    
    print("Loading synthetic data...")
    synthetic_data = load_synthetic_data("synthetic_sft_data.jsonl")
    print(f"Loaded {len(synthetic_data)} synthetic examples")
    
    # Convert to HuggingFace Dataset
    synthetic_dataset = Dataset.from_list(synthetic_data)
    
    # Print example to verify format
    print("\nExample from synthetic dataset:")
    pprint(synthetic_dataset[0])

    # Add to all datasets
    combined_datasets.append(synthetic_dataset)

# Combine all datasets
print("\n\nCombining datasets...")
combined_train = concatenate_datasets(combined_datasets)

print(f"Total training examples: {len(combined_train)}")
print("\nExample from combined dataset:")
pprint(next(iter(combined_train)))

# Tokenization function
def tokenize_sft(example):
    text = (
        "<|user|>\n"
        f"{example['user']}\n"
        "<|assistant|>\n"
        f"{example['assistant']}"
    )
    ids = tokenizer.encode(text).ids
    ids.append(eot_id)
    
    # Find assistant token
    try:
        assistant_pos = ids.index(assistant_id)
    except ValueError:
        return None
    
    # Create labels as shifted version of ids
    labels = [IGNORE_INDEX] * len(ids)
    labels[:-1] = ids[1:]  # Shift: label[i] = ids[i+1]
    
    # Now mask out everything before assistant response
    labels[:assistant_pos + 1] = [IGNORE_INDEX] * (assistant_pos + 1)
    
    # Truncate
    ids = ids[:block_size]
    labels = labels[:block_size]
    
    return {
        "input_ids": ids,
        "labels": labels,
    }

# Tokenize combined dataset
print("Tokenizing combined dataset...")
combined_tokenized = combined_train.map(
    tokenize_sft,
    remove_columns=combined_train.column_names,
    num_proc=4,
)

# Collate function
def collate_fn(batch):
    batch = [x for x in batch if x is not None]
    max_len = max(len(x["input_ids"]) for x in batch)
    
    input_ids = []
    labels = []
    
    for x in batch:
        pad_len = max_len - len(x["input_ids"])
        input_ids.append(
            x["input_ids"] + [eot_id] * pad_len
        )
        labels.append(
            x["labels"] + [IGNORE_INDEX] * pad_len
        )
    
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
    }

Loading datasets...
Formatting datasets...
Combining datasets...
Total training examples: 86186

Example from combined dataset:
{'assistant': '1. Eat a balanced and nutritious diet: Make sure your meals are '
              'inclusive of a variety of fruits and vegetables, lean protein, '
              'whole grains, and healthy fats. This helps to provide your body '
              'with the essential nutrients to function at its best and can '
              'help prevent chronic diseases.\n'
              '\n'
              '2. Engage in regular physical activity: Exercise is crucial for '
              'maintaining strong bones, muscles, and cardiovascular health. '
              'Aim for at least 150 minutes of moderate aerobic exercise or 75 '
              'minutes of vigorous exercise each week.\n'
              '\n'
              '3. Get enough sleep: Getting enough quality sleep is crucial '
              'for physical and mental well-being. It helps to regulate mood, '
        

In [8]:
#### CONFIG #####

# Basically GPT-2 Small
block_size = 1024  # 512 for faster convergence then 1024 to finish training
batch_size = 16
embed_dim = 768
num_layers = 12
num_heads = 12
dropout_prob = 0.1
mlp_ratio = 4  # standard 4x expansion


# Training
MAX_STEPS = 500000       # Total number of micro-batches to process
GRAD_ACCUM_STEPS = 40    # Accumulate gradients over 40 batches
LOG_INTERVAL = 500       # Log every 500 micro-batches
num_workers = 4          # For data loading
prefetch = 4
dtype = torch.bfloat16
device = "cuda"
model_path = f"gpt_model_{block_size}_final.pt"  # where do we store trained model
print("torch.cuda.is_bf16_supported()", torch.cuda.is_bf16_supported())

torch.cuda.is_bf16_supported() True


In [9]:
# Create dataloader
train_loader = DataLoader(
    combined_tokenized,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    prefetch_factor=prefetch,
    pin_memory=True,
)

print(f"DataLoader created with {len(train_loader)} batches")

DataLoader created with 5387 batches


In [10]:
def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith(prefix):
            new_state_dict[k[len(prefix):]] = v
        else:
            new_state_dict[k] = v
    return new_state_dict


model_config = {
    "vocab_size": tokenizer.get_vocab_size() - len(ROLE_TOKENS),
    "embed_dim": embed_dim,
    "num_layers": num_layers,
    "num_heads": num_heads,
    "mlp_ratio": mlp_ratio,
    "dropout_prob": dropout_prob,
    "use_swiglu": True,
}

print("Initializing model with config : ")
pprint(model_config)

model = GPT(**model_config).to(device)

trained_weights = strip_compile_prefix(torch.load("gpt_model_1024_final_60000.pt", map_location=device))

model.load_state_dict(trained_weights, strict=True)

# Need to handle the embeddings for the 2 new tokens we added
old_vocab_size, dim = model.embedding.weight.shape
new_vocab_size = old_vocab_size + len(ROLE_TOKENS)

new_embedding = torch.nn.Embedding(new_vocab_size, dim).to(device)
new_embedding.weight.data[:old_vocab_size] = model.embedding.weight.data
new_embedding.weight.data[old_vocab_size:] = model.embedding.weight.data[eot_id]  # copy paste embedding of end of text to usr/assistant

model.embedding = new_embedding
model.lm_head.weight = model.embedding.weight

# Ready to train ! 
# model = torch.compile(model)  ## bad idea unless we implement bucketing or fixed size padding
model.train()

Initializing model with config : 
{'dropout_prob': 0.1,
 'embed_dim': 768,
 'mlp_ratio': 4,
 'num_heads': 12,
 'num_layers': 12,
 'use_swiglu': True,
 'vocab_size': 50257}
Number of parameters: 84.95M


GPT(
  (embedding): Embedding(50259, 768)
  (dropout): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (norm1): RMSNorm((768,), eps=None, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (qkv_proj): Linear(in_features=768, out_features=2304, bias=False)
        (rotary_emb): RotaryEmbedding()
        (out_proj): Linear(in_features=768, out_features=768, bias=False)
      )
      (norm2): RMSNorm((768,), eps=None, elementwise_affine=True)
      (mlp): MLP(
        (gate_proj): Linear(in_features=768, out_features=2048, bias=False)
        (up_proj): Linear(in_features=768, out_features=2048, bias=False)
        (down_proj): Linear(in_features=2048, out_features=768, bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): RMSNorm((768,), eps=None, elementwise_affine=True)
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [11]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    betas=(0.9, 0.95),
    weight_decay=0.01,
)

In [None]:
from torch.cuda.amp import autocast, GradScaler


for epoch in range(10):
    for step, batch in enumerate(train_loader):
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        
        with autocast(dtype=dtype):
            logits = model(input_ids)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=IGNORE_INDEX,
            )
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        if step % 100 == 0:
            print(
                f"epoch {epoch} | step {step} | loss {loss.item():.4f}"
            )
    
    # Save model in full precision
    torch.save(model.state_dict(), f"GPT_SFT_epoch_{epoch}.pt")

  with autocast(dtype=dtype):


epoch 0 | step 0 | loss 3.6935
epoch 0 | step 100 | loss 3.2384
epoch 0 | step 200 | loss 2.4780
epoch 0 | step 300 | loss 3.1288
epoch 0 | step 400 | loss 2.6858
epoch 0 | step 500 | loss 2.5710
epoch 0 | step 600 | loss 2.7187
epoch 0 | step 700 | loss 2.2620
epoch 0 | step 800 | loss 2.7491
epoch 0 | step 900 | loss 2.7041
epoch 0 | step 1000 | loss 2.8460
epoch 0 | step 1100 | loss 2.8642
epoch 0 | step 1200 | loss 3.2335
epoch 0 | step 1300 | loss 2.7360
epoch 0 | step 1400 | loss 2.9145
epoch 0 | step 1500 | loss 2.9781
epoch 0 | step 1600 | loss 3.2000
epoch 0 | step 1700 | loss 3.0878
epoch 0 | step 1800 | loss 2.8833
epoch 0 | step 1900 | loss 2.3682
epoch 0 | step 2000 | loss 2.7500
epoch 0 | step 2100 | loss 2.8032
epoch 0 | step 2200 | loss 2.7443
epoch 0 | step 2300 | loss 2.5857
epoch 0 | step 2400 | loss 2.7358
epoch 0 | step 2500 | loss 2.7681
epoch 0 | step 2600 | loss 2.5352
epoch 0 | step 2700 | loss 2.5798
epoch 0 | step 2800 | loss 3.0915
epoch 0 | step 2900 | loss

In [72]:
prompt = (
    "<|user|>\n"
    "Give an example of a hard data mining task.\n"
    "<|assistant|>\n"
)

enc = tokenizer.encode(
    prompt,
)

input_ids = torch.tensor(enc.ids, dtype=torch.long).unsqueeze(0).to(device)

@torch.no_grad()
def generate_greedy(
    model,
    input_ids,
    max_new_tokens=100,
    stop_token_id=eot_id,
):
    model.eval()

    for _ in range(max_new_tokens):
        logits = model(input_ids)
        next_token_logits = logits[:, -1] / 0.7
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        input_ids = torch.cat([input_ids, next_token], dim=1)

        if next_token.item() == stop_token_id:
            break

    return input_ids


out = generate_greedy(model, input_ids, max_new_tokens=150)

print(tokenizer.decode(out[0].tolist(), skip_special_tokens=False))

<|user|>
Give an example of a hard data mining task.
<|assistant|>
An example of a hard data mining task is by the following steps: 
1. What is needed to store the data? 
2. What is needed to store the data? 
3. How can you store the data? 
4. What gives you the right model?<|endoftext|>
