In [1]:
pip install torch transformers datasets


Note: you may need to restart the kernel to use updated packages.


In [3]:
# enhanced_gpt_moe_training.py
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer
from torch.amp import autocast, GradScaler

# -----------------------
# Environment and device
# -----------------------
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# -----------------------
# Model / training config
# -----------------------
MODEL_CONFIG = {
    "dim": 512,
    "depth": 8,
    "heads": 8,
    "num_experts": 12,
    "k": 2,
    "max_len": 512,
    "dropout": 0.1,
    "moe_expansion": 2.0
}

TRAINING_CONFIG = {
    "seq_len": 256,
    "batch_size": 1,
    "epochs": 6,
    "lr": 3e-4,
    "weight_decay": 0.1,
    "betas": (0.9, 0.95),
    "cosine_Tmax": 6,
    "empty_cache_every": 200
}

# -----------------------
# Rotary embeddings
# -----------------------
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[-2]
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)            # (seq_len, dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)         # (seq_len, dim)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

def apply_rotary_pos_emb(q, k, cos, sin):
    # q,k: (..., head_dim)
    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# -----------------------
# SwiGLU FFN and Expert
# -----------------------
class SwiGLU(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        return self.drop(self.w2(F.silu(self.w1(x)) * self.w3(x)))

class Expert(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.ffn = SwiGLU(dim, hidden_dim, dropout)
    def forward(self, x):
        return self.ffn(x)

# -----------------------
# Router (Top-k)
# -----------------------
class TopKRouter(nn.Module):
    def __init__(self, dim, num_experts, k=2, jitter=0.01):
        super().__init__()
        self.k = k
        self.num_experts = num_experts
        self.jitter = jitter
        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, 0, 0.1)

    def forward(self, x):
        # x: (B,T,D)
        if self.training and self.jitter > 0:
            x = x + torch.randn_like(x) * self.jitter
        logits = self.gate(x)  # (B,T,E)
        topk_vals, topk_idx = torch.topk(logits, self.k, dim=-1)  # (B,T,k)
        weights = torch.zeros_like(logits, dtype=topk_vals.dtype)  # (B,T,E)
        soft = F.softmax(topk_vals, dim=-1).to(logits.dtype)  # match weights dtype
        weights = torch.zeros_like(logits)                   # no need to force dtype here
        weights.scatter_(-1, topk_idx, soft)
        return weights, topk_idx  # weights: (B,T,E); topk_idx: (B,T,k)

# -----------------------
# MoE Layer (memory-efficient)
# -----------------------
class MoELayer(nn.Module):
    def __init__(self, dim, hidden_dim, num_experts, k, dropout=0.1):
        super().__init__()
        self.router = TopKRouter(dim, num_experts, k)
        self.experts = nn.ModuleList([Expert(dim, hidden_dim, dropout) for _ in range(num_experts)])
        self.num_experts = num_experts

    def forward(self, x):
        B, T, D = x.shape
        weights, topk_idx = self.router(x)  # weights: (B,T,E), topk_idx: (B,T,k)
        flat_x = x.view(-1, D)  # (B*T, D)
        flat_w = weights.view(-1, self.num_experts)  # (B*T, E)
        out_flat = torch.zeros_like(flat_x)

        # For each expert, find tokens routed to it (via topk indices)
        # Build boolean mask for tokens assigned to each expert (if any of the k indices equals e)
        # topk_idx == e -> (B,T,k) -> any(dim=-1) -> (B,T)
        for e, expert in enumerate(self.experts):
            # boolean mask (B,T) whether any of the k picks equals e
            mask_e = (topk_idx == e).any(dim=-1).view(-1)  # (B*T,)
            if not mask_e.any():
                continue
            inp = flat_x[mask_e]  # tokens assigned to this expert
            out_e = expert(inp)   # (num_tokens_for_e, D)
            w_e = flat_w[mask_e, e].unsqueeze(-1)  # (num_tokens_for_e,1)
            out_flat[mask_e] += out_e * w_e

        return out_flat.view(B, T, D)

# -----------------------
# Multi-Head Attention with RoPE & causal masking
# -----------------------
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, heads, dropout=0.1):
        super().__init__()
        assert dim % heads == 0, "dim must be divisible by heads"
        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x, attn_mask=None):
        B, T, D = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(B, T, self.heads, self.head_dim).transpose(1, 2), qkv)
        # RoPE
        cos, sin = self.rope(x, T)  # cos/sin shapes (T, head_dim)
        cos = cos.view(1, 1, T, self.head_dim)
        sin = sin.view(1, 1, T, self.head_dim)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, T)
        # causal mask
        causal_mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        attn = attn.masked_fill(causal_mask, float("-inf"))
        # padding mask
        if attn_mask is not None:
            pad_mask = (~attn_mask.bool()).view(B, 1, 1, T)  # True where padding
            attn = attn.masked_fill(pad_mask, float("-inf"))
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        attn = self.drop(attn)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, D)
        return self.proj(out)

# -----------------------
# Transformer Block & Model
# -----------------------
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, num_experts, k, dropout=0.1, moe_expansion=2.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, heads, dropout)
        self.ln2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * moe_expansion)
        self.moe = MoELayer(dim, hidden_dim, num_experts, k, dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        resid = x
        x = self.ln1(x)
        x = self.attn(x, attn_mask)
        x = self.drop(x)
        x = resid + x
        resid = x
        x = self.ln2(x)
        x = self.moe(x)
        x = self.drop(x)
        x = resid + x
        return x

class EnhancedGPTMoE(nn.Module):
    def __init__(self, vocab_size, dim, depth, heads, num_experts, k, max_len, dropout, moe_expansion):
        super().__init__()
        self.max_len = max_len
        self.tok = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([
            TransformerBlock(dim, heads, num_experts, k, dropout, moe_expansion) for _ in range(depth)
        ])
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
        # tie weights
        self.head.weight = self.tok.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, ids, attn_mask=None, targets=None, pad_id=None):
        B, T = ids.shape
        x = self.tok(ids)  # (B, T, D)
        for blk in self.layers:
            x = blk(x, attn_mask)
        x = self.ln_f(x)
        logits = self.head(x)  # (B, T, V)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=pad_id if pad_id is not None else -100
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, ids, attn_mask=None, max_new_tokens=50, temp=0.7, top_p=0.9):
        # ids: (B, cur_len)
        for _ in range(max_new_tokens):
            if ids.shape[1] > self.max_len:
                ids = ids[:, -(self.max_len // 2):]
                if attn_mask is not None:
                    attn_mask = attn_mask[:, -(self.max_len // 2):]
            logits, _ = self(ids, attn_mask)
            logits = logits[:, -1, :] / max(temp, 1e-8)
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # mask tokens with cumulative prob > top_p
                sorted_indices_to_remove = cumulative_probs > top_p
                # keep first token that exceeds threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits.scatter_(1, indices_to_remove, float("-inf"))
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
            ids = torch.cat([ids, next_token], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, torch.ones((attn_mask.shape[0], 1), device=ids.device, dtype=attn_mask.dtype)], dim=1)
        return ids

# -----------------------
# Data: tokenizer, dataset, dataloader
# -----------------------
def setup_data(seq_len=256, fraction="train[:5%]"):
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=fraction)
    ds = ds.filter(lambda e: len(e["text"].strip()) > 30)

    def tok_fn(batch):
        return tokenizer(batch["text"], truncation=True, max_length=seq_len, padding="max_length")

    ds_tok = ds.map(tok_fn, batched=True, remove_columns=["text"])
    ds_tok = ds_tok.filter(lambda e: sum(e["attention_mask"]) > 10)
    ds_tok.set_format(type="torch", columns=["input_ids", "attention_mask"])

    dl = DataLoader(ds_tok, batch_size=TRAINING_CONFIG["batch_size"], shuffle=True)
    return tokenizer, dl

# -----------------------
# Training loop (OOM-safe + AMP)
# -----------------------
def train():
    seq_len = TRAINING_CONFIG["seq_len"]
    tokenizer, dataloader = setup_data(seq_len)
    vocab_size = len(tokenizer)

    model = EnhancedGPTMoE(
        vocab_size=vocab_size,
        dim=MODEL_CONFIG["dim"],
        depth=MODEL_CONFIG["depth"],
        heads=MODEL_CONFIG["heads"],
        num_experts=MODEL_CONFIG["num_experts"],
        k=MODEL_CONFIG["k"],
        max_len=MODEL_CONFIG["max_len"],
        dropout=MODEL_CONFIG["dropout"],
        moe_expansion=MODEL_CONFIG["moe_expansion"],
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=TRAINING_CONFIG["lr"],
        betas=TRAINING_CONFIG["betas"],
        weight_decay=TRAINING_CONFIG["weight_decay"]
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=TRAINING_CONFIG["cosine_Tmax"], eta_min=1e-5
    )
    scaler = GradScaler()

    model.train()
    empty_every = TRAINING_CONFIG["empty_cache_every"]

    for epoch in range(1, TRAINING_CONFIG["epochs"] + 1):
        total_loss = 0.0
        steps = 0
        for i, batch in enumerate(dataloader, start=1):
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attn_mask = batch["attention_mask"].to(device, non_blocking=True)

            # Shift tokens: predict next token
            targets = input_ids[:, 1:].contiguous()
            input_ids = input_ids[:, :-1].contiguous()
            attn_mask = attn_mask[:, :-1].contiguous()

            optimizer.zero_grad(set_to_none=True)
            try:
                # autocast device type
                with autocast(device_type="cuda" if device == "cuda" else "cpu"):
                    logits, loss = model(input_ids, attn_mask, targets, pad_id=tokenizer.pad_token_id)

                if loss is None or torch.isnan(loss):
                    raise RuntimeError("Invalid loss (None or NaN)")

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()

                total_loss += loss.item()
                steps += 1

            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("GPU OOM - skipping batch safely...")
                    optimizer.zero_grad(set_to_none=True)
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise

            if i % 100 == 0 and steps > 0:
                print(f"Epoch {epoch} Step {i} Loss {loss.item():.4f}")

            if empty_every and (i % empty_every == 0) and torch.cuda.is_available():
                torch.cuda.empty_cache()

        if steps > 0:
            avg = total_loss / steps
            print(f"Epoch {epoch}/{TRAINING_CONFIG['epochs']} Avg Loss: {avg:.4f}")
        scheduler.step()

    return model, tokenizer

# -----------------------
# Inference helper
# -----------------------
def generate_text(model, tokenizer, prompt="Once upon a time", max_new_tokens=50, temp=0.7, top_p=0.9):
    model.eval()
    with torch.no_grad():
        ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        attn = torch.ones_like(ids).to(device)
        out_ids = model.generate(ids, attn, max_new_tokens=max_new_tokens, temp=temp, top_p=top_p)
        # out_ids is tensor (1, L)
        return tokenizer.decode(out_ids[0].tolist(), skip_special_tokens=True)

# -----------------------
# Main
# -----------------------
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model, tokenizer = train()
    sample = generate_text(model, tokenizer, "Hello world", max_new_tokens=40)
    print("\nGenerated sample:\n", sample)

    # Save a checkpoint
    save_path = "enhanced_gpt_moe_checkpoint.pt"
    torch.save({
        "model_state_dict": model.state_dict(),
        "model_config": MODEL_CONFIG,
        "vocab_size": len(tokenizer)
    }, save_path)
    print(f"Saved checkpoint to {save_path}")


Using device: cuda
Model parameters: 185,185,792
Epoch 1 Step 100 Loss 7.4320
Epoch 1 Step 200 Loss 7.2516
Epoch 1 Step 300 Loss 7.1622
Epoch 1 Step 400 Loss 6.3386
Epoch 1 Step 500 Loss 7.3772
Epoch 1 Step 600 Loss 6.2084
Epoch 1 Step 700 Loss 7.1914
Epoch 1 Step 800 Loss 7.0790
Epoch 1 Step 900 Loss 7.5319
Epoch 1/6 Avg Loss: 7.0181
Epoch 2 Step 100 Loss 6.1176
Epoch 2 Step 200 Loss 7.1716
Epoch 2 Step 300 Loss 6.5150
Epoch 2 Step 400 Loss 6.4436
Epoch 2 Step 500 Loss 4.7971
Epoch 2 Step 600 Loss 5.5999
Epoch 2 Step 700 Loss 6.1599
Epoch 2 Step 800 Loss 6.7628
Epoch 2 Step 900 Loss 7.5145
Epoch 2/6 Avg Loss: 6.2372
Epoch 3 Step 100 Loss 5.9019
Epoch 3 Step 200 Loss 5.9103
Epoch 3 Step 300 Loss 6.2394
Epoch 3 Step 400 Loss 6.3933
Epoch 3 Step 500 Loss 6.2708
Epoch 3 Step 600 Loss 6.4812
Epoch 3 Step 700 Loss 6.5506
Epoch 3 Step 800 Loss 5.3172
Epoch 3 Step 900 Loss 2.4263
Epoch 3/6 Avg Loss: 5.9500
Epoch 4 Step 100 Loss 7.5102
Epoch 4 Step 200 Loss 6.3402
Epoch 4 Step 300 Loss 6.3631


RuntimeError: Index tensor must have the same number of dimensions as self tensor