# GPT with Mixture of Experts (MoE)

In `GPT_pretraining.ipynb` we added a number of architectural improvements and optimization tricks, but we still worked with a tiny model (`GPT2-small` scale, i.e. ~$120M$ parameters), which limited the quality of the final output (although it wasn't too bad!). In this short notebook, we scale the model up a bit and experiment with the idea of a mixture of experts (MoE).


- **Mixture of Experts (MoE)**: Multiple expert MLPs with learned routing. Only top-k experts activated per token.
  - More total parameters but same compute per forward pass
  - Each expert can specialize in different types of tokens/patterns
- **QK-Norm**: RMSNorm on queries and keys before RoPE for training stability
- **Load balancing loss**: Prevents expert collapse (all tokens going to same expert)

Expected behavior:
- ~4x total parameters vs dense model at same compute budget
- Better quality at fixed FLOP budget
- Requires load balancing to prevent expert collapse

In [None]:
# Standard library
import csv
import math
import multiprocessing
import os
import random
import time
from pprint import pprint
from datetime import datetime

# Environment config
# The allocator option is exported under both variable names used in this
# notebook, ahead of the torch import further down.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

from huggingface_hub import login
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"

if os.path.isfile(".env"):
    # Load KEY=VALUE pairs from a local .env file into the process environment.
    with open(".env", encoding="utf-8") as f:
        for line in f:
            entry = line.strip()
            # Blank lines and comments carry no assignment; skip them instead
            # of failing to unpack.
            if not entry or entry.startswith("#"):
                continue
            # Split on the first "=" only, so values may themselves contain "=".
            key, sep, value = entry.partition("=")
            if not sep:
                continue
            os.environ[key.strip()] = value.strip()

    # Log in only when a token is actually available, rather than raising
    # KeyError for a .env file that does not define HF_TOKEN.
    HF_TOKEN = os.environ.get("HF_TOKEN")
    if HF_TOKEN:
        login(HF_TOKEN)
    else:
        print("HF_TOKEN not found in .env or environment; skipping Hugging Face login")

# Third-party
import numpy as np
import tiktoken
from datasets import Dataset as ds, concatenate_datasets, load_dataset
from rotary_embedding_torch import RotaryEmbedding

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Torch runtime config
# Select the "medium" float32 matmul precision mode (a speed/precision
# trade-off, see the PyTorch docs) and release any cached CUDA allocator
# blocks before the notebook starts allocating.
torch.set_float32_matmul_precision("medium")
torch.cuda.empty_cache()

# Custom
from utils import count_parameters, load_synthetic_data, strip_compile_prefix, round_up, clean_columns

## Config & Model Definition

Mostly the same code as the pretraining notebook, including Flash Attention, RMSNorm etc.

In [None]:
#### CONFIG #####

# Model architecture - Moderate scale MoE
# Strategy: Go deeper (better for small models) + MoE for capacity
block_size = 1024
batch_size = 10       # conservative for MoE memory overhead
embed_dim = 1024      # GPT-2 Medium width
num_layers = 16       # deeper than GPT-2 Small (12), helps reasoning
num_heads = 16        # head_dim = 64, good for efficiency
dropout_prob = 0.1    # should help at this scale
mlp_ratio = 4

# MoE config
num_experts = 8       # 8 experts total
top_k_experts = 2     # activate 2 per token → 4x params, same compute
aux_loss_coef = 0.01  # load balancing (0.01-0.1 typical range)

# Expected params:
#   Total:  ~800M (all experts)
#   Active: ~250M (similar compute to GPT-2 Medium)

# Training
MAX_STEPS = 700_000                        # Total number of micro-batches to process
GRAD_ACCUM_STEPS = 20                      # Gradient accumulation steps
LOG_INTERVAL = MAX_STEPS // 1000           # Log every xxx micro-batches
num_workers = 4
prefetch = 8
dtype = torch.bfloat16
device = "cuda"
model_path = "gpt_moe_pretrain.pt"

# Estimated VRAM usage (bf16):
#   Model params:     ~800M × 2B = 1.6 GB
#   Optimizer states: ~800M × 8B = 6.4 GB (AdamW)
#   Gradients:        ~800M × 2B = 1.6 GB
#   Activations:      ~8-12 GB (batch=12, seq=1024)
#   Total:            ~18-22 GB → fits in 32GB comfortably

print("=" * 60)
print("MoE Model Configuration")
print("=" * 60)
print(f"  Architecture:    {embed_dim}d × {num_layers}L × {num_heads}H")
print(f"  MoE:             {num_experts} experts, top-{top_k_experts}")
print(f"  Batch:           {batch_size} × {block_size} tokens")
print(f"  Grad Accum:      {GRAD_ACCUM_STEPS} steps")
print(f"  Effective Batch: {batch_size * GRAD_ACCUM_STEPS * block_size:,} tokens")
print(f"  Params:          ~800M total, ~250M active")
print(f"  bf16 supported:  {torch.cuda.is_bf16_supported()}")
print("=" * 60)

In [None]:
class MultiHeadAttention(nn.Module):
    """Self-attention over several heads, with RMS-normalised queries and keys.

    Queries and keys are normalised per head (QK-Norm) before the rotary
    position embedding is applied. An optional dict-based key/value cache
    allows incremental decoding.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 rotary_emb: RotaryEmbedding,
                 causal: bool = True,
                 dropout: float = 0.1
                ):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads}).")

        self.causal = causal
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout_p = dropout

        # One matmul produces queries, keys and values together.
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)

        # Per-head RMSNorm applied to queries and keys (QK-Norm).
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)

        self.rotary_emb = rotary_emb
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x, k_v_cache=None):
        """Attend over ``x`` and return ``(output, k_v_cache)``.

        Args:
            x: input of shape (batch, seq, embed_dim).
            k_v_cache: ``None`` (no caching) or a dict. An empty dict is
                filled with the "K"/"V" tensors of this call; a dict that
                already holds "K" switches to incremental mode, where only
                the last position of ``x`` is projected and appended.

        Returns:
            The projected attention output and the (possibly updated) cache.
        """
        batch, _, _ = x.shape
        has_past = k_v_cache is not None and "K" in k_v_cache

        # With a populated cache only the newest token needs new Q/K/V.
        source = x[:, -1:, :] if has_past else x

        # (batch, seq, 3*embed) -> three tensors of (batch, heads, seq, head_dim)
        q, k, v = (
            part.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
            for part in self.qkv_proj(source).chunk(3, dim=-1)
        )

        # QK-Norm comes first, rotary embedding second.
        q = self.q_norm(q)
        k = self.k_norm(k)

        if has_past:
            # Rotate the new position at its absolute index, then extend the cache.
            offset = k_v_cache["K"].shape[-2]
            q = self.rotary_emb.rotate_queries_or_keys(q, offset=offset)
            k = self.rotary_emb.rotate_queries_or_keys(k, offset=offset)

            k = torch.cat([k_v_cache["K"], k], dim=-2)
            v = torch.cat([k_v_cache["V"], v], dim=-2)
            # A single query may look at every cached key: no mask needed.
            apply_causal_mask = False
        else:
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)
            apply_causal_mask = self.causal

        if k_v_cache is not None:
            k_v_cache["K"] = k.detach()
            k_v_cache["V"] = v.detach()

        attended = F.scaled_dot_product_attention(
            query=q, key=k, value=v,
            attn_mask=None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=apply_causal_mask,
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)
        merged = attended.transpose(1, 2).contiguous().view(batch, -1, self.embed_dim)
        return self.out_proj(merged), k_v_cache


class Expert(nn.Module):
    """One SwiGLU feed-forward expert: ``down(silu(gate) * up)`` plus dropout."""

    def __init__(self, embed_dim, hidden_dim, dropout_prob=0.1):
        super().__init__()
        # Two thirds of the requested width, passed through `round_up(..., 8)`
        # (presumably rounds up to a multiple of 8 - helper comes from utils).
        inner_dim = round_up(2 * hidden_dim // 3, 8)

        # Gate and up projections share one fused weight matrix.
        self.gate_up_proj = nn.Linear(embed_dim, 2 * inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        # Split the fused projection into its gate / up halves.
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        activated = F.silu(gate) * up
        return self.dropout(self.down_proj(activated))


class MoEMLP(nn.Module):
    """
    Mixture of Experts MLP layer.

    A linear router scores every token against all experts; each token is
    processed by its top-k experts and the expert outputs are combined with
    the renormalised router probabilities. In training mode the layer also
    stores an auxiliary load-balancing loss in ``self.aux_loss``.
    """

    def __init__(self, embed_dim, hidden_dim, num_experts, top_k, dropout_prob=0.1):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.embed_dim = embed_dim

        # Router: learned linear layer to score experts
        self.router = nn.Linear(embed_dim, num_experts, bias=False)

        # Expert MLPs
        self.experts = nn.ModuleList([
            Expert(embed_dim, hidden_dim, dropout_prob)
            for _ in range(num_experts)
        ])

        # Latest load-balancing loss. Only refreshed by forward passes in
        # training mode, so in eval mode it keeps its previous value.
        self.aux_loss = 0.0

    @staticmethod
    def _load_balancing_loss(router_probs, top_k_indices, num_experts):
        """Return ``num_experts * sum_e(load_e * mean_prob_e)``.

        ``load_e`` is the mean number of top-k slots per token assigned to
        expert e (so the loads sum to top_k), ``mean_prob_e`` the mean router
        probability of expert e. With perfectly uniform routing the value
        equals top_k.
        """
        # (N, top_k, E) one-hot -> (N, E) assignment counts per token
        assignment = F.one_hot(top_k_indices, num_classes=num_experts).sum(dim=1)
        tokens_per_expert = assignment.float().mean(dim=0)   # (E,)
        router_prob_per_expert = router_probs.mean(dim=0)    # (E,)
        return num_experts * (tokens_per_expert * router_prob_per_expert).sum()

    def forward(self, x):
        """
        Args:
            x: (B, T, D) input tensor
        Returns:
            output: (B, T, D) combined expert outputs
        """
        B, T, D = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well
        x_flat = x.reshape(-1, D)  # (B*T, D)

        # Router probabilities over all experts, then the top-k per token
        router_logits = self.router(x_flat)  # (B*T, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Renormalise the selected probabilities so they sum to 1 per token
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        if self.training:
            self.aux_loss = self._load_balancing_loss(
                router_probs, top_k_indices, self.num_experts
            )

        # "Loop over experts" dispatch - simpler than sparse dispatch
        output = torch.zeros_like(x_flat)

        for expert_idx, expert in enumerate(self.experts):
            # (token, slot) coordinates where this expert was selected. topk
            # returns distinct experts per token, so each token shows up at
            # most once and token_indices is in ascending order.
            token_indices, slot_indices = torch.where(top_k_indices == expert_idx)

            if token_indices.numel() == 0:
                continue

            # Routing weight of this expert for each selected token
            weights = top_k_probs[token_indices, slot_indices]

            expert_output = expert(x_flat[token_indices])

            # Weighted addition to output
            output[token_indices] += weights.unsqueeze(-1) * expert_output

        return output.view(B, T, D)


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention, then an MoE feed-forward, each with a residual."""

    def __init__(self,
                 embed_dim,
                 num_heads,
                 rotary_emb,
                 mlp_ratio=4,
                 num_experts=8,
                 top_k_experts=2,
                 dropout_prob=0.1,
                 causal=True,
                ):
        super().__init__()
        self.norm1 = nn.RMSNorm(embed_dim)
        self.mha = MultiHeadAttention(embed_dim, num_heads, rotary_emb, causal, dropout_prob)
        self.norm2 = nn.RMSNorm(embed_dim)

        # The dense MLP is replaced by a routed mixture of expert MLPs.
        self.moe = MoEMLP(
            embed_dim,
            mlp_ratio * embed_dim,
            num_experts,
            top_k_experts,
            dropout_prob,
        )

    def forward(self, x, cache=None):
        """Run the block; returns ``(hidden_states, cache)`` with the cache passed through attention."""
        attn_out, cache = self.mha(self.norm1(x), cache)
        hidden = attn_out + x                              # residual around attention
        hidden = self.moe(self.norm2(hidden)) + hidden     # residual around MoE
        return hidden, cache

    def get_aux_loss(self):
        """Expose the load-balancing loss currently stored on this block's MoE layer."""
        return self.moe.aux_loss

In [None]:
class GPT_MoE(nn.Module):
    """
    GPT with Mixture of Experts.

    Token embedding -> dropout -> a stack of ``TransformerBlock``s (attention
    + MoE feed-forward) -> final RMSNorm -> LM head tied to the embedding
    matrix. The per-layer load-balancing losses are collected through
    ``get_aux_loss``.
    """

    def __init__(self,
                 vocab_size,
                 embed_dim,
                 num_layers,
                 num_heads,
                 mlp_ratio=4,
                 num_experts=8,
                 top_k_experts=2,
                 dropout_prob=0.1,
                 is_causal=True,
                ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.num_experts = num_experts
        self.top_k_experts = top_k_experts

        self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
        self.dropout = nn.Dropout(dropout_prob)

        # A single rotary embedding module is shared by every attention layer.
        head_dim = embed_dim // num_heads
        self.rotary_emb = RotaryEmbedding(dim=head_dim)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                embed_dim, num_heads, self.rotary_emb,
                mlp_ratio, num_experts, top_k_experts,
                dropout_prob, is_causal
            )
            for _ in range(num_layers)
        ])
        self.norm = nn.RMSNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # weight tying

        # Initialize weights
        self.apply(self._init_weights)
        # Residual-path output projections get a smaller std, scaled by depth.
        for pn, p in self.named_parameters():
            if pn.endswith(("out_proj.weight", "down_proj.weight")):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.num_layers))

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights, zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens):
        """Return next-token logits for ``tokens``; no KV cache is used."""
        embeddings = self.embedding(tokens)
        x = self.dropout(embeddings)
        for b in self.blocks:
            x, _ = b(x)
        features = self.norm(x)
        return self.lm_head(features)

    def get_aux_loss(self):
        """Collect the auxiliary losses of all MoE layers and average them over layers."""
        total_aux_loss = 0.0
        for block in self.blocks:
            total_aux_loss += block.get_aux_loss()
        return total_aux_loss / self.num_layers  # average over layers

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    def count_parameters(self):
        """Return ``(total_params, active_params)``.

        Parameters whose name contains ``'experts'`` count as MoE parameters;
        only the ``top_k_experts / num_experts`` fraction of them is counted
        as active. Everything else (including the routers) is always active.
        """
        total_params = sum(p.numel() for p in self.parameters())

        non_moe_params = 0
        moe_params = 0

        for name, p in self.named_parameters():
            if 'experts' in name:
                moe_params += p.numel()
            else:
                non_moe_params += p.numel()

        # Active MoE params = (top_k / num_experts) * total_moe_params
        active_moe = moe_params * (self.top_k_experts / self.num_experts)
        active_params = non_moe_params + active_moe

        return total_params, int(active_params)

    @staticmethod
    def _sample_next_token(last_logits, temperature, top_k, top_p):
        """Pick one token id per row of ``last_logits`` (batch, vocab) -> (batch, 1).

        ``temperature == 0`` means greedy argmax. Otherwise logits are divided
        by the temperature, optionally restricted to the ``top_k`` largest
        entries and/or the nucleus of cumulative probability ``top_p``, and a
        token is drawn from the resulting distribution.
        """
        if temperature == 0:
            return torch.argmax(last_logits, dim=-1, keepdim=True)

        scaled_logits = last_logits / temperature

        if int(top_k) > 0:
            # torch.topk raises when k exceeds the vocabulary size; clamp it.
            k = min(int(top_k), scaled_logits.size(-1))
            values, indices = torch.topk(scaled_logits, k)
            scaled_logits = torch.full_like(scaled_logits, float('-inf'))
            scaled_logits.scatter_(1, indices, values)

        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept,
            # and always keep the most likely token.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            scaled_logits[indices_to_remove] = float('-inf')

        probs = torch.softmax(scaled_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self,
                 prompt_tokens,
                 max_new_tokens=50,
                 temperature=1.0,
                 top_k=0,
                 top_p=0.0,
                 use_cache=True,
                ):
        """Autoregressively extend ``prompt_tokens`` (batch, seq) by up to ``max_new_tokens``.

        Generation stops early once every sequence in the batch has produced
        the end-of-text id (read from the notebook-level ``eot_id``). The
        final end-of-text token that ends generation is not appended; in a
        batch, sequences that finish earlier are padded with ``eot_id`` until
        the rest finish.

        Note: switches the model to eval mode and does not switch it back.

        Returns:
            Tensor of prompt plus generated tokens, on the model's device.
        """
        self.eval()

        tokens_out = prompt_tokens.clone().to(self.device)
        current_tokens = prompt_tokens.clone().to(self.device)
        cache = [{} if use_cache else None for _ in range(len(self.blocks))]
        # Per-sequence flag: has this row already emitted end-of-text?
        finished = torch.zeros(tokens_out.shape[0], dtype=torch.bool, device=tokens_out.device)

        for _ in range(max_new_tokens):
            x = self.embedding(current_tokens)
            for i, b in enumerate(self.blocks):
                x, cache[i] = b(x, cache[i])

            features = self.norm(x)
            logits = self.lm_head(features)
            last_logits = logits[:, -1, :]

            next_token = self._sample_next_token(last_logits, temperature, top_k, top_p)

            # Rows that already ended keep emitting end-of-text as padding.
            next_token = next_token.masked_fill(finished.unsqueeze(1), eot_id)
            finished = finished | (next_token.squeeze(1) == eot_id)
            if finished.all():
                break

            tokens_out = torch.cat([tokens_out, next_token], dim=1)
            # With a cache only the new token is fed; otherwise the whole sequence.
            current_tokens = next_token if use_cache else tokens_out

        return tokens_out

## Tokenizer

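The GPT-2 tokenizer (tiktoken) setup is kept in the first cell below for reference but is commented out; this notebook uses the Mistral tokenizer loaded in the second cell.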

In [None]:
# ============================================================================
# Setup Tokenizer (tiktoken gpt2 - 50,257 vocab)
# ============================================================================

# print("Loading tokenizer...")
# tokenizer = tiktoken.get_encoding("gpt2")
# eot_id = tokenizer.eot_token

# # Pad vocab to nearest 128 multiple for GPU efficiency
# vocab_size = round_up(tokenizer.n_vocab, 128)

# print(f"✓ Loaded tiktoken gpt2")
# print(f"  Vocab size: {tokenizer.n_vocab} → padded: {vocab_size}")
# print(f"  EOT token ID: {eot_id}")

In [None]:
# Option 2: Mistral tokenizer (Hugging Face AutoTokenizer)
from transformers import AutoTokenizer

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# The tokenizer's EOS id doubles as the end-of-text / document separator id.
eot_id = tokenizer.eos_token_id
# Vocabulary size passed through round_up(..., 128) before sizing the embedding table.
vocab_size = round_up(len(tokenizer), 128)

print(
    f"✓ Loaded Mistral tokenizer",
    f"  Vocab size: {len(tokenizer)} → padded: {vocab_size}",
    f"  EOT token ID: {eot_id}",
    sep="\n",
)

In [None]:
# ============================================================================
# Load and Clean Datasets
# ============================================================================
print("\nLoading & cleaning up datasets...")
cleaned_datasets = {
    # Main dataset, high quality web crawl
    "fineweb-edu": load_dataset(
        "HuggingFaceFW/fineweb-edu",
        name="sample-10BT",
        split="train",
        trust_remote_code=True
    ),
}

# Apply the project's clean_columns helper to every dataset before concatenation.
cleaned_datasets = {n: clean_columns(d) for n, d in cleaned_datasets.items()}

print(f"\n✓ Loaded {len(cleaned_datasets)} datasets")
# The loop variable is deliberately not called `ds`: that name is the alias
# under which `datasets.Dataset` is imported at the top of the notebook.
for n, dataset in cleaned_datasets.items():
    print(f"  Dataset {n}: {len(dataset):,} examples")


print("\nConcatenating datasets...")
train_ds = concatenate_datasets(list(cleaned_datasets.values()))
print(f"  Combined size: {len(train_ds):,} examples")


print("Shuffling...")
# Fixed seed so the shuffled order is reproducible.
train_ds = train_ds.shuffle(seed=42)
print(f"✓ Final Train Size: {len(train_ds):,} rows")


def clean_text(example):
    """Strip chat/special-token markers from ``example["text"]``.

    Each marker is removed in a single pass, in the order listed below. The
    example dict is modified in place and also returned.
    """
    markers = (
        "<|endoftext|>", "<|im_start|>", "<|im_end|>",
        "<|system|>", "<|user|>", "<|assistant|>",
        "<s>", "</s>", "[INST]", "[/INST]",
        "<<SYS>>", "<</SYS>>",
    )

    cleaned = example["text"]
    for marker in markers:
        cleaned = cleaned.replace(marker, "")

    example["text"] = cleaned
    return example

# Tokenizer is going to complain if special tokens are found in training data
print("\nCleaning special tokens from datasets...")
train_ds = train_ds.map(
    clean_text,
    # Use half the cores, but never fewer than one worker: os.cpu_count() may
    # return None, and halving a single core would otherwise give 0.
    num_proc=max(1, (os.cpu_count() or 1) // 2),
    desc="Cleaning special tokens"
)

# ============================================================================
# Document Packing Function
# ============================================================================
def pack_documents(examples):
    """
    Tokenize a batch of documents into one token stream and cut it into windows.

    Every non-blank document is encoded with the notebook-level ``tokenizer``
    and followed by one ``eot_id``; documents whose encoding raises are
    skipped. The stream is then sliced into windows of ``block_size + 1``
    tokens taken every ``block_size`` tokens (so consecutive windows share one
    token), which lets a consumer use ``x = chunk[:-1]``, ``y = chunk[1:]``.
    A trailing window that is too short is dropped.

    Returns:
        ``{"chunk_ids": [...]}`` with one list of token ids per window.
    """
    stream = []

    for text in examples["text"]:
        # Ignore missing, empty and whitespace-only documents.
        if not text or not text.strip():
            continue

        try:
            ids = tokenizer.encode(text)
        except Exception:
            # Best effort: a document that cannot be encoded is skipped.
            continue

        ids.append(eot_id)  # document separator
        stream.extend(ids)

    window = block_size + 1
    stream_len = len(stream)

    # Keep only full windows; the short tail at the end is discarded.
    chunks = [
        stream[start:start + window]
        for start in range(0, stream_len, block_size)
        if start + window <= stream_len
    ]

    return {"chunk_ids": chunks}


# ============================================================================
# Apply Processing
# ============================================================================
print("\nTokenizing and packing documents (this may take a few minutes)...")
train_tokenized = train_ds.map(
    pack_documents,
    batched=True,
    batch_size=1250,
    # Half the cores, but at least one worker (a single-core host would
    # otherwise yield num_proc=0).
    num_proc=max(1, multiprocessing.cpu_count() // 2),
    # Packing changes the number of rows, so the original columns are dropped.
    remove_columns=train_ds.column_names,
    desc="Packing documents"
)

print(f"\n✓ Tokenization complete!")
print(f"  Total chunks: {len(train_tokenized):,}")
print(f"  Chunk size: {block_size + 1} tokens")
print(f"  Approx total tokens: {len(train_tokenized) * block_size:,}")

# ============================================================================
# Create DataLoader
# ============================================================================
class TokenDataset(Dataset):
    """Map-style dataset yielding ``(input, target)`` pairs from packed token chunks.

    Each row of the wrapped data must expose a ``"chunk_ids"`` tensor; the
    input is the chunk without its last token, the target the chunk without
    its first token, both cast to int64.
    """

    def __init__(self, tokenized_data):
        self.data = tokenized_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens = self.data[idx]["chunk_ids"]
        # Next-token prediction: targets are the inputs shifted left by one.
        inputs, targets = tokens[:-1], tokens[1:]
        return inputs.long(), targets.long()

# Rows are returned as torch tensors here; a numpy format with a cast in the
# loader would also work (relative speed not measured).
train_tokenized.set_format(type="torch", columns=["chunk_ids"])
token_dataset = TokenDataset(train_tokenized)

loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    # CRITICAL for torch.compile: a smaller final batch would have a different
    # shape, trigger a re-compilation and slow training down a lot.
    drop_last=True,
    num_workers=num_workers,
    prefetch_factor=prefetch,
    persistent_workers=True,
    pin_memory=True,
)
train_loader = DataLoader(token_dataset, **loader_kwargs)

print(f"\n✓ DataLoader ready")
print(f"  Batches per epoch: {len(train_loader):,}")
print(f"  Tokens per batch: {batch_size * block_size:,}")

## Model Initialization

**Note**: MoE architecture is different from dense GPT, so we train from scratch.
Cannot load pretrained weights from the dense model.

In [None]:
# Hyperparameters forwarded verbatim to the GPT_MoE constructor.
model_config = dict(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    mlp_ratio=mlp_ratio,
    num_experts=num_experts,
    top_k_experts=top_k_experts,
    dropout_prob=dropout_prob,
)

print("Initializing MoE model with config:")
pprint(model_config)

model = GPT_MoE(**model_config)
model = model.to(device)

# Total vs. per-token active parameter counts, as reported by the model itself.
total_params, active_params = model.count_parameters()
print(f"\nParameter counts:")
print(f"  Total parameters:  {total_params:,} ({total_params/1e6:.1f}M)")
print(f"  Active parameters: {active_params:,} ({active_params/1e6:.1f}M)")
print(f"  Ratio: {total_params/active_params:.2f}x total vs active")

model.train()

## Optimizer

For MoE models:
- Same weight decay rules (don't decay norms, embeddings)
- Router parameters should NOT be decayed (the router is small and should stay free to adapt its routing)
- Add auxiliary load balancing loss to prevent expert collapse

In [None]:
# Split trainable parameters into a weight-decayed and a non-decayed group.
# A parameter is exempt from decay when its (lower-cased) name contains one of
# these substrings: norm weights (RMSNorm, QK-norm), biases, the embedding
# table (tied with lm_head) and the router.
no_decay_keywords = ('norm', 'bias', 'embed', 'embedding', 'lm_head', 'router')

decay_params = []
no_decay_params = []

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    lowered = name.lower()
    exempt = any(kw in lowered for kw in no_decay_keywords)
    bucket = no_decay_params if exempt else decay_params
    bucket.append(param)

print(f"Decay params: {len(decay_params)}, No decay params: {len(no_decay_params)}")

# Hyperparameters tuned for ~800M param MoE model
# Rule of thumb: larger models need lower LR
base_lr = 2e-4  # slightly lower than 124M model

param_groups = [
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(
    param_groups,
    lr=base_lr,
    betas=(0.9, 0.95),  # standard for LLMs
    eps=1e-8,
)

# The schedule is expressed in OPTIMIZER steps, not micro-batches.
total_optim_steps = MAX_STEPS // GRAD_ACCUM_STEPS
warmup_steps = int(total_optim_steps * 0.05)  # 5% warmup

from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

# Phase 1: linear ramp from 1% of base LR up to base LR.
warmup_scheduler = LinearLR(
    optimizer,
    start_factor=0.01,
    total_iters=warmup_steps,
)
# Phase 2: cosine decay down to 10% of base LR over the remaining steps.
cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=total_optim_steps - warmup_steps,
    eta_min=base_lr * 0.1,
)
# Warmup first, then cosine, switching at `warmup_steps`.
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_steps],
)

print(f"\nTraining schedule:")
print(f"  Total micro steps: {MAX_STEPS:,}")
print(f"  Grad accum steps:  {GRAD_ACCUM_STEPS}")
print(f"  Total optim steps: {total_optim_steps:,}")
print(f"  Warmup steps:      {warmup_steps:,} ({100*warmup_steps/total_optim_steps:.0f}%)")
print(f"  Base LR:           {base_lr}")
print(f"  Min LR:            {base_lr * 0.1}")
print(f"  Weight decay:      0.1 (linear), 0.0 (norm/embed/router)")
print(f"  Aux loss coef:     {aux_loss_coef}")

In [None]:
# ============================================================================
# Fused Cross-Entropy Loss (Triton)
# ============================================================================
# Saves ~3GB memory and is ~2x faster than standard F.cross_entropy for large vocab

import triton
import triton.language as tl

# Forward kernel: each program instance handles one row of the logits matrix.
# It streams over the row in BLOCK_SIZE-wide column tiles, maintaining a
# running maximum and a rescaled running sum of exponentials (online
# log-sum-exp), then writes the row's log-sum-exp to `lse_ptr` and the
# cross-entropy loss `lse - logits[target]` to `losses_ptr`.
#   logits_ptr        : logits, rows `stride_logits_row` elements apart
#   losses_ptr        : per-row loss output
#   lse_ptr           : per-row log-sum-exp output (reused by the backward kernel)
#   targets_ptr       : per-row target class index
#   n_cols            : number of classes per row
#   BLOCK_SIZE        : compile-time tile width
@triton.jit
def fused_cross_entropy_fwd_kernel(
    logits_ptr, losses_ptr, lse_ptr, targets_ptr,
    stride_logits_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # 64-bit row index so the row offset computation is done in int64
    row_idx = tl.program_id(0).to(tl.int64)
    logits_row_ptr = logits_ptr + row_idx * stride_logits_row
    
    # Running statistics for the online log-sum-exp
    max_val = -float('inf')
    sum_exp = 0.0
    
    for off in range(0, n_cols, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        # Out-of-range lanes load -inf so they contribute exp(-inf) = 0; math is done in fp32
        logits = tl.load(logits_row_ptr + cols, mask=mask, other=-float('inf')).to(tl.float32)
        
        curr_max = tl.max(logits, axis=0)
        new_max = tl.maximum(max_val, curr_max)
        # Rescale the sum accumulated so far to the new maximum, then add this tile
        sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(logits - new_max), axis=0)
        max_val = new_max
    
    lse = max_val + tl.log(sum_exp)
    tl.store(lse_ptr + row_idx, lse)
    
    # loss = -log softmax(logits)[target] = lse - logits[target]
    target = tl.load(targets_ptr + row_idx).to(tl.int64)
    target_logit = tl.load(logits_row_ptr + target).to(tl.float32)
    loss = lse - target_logit
    
    tl.store(losses_ptr + row_idx, loss)


# Backward kernel: each program instance handles one row. Using the
# log-sum-exp saved by the forward kernel it recomputes the softmax tile by
# tile and writes grad = upstream_grad * (softmax - one_hot(target)) for every
# column of the row. The gradient is cast to bfloat16 before being stored.
#   grad_input_ptr    : gradient w.r.t. logits (output), rows `stride_grad_row` apart
#   grad_output_ptr   : upstream gradient, one value per row
#   lse_ptr           : per-row log-sum-exp from the forward pass
#   logits_ptr        : logits, rows `stride_logits_row` elements apart
#   targets_ptr       : per-row target class index
#   n_cols            : number of classes per row
#   BLOCK_SIZE        : compile-time tile width
@triton.jit  
def fused_cross_entropy_bwd_kernel(
    grad_input_ptr, grad_output_ptr, lse_ptr, logits_ptr, targets_ptr,
    stride_logits_row, stride_grad_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # 64-bit row index so the row offset computation is done in int64
    row_idx = tl.program_id(0).to(tl.int64)
    logits_row_ptr = logits_ptr + row_idx * stride_logits_row
    grad_row_ptr = grad_input_ptr + row_idx * stride_grad_row
    
    lse = tl.load(lse_ptr + row_idx)
    grad_loss = tl.load(grad_output_ptr + row_idx)
    target = tl.load(targets_ptr + row_idx).to(tl.int64)
    
    for off in range(0, n_cols, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        
        # Masked-out lanes load 0.0; their results are never stored (store is masked too)
        logits = tl.load(logits_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        # softmax = exp(logit - logsumexp)
        probs = tl.exp(logits - lse)
        is_target = (cols == target).to(tl.float32)
        grad = grad_loss * (probs - is_target)
        
        tl.store(grad_row_ptr + cols, grad.to(tl.bfloat16), mask=mask)


class FusedCrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper around the Triton cross-entropy kernels.

    ``forward`` takes 2-D ``logits`` (rows x classes) and 1-D ``targets`` and
    returns one float32 loss per row. ``backward`` returns a bfloat16 gradient
    for the logits and ``None`` for the targets.
    """

    @staticmethod
    def forward(ctx, logits, targets):
        num_rows, num_classes = logits.shape
        device = logits.device

        # Per-row outputs: the loss and the log-sum-exp reused by backward.
        losses = torch.empty(num_rows, dtype=torch.float32, device=device)
        lse = torch.empty(num_rows, dtype=torch.float32, device=device)

        logits = logits.contiguous()
        targets = targets.contiguous()

        # One kernel program per row.
        fused_cross_entropy_fwd_kernel[(num_rows,)](
            logits, losses, lse, targets,
            logits.stride(0),
            num_classes,
            BLOCK_SIZE=1024,
            num_warps=8,
        )

        ctx.save_for_backward(logits, targets, lse)
        return losses

    @staticmethod
    def backward(ctx, grad_output):
        logits, targets, lse = ctx.saved_tensors
        num_rows, num_classes = logits.shape

        # Gradient buffer is always bfloat16, matching the kernel's store dtype.
        grad_logits = torch.empty_like(logits, dtype=torch.bfloat16)
        grad_output = grad_output.contiguous()

        fused_cross_entropy_bwd_kernel[(num_rows,)](
            grad_logits, grad_output, lse, logits, targets,
            logits.stride(0), grad_logits.stride(0),
            num_classes,
            BLOCK_SIZE=1024,
            num_warps=8,
        )
        # No gradient flows to the integer targets.
        return grad_logits, None


def fused_cross_entropy(logits, targets):
    """Mean cross-entropy over all rows, computed with the fused kernel.

    Runs ``FusedCrossEntropyLoss`` to get one loss per row and averages them.
    This wrapper exposes only ``logits`` and ``targets``; it offers no
    ``ignore_index``, class-weight, or alternative-reduction options.
    """
    per_row_loss = FusedCrossEntropyLoss.apply(logits, targets)
    return per_row_loss.mean()


print("✓ Fused cross-entropy loss defined")

In [None]:
# ============================================================================
# Compile Model
# ============================================================================

# Opt in to Inductor's coordinate-descent kernel tuning before compiling.
torch._inductor.config.coordinate_descent_tuning = True
model = torch.compile(model, dynamic=False, fullgraph=False)
model.to(device)
model.train()

# Warmup passes to force compilation: run full forward + backward steps on
# random token ids so compile time is paid here instead of inside the training
# loop. No optimizer step is taken; gradients are discarded after each pass.
print("Warming up compilation...")
x = torch.randint(0, vocab_size, (batch_size, block_size), device=device)
y = x.clone()

s = time.time()
for i in range(5):
    start = time.time()
    # NOTE(review): autocast is pinned to bfloat16 here while the training loop
    # uses the `dtype` variable — confirm they match to avoid a recompile.
    with torch.autocast("cuda", torch.bfloat16):
        logits = model(x)
        ce_loss = fused_cross_entropy(
            logits.view(-1, logits.size(-1)),
            y.view(-1)
        )
        aux_loss = model.get_aux_loss()
        loss = ce_loss + aux_loss_coef * aux_loss
    loss.backward()
    model.zero_grad(set_to_none=True)
    # Wait for the GPU so the printed per-pass time covers the actual work.
    torch.cuda.synchronize()
    print(f"  Warmup {i+1}: {time.time() - start:.2f}s")

print(f"✓ Model fully compiled in {time.time() - s:.2f}s")

# Release every tensor produced by the warmup, not only the inputs: the last
# pass leaves `logits` and the loss tensors referenced, and empty_cache() can
# only hand back memory that no live tensor still occupies.
del x, y, logits, ce_loss, aux_loss, loss
torch.cuda.empty_cache()

## Training Loop

MoE training includes:
- Main cross-entropy loss (language modeling)
- Auxiliary load balancing loss (prevents expert collapse)

Total loss = CE loss + `aux_loss_coef` × `aux_loss`

In [None]:
# --- CSV Logger ---
# Split the checkpoint path into stem and extension once. os.path.splitext only
# strips the final suffix, so a path such as "./runs/moe.pt" keeps its
# directory (splitting on the first "." would leave an empty stem there).
model_stem, model_ext = os.path.splitext(model_path)
log_file = f'{model_stem}__{datetime.now().strftime("%Y-%m-%d")}__pretraining_logs.csv'
print("Saving logs in : ", log_file)
# Write the header only when the file is new, so reruns on the same day append
# rows to the existing log.
file_exists = os.path.isfile(log_file)
with open(log_file, "a", newline="") as f:
    writer = csv.writer(f)
    if not file_exists:
        writer.writerow(["micro_step", "optim_step", "ce_loss", "aux_loss", "total_loss", "lr", "tokens_seen", "tokens_per_sec", "timestamp"])

# --- Training Loop ---
micro_step = 0   # forward/backward passes executed
optim_step = 0   # optimizer updates applied (one per GRAD_ACCUM_STEPS micro steps)
tokens_seen = 0

# Accumulate losses on GPU to avoid CPU-GPU sync every step
running_ce_loss = torch.zeros(1, device=device)
running_aux_loss = torch.zeros(1, device=device)

start_time = time.time()       # reset at every log line (throughput window)
start_training = time.time()   # never reset (total wall-clock time)
last_tokens_seen = 0

model_params = decay_params + no_decay_params
optimizer.zero_grad(set_to_none=True)

print(f"\nStarting training for {MAX_STEPS:,} micro steps...")
print(f"Logging every {LOG_INTERVAL} steps, checkpointing every 50k steps\n")

# The outer loop restarts the data loader whenever it is exhausted before
# MAX_STEPS micro steps have been run.
while micro_step < MAX_STEPS:
    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        B, T = x.shape
        tokens_seen += B * T

        # --- Forward ---
        with torch.autocast(device_type="cuda", dtype=dtype):
            logits = model(x)
            ce_loss = fused_cross_entropy(
                logits.view(-1, logits.size(-1)),
                y.view(-1)
            )
            # MoE auxiliary loss for load balancing
            aux_loss = model.get_aux_loss()
            loss = ce_loss + aux_loss_coef * aux_loss

        # --- Backward (gradient accumulation) ---
        # Scale so the accumulated gradient is the mean over the micro batches.
        (loss / GRAD_ACCUM_STEPS).backward()

        # --- Optimizer step ---
        if (micro_step + 1) % GRAD_ACCUM_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model_params, 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            optim_step += 1

        # --- Bookkeeping (.item() triggers cpu-gpu sync so we avoid it) ---
        running_ce_loss += ce_loss.detach()
        running_aux_loss += aux_loss.detach() if isinstance(aux_loss, torch.Tensor) else aux_loss
        micro_step += 1

        # --- Logging ---
        if micro_step % LOG_INTERVAL == 0:
            # Only sync with CPU here (once per LOG_INTERVAL steps)
            avg_ce_loss = (running_ce_loss / LOG_INTERVAL).item()
            avg_aux_loss = (running_aux_loss / LOG_INTERVAL).item()
            avg_total_loss = avg_ce_loss + aux_loss_coef * avg_aux_loss

            # Throughput over the window since the previous log line.
            elapsed = time.time() - start_time
            tokens_delta = tokens_seen - last_tokens_seen
            tokens_per_sec = tokens_delta / elapsed

            current_lr = optimizer.param_groups[0]["lr"]
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            print(
                f"step {micro_step:06d} | "
                f"opt {optim_step:05d} | "
                f"ce {avg_ce_loss:.3f} | "
                f"aux {avg_aux_loss:.3f} | "
                f"lr {current_lr:.2e} | "
                f"{tokens_per_sec:,.0f} tok/s"
            )

            with open(log_file, "a", newline="") as f:
                writer = csv.writer(f)
                writer.writerow([
                    micro_step,
                    optim_step,
                    f"{avg_ce_loss:.4f}",
                    f"{avg_aux_loss:.4f}",
                    f"{avg_total_loss:.4f}",
                    f"{current_lr:.2e}",
                    tokens_seen,
                    int(tokens_per_sec),
                    timestamp,
                ])

            # Start a fresh averaging / throughput window.
            running_ce_loss.zero_()
            running_aux_loss.zero_()
            start_time = time.time()
            last_tokens_seen = tokens_seen

        # --- Checkpointing ---
        if micro_step % 50_000 == 0 and micro_step > 0:
            # Insert the step count before the extension; built from the split
            # path so it never collides with the final checkpoint name.
            mid_model_path = f"{model_stem}_{micro_step}{model_ext}"
            print(f"Saving intermediate model to {mid_model_path}")
            torch.save(model.state_dict(), mid_model_path)

        # --- Exit ---
        if micro_step >= MAX_STEPS:
            elapsed = int(time.time() - start_training)
            h, m, s = elapsed // 3600, (elapsed % 3600) // 60, elapsed % 60
            print(f"\nProcessed {tokens_seen:,} tokens in {h:02d}:{m:02d}:{s:02d}")
            print(f"Saving final model to {model_path}")
            torch.save(model.state_dict(), model_path)
            break

In [None]:
# Quick test generation: sample a short continuation from the trained model.
model.eval()

prompt = "The mixture of experts architecture allows"
x = torch.tensor(tokenizer.encode(prompt))

# generate() is given a batch of one prompt, placed on the model's device.
prompt_batch = x.unsqueeze(0).to(device)
out = model.generate(
    prompt_batch,
    max_new_tokens=100,
    temperature=0.9,
    top_p=0.95,
    top_k=0,
    use_cache=True,
)

# Decode the first (only) sequence of the returned batch.
completion = tokenizer.decode(out[0].tolist())

print("\nPrompt:", prompt)
print("Output:", completion)