# GPT with Mixture of Experts (MoE)

Architectural improvements over base GPT:
- **Mixture of Experts (MoE)**: Multiple expert MLPs with learned routing. Only top-k experts activated per token.
  - More total parameters but same compute per forward pass
  - Each expert can specialize in different types of tokens/patterns
- **QK-Norm**: RMSNorm on queries and keys before RoPE for training stability
- **Load balancing loss**: Prevents expert collapse (all tokens going to same expert)

Expected behavior:
- ~4x total parameters vs dense model at same compute budget
- Better quality at fixed FLOP budget
- Requires load balancing to prevent expert collapse

In [None]:
# Standard library
import csv
import math
import multiprocessing
import os
import random
import time
from pprint import pprint
from datetime import datetime

# Environment config
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Third-party
import numpy as np
from datasets import Dataset as ds, concatenate_datasets, load_dataset
from rotary_embedding_torch import RotaryEmbedding
from tokenizers import Tokenizer

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Torch runtime config
torch.set_float32_matmul_precision("medium")
torch.cuda.empty_cache()

# Custom
from utils import count_parameters, load_synthetic_data, strip_compile_prefix, round_up

## Config & Model Definition

Mostly the same code as the pretraining notebook, including Flash Attention, RMSNorm etc.

In [None]:
#### CONFIG #####

# Model architecture - moderate-scale MoE aimed at a single 32GB GPU
# Strategy: Go deeper (better for small models) + MoE for capacity
block_size = 1024
batch_size = 12       # conservative for MoE memory overhead
embed_dim = 1024      # GPT-2 Medium width
num_layers = 16       # deeper than GPT-2 Small (12), helps reasoning
num_heads = 16        # head_dim = 64, good for efficiency
dropout_prob = 0.1
mlp_ratio = 4

# MoE config
num_experts = 8       # 8 experts total
top_k_experts = 2     # activate 2 per token → 4x expert params, same expert compute
aux_loss_coef = 0.01  # load balancing (0.01-0.1 typical range)

# Tokenizer
ROLE_TOKENS = ["<|user|>", "<|assistant|>"]
IGNORE_INDEX = -100

# Training
NUM_EPOCHS = 3
num_workers = 4
prefetch = 8
dtype = torch.bfloat16
device = "cuda"

# Parameter estimate, derived from the values above so it cannot drift from the config
# (norm weights are ignored; they are negligible).
# The real vocabulary size is only known once the tokenizer cell has run, so the padded
# GPT-2 size is assumed here: 50,257 + 2 role tokens, padded to 50,304
# (assumes round_up(n, 128) rounds up to the next multiple of 128 - TODO confirm).
approx_vocab_size = 50_304
expert_hidden = round_up(2 * (mlp_ratio * embed_dim) // 3, 8)   # same formula as Expert.__init__
expert_params = 3 * embed_dim * expert_hidden                   # fused gate/up + down projection
shared_params = (
    approx_vocab_size * embed_dim                               # tied embedding / lm_head
    + num_layers * 4 * embed_dim * embed_dim                    # qkv + output projections
    + num_layers * embed_dim * num_experts                      # routers
)
est_total_params = shared_params + num_layers * num_experts * expert_params
est_active_params = shared_params + num_layers * top_k_experts * expert_params

# Estimated VRAM for the training state:
#   The model is moved to the device without a dtype cast and bf16 is only applied via
#   autocast, so weights, gradients and AdamW moments are all fp32:
#     4 B weights + 4 B gradients + 8 B AdamW moments = 16 B per parameter
#   Activations come on top of that and scale with batch_size × block_size.
est_state_gb = est_total_params * 16 / 1e9

print("=" * 60)
print("MoE Model Configuration")
print("=" * 60)
print(f"  Architecture:    {embed_dim}d × {num_layers}L × {num_heads}H")
print(f"  MoE:             {num_experts} experts, top-{top_k_experts}")
print(f"  Batch:           {batch_size} × {block_size} tokens")
print(f"  Params:          ~{est_total_params / 1e6:,.0f}M total, ~{est_active_params / 1e6:,.0f}M active")
print(f"  Train state:     ~{est_state_gb:.1f} GB (fp32 weights + grads + AdamW, excl. activations)")
print(f"  bf16 supported:  {torch.cuda.is_bf16_supported()}")
print("=" * 60)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with QK-Norm and rotary position embeddings.

    Queries and keys are RMS-normalised per head before RoPE is applied.
    An optional dict cache (keys ``"K"`` / ``"V"``) enables incremental decoding:
    once it is populated, only the last position of the input is projected and
    attended against the cached keys/values.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 rotary_emb: RotaryEmbedding,
                 causal: bool = True,
                 dropout: float = 0.1
                ):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads}).")

        self.causal = causal
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout_p = dropout

        # A single matmul produces queries, keys and values.
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)

        # QK-Norm operates on the per-head feature dimension.
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)

        self.rotary_emb = rotary_emb
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x, k_v_cache=None):
        """Attend over ``x`` of shape (B, T, embed_dim).

        Returns a tuple ``(output, k_v_cache)``; the cache object passed in is
        updated in place with the (detached) rotated keys and the values.
        """
        batch, _, _ = x.shape
        has_past = k_v_cache is not None and "K" in k_v_cache

        # With a populated cache only the newest position needs to be projected.
        proj_input = x[:, -1:, :] if has_past else x
        q, k, v = self.qkv_proj(proj_input).chunk(3, dim=-1)

        # (B, L, embed_dim) -> (B, num_heads, L, head_dim)
        head_shape = (batch, -1, self.num_heads, self.head_dim)
        q = q.view(head_shape).transpose(1, 2)
        k = k.view(head_shape).transpose(1, 2)
        v = v.view(head_shape).transpose(1, 2)

        # QK-Norm first, rotary embedding afterwards.
        q = self.q_norm(q)
        k = self.k_norm(k)

        rotate = self.rotary_emb.rotate_queries_or_keys
        if has_past:
            # Positions of the new token continue after what is already cached.
            n_cached = k_v_cache["K"].shape[-2]
            q = rotate(q, offset=n_cached)
            k = rotate(k, offset=n_cached)

            k = torch.cat([k_v_cache["K"], k], dim=-2)
            v = torch.cat([k_v_cache["V"], v], dim=-2)
            # The single query may look at every cached position: no mask needed.
            apply_causal_mask = False
        else:
            q = rotate(q)
            k = rotate(k)
            apply_causal_mask = self.causal

        if k_v_cache is not None:
            k_v_cache["K"] = k.detach()
            k_v_cache["V"] = v.detach()

        attended = F.scaled_dot_product_attention(
            query=q, key=k, value=v,
            attn_mask=None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=apply_causal_mask
        )

        # (B, num_heads, L, head_dim) -> (B, L, embed_dim)
        merged = attended.transpose(1, 2).contiguous().view(batch, -1, self.embed_dim)
        return self.out_proj(merged), k_v_cache


class Expert(nn.Module):
    """One SwiGLU feed-forward expert: ``dropout(down(silu(gate) * up))``."""

    def __init__(self, embed_dim, hidden_dim, dropout_prob=0.1):
        super().__init__()
        # Take two thirds of the requested width, then pass it through round_up(..., 8).
        inner_dim = round_up(2 * hidden_dim // 3, 8)

        # Gate and up projections share one fused linear layer.
        self.gate_up_proj = nn.Linear(embed_dim, 2 * inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        # Split the fused projection into its two halves along the feature axis.
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        activated = F.silu(gate) * up
        return self.dropout(self.down_proj(activated))


class MoEMLP(nn.Module):
    """
    Mixture of Experts MLP layer.

    Routes each token to its top-k experts and combines their outputs, weighted by
    the renormalised router probabilities.

    While training, every forward pass stores an auxiliary load-balancing loss in
    ``self.aux_loss``. In eval mode ``self.aux_loss`` is reset to ``0.0`` so that a
    stale value from the last training step is never read back.
    """

    def __init__(self, embed_dim, hidden_dim, num_experts, top_k, dropout_prob=0.1):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.embed_dim = embed_dim

        # Router: learned linear layer to score experts
        self.router = nn.Linear(embed_dim, num_experts, bias=False)

        # Expert MLPs
        self.experts = nn.ModuleList([
            Expert(embed_dim, hidden_dim, dropout_prob)
            for _ in range(num_experts)
        ])

        # Load-balancing loss of the most recent training forward pass
        self.aux_loss = 0.0

    def forward(self, x):
        """
        Args:
            x: (B, T, D) input tensor
        Returns:
            output: (B, T, D) combined expert outputs
        """
        B, T, D = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well
        x_flat = x.reshape(-1, D)  # (B*T, D)

        # Compute router logits and probabilities
        router_logits = self.router(x_flat)  # (B*T, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-k experts per token
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Normalize top-k probabilities to sum to 1
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        if self.training:
            # Fraction of tokens routed to each expert (multi-hot over the top-k slots)
            routed = F.one_hot(top_k_indices, num_classes=self.num_experts).sum(dim=1)  # (B*T, E)
            tokens_per_expert = routed.float().mean(dim=0)  # (E,)

            # Average router probability per expert
            router_prob_per_expert = router_probs.mean(dim=0)  # (E,)

            # Load balancing loss: the product is smallest for a uniform assignment
            self.aux_loss = self.num_experts * (tokens_per_expert * router_prob_per_expert).sum()
        else:
            # No balancing loss outside training; also drops the reference to the old tensor
            self.aux_loss = 0.0

        # "Loop over experts" dispatch - simpler than sparse dispatch
        output = torch.zeros_like(x_flat)

        for expert_idx in range(self.num_experts):
            # Rows = tokens that picked this expert, cols = the top-k slot it was picked in.
            # topk returns distinct experts per token, so every token appears at most once.
            token_indices, slot_indices = (top_k_indices == expert_idx).nonzero(as_tuple=True)

            if token_indices.numel() == 0:
                continue

            # Routing weight of this expert for each selected token: (num_selected_tokens,)
            weights = top_k_probs[token_indices, slot_indices]

            expert_output = self.experts[expert_idx](x_flat[token_indices])

            # Weighted addition to output
            output[token_indices] += weights.unsqueeze(-1) * expert_output

        return output.view(B, T, D)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention sub-layer followed by an MoE feed-forward sub-layer."""

    def __init__(self,
                 embed_dim,
                 num_heads,
                 rotary_emb,
                 mlp_ratio=4,
                 num_experts=8,
                 top_k_experts=2,
                 dropout_prob=0.1,
                 causal=True,
                ):
        super().__init__()
        self.norm1 = nn.RMSNorm(embed_dim)
        self.mha = MultiHeadAttention(embed_dim, num_heads, rotary_emb, causal, dropout_prob)
        self.norm2 = nn.RMSNorm(embed_dim)

        # The feed-forward part is a mixture of experts; mlp_ratio sets the requested hidden width.
        self.moe = MoEMLP(
            embed_dim,
            mlp_ratio * embed_dim,
            num_experts,
            top_k_experts,
            dropout_prob,
        )

    def forward(self, x, cache=None):
        """Run both sub-layers with residual connections; returns ``(output, cache)``."""
        attn_out, cache = self.mha(self.norm1(x), cache)
        hidden = attn_out + x  # residual around attention

        out = self.moe(self.norm2(hidden)) + hidden  # residual around MoE
        return out, cache

    def get_aux_loss(self):
        """Expose the load-balancing loss currently stored on this block's MoE layer."""
        return self.moe.aux_loss

In [None]:
class GPT_MoE(nn.Module):
    """
    GPT with Mixture of Experts.

    Token embedding -> dropout -> ``num_layers`` TransformerBlocks (attention + MoE MLP)
    -> final RMSNorm -> LM head whose weight is tied to the embedding table.
    The per-layer load-balancing losses are collected through ``get_aux_loss``.
    """

    def __init__(self,
                 vocab_size,
                 embed_dim,
                 num_layers,
                 num_heads,
                 mlp_ratio=4,
                 num_experts=8,
                 top_k_experts=2,
                 dropout_prob=0.1,
                 is_causal=True,
                ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.num_experts = num_experts
        self.top_k_experts = top_k_experts

        self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
        self.dropout = nn.Dropout(dropout_prob)

        # One rotary embedding instance is shared by every attention layer.
        head_dim = embed_dim // num_heads
        self.rotary_emb = RotaryEmbedding(dim=head_dim)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                embed_dim, num_heads, self.rotary_emb,
                mlp_ratio, num_experts, top_k_experts,
                dropout_prob, is_causal
            )
            for _ in range(num_layers)
        ])
        self.norm = nn.RMSNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # weight tying

        # Initialize weights
        self.apply(self._init_weights)
        # Residual-path projections get a smaller std, scaled by depth
        for pn, p in self.named_parameters():
            if pn.endswith(("out_proj.weight", "down_proj.weight")):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.num_layers))

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear / Embedding weights, zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens):
        """Map token ids to next-token logits (no KV cache is used here)."""
        embeddings = self.embedding(tokens)
        x = self.dropout(embeddings)
        for b in self.blocks:
            x, _ = b(x)
        features = self.norm(x)
        return self.lm_head(features)

    def get_aux_loss(self):
        """Return the MoE auxiliary loss averaged over all layers."""
        total_aux_loss = 0.0
        for block in self.blocks:
            total_aux_loss += block.get_aux_loss()
        return total_aux_loss / self.num_layers  # average over layers

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    def count_parameters(self):
        """Return ``(total_params, active_params)``.

        Active parameters are everything whose name does not contain ``'experts'``
        plus the ``top_k / num_experts`` fraction of the expert parameters.
        """
        total_params = sum(p.numel() for p in self.parameters())

        non_moe_params = 0
        moe_params = 0

        for name, p in self.named_parameters():
            if 'experts' in name:
                moe_params += p.numel()
            else:
                non_moe_params += p.numel()

        # Active MoE params = (top_k / num_experts) * total_moe_params
        active_moe = moe_params * (self.top_k_experts / self.num_experts)
        active_params = non_moe_params + active_moe

        return total_params, int(active_params)

    @torch.no_grad()
    def generate(self,
                 prompt_tokens,
                 max_new_tokens=50,
                 temperature=1.0,
                 top_k=0,
                 top_p=0.0,
                 use_cache=True,
                 eos_token_id=None,
                ):
        """Autoregressively extend ``prompt_tokens`` of shape (B, T) and return the token ids.

        Args:
            prompt_tokens: LongTensor of shape (B, T).
            max_new_tokens: upper bound on the number of generated positions.
            temperature: 0 selects greedy decoding, otherwise logits are divided by it.
            top_k: keep only the k most likely tokens (0 disables; clamped to the vocab size).
            top_p: nucleus sampling threshold (active when 0 < top_p < 1).
            use_cache: reuse cached keys/values so each step only feeds the newest token.
            eos_token_id: id that ends a sequence. When omitted, a module-level ``eot_id``
                is used if one exists; if neither is available, generation always runs
                for ``max_new_tokens`` steps.

        A sequence that has produced the end token keeps receiving the end token until
        every sequence in the batch has finished. The tokens of the step on which the
        last sequence finishes are not appended, so for B == 1 the end token itself is
        never part of the output.

        Note: the model is switched to eval mode and left there.
        """
        self.eval()

        if eos_token_id is None:
            # Backward compatible default: the notebook defines `eot_id` in the tokenizer cell.
            eos_token_id = globals().get("eot_id")

        tokens_out = prompt_tokens.clone().to(self.device)
        current_tokens = tokens_out
        cache = [{} if use_cache else None for _ in range(len(self.blocks))]
        finished = torch.zeros(tokens_out.shape[0], dtype=torch.bool, device=tokens_out.device)

        for _ in range(max_new_tokens):
            x = self.embedding(current_tokens)
            for i, b in enumerate(self.blocks):
                x, cache[i] = b(x, cache[i])

            features = self.norm(x)
            logits = self.lm_head(features)
            last_logits = logits[:, -1, :]

            if temperature == 0:
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                scaled_logits = last_logits / temperature

                k = min(int(top_k), scaled_logits.size(-1))  # topk rejects k > vocab size
                if k > 0:
                    values, indices = torch.topk(scaled_logits, k)
                    scaled_logits = torch.full_like(scaled_logits, float('-inf'))
                    scaled_logits.scatter_(1, indices, values)

                if top_p > 0.0 and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True, dim=-1)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is still kept
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    scaled_logits[indices_to_remove] = float('-inf')

                probs = torch.softmax(scaled_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            if eos_token_id is not None:
                # Already-finished rows only ever emit the end token
                next_token[finished] = eos_token_id
                finished = finished | (next_token.squeeze(1) == eos_token_id)
                if bool(finished.all()):
                    break

            tokens_out = torch.cat([tokens_out, next_token], dim=1)
            current_tokens = next_token if use_cache else tokens_out

        return tokens_out

## Tokenizer

This must be exactly the same tokenizer that was used for pretraining. On top of it we add just 2 extra tokens for the user / assistant roles, and initialize their embeddings from the end-of-text token's embedding learnt during pretraining.

In [None]:
import tiktoken
from tiktoken.core import Encoding

# Start from the stock GPT-2 BPE and extend it with the two chat-role tokens.
base = tiktoken.get_encoding("gpt2")

# The role tokens take the two ids directly after the existing vocabulary;
# <|endoftext|> keeps the id it already has.
special_tokens = {
    "<|endoftext|>": base.eot_token,
    "<|user|>": base.n_vocab,
    "<|assistant|>": base.n_vocab + 1,
}

# Same merges and split pattern as GPT-2, only the special-token table differs.
tokenizer = Encoding(
    name="gpt2-with-roles",
    pat_str=base._pat_str,
    mergeable_ranks=base._mergeable_ranks,
    special_tokens=special_tokens,
)

eot_id = tokenizer.eot_token
user_id, assistant_id = (
    tokenizer.encode_single_token(role_token)
    for role_token in ("<|user|>", "<|assistant|>")
)

# Sanity check: ids and their round-trip through decode
print(user_id, assistant_id, eot_id)
print(tokenizer.decode([user_id, assistant_id, eot_id]))

# The embedding table is sized with round_up(..., 128) rather than the raw vocab size.
vocab_size = round_up(tokenizer.n_vocab, 128)
print("Vocab size:", tokenizer.n_vocab, "→ padded:", vocab_size)

## Supervised Fine-Tuning Dataset

### Goal
Train the model to generate assistant responses, NOT to predict instructions. Autoregressive LMs predict the NEXT token at each position; we exploit this by masking the instruction tokens out of the loss calculation.

---

### The Process

#### 1. Format the Data
```
Input text: "<|user|>\n{instruction}\n<|assistant|>\n{response}"

```

#### 2. Tokenize
```
ids = [user_tok, What, is, 2, +, 2, ?, asst_tok, The, answer, is, 4, eot]
idx:   0        1     2   3  4  5  6  7         8    9      10  11 12
```

#### 3. Create Shifted Labels
```python
labels[:-1] = ids[1:]  # Each label is the NEXT token to predict

labels = [What, is, 2, +, 2, ?, asst_tok, The, answer, is, 4, eot, IGNORE]
```
**Meaning:** `labels[i]` = what should be predicted after seeing `ids[i]`

#### 4. Mask the Instruction
```python
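# Find position of <|assistant|> token (position 7 in example)
labels[:assistant_pos] = IGNORE_INDEX (-100)   # masks positions 0..6
# Note: the real template puts a "\n" right after <|assistant|>, so the code below
# masks one position more (labels[:assistant_pos + 1]) to skip that fixed newline too.

labels = [IGN, IGN, IGN, IGN, IGN, IGN, IGN, The, answer, is, 4, eot, IGN]
          └─────────instruction masked───────────┘  └────train here────┘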
```

#### 5. Compute Loss (during training)
```python
logits = model(ids)  # Model predicts next token at each position
loss = CrossEntropyLoss(logits, labels, ignore_index=-100)
```

**What happens:**
- Position 0-6: `labels[i] = -100` → loss ignored (don't train on instruction)
- Position 7: predict "The" after `<|assistant|>` → **COMPUTE LOSS** ✓
- Position 8: predict "answer" after "The" → **COMPUTE LOSS** ✓
- Position 9: predict "is" after "answer" → **COMPUTE LOSS** ✓
- Position 10: predict "4" after "is" → **COMPUTE LOSS** ✓
- Position 11: predict `<eot>` after "4" → **COMPUTE LOSS** ✓

---

### Why This Works

**Causal Attention Mask**
- Prevents the model from "seeing" future tokens
- At position i, model only attends to tokens 0 to i

**Teacher Forcing**
- Model sees correct previous tokens during training
- Learns to predict the next one

**Masking with -100**
- `CrossEntropyLoss` ignores these positions
- Gradients only flow through response tokens

**Result:** Model learns "given instruction X, generate response Y" without wasting compute trying to predict the instruction itself.

---

### Key Facts

- `IGNORE_INDEX = -100` (standard PyTorch convention)
- Only ~5-20% of tokens typically contribute to loss (just the responses)
- The shift (`labels[i] = ids[i+1]`) aligns predictions with targets
- Masking combined with a smaller dataset fine-tunes the model's behavior, but changes its knowledge little or not at all!

In [None]:
# Fetch the three public instruction datasets from the Hugging Face Hub (in this order)
print("Loading datasets...")
alpaca_cleaned, platypus, no_robots = (
    load_dataset(repo_id)
    for repo_id in (
        "yahma/alpaca-cleaned",
        "garage-bAInd/Open-Platypus",
        "HuggingFaceH4/no_robots",
    )
)

# Format functions for each dataset
def format_alpaca(example):
    """Convert an Alpaca record into the shared ``{"user", "assistant"}`` schema.

    A non-blank ``input`` field is appended to the instruction after an empty line;
    otherwise the instruction alone is the user text.
    """
    context = example["input"]
    instruction = example["instruction"]
    user_text = f"{instruction}\n\n{context}" if context.strip() else instruction
    return {"user": user_text, "assistant": example["output"]}

def format_platypus(example):
    """Convert an Open-Platypus record into the shared ``{"user", "assistant"}`` schema."""
    user_text, assistant_text = example["instruction"], example["output"]
    return dict(user=user_text, assistant=assistant_text)

def format_no_robots(example):
    """Convert a no_robots record into the shared ``{"user", "assistant"}`` schema.

    The assistant text is the content of the first message whose role is
    ``"assistant"``. Picking ``messages[1]`` unconditionally would return the user
    turn whenever a conversation opens with a system message. If no message carries
    an assistant role, the second message is used, as before.
    """
    messages = example["messages"]
    assistant_text = next(
        (
            msg["content"]
            for msg in messages
            if isinstance(msg, dict) and msg.get("role") == "assistant"
        ),
        None,
    )
    if assistant_text is None:
        # Records without role annotations: keep the positional behaviour
        assistant_text = messages[1]["content"]

    return {
        "user": example["prompt"],
        "assistant": assistant_text
    }

# Map each dataset to the common {"user", "assistant"} format, dropping the original columns
print("Formatting datasets...")
alpaca_formatted = alpaca_cleaned.map(
    format_alpaca,
    remove_columns=alpaca_cleaned["train"].column_names
)
platypus_formatted = platypus.map(
    format_platypus,
    remove_columns=platypus["train"].column_names
)
no_robots_formatted = no_robots.map(
    format_no_robots,
    remove_columns=no_robots["train"].column_names
)

combined_datasets = [
    alpaca_formatted["train"],
    platypus_formatted["train"],
    no_robots_formatted["train"]
]

# OPTIONAL : can use any modern LLM to generate custom SFT data
if os.path.isfile("synthetic_sft_data.jsonl"):
    print("Loading synthetic data...")
    synthetic_data = load_synthetic_data("synthetic_sft_data.jsonl")
    print(f"Loaded {len(synthetic_data)} synthetic examples")

    if len(synthetic_data) > 0:
        # Transform list of dicts to dict of lists (column names come from the first record)
        data_dict = {
            key: [item[key] for item in synthetic_data]
            for key in synthetic_data[0].keys()
        }

        synthetic_dataset = ds.from_dict(data_dict)

        # Print example to verify format
        print("\nExample from synthetic dataset:")
        pprint(synthetic_dataset[0])

        # Add to all datasets
        combined_datasets.append(synthetic_dataset)
    else:
        # An empty file has no first record to read the column names from
        print("Synthetic data file contains no examples - skipping")

# Combine all datasets
print("\n\nCombining datasets...")
combined_train = concatenate_datasets(combined_datasets)

print(f"Total training examples: {len(combined_train)}")
print("\nExample from combined dataset:")
pprint(next(iter(combined_train)))

def tokenize_sft(example):
    """Tokenize one ``{"user", "assistant"}`` example for supervised fine-tuning.

    The sequence layout is::

        <|user|> "\\n{user}\\n" <|assistant|> "\\n{assistant}" <|endoftext|>

    The role / end tokens are inserted by id and the free text is encoded with
    ``encode_ordinary``. A literal "<|assistant|>" or "<|endoftext|>" inside the
    dataset text therefore stays plain text and cannot shift the prompt/response
    boundary.

    Returns:
        dict with ``input_ids`` and ``labels`` (both truncated to ``block_size``).
        ``labels[i]`` is the token following ``input_ids[i]``; everything up to and
        including the <|assistant|> position is ``IGNORE_INDEX``, and so is the last
        position. Both lists are empty when truncation leaves no supervised token,
        so the example is removed by the length filter that follows tokenization.
    """
    prompt_ids = (
        [user_id]
        + tokenizer.encode_ordinary(f"\n{example['user']}\n")
        + [assistant_id]
    )
    response_ids = tokenizer.encode_ordinary(f"\n{example['assistant']}") + [eot_id]

    ids = prompt_ids + response_ids
    assistant_pos = len(prompt_ids) - 1

    # Labels are the inputs shifted left by one; the final position has no target
    labels = ids[1:] + [IGNORE_INDEX]

    # Mask the prompt, including the label at the <|assistant|> position
    # (the token predicted there is the start of the template's "\n" separator)
    labels[:assistant_pos + 1] = [IGNORE_INDEX] * (assistant_pos + 1)

    # Truncate
    ids = ids[:block_size]
    labels = labels[:block_size]

    # A sample without any supervised token contributes nothing to the loss, and a
    # batch made only of such samples would produce a NaN cross-entropy.
    if all(label == IGNORE_INDEX for label in labels):
        return {"input_ids": [], "labels": []}

    return {
        "input_ids": ids,
        "labels": labels,
    }

# Run the SFT tokenizer over every example, replacing the text columns
print("Tokenizing combined dataset...")
combined_tokenized = combined_train.map(
    tokenize_sft,
    num_proc=4,
    remove_columns=combined_train.column_names,
)

# tokenize_sft signals an unusable example with empty lists; drop those rows
pre_filter_len = len(combined_tokenized)
combined_tokenized = combined_tokenized.filter(lambda row: len(row["input_ids"]) > 0)
num_dropped = pre_filter_len - len(combined_tokenized)
print(f"Filtered {num_dropped} invalid examples")
print(f"Final dataset size: {len(combined_tokenized)}")

In [None]:
def collate_fn(batch):
    """Right-pad a list of tokenized examples to the longest one and stack them.

    Inputs are padded with ``eot_id`` and labels with ``IGNORE_INDEX``, so padded
    positions never contribute to the loss. ``None`` entries are dropped first.
    Padding is per batch (no bucketing / fixed length), which is fine for this data size.
    """
    samples = [sample for sample in batch if sample is not None]
    target_len = max(len(sample["input_ids"]) for sample in samples)

    padded_ids, padded_labels = [], []
    for sample in samples:
        n_pad = target_len - len(sample["input_ids"])
        padded_ids.append(sample["input_ids"] + [eot_id] * n_pad)
        padded_labels.append(sample["labels"] + [IGNORE_INDEX] * n_pad)

    return {
        "input_ids": torch.tensor(padded_ids, dtype=torch.long),
        "labels": torch.tensor(padded_labels, dtype=torch.long),
    }

# Shuffled training loader; batches are padded on the fly by collate_fn
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    prefetch_factor=prefetch,
    pin_memory=True,
)
train_loader = DataLoader(combined_tokenized, **loader_kwargs)

print(f"DataLoader created with {len(train_loader)} batches of {batch_size} seqs")

## Model Initialization

**Note**: MoE architecture is different from dense GPT, so we train from scratch.
Cannot load pretrained weights from the dense model.

In [None]:
# Everything GPT_MoE needs, gathered from the config cell and the tokenizer cell
model_config = dict(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    mlp_ratio=mlp_ratio,
    num_experts=num_experts,
    top_k_experts=top_k_experts,
    dropout_prob=dropout_prob,
)

print("Initializing MoE model with config:")
pprint(model_config)

model = GPT_MoE(**model_config)
model = model.to(device)

# Report total vs. per-token active parameter counts
total_params, active_params = model.count_parameters()
print(f"\nParameter counts:")
print(f"  Total parameters:  {total_params:,} ({total_params/1e6:.1f}M)")
print(f"  Active parameters: {active_params:,} ({active_params/1e6:.1f}M)")
print(f"  Ratio: {total_params/active_params:.2f}x total vs active")

# Both role tokens start from a copy of the end-of-text embedding
with torch.no_grad():
    for role_id in (user_id, assistant_id):
        model.embedding.weight[role_id] = model.embedding.weight[eot_id].clone()

print(f"\nInitialized special tokens: <|user|>={user_id}, <|assistant|>={assistant_id}")
model.train()

## Optimizer

For MoE models:
- Same weight decay rules (don't decay norms, embeddings)
- Router parameters should NOT be decayed (they are small gating weights and should stay flexible)
- Add auxiliary load balancing loss to prevent expert collapse

In [None]:
# Parameters whose name contains one of these substrings are excluded from weight decay:
# norm weights (RMSNorm, QK-norm), biases, the embedding table / tied lm_head, and the routers.
NO_DECAY_KEYWORDS = ('norm', 'bias', 'embed', 'embedding', 'lm_head', 'router')

decay_params, no_decay_params = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    lowered = name.lower()
    exempt = any(kw in lowered for kw in NO_DECAY_KEYWORDS)
    (no_decay_params if exempt else decay_params).append(param)

print(f"Decay params: {len(decay_params)}, No decay params: {len(no_decay_params)}")

# Peak learning rate; larger models generally want a lower LR
base_lr = 2e-4
min_lr = base_lr * 0.1  # cosine floor: 10% of the peak

param_groups = [
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': no_decay_params, 'weight_decay': 0.0}
]
optimizer = torch.optim.AdamW(
    param_groups,
    lr=base_lr,
    betas=(0.9, 0.95),
    eps=1e-8,
)

# Schedule: linear warmup over the first 10% of steps, then cosine decay to min_lr
total_steps = NUM_EPOCHS * len(train_loader)
warmup_steps = int(total_steps * 0.10)

from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

warmup_scheduler = LinearLR(
    optimizer,
    start_factor=0.01,  # warmup begins at 1% of base_lr
    total_iters=warmup_steps
)
cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=total_steps - warmup_steps,
    eta_min=min_lr
)
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_steps]
)

print(f"\nTraining schedule:")
print(f"  Total steps:     {total_steps:,}")
print(f"  Warmup steps:    {warmup_steps:,} ({100*warmup_steps/total_steps:.0f}%)")
print(f"  Base LR:         {base_lr}")
print(f"  Min LR:          {min_lr}")
print(f"  Weight decay:    0.1 (linear), 0.0 (norm/embed/router)")
print(f"  Aux loss coef:   {aux_loss_coef}")

## Training Loop

MoE training includes:
- Main cross-entropy loss (language modeling)
- Auxiliary load balancing loss (prevents expert collapse)

Total loss = CE loss + aux_loss_coef × aux_loss

In [None]:
import time

global_step = 0
skipped_batches = 0
total_batches = len(train_loader)
log_interval = 100

print(f"Starting MoE training: {NUM_EPOCHS} epochs, {total_batches} batches/epoch")
print(f"Aux loss coefficient: {aux_loss_coef}")
print("-" * 80)

training_start = time.time()

for epoch in range(NUM_EPOCHS):
    # Re-enter training mode every epoch: generation helpers switch the model to eval,
    # which would silently disable dropout and the MoE auxiliary loss.
    model.train()

    epoch_start = time.time()
    epoch_ce_loss_sum = 0.0
    epoch_aux_loss_sum = 0.0
    epoch_loss_count = 0

    for step, batch in enumerate(train_loader):
        # With every label ignored, the mean-reduced cross-entropy is NaN and the
        # backward pass would poison the weights. Checked on the CPU copy, so no GPU sync.
        if not (batch["labels"] != IGNORE_INDEX).any():
            skipped_batches += 1
            continue

        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        with torch.amp.autocast('cuda', dtype=dtype):
            logits = model(input_ids)

            # Main language modeling loss
            ce_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=IGNORE_INDEX,
            )

            # Auxiliary load balancing loss (averaged over layers)
            aux_loss = model.get_aux_loss()

            # Combined loss
            loss = ce_loss + aux_loss_coef * aux_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        global_step += 1

        # Track losses
        epoch_ce_loss_sum += ce_loss.item()
        epoch_aux_loss_sum += aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
        epoch_loss_count += 1

        if step % log_interval == 0:
            current_lr = optimizer.param_groups[0]["lr"]
            pct_complete = 100 * (step + 1) / total_batches
            elapsed = time.time() - epoch_start

            if step > 0:
                steps_per_sec = step / elapsed
                remaining_steps = total_batches - step
                eta_sec = remaining_steps / steps_per_sec
                eta_str = f"{int(eta_sec // 60):02d}:{int(eta_sec % 60):02d}"
            else:
                eta_str = "--:--"

            avg_ce = epoch_ce_loss_sum / epoch_loss_count
            avg_aux = epoch_aux_loss_sum / epoch_loss_count

            print(
                f"epoch {epoch+1}/{NUM_EPOCHS} | "
                f"step {step:>5}/{total_batches} ({pct_complete:5.1f}%) | "
                f"ce {ce_loss.item():.3f} (avg {avg_ce:.3f}) | "
                f"aux {aux_loss:.3f} | "
                f"lr {current_lr:.2e} | "
                f"ETA {eta_str}"
            )

    # End of epoch summary (max(..., 1) guards against an epoch with no usable batch)
    epoch_elapsed = time.time() - epoch_start
    epoch_avg_ce = epoch_ce_loss_sum / max(epoch_loss_count, 1)
    epoch_avg_aux = epoch_aux_loss_sum / max(epoch_loss_count, 1)
    print("-" * 80)
    print(
        f"Epoch {epoch+1} complete | "
        f"avg CE loss: {epoch_avg_ce:.4f} | "
        f"avg aux loss: {epoch_avg_aux:.4f} | "
        f"time: {int(epoch_elapsed // 60):02d}:{int(epoch_elapsed % 60):02d}"
    )
    print("-" * 80)

    # Checkpoint files are numbered from 0, while the log lines count epochs from 1
    torch.save(model.state_dict(), f"GPT_MoE_epoch_{epoch}.pt")
    print(f"Saved checkpoint: GPT_MoE_epoch_{epoch}.pt")

# Final summary
total_time = time.time() - training_start
print("=" * 80)
print(f"Training complete! Total time: {int(total_time // 3600):02d}:{int((total_time % 3600) // 60):02d}:{int(total_time % 60):02d}")
print(f"Final checkpoint: GPT_MoE_epoch_{NUM_EPOCHS - 1}.pt")
if skipped_batches:
    print(f"Skipped {skipped_batches} batches without any supervised token")

In [None]:
def _sample_next_token(next_token_logits, temperature, top_p):
    """Choose one token id per row from last-position logits.

    Args:
        next_token_logits: 2-D tensor of unnormalized scores, one row per
            sequence and one column per vocabulary entry.
        temperature: Divisor applied to the logits before sampling. A value
            of 0 (or below) selects the highest-scoring token directly
            instead of sampling.
        top_p: Nucleus threshold. When below 1.0, only the highest-probability
            tokens whose cumulative probability reaches ``top_p`` stay
            eligible; the single most likely token is always kept.

    Returns:
        Tensor with one column holding the chosen token id for each row.
    """
    # Greedy decoding: dividing by a zero temperature would produce inf/NaN
    # logits, and torch.multinomial rejects the resulting probabilities.
    if temperature <= 0:
        return next_token_logits.argmax(dim=-1, keepdim=True)

    next_token_logits = next_token_logits / temperature

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right by one so the token that crosses the threshold
        # is kept, and never drop the most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        # Out-of-place fill: leaves the model's output tensor untouched.
        next_token_logits = next_token_logits.masked_fill(indices_to_remove, float("-inf"))

    probs = F.softmax(next_token_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate(
    model,
    prompt_text,
    max_new_tokens=150,
    temperature=0.7,
    top_p=0.9,
    stop_token_id=eot_id,
):
    """Generate a continuation for a prompt string.

    The prompt is encoded with the module-level ``tokenizer`` (special tokens
    from ``special_tokens`` are allowed) and extended one token at a time.
    Only the last ``block_size`` tokens are fed to the model at each step.

    Args:
        model: Callable returning logits indexed as ``[batch, position, vocab]``
            for a batch of token ids. It is switched to eval mode and left
            there.
        prompt_text: Prompt to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; 0 (or below) means greedy decoding.
        top_p: Nucleus-sampling threshold; values >= 1.0 disable the filter.
        stop_token_id: Token id that ends generation early, or ``None`` to
            always generate ``max_new_tokens`` tokens. The default is the
            value ``eot_id`` had when this function was defined.

    Returns:
        The decoded text of the whole sequence: the prompt followed by the
        generated tokens, including the stop token when one was produced.
    """
    model.eval()

    enc = tokenizer.encode(prompt_text, allowed_special=set(special_tokens.keys()))
    input_ids = torch.tensor(enc, dtype=torch.long).unsqueeze(0).to(device)

    for _ in range(max_new_tokens):
        # Negative slicing is a no-op while the sequence still fits the context.
        input_ids_cond = input_ids[:, -block_size:]

        logits = model(input_ids_cond)
        next_token = _sample_next_token(logits[:, -1, :], temperature, top_p)

        input_ids = torch.cat([input_ids, next_token], dim=1)

        # The stop token is appended before the check, so it is part of the output.
        if stop_token_id is not None and next_token.item() == stop_token_id:
            break

    return tokenizer.decode(input_ids[0].tolist())


def format_prompt(user_message):
    """Wrap a user message in the chat template and open the assistant turn.

    The result is the ``<|user|>`` tag, the message, and the ``<|assistant|>``
    tag, each terminated by a newline.
    """
    user_turn = f"<|user|>\n{user_message}\n"
    assistant_header = "<|assistant|>\n"
    return user_turn + assistant_header


# =============================================================================
# Quick Evaluation Suite
# =============================================================================
# Qualitative smoke test: sample one response per prompt and print it so the
# output can be read by eye.
print("=" * 80)
print("POST-TRAINING EVALUATION")
print("=" * 80)

# (category label, user message) pairs.
eval_prompts = [
    ("Simple instruction", "Write a short greeting message."),  # basic instruction following
    ("Knowledge", "What is the capital of France?"),  # knowledge recall: expect coherence, not accuracy
    ("Basic reasoning", "If I have 3 apples and buy 2 more, how many do I have?"),  # basic reasoning
    ("Creative", "Write a haiku about programming."),  # creative
    ("Explanation", "Explain what a neural network is in simple terms."),  # explanation
    ("Code", "Write a Python function that adds two numbers."),  # code (if trained on code data)
    ("Multi-step", "List 3 benefits of exercise."),  # multi-step
    ("Edge case", "Summarize the following text: "),  # boundary case: instruction with no text to act on
]

model.eval()
for label, question in eval_prompts:
    prompt = format_prompt(question)

    print(f"\n[{label}]")
    print(f"User: {question}")
    print("-" * 40)

    # Best-effort: a failure on one prompt is reported and the loop moves on.
    try:
        response = generate(model, prompt, max_new_tokens=150, temperature=0.7)
        if "<|assistant|>" not in response:
            # No assistant tag in the decoded text: show it unchanged.
            reply = response
        else:
            # Keep only what follows the last assistant header, then drop the
            # end-of-text marker and surrounding whitespace.
            reply = response.rpartition("<|assistant|>\n")[2]
            reply = reply.replace("<|endoftext|>", "").strip()
        print(f"Assistant: {reply}")
    except Exception as e:
        print(f"Error: {e}")

    print("-" * 40)

# =============================================================================
# Quantitative checks
# =============================================================================
# Greedy-decode the first three evaluation prompts and measure how often the
# model emits the end-of-text token within the token budget, and how many
# tokens it generates on average.
print("\n" + "=" * 80)
print("FORMAT COMPLIANCE CHECK")
print("=" * 80)

# Check if model properly terminates with EOT
test_prompts = [format_prompt(p) for _, p in eval_prompts[:3]]
eot_count = 0
total_length = 0

max_check_tokens = 200  # generation budget per prompt
allowed_special = set(special_tokens.keys())

model.eval()
# Inference only: without no_grad every forward pass records autograd state
# for a backward pass that never happens, inflating peak memory.
with torch.no_grad():
    for prompt in test_prompts:
        enc = tokenizer.encode(prompt, allowed_special=allowed_special)
        input_ids = torch.tensor(enc, dtype=torch.long).unsqueeze(0).to(device)

        # Generate with greedy decoding for consistency
        for _ in range(max_check_tokens):
            # Negative slicing is a no-op while the sequence fits the context.
            logits = model(input_ids[:, -block_size:])
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if next_token.item() == eot_id:
                eot_count += 1
                break

        # Generated tokens only (the EOT token, when produced, is counted).
        total_length += input_ids.shape[1] - len(enc)

avg_response_len = total_length / len(test_prompts)
eot_rate = 100 * eot_count / len(test_prompts)

print(f"EOT termination rate: {eot_count}/{len(test_prompts)} ({eot_rate:.0f}%)")
print(f"Avg response length: {avg_response_len:.0f} tokens")
print(f"  → {'Good: Model learns to stop' if eot_rate > 50 else 'Warning: Model may ramble'}")
print(f"  → {'Good: Reasonable length' if 20 < avg_response_len < 150 else 'Check: Unusual response length'}")

print("\n" + "=" * 80)
print("Evaluation complete!")
print("=" * 80)