In [None]:
import os, math, random, gc, sys, time
import itertools
from itertools import cycle
from typing import Any, Sequence, Tuple, Optional, Iterable
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import sentencepiece as spm
import numpy as np
from torch.amp import GradScaler
import matplotlib.pyplot as plt
import re

In [None]:
# ========== Runtime / Performance Configurations ==========
# Central list of tunables for the notebook. Constants whose consumer is not
# visible in this section are marked as assumptions to verify.

# --- Memory / throughput ---
USE_CHECKPOINTING = False  # read by TransformerBlock.forward: enables activation checkpointing while training
CHECKPOINT_ATTENTION_ONLY = False  # presumably limits checkpointing to the attention sub-layer; verify where it is consumed
ACCUM_STEPS = 1  # presumably gradient-accumulation steps; consumer is outside this section
BATCH_SIZE = 16
SEQ_LEN = 512  # default max_seq_len for MLAAttention's RoPE cache; also a context-length bound in generate_kimi
NUM_WORKERS = 6  # DataLoader worker count (make_dataloaders caps it at 8)
PIN_MEMORY = True  # forwarded to DataLoader(pin_memory=...)
PERSISTENT_WORKERS = True  # forwarded to DataLoader(persistent_workers=...)

# --- Mixture-of-Experts ---
MOE_CHUNK_SIZE = 32768  # default chunk_size of MoELayer; stored on the layer, not read by its forward pass
MOE_NUM_EXPERTS = 32  # number of routed experts per MoELayer (a shared expert is added on top)
MOE_TOP_K = 6  # default top_k of MoELayer; stored on the layer, routing in forward is capacity-based
MOE_CAPACITY_FACTOR = 2.5  # scales how many tokens each expert selects relative to N / num_experts
MOE_AUX_LAMBDA = 0.0025  # weight applied to the MoE auxiliary loss when it is added to the NLL

# --- Training schedule (consumers are outside this section; meanings inferred from names) ---
OOM_RETRY_LIMIT = 3
EPOCHS = 5
STEPS_PER_EPOCH = 50000
LOG_INTERVAL = 50
EVAL_INTERVAL = 200
STORY_GEN_INTERVAL = 1000
COMPARISON_GRAPH_INTERVAL = 5000

# --- Model / optimizer hyper-parameters (lower-case names kept as in the original) ---
VOCAB_SIZE = 8000
n_embd = 768
n_head = 8
n_layer = 8
dropout = 0.0
learning_rate = 6e-4  
weight_decay = 0.1
warmup_ratio = 0.2  # presumably the fraction of total steps spent in LR warm-up — TODO confirm in the scheduler call

In [None]:
# ========== Utility Functions ==========
def cleanup_memory():
    """Release cached CUDA blocks (best effort) and then run the Python garbage collector.

    Any exception raised while emptying the CUDA cache is deliberately
    swallowed so that this helper is always safe to call.
    """
    try:
        torch.cuda.empty_cache()
    except Exception:
        # Best effort only: failing to empty the cache must not abort the caller.
        pass
    gc.collect()

def set_seed(seed: int = 1337):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global generator, PyTorch's CPU
    generator and, when CUDA is available, every CUDA device generator.

    Args:
        seed: Integer seed. NumPy only accepts values in ``[0, 2**32)``, so the
            value is reduced modulo ``2**32`` for NumPy only.
    """
    random.seed(seed)
    # numpy is imported by the notebook; without this call any np.random usage
    # stays non-reproducible across runs.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_device():
    """Return the device to run on: ``cuda:0`` if CUDA is available, else CPU.

    When CUDA is available, device 0 is also made the current CUDA device.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    torch.cuda.set_device(0)
    return torch.device("cuda:0")

# Global numeric-precision switches, applied once when this cell runs.
# Allow TF32 for CUDA matmuls (trades a little float32 precision for speed).
torch.backends.cuda.matmul.allow_tf32 = True
try:
    # Request the 'high' float32 matmul precision mode.
    torch.set_float32_matmul_precision('high')
except Exception:
    # Best effort: presumably guards PyTorch builds that lack this API.
    pass

In [None]:
# ========== Mixed Precision Manager ==========
class MixedPrecisionManager:
    def __init__(self, device):
        self.device = device
        self.use_bf16 = (device.type == "cuda")
        self.use_fp16 = False
        self.use_fp32 = False
        self.scaler = None
        if self.use_bf16:
            self._info = "start: bf16 (no GradScaler)"
        else:
            self.use_fp32 = True
            self._info = "start: fp32"

    def autocast_dtype(self):
        if self.use_bf16 and self.device.type == "cuda":
            return torch.bfloat16
        elif self.use_fp16 and self.device.type == "cuda":
            return torch.float16
        else:
            return torch.float32

    def enable_fp16_with_scaler(self):
        self.use_bf16 = False
        self.use_fp16 = True
        self.use_fp32 = False
        self.scaler = GradScaler(enabled=(self.device.type == "cuda"))
        self._info = "switched to fp16 + GradScaler"

    def enable_fp32(self):
        self.use_bf16 = False
        self.use_fp16 = False
        self.use_fp32 = True
        self.scaler = None
        self._info = "switched to fp32"

    def info(self):
        return self._info

In [None]:
# ========== NanoKimiK2 Components ==========

# SwiGLU Module
class SwiGLU(nn.Module):
    """Gated activation: ``silu(fc1(x)) * fc2(x)``.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension; falls back to
            ``in_features`` when falsy (``None`` or 0).
    """

    def __init__(self, in_features, out_features=None):
        super().__init__()
        width = out_features or in_features
        self.fc1 = nn.Linear(in_features, width, bias=True)  # gate branch
        self.fc2 = nn.Linear(in_features, width, bias=True)  # value branch

    def forward(self, x):
        gate = F.silu(self.fc1(x))
        value = self.fc2(x)
        return gate * value

# Expert Module
class Expert(nn.Module):
    """Feed-forward expert: Linear -> SwiGLU -> Linear.

    Args:
        in_features: Input width.
        hidden_features: Width of the hidden layer and of the SwiGLU gate.
        out_features: Output width.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
        self.act = SwiGLU(hidden_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, x):
        hidden = self.fc1(x)
        gated = self.act(hidden)
        return self.fc2(gated)

# Mixture of Experts (MoE) Layer
class MoELayer(nn.Module):
    """Mixture-of-Experts feed-forward layer with expert-choice routing.

    Every routed expert selects the ``capacity`` tokens for which its gate
    probability is highest, processes them, and the gate-weighted results are
    scattered back onto the tokens. A shared expert is additionally applied to
    every token.

    Args:
        in_features: Token width ``d``.
        out_features: Must equal ``in_features`` (asserted).
        num_experts: Number of routed experts.
        top_k: Stored on the module; the routing below is capacity-based and
            does not read it.
        capacity_factor: Each expert selects about
            ``capacity_factor * N / num_experts`` of the ``N`` tokens.
        chunk_size: Stored on the module; not read by ``forward``.

    ``forward`` returns ``(output, aux_loss, stats)`` where ``output`` has the
    shape of the input, ``aux_loss`` is a scalar tensor and ``stats`` is a
    dict of Python floats.
    """

    def __init__(self, in_features, out_features,
                 num_experts=MOE_NUM_EXPERTS, top_k=MOE_TOP_K,
                 capacity_factor=MOE_CAPACITY_FACTOR, chunk_size=MOE_CHUNK_SIZE):
        super().__init__()
        assert out_features == in_features, "Residual-friendly MoE requires d_in == d_out."
        self.d = in_features
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.chunk_size = chunk_size
        self.entropy_lambda = 0.05  # weight of the gate-entropy term in the aux loss

        self.gate = nn.Linear(in_features, num_experts, bias=False)
        self.experts = nn.ModuleList([Expert(in_features, in_features * 4, out_features) for _ in range(num_experts)])
        self.shared_expert = Expert(in_features, in_features * 4, out_features)

    def forward(self, x):
        B, T, C = x.shape
        x_flat = x.reshape(B * T, C)
        N = x_flat.size(0)
        e = self.num_experts
        d = self.d

        logits = self.gate(x_flat)    # (N, e)
        S = F.softmax(logits, dim=1)  # (N, e)

        # Tokens per expert. Clamp to [1, N]: the raw value rounds down to 0 for
        # short inputs (which silently disabled every routed expert) and can
        # exceed N when capacity_factor > num_experts (which makes topk raise).
        capacity = min(N, max(1, int(self.capacity_factor * (N / e))))

        # Expert-centric view: each expert picks its `capacity` best tokens.
        G, I = torch.topk(S.transpose(0, 1), capacity, dim=1)  # both (e, capacity)

        # Gather the chosen tokens by index instead of multiplying with a dense
        # (e, capacity, N) one-hot matrix; same result without the O(N^2) buffer.
        X_in = x_flat[I]  # (e, capacity, d)

        X_e = torch.stack(
            [expert(X_in[i]) for i, expert in enumerate(self.experts)], dim=0
        )  # (e, capacity, d)

        # Weight by the gate value and sum each token's contributions back.
        weighted_X_e = G.unsqueeze(2) * X_e  # (e, capacity, d)
        out_flat = weighted_X_e.new_zeros(N, d).index_add(
            0, I.reshape(-1), weighted_X_e.reshape(-1, d)
        )  # (N, d)

        # The shared expert sees every token.
        out_flat = out_flat + self.shared_expert(x_flat)

        # Auxiliary loss: coefficient of variation of importance and load, plus gate entropy.
        importance = S.mean(dim=0)  # (e,)
        # NOTE(review): every expert takes exactly `capacity` tokens, so `load`
        # is constant across experts and load_cv evaluates to 0 with this routing.
        load = (I >= 0).sum(dim=1).float() / max(1, N)  # (e,)
        imp_cv = importance.std() / (importance.mean() + 1e-12)
        load_cv = load.std() / (load.mean() + 1e-12)
        entropy = -(S * torch.log(S + 1e-12)).sum(dim=-1).mean()
        aux_loss = e * (imp_cv + load_cv) + self.entropy_lambda * entropy

        # overflow/dropped are not measured here and are reported as 0.0.
        stats = {
            "importance_cv": float(imp_cv.item()),
            "load_cv": float(load_cv.item()),
            "overflow_pct": 0.0,
            "dropped_pct": 0.0
        }

        return out_flat.reshape(B, T, C), aux_loss, stats

# RoPE Cache Builder
def build_rope_cache(max_seq_len: int, head_dim: int, device: torch.device):
    """Precompute rotary-embedding cosine and sine tables.

    Args:
        max_seq_len: Number of positions to tabulate.
        head_dim: Rotary dimension; ``head_dim / 2`` frequencies are generated.
        device: Device on which the tables are created.

    Returns:
        ``(cos, sin)``, each of shape ``(max_seq_len, head_dim)`` for even
        ``head_dim``. The frequency block is repeated twice along the last
        axis, i.e. the layout is ``[f_0..f_{h/2-1}, f_0..f_{h/2-1}]``.
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (10000 ** exponents)
    positions = torch.arange(0, max_seq_len, device=device).float()
    angles = torch.einsum("i,j->ij", positions, inv_freq)  # (max_seq_len, head_dim / 2)
    table = torch.cat((angles, angles), dim=-1)
    return torch.cos(table).to(device), torch.sin(table).to(device)

# RoPE Application
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate adjacent feature pairs of ``x`` by position-dependent angles.

    Args:
        x: Tensor of shape ``(B, H, T, D)``; the sequence axis must be axis 2.
        cos, sin: Tables of shape ``(>= T, D)`` laid out as
            ``concat(freqs, freqs)`` on the last axis (as produced by
            ``build_rope_cache``).

    Returns:
        Tensor of the same shape as ``x`` in which every pair
        ``(x[..., 2j], x[..., 2j+1])`` at position ``t`` has been rotated by
        the angle ``t * inv_freq[j]``.
    """
    B, H, T, D = x.shape
    half = D // 2
    # The table stores the D/2 distinct frequencies twice. Take the first half
    # so pair j uses frequency j; striding with [::2] would pick only the
    # even-indexed frequencies (each twice) and never use the odd ones.
    cos = cos[:T, :half].unsqueeze(0).unsqueeze(0)  # (1, 1, T, D/2)
    sin = sin[:T, :half].unsqueeze(0).unsqueeze(0)
    x1 = x[..., ::2]   # first element of each pair
    x2 = x[..., 1::2]  # second element of each pair
    x_rot_even = x1 * cos - x2 * sin
    x_rot_odd = x1 * sin + x2 * cos
    # Re-interleave the rotated pairs back into the original feature order.
    x_out = torch.stack((x_rot_even, x_rot_odd), dim=-1).flatten(-2)
    return x_out

# MLAAttention
class MLAAttention(nn.Module):
    """Multi-head attention with low-rank Q / KV projections and partial RoPE.

    Queries go through ``q_a_proj -> SiLU -> q_b_proj``; keys and values share
    ``kv_a_proj -> SiLU -> kv_b_proj``. ``kv_b_proj`` emits, for every head,
    ``head_dim`` key features followed by ``head_dim`` value features. In each
    head the first ``q_lora_rank`` features of Q and K are left as they are
    and the remaining ``rope_dim`` features receive rotary position embedding.

    Args:
        embed_dim: Model width.
        num_heads: Number of attention heads.
        q_lora_rank: Bottleneck width of the query projection, and the number
            of non-rotary features per head.
        kv_lora_rank: Bottleneck width of the key/value projection.
        rope_dim: Number of rotary features per head.
        head_dim: Per-head width; must equal ``q_lora_rank + rope_dim``.
        dropout: Dropout probability applied to the attention weights.
        max_seq_len: Number of positions in the precomputed RoPE tables.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 q_lora_rank: int = 192,
                 kv_lora_rank: int = 32,
                 rope_dim: int = 32,
                 head_dim: int = 224,
                 dropout: float = 0.0,
                 max_seq_len: int = SEQ_LEN):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.rope_dim = rope_dim
        assert head_dim == (q_lora_rank + rope_dim), "head_dim must equal q_lora_rank + rope_dim."

        self.q_a_proj = nn.Linear(embed_dim, q_lora_rank, bias=False)
        self.q_b_proj = nn.Linear(q_lora_rank, num_heads * head_dim, bias=False)
        self.kv_a_proj = nn.Linear(embed_dim, kv_lora_rank, bias=False)
        self.kv_b_proj = nn.Linear(kv_lora_rank, num_heads * head_dim * 2, bias=False)
        self.out_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=False)
        self.dropout_p = dropout

        # Non-persistent buffers: they follow .to(device) but are not saved in state_dict.
        cos, sin = build_rope_cache(max_seq_len, rope_dim, device=torch.device("cpu"))
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)

        self.attn_dropout = nn.Dropout(dropout)

    def _shape_heads(self, x: torch.Tensor, B: int, T: int, D: int):
        """Reshape ``(B, T, H*D)`` into a contiguous ``(B, H, T, D)`` tensor."""
        return x.view(B, T, self.num_heads, D).transpose(1, 2).contiguous()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(B, T, embed_dim)``.
            attn_mask: Optional mask broadcastable to ``(B, H, T, T)``;
                positions where it equals 0 are masked out. When ``None`` a
                causal (lower-triangular) mask is used, so each position only
                attends to itself and earlier positions. Pass an all-ones mask
                to get bidirectional attention.

        Returns:
            Tensor of shape ``(B, T, embed_dim)``.
        """
        B, T, C = x.shape

        q = self.q_b_proj(F.silu(self.q_a_proj(x)))
        kv = self.kv_b_proj(F.silu(self.kv_a_proj(x)))

        # Move heads in front of the sequence axis BEFORE applying RoPE:
        # apply_rope indexes positions along axis 2 of a (B, H, T, D) tensor.
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)       # (B, H, T, Dh)
        kv = kv.view(B, T, self.num_heads, self.head_dim * 2).transpose(1, 2)  # (B, H, T, 2*Dh)
        k, v = kv.chunk(2, dim=-1)

        q_nope, q_pe = q[..., :self.q_lora_rank], q[..., self.q_lora_rank:]
        k_nope, k_pe = k[..., :self.q_lora_rank], k[..., self.q_lora_rank:]

        cos = self.rope_cos.to(x.device)
        sin = self.rope_sin.to(x.device)
        q_pe = apply_rope(q_pe, cos, sin)
        k_pe = apply_rope(k_pe, cos, sin)

        q = torch.cat([q_nope, q_pe], dim=-1)
        k = torch.cat([k_nope, k_pe], dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, T, T)
        if attn_mask is None:
            # Default to causal attention; without it the language model can
            # read the token it is asked to predict.
            pos = torch.arange(T, device=x.device)
            future = pos.unsqueeze(0) > pos.unsqueeze(1)  # (T, T): True where key is after query
            scores = scores.masked_fill(future, float('-inf'))
        else:
            scores = scores.masked_fill(attn_mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        y = torch.matmul(attn, v)  # (B, H, T, Dh)

        y = y.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.head_dim)
        return self.out_proj(y)

# Transformer Block
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: MLA attention followed by an MoE feed-forward.

    ``forward`` returns ``(x, aux_loss, moe_stats)``: the updated hidden
    states, the MoE auxiliary loss of this block and the MoE statistics dict
    (empty if the MoE layer returned none).

    Activation checkpointing is controlled by the module-level flags
    ``USE_CHECKPOINTING`` (on/off, training only) and
    ``CHECKPOINT_ATTENTION_ONLY`` (when true, only the attention sub-layer is
    checkpointed and the MoE runs normally).
    """

    def __init__(self, embed_dim, num_heads, max_seq_len,
                 q_lora_rank=192, kv_lora_rank=32, rope_dim=32, head_dim=224, dropout=0.0):
        super().__init__()
        self.attn = MLAAttention(embed_dim, num_heads,
                                 q_lora_rank=q_lora_rank, kv_lora_rank=kv_lora_rank,
                                 rope_dim=rope_dim, head_dim=head_dim,
                                 dropout=dropout, max_seq_len=max_seq_len)
        self.moe = MoELayer(embed_dim, embed_dim,
                            num_experts=MOE_NUM_EXPERTS,
                            top_k=MOE_TOP_K,
                            capacity_factor=MOE_CAPACITY_FACTOR)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Imported here so the block does not depend on `import torch` having
        # loaded the torch.utils.checkpoint submodule as a side effect.
        from torch.utils.checkpoint import checkpoint

        use_ckpt = self.training and USE_CHECKPOINTING
        aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        def attn_fn(h):
            return self.attn(self.norm1(h))

        # NOTE(review): preserve_rng_state=False is only safe while dropout is 0.
        if use_ckpt:
            attn_out = checkpoint(attn_fn, x, use_reentrant=True, preserve_rng_state=False)
        else:
            attn_out = attn_fn(x)
        x = x + attn_out

        moe_in = self.norm2(x)
        # Honour CHECKPOINT_ATTENTION_ONLY: skip MoE checkpointing when it is set.
        if use_ckpt and not CHECKPOINT_ATTENTION_ONLY:
            moe_out_tuple = checkpoint(self.moe, moe_in, use_reentrant=True, preserve_rng_state=False)
        else:
            moe_out_tuple = self.moe(moe_in)

        # Accept (out, aux, stats), (out, aux) or a bare tensor from the MoE layer.
        if isinstance(moe_out_tuple, tuple) and len(moe_out_tuple) == 3:
            moe_out, moe_aux, moe_stats = moe_out_tuple
        elif isinstance(moe_out_tuple, tuple) and len(moe_out_tuple) == 2:
            moe_out, moe_aux = moe_out_tuple
            moe_stats = {}
        else:
            moe_out = moe_out_tuple
            moe_aux = torch.tensor(0.0, device=x.device, dtype=x.dtype)
            moe_stats = {}

        x = x + moe_out
        aux_loss = aux_loss + moe_aux
        return x, aux_loss, moe_stats

# NanoKimi Transformer Model
class NanoKimiTransformer(nn.Module):
    """Token embedding -> stack of TransformerBlocks -> LayerNorm -> vocabulary head.

    ``forward`` takes token ids of shape ``(B, T)`` and returns
    ``(logits, total_aux, agg_stats)``: logits of shape ``(B, T, vocab_size)``,
    the sum of the blocks' auxiliary losses, and a dict holding the mean of
    ``importance_cv``, ``load_cv`` and ``overflow_pct`` over the blocks that
    reported statistics (all 0.0 when none did).
    """

    _STAT_KEYS = ("importance_cv", "load_cv", "overflow_pct")

    def __init__(self, vocab_size, embed_dim=768, max_length=512, num_layers=8, num_heads=8,
                 q_lora_rank=192, kv_lora_rank=32, rope_dim=32, head_dim=224, dropout=0.0):
        super().__init__()
        self.max_length = max_length
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        block_kwargs = dict(max_seq_len=max_length,
                            q_lora_rank=q_lora_rank, kv_lora_rank=kv_lora_rank,
                            rope_dim=rope_dim, head_dim=head_dim, dropout=dropout)
        self.blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, **block_kwargs) for _ in range(num_layers)]
        )
        self.norm_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, x):
        _batch, _seq = x.shape  # input must be 2-D: (batch, sequence)
        h = self.token_embed(x)
        total_aux = torch.tensor(0.0, device=h.device, dtype=h.dtype)

        reported = []  # stats dicts of the blocks that returned a non-empty one
        for blk in self.blocks:
            h, aux, stats = blk(h)
            total_aux = total_aux + aux
            if isinstance(stats, dict) and stats:
                reported.append(stats)

        logits = self.head(self.norm_f(h))

        if reported:
            agg_stats = {
                key: sum((s.get(key, 0.0) for s in reported), 0.0) / len(reported)
                for key in self._STAT_KEYS
            }
        else:
            agg_stats = {key: 0.0 for key in self._STAT_KEYS}
        return logits, total_aux, agg_stats

# Utility Functions
def _add_identity(M, eps):
    if M.ndim == 2:
        return M + eps * torch.eye(M.size(0), device=M.device, dtype=M.dtype)
    elif M.ndim == 3:
        I = torch.eye(M.size(-1), device=M.device, dtype=M.dtype).expand(M.size(0), -1, -1)
        return M + eps * I
    return M

@torch.no_grad()
def _newton_schulz_inverse_pth_root(A, p=2, iters=5, eps=1e-6):
    """Approximate ``A^(-1/p)`` with a coupled Newton-Schulz iteration.

    ``A`` is symmetrised, regularised with ``eps * I`` and divided by its
    Frobenius norm before iterating; the result is scaled back afterwards.

    Args:
        A: Square matrix (expected to be symmetric positive semi-definite).
        p: Root order. NOTE(review): the update matches the standard
            inverse-square-root iteration for ``p=2``, the only value used in
            this file; other values look unverified.
        iters: Number of iterations.
        eps: Diagonal regularisation added before iterating.

    Returns:
        Tensor of the same shape as ``A``.
    """
    sym = 0.5 * (A + A.transpose(-1, -2))
    sym = _add_identity(sym, eps)
    norm = sym.norm(p='fro') + 1e-12
    size = sym.size(-1)
    I = torch.eye(size, device=sym.device, dtype=sym.dtype)
    Y = sym / norm  # normalised input
    Z = torch.eye(size, device=sym.device, dtype=sym.dtype)  # running inverse-root estimate
    for _ in range(iters):
        T = (p + 1) * 0.5 * I - 0.5 * (Y @ Z + Z @ Y) / p
        Y = Y @ T
        Z = T @ Z
    # Undo the normalisation: (A / norm)^(-1/p) = A^(-1/p) * norm^(1/p).
    return Z * (norm ** (-1.0 / p))

def _rms(x: torch.Tensor) -> torch.Tensor:
    return (x.pow(2).mean()).sqrt()

# Muon Optimizer
class Muon(Optimizer):
    """Preconditioned optimizer with Adam-RMS matching and QK-clip.

    For 2-D parameters the gradient is preconditioned on both sides with
    inverse square roots of exponential moving averages of ``g g^T`` and
    ``g^T g`` (recomputed every ``precond_update_freq`` steps). For every
    parameter the update is rescaled so that its RMS equals the RMS of the
    Adam-style update ``g / (sqrt(v_hat) + eps)``. Decoupled weight decay is
    applied to every parameter.

    After the parameter update, QK-clip rescales per-head rows of
    ``q_b_proj`` / ``kv_b_proj`` of every bound ``MLAAttention`` module whose
    estimated attention-logit scale exceeds ``qkclip_tau``.

    Args:
        params: Parameters or parameter groups.
        lr: Learning rate.
        beta1: Kept in the defaults; not used by ``step`` (no momentum term).
        beta2: Decay of the second-moment and preconditioner statistics.
        eps: Added to ``sqrt(v_hat)`` in the Adam reference update.
        wd: Decoupled weight-decay coefficient.
        ns_iters: Newton-Schulz iterations per inverse-root computation.
        precond_eps: Diagonal regularisation for the inverse root.
        precond_update_freq: Steps between inverse-root recomputations.
        adam_bias_correction: Whether to bias-correct ``v``.
        qkclip_tau: Threshold on the estimated per-head logit scale.
        qkclip_every: Apply QK-clip every this many steps.
        qkclip_calibrate_default: Input RMS assumed before any forward pass.
        qkclip_decay: EMA decay of the tracked attention-input RMS.
    """

    def __init__(self,
                 params,
                 lr=6e-4,
                 beta1=0.9, beta2=0.999,
                 eps=1e-8,
                 wd=0.1,
                 ns_iters=5,
                 precond_eps=1e-4,
                 precond_update_freq=20,
                 adam_bias_correction=True,
                 qkclip_tau=20.0,
                 qkclip_every=1,
                 qkclip_calibrate_default=1.0,
                 qkclip_decay=0.95):
        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps=eps, wd=wd,
                        ns_iters=ns_iters, precond_eps=precond_eps,
                        precond_update_freq=precond_update_freq,
                        adam_bias_correction=adam_bias_correction,
                        qkclip_tau=qkclip_tau, qkclip_every=qkclip_every,
                        qkclip_calibrate_default=qkclip_calibrate_default,
                        qkclip_decay=qkclip_decay)
        super().__init__(params, defaults)
        self._parent_modules = []
        self._had_forward_once = False

    @staticmethod
    def _make_rms_hook(mod):
        """Build a forward-pre-hook that keeps an EMA of ``mod``'s input RMS."""
        def _hook(module, inputs):
            x = inputs[0]
            with torch.no_grad():
                val = _rms(x.detach()).to(x.dtype)
                if mod._qkclip_rms is None:
                    # First observation initialises the running value.
                    mod._qkclip_rms = val.detach()
                    mod._qkclip_decay = None
                else:
                    decay = mod._qkclip_decay if mod._qkclip_decay is not None else 0.95
                    mod._qkclip_rms.mul_(decay).add_(val, alpha=(1.0 - decay))
            return None
        return _hook

    def bind_modules(self, modules_list):
        """Register the modules QK-clip should inspect.

        Args:
            modules_list: List or tuple of modules; only ``MLAAttention``
                instances among them are used. Each gets (once) a
                forward-pre-hook tracking the RMS of its input.

        Raises:
            ValueError: If ``modules_list`` is not a list or tuple.
        """
        if not isinstance(modules_list, (list, tuple)):
            raise ValueError("bind_modules expects list/tuple of modules")
        self._parent_modules = modules_list

        for m in self._iter_attention_modules():
            if not hasattr(m, "_qkclip_rms"):
                m._qkclip_rms = None
            if not hasattr(m, "_qkclip_decay"):
                m._qkclip_decay = None
            if not hasattr(m, "_qkclip_hook_handle"):
                m._qkclip_hook_handle = m.register_forward_pre_hook(
                    self._make_rms_hook(m), with_kwargs=False)

    def _iter_attention_modules(self):
        """Yield the bound modules that are ``MLAAttention`` instances."""
        for m in self._parent_modules:
            if isinstance(m, MLAAttention):
                yield m

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step followed by QK-clip.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The closure's loss, or ``None`` when no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta2 = group['beta2']
            eps = group['eps']
            wd = group['wd']
            ns_iters = group['ns_iters']
            precond_eps = group['precond_eps']
            precond_update_freq = group['precond_update_freq']
            adam_bias_correction = group['adam_bias_correction']

            for p in group['params']:
                if p.grad is None:
                    continue

                # Two-sided preconditioning is only defined for matrices; other
                # shapes (biases, norms, >2-D tensors) use the plain gradient.
                use_precond = (p.ndim == 2)

                st = self.state[p]
                if len(st) == 0:
                    st['step'] = 0
                    # NOTE(review): 'm' is allocated but never read; kept so
                    # the state layout stays unchanged.
                    st['m'] = torch.zeros_like(p)
                    st['v'] = torch.zeros_like(p)
                    if use_precond:
                        st['G'] = torch.zeros((p.shape[-2], p.shape[-2]), device=p.device, dtype=p.dtype)
                        st['H'] = torch.zeros((p.shape[-1], p.shape[-1]), device=p.device, dtype=p.dtype)

                st['step'] += 1

                # Decoupled weight decay.
                if wd != 0:
                    p.add_(p, alpha=-lr * wd)

                g = p.grad
                v = st['v']
                v.mul_(beta2).addcmul_(g, g, value=1 - beta2)

                if use_precond and 'G' in st and 'H' in st:
                    G = st['G']
                    H = st['H']
                    # EMA of the outer products. The new term must be weighted
                    # through `alpha`; `beta` would re-scale the accumulator
                    # and add the outer product at full weight.
                    G.mul_(beta2).addmm_(g, g.t(), alpha=1 - beta2)
                    H.mul_(beta2).addmm_(g.t(), g, alpha=1 - beta2)

                    if (st['step'] % precond_update_freq == 0) or ('G_inv_root' not in st):
                        st['G_inv_root'] = _newton_schulz_inverse_pth_root(G, p=2, iters=ns_iters, eps=precond_eps)
                        st['H_inv_root'] = _newton_schulz_inverse_pth_root(H, p=2, iters=ns_iters, eps=precond_eps)

                    g_pre = st['G_inv_root'] @ g @ st['H_inv_root']
                else:
                    g_pre = g

                if adam_bias_correction:
                    v_hat = v / (1.0 - (beta2 ** st['step']))
                else:
                    v_hat = v
                denom = v_hat.sqrt().add_(eps)
                g_adam_ref = g / denom

                # Match the RMS of the preconditioned update to the Adam reference.
                rms_pre = _rms(g_pre)
                rms_ref = _rms(g_adam_ref)
                scale = (rms_ref / (rms_pre + 1e-16)).clamp_(min=0.0, max=1e6)
                g_muon = g_pre * scale

                p.add_(g_muon, alpha=-lr)

        self._had_forward_once = True

        # ---- QK-clip (runs once per parameter group, as before) ----
        for group in self.param_groups:
            tau = group['qkclip_tau']
            every = group['qkclip_every']
            decay = group['qkclip_decay']
            default_rms = group['qkclip_calibrate_default']

            # Use the step count of the first parameter that has state.
            any_param = None
            for p in group['params']:
                if p in self.state:
                    any_param = p
                    break
            if any_param is None:
                continue
            t = self.state[any_param].get('step', 1)
            if (t % max(1, every)) != 0:
                continue

            for attn in self._iter_attention_modules():
                if hasattr(attn, "_qkclip_rms") and attn._qkclip_rms is not None:
                    attn._qkclip_decay = decay
                    a_rms = float(attn._qkclip_rms.detach().item())
                else:
                    a_rms = float(default_rms)

                num_heads = attn.num_heads
                dh = attn.head_dim

                # Skip modules lacking the low-rank projections instead of
                # relying on a flag that may be unbound or stale.
                needed = ("q_a_proj", "q_b_proj", "kv_a_proj", "kv_b_proj")
                if not all(hasattr(attn, name) for name in needed):
                    continue
                try:
                    # Effective input -> per-head query / key-value maps.
                    Mq = attn.q_b_proj.weight @ attn.q_a_proj.weight
                    Mk = attn.kv_b_proj.weight @ attn.kv_a_proj.weight
                except Exception:
                    continue

                for h in range(num_heads):
                    q_start = h * dh
                    q_end = q_start + dh
                    # kv_b_proj rows are laid out per head as [K (dh) | V (dh)]
                    # (MLAAttention views them as (H, 2*dh) and chunks), so the
                    # key rows of head h start at h * 2 * dh.
                    k_start = h * 2 * dh
                    k_end = k_start + dh
                    if q_end > Mq.size(0) or k_end > Mk.size(0):
                        continue

                    nq = torch.linalg.norm(Mq[q_start:q_end, :], ord='fro')
                    nk = torch.linalg.norm(Mk[k_start:k_end, :], ord='fro')
                    # Estimated attention-logit scale of this head.
                    s_h = (float(nq.item()) * a_rms) * (float(nk.item()) * a_rms) / math.sqrt(max(1, dh))

                    if s_h > tau and math.isfinite(s_h):
                        # Split the correction evenly between Q and K.
                        alpha = math.sqrt(tau / s_h)
                        attn.q_b_proj.weight[q_start:q_end, :].mul_(alpha)
                        attn.kv_b_proj.weight[k_start:k_end, :].mul_(alpha)

        return loss

In [None]:
# ========== Data ==========

# Story Dataset
class StoryDataset(Dataset):
    """Fixed-length next-token-prediction windows over a token sequence.

    Sample ``k`` starts at ``i = k * stride`` and holds
    ``input = tokens[i : i + max_length]`` and
    ``target = tokens[i + 1 : i + max_length + 1]``. Only windows whose input
    AND target are both exactly ``max_length`` long are kept, so every sample
    has the same shape and can be stacked into a batch.

    Args:
        tokens: Sequence of integer token ids.
        max_length: Window length.
        stride: Distance between consecutive window starts (must be > 0).
    """

    def __init__(self, tokens, max_length=512, stride=512):
        self.max_length = max_length
        self.input_ids = []
        self.target_ids = []
        # A window needs max_length + 1 tokens (input plus the shifted target).
        # Stopping at len(tokens) - max_length - 1 avoids a final window whose
        # target is one token short.
        last_start = len(tokens) - max_length - 1
        for i in range(0, last_start + 1, stride):
            inp = tokens[i:i + max_length]
            tgt = tokens[i + 1:i + max_length + 1]
            self.input_ids.append(torch.tensor(inp, dtype=torch.long))
            self.target_ids.append(torch.tensor(tgt, dtype=torch.long))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

# Collate Function
def collate(batch: Sequence):
    """Stack ``(input, target)`` samples into two int64 batch tensors.

    Args:
        batch: Sequence of ``(input, target)`` pairs; every input must have
            the same shape, and likewise every target.

    Returns:
        ``(x, y)`` with a new leading batch dimension.
    """
    def _stack_long(items):
        return torch.stack([torch.as_tensor(item, dtype=torch.long) for item in items])

    inputs, targets = zip(*batch)
    return _stack_long(inputs), _stack_long(targets)

# DataLoader Creation
def make_dataloaders(train_ds, val_ds, batch_size):
    """Build the training and validation DataLoaders.

    Both loaders share the worker / pinning settings from the module-level
    configuration and use ``collate``. The training loader shuffles and drops
    the last incomplete batch; the validation loader does neither.

    Args:
        train_ds: Training dataset.
        val_ds: Validation dataset.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    workers = max(0, min(NUM_WORKERS, 8))
    shared = dict(
        batch_size=batch_size,
        num_workers=workers,
        # Pinning only helps host->GPU copies; without CUDA it just warns.
        pin_memory=PIN_MEMORY and torch.cuda.is_available(),
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers == 0, so only request it when workers exist.
        persistent_workers=PERSISTENT_WORKERS and workers > 0,
        collate_fn=collate,
    )
    train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **shared)
    val_loader = DataLoader(val_ds, shuffle=False, drop_last=False, **shared)
    return train_loader, val_loader

# Bits Per Character (BPC) Computation
def compute_bpc(nll, tokens, encoding, is_gpt=False):
    """Convert a mean per-token NLL (in nats) into bits per character.

    Args:
        nll: Mean negative log-likelihood per token, in nats.
        tokens: Tensor of token ids the NLL was computed on.
        encoding: With ``is_gpt=False`` an object exposing
            ``decode(list_of_ids) -> str``; with ``is_gpt=True`` a mapping
            from token id to string that is indexed directly.
        is_gpt: Selects how ``encoding`` is used to rebuild the text.

    Returns:
        ``nll * num_tokens / num_chars / ln(2)`` with ``num_chars`` floored at 1.
    """
    flat_ids = tokens.flatten()
    if is_gpt:
        text = ''.join(encoding[tok.item()] for tok in flat_ids)
    else:
        text = encoding.decode(flat_ids.tolist())
    char_count = max(1, len(text))
    total_nats = nll * tokens.numel()
    return total_nats / char_count / math.log(2)

In [None]:
# ========== Eval & Generation for NanoKimiK2 ==========

# Evaluation Function
@torch.no_grad()
def evaluate_kimi(model, loader, device, mp_mgr, encoding):
    """Evaluate the MoE model on ``loader``.

    The model is put in eval mode and is left in eval mode on return.

    Args:
        model: Callable returning ``(logits, aux, stats)`` for a batch of ids.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to.
        mp_mgr: Object providing ``autocast_dtype()``.
        encoding: Tokenizer exposing ``decode`` (used for bits-per-character).

    Returns:
        ``(perplexity, mean_nll, mean_aux, mean_bpc, avg_stats)``. ``mean_nll``
        is weighted by target count; ``mean_aux``, ``mean_bpc`` and the
        entries of ``avg_stats`` are plain per-batch means. For an empty
        loader the four numbers are ``inf`` and ``avg_stats`` is ``{}``.
    """
    model.eval()
    amp_dtype = mp_mgr.autocast_dtype()
    amp_device = "cuda" if device.type == "cuda" else "cpu"

    nll_weighted_sum = 0.0
    aux_sum = 0.0
    bpc_sum = 0.0
    stat_sums = {"importance_cv": 0.0, "load_cv": 0.0, "overflow_pct": 0.0}
    seen_targets = 0
    seen_batches = 0

    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        with torch.amp.autocast(device_type=amp_device, dtype=amp_dtype):
            logits, aux, stats = model(xb)
            nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
            # Combined training objective; computed but not reported.
            _ = nll + MOE_AUX_LAMBDA * aux

        batch_targets = yb.numel()
        batch_nll = float(nll.detach().cpu().item())
        nll_weighted_sum += batch_nll * batch_targets
        aux_sum += float(aux.detach().cpu().item()) if torch.is_tensor(aux) else float(aux)
        for key in stat_sums:
            stat_sums[key] += stats.get(key, 0.0)
        bpc_sum += compute_bpc(float(nll.detach().cpu().item()), yb.cpu(), encoding, is_gpt=False)
        seen_targets += batch_targets
        seen_batches += 1

    if seen_batches == 0:
        return float("inf"), float("inf"), float("inf"), float("inf"), {}

    mean_nll = nll_weighted_sum / max(1, seen_targets)
    mean_aux = aux_sum / seen_batches
    perplexity = math.exp(mean_nll)
    mean_bpc = bpc_sum / seen_batches
    avg_stats = {
        "importance_cv": stat_sums["importance_cv"] / max(1, seen_batches),
        "load_cv": stat_sums["load_cv"] / max(1, seen_batches),
        "overflow_pct": stat_sums["overflow_pct"] / max(1, seen_batches),
        "avg_aux": mean_aux
    }
    torch.cuda.empty_cache()
    return perplexity, mean_nll, mean_aux, mean_bpc, avg_stats

# Generation Function
@torch.no_grad()
def generate_kimi(model, encoding, prompt, device, mp_mgr, max_tokens=200, temperature=1.2, top_k=100):
    """Sample a continuation of ``prompt`` from the MoE model.

    The model is put in eval mode and is left in eval mode on return.

    Args:
        model: Callable returning ``(logits, aux, stats)`` for a batch of ids.
        encoding: Tokenizer exposing ``encode(text, out_type=int)`` and ``decode``.
        prompt: Text to continue.
        device: Device to run on.
        mp_mgr: Object providing ``autocast_dtype()``.
        max_tokens: Number of tokens to sample.
        temperature: Softmax temperature (floored at 1e-6).
        top_k: When a positive int, sample only among the ``top_k`` most
            likely tokens; ``None`` or 0 disables the filter.

    Returns:
        The decoded prompt plus all sampled tokens.
    """
    model.eval()
    ids = encoding.encode(prompt, out_type=int)
    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)

    # Longest context the model can take; fall back to the global SEQ_LEN.
    ctx_len = getattr(model, "max_length", SEQ_LEN)
    amp_device = "cuda" if device.type == "cuda" else "cpu"

    for _ in range(max_tokens):
        dtype = mp_mgr.autocast_dtype()
        # Crop BEFORE the forward pass so a long prompt cannot exceed the
        # model's positional range, and keep `x` intact so the returned text
        # still contains the whole prompt and every generated token.
        x_cond = x[:, -ctx_len:]
        with torch.amp.autocast(device_type=amp_device, dtype=dtype):
            logits, _, _ = model(x_cond)
            logits = logits[:, -1, :] / max(1e-6, temperature)
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                thresh = v[:, -1].unsqueeze(-1)  # smallest logit that survives the filter
                logits = torch.where(logits < thresh, torch.full_like(logits, -float('inf')), logits)
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        x = torch.cat([x, next_id], dim=1)
    return encoding.decode(x[0].tolist())

In [None]:
# ========== NanoGPT Components ==========

# Attention Head
class Head(nn.Module):
    """Single causal self-attention head.

    Args:
        n_embd: Input width.
        head_size: Width of the key / query / value projections.
        block_size: Maximum sequence length (size of the causal mask buffer).
        dropout: Dropout probability on the attention weights.
    """

    def __init__(self, n_embd: int, head_size: int, block_size: int, dropout: float):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: entry (i, j) is 1 when position i may attend to j.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal)
        self.dropout = nn.Dropout(dropout)
        self.scale = head_size ** -0.5

    def forward(self, x):
        _batch, seq_len, _width = x.shape
        keys = self.key(x)
        queries = self.query(x)
        scores = queries @ keys.transpose(-2, -1) * self.scale  # (B, T, T)
        blocked = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(blocked, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        values = self.value(x)
        return weights @ values

# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Several ``Head`` modules run independently, concatenated and projected.

    Each head has width ``n_embd // num_heads``; the concatenation is
    projected to ``n_embd`` and passed through dropout.
    """

    def __init__(self, n_embd: int, num_heads: int, block_size: int, dropout: float):
        super().__init__()
        head_size = n_embd // num_heads
        heads = [Head(n_embd, head_size, block_size, dropout) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

# Feed-Forward Network
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(n_embd -> 4*n_embd), ReLU, Linear(back), Dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# Transformer Block
class Block(nn.Module):
    """Pre-norm transformer block: self-attention then feed-forward, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

# GPT Language Model
class GPTLanguageModel(nn.Module):
    """Baseline GPT: learned token + position embeddings, ``n_layer`` Blocks, LM head.

    Linear and Embedding weights are initialised from N(0, 0.02^2); Linear
    biases are zeroed.
    """

    def __init__(self, vocab_size: int, block_size: int, n_embd: int, n_head: int, n_layer: int, dropout: float):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.block_size = block_size
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialisation callback passed to ``nn.Module.apply``."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets: Optional[torch.Tensor] = None):
        """Return ``(logits, loss)``.

        Without ``targets`` the logits have shape ``(B, T, vocab)`` and the
        loss is ``None``. With ``targets`` the logits are returned flattened
        to ``(B*T, vocab)`` together with the mean cross-entropy.
        """
        _, seq_len = idx.shape
        tok_emb = self.token_embedding_table(idx)
        positions = torch.arange(seq_len, device=tok_emb.device)
        pos_emb = self.position_embedding_table(positions)
        hidden = self.blocks(tok_emb + pos_emb)
        logits = self.lm_head(self.ln_f(hidden))
        if targets is None:
            return logits, None
        batch, steps, vocab = logits.shape
        logits = logits.view(batch * steps, vocab)
        loss = F.cross_entropy(logits, targets.view(batch * steps))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens: int, temperature: float = 1.2, top_k: Optional[int] = 100):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` and return the result.

        The model only sees the last ``block_size`` tokens at each step.
        """
        for _ in range(max_new_tokens):
            context = idx[:, -self.block_size:]
            logits, _ = self(context)
            last = logits[:, -1, :] / max(1e-6, temperature)
            if top_k is not None and top_k > 0:
                kept, _ = torch.topk(last, min(top_k, last.size(-1)))
                cutoff = kept[:, -1].unsqueeze(-1)  # smallest logit that survives the filter
                last = torch.where(last < cutoff, torch.full_like(last, -float('inf')), last)
            probs = F.softmax(last, dim=-1)
            sampled = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, sampled), dim=1)
        return idx

In [None]:
# ========== Eval & Generation for NanoGPT ==========

# Evaluation Function for NanoGPT
@torch.no_grad()
def evaluate_gpt(model, loader, device, mp_mgr, encoding):
    """Evaluate the NanoGPT baseline over every batch of ``loader``.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so calling this from inside
    a training loop does not silently leave the model in eval mode.

    Args:
        model: module called as ``model(xb, yb)`` and returning ``(logits, loss)``.
        loader: iterable of ``(xb, yb)`` tensor batches.
        device: torch device the batches are moved to; also selects the autocast backend.
        mp_mgr: object whose ``autocast_dtype()`` provides the autocast dtype.
        encoding: tokenizer, forwarded unchanged to ``compute_bpc``.

    Returns:
        ``(perplexity, mean_nll, mean_bpc, stats)``:
        ``mean_nll`` is the per-batch loss averaged with ``yb.numel()`` as weight,
        ``perplexity`` is ``exp(mean_nll)`` (``inf`` if that overflows a float),
        ``mean_bpc`` is the unweighted mean of the per-batch ``compute_bpc`` values,
        ``stats`` is ``{"mean_nll": mean_nll}``.
        When no batch yields a loss the result is ``(inf, inf, inf, {})``.
    """
    was_training = model.training
    model.eval()
    try:
        total_nll_acc = 0.0
        total_targets = 0
        total_bpc = 0.0
        count = 0
        dtype = mp_mgr.autocast_dtype()
        autocast_device = "cuda" if device.type == "cuda" else "cpu"
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            with torch.amp.autocast(device_type=autocast_device, dtype=dtype):
                _, loss = model(xb, yb)
            if loss is None:
                continue
            # Single device->host sync per batch; the value is reused below.
            nll_value = float(loss.detach().cpu().item())
            num_targets = yb.numel()
            total_nll_acc += nll_value * num_targets
            # NOTE(review): ``is_gpt=False`` is passed although this is the GPT
            # evaluator; compute_bpc is defined elsewhere — confirm this is intended.
            total_bpc += compute_bpc(nll_value, yb.cpu(), encoding, is_gpt=False)
            total_targets += num_targets
            count += 1
    finally:
        # Restore whatever mode the caller had (train during training, eval otherwise).
        model.train(was_training)

    if count == 0:
        return float("inf"), float("inf"), float("inf"), {}
    mean_nll = total_nll_acc / max(1, total_targets)
    try:
        perplexity = math.exp(mean_nll)
    except OverflowError:
        # A diverged model can produce an NLL too large for exp(); report inf instead of crashing.
        perplexity = float("inf")
    mean_bpc = total_bpc / count
    stats = {"mean_nll": mean_nll}
    return perplexity, mean_nll, mean_bpc, stats

# Generation Function for NanoGPT
@torch.no_grad()
def generate_gpt(model, encoding, prompt, device, mp_mgr, max_tokens=200, temperature=1.2, top_k=100):
    """Generate a text continuation of ``prompt`` with the NanoGPT baseline.

    The model is put in eval mode while sampling and its previous train/eval
    mode is restored before returning, so sampling in the middle of training
    does not leave the model in eval mode.

    Args:
        model: module exposing ``generate(idx, max_new_tokens, temperature, top_k)``.
        encoding: tokenizer exposing ``encode(text, out_type=int)``, ``decode(ids)``
            and ``bos_id()`` (the latter is only used for an empty prompt).
        prompt: text to continue.
        device: torch device for the input ids; also selects the autocast backend.
        mp_mgr: object whose ``autocast_dtype()`` provides the autocast dtype.
        max_tokens: number of new tokens to sample.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_k: top-k cutoff forwarded to ``model.generate``.

    Returns:
        The decoded string for the prompt ids followed by the sampled ids.
    """
    was_training = model.training
    model.eval()
    try:
        ids = encoding.encode(prompt, out_type=int)
        if len(ids) == 0:
            # An empty prompt would give a zero-length context, which the model
            # cannot index for its last position; start from the BOS token instead.
            ids = [encoding.bos_id()]
        x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
        dtype = mp_mgr.autocast_dtype()
        with torch.amp.autocast(device_type="cuda" if device.type == "cuda" else "cpu", dtype=dtype):
            sample_ids = model.generate(x, max_tokens, temperature, top_k)
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    return encoding.decode(sample_ids[0].tolist())

In [None]:
# ========== Scheduler ==========
def get_cosine_with_warmup_scheduler(optimizer, warmup_steps, total_steps, min_multiplier=0.0):
    """Build a ``LambdaLR`` with linear warmup followed by a single cosine decay.

    Args:
        optimizer: optimizer whose learning rate is scaled by the schedule.
        warmup_steps: number of scheduler steps over which the multiplier rises
            linearly from 0 to 1; ``0`` disables warmup.
        total_steps: scheduler step at which the cosine decay reaches its minimum.
        min_multiplier: lower bound for the multiplier during/after the decay.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.

    The decay progress is clamped to [0, 1]: stepping past ``total_steps`` keeps
    the multiplier at its floor instead of following the cosine back upwards.
    """
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(current_step):
        if current_step < warmup_steps and warmup_steps > 0:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(decay_steps)
        # Without the clamp cos(pi * progress) turns around for progress > 1
        # and the learning rate would start increasing again.
        progress = min(1.0, max(0.0, progress))
        return max(min_multiplier, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [None]:
# ========== Aux multiplier for NanoKimiK2 ==========
def compute_aux_multiplier(step: int, batch_size: int):
    """Return the scaling factor applied to the MoE auxiliary loss.

    The raw value is ``log1p(batch_size) * sqrt(log1p(step)) / 10`` with both
    ``step`` (truncated to int) and ``batch_size`` floored at 1; the result is
    clamped to the range [0.5, 1.0].
    """
    lower_bound, upper_bound = 0.5, 1.0
    safe_step = int(step)
    if safe_step < 1:
        safe_step = 1
    batch_term = math.log1p(max(1, batch_size))
    step_term = math.sqrt(math.log1p(safe_step))
    raw = (batch_term * step_term) / 10.0
    return min(upper_bound, max(lower_bound, raw))

In [None]:
# ========== Main ==========

# Main Training Function
def main():
    """Train NanoKimiK2 (Muon) and a NanoGPT baseline (AdamW) side by side on TinyStories.

    Steps, in order:
      1. Pick the device and download the TinyStories dataset.
      2. Dump the training texts to ``tinystories.txt`` and train a SentencePiece tokenizer on it.
      3. Tokenize the train/validation splits and wrap them in ``StoryDataset``.
      4. Build both models, their optimizers and cosine-with-warmup schedulers.
      5. Run ``EPOCHS`` x ``STEPS_PER_EPOCH`` training steps feeding the same batch to
         both models, with periodic logging, evaluation, sample generation and loss tables.

    Side effects: writes ``tinystories.txt`` and the SentencePiece model/vocab files
    to the working directory and prints progress to stdout. Returns ``None``.
    """
    cleanup_memory()
    set_seed(1337)

    # Device Setup
    device = get_device()
    print("Device:", device, "| cuda available:", torch.cuda.is_available())
    if device.type == "cuda":
        try:
            print(torch.cuda.get_device_name(0))
        except Exception:
            pass
        print("VRAM reserved (MB):", torch.cuda.memory_reserved()/1e6)

    # Load TinyStories Dataset
    dataset = load_dataset("roneneldan/TinyStories")
    # Stream the texts straight to disk; building an intermediate list of every
    # story would only duplicate the dataset in memory.
    with open("tinystories.txt", "w", encoding="utf-8") as f:
        for ex in dataset["train"]:
            f.write(ex["text"].strip() + "\n")

    # Tokenizer Training
    def train_tokenizer(vocab_size=8000, model_type="unigram"):
        """Train a SentencePiece model on ``tinystories.txt`` and return the model file name."""
        model_prefix = f"spm_tinystories_{vocab_size}"
        spm.SentencePieceTrainer.Train(
            input="tinystories.txt",
            model_prefix=model_prefix,
            vocab_size=vocab_size,
            model_type=model_type,
            character_coverage=1.0,
            byte_fallback=True,
            unk_id=0, bos_id=1, eos_id=2, pad_id=3
        )
        print(f"Tokenizer trained: {model_prefix}.model")
        return model_prefix + ".model"

    # Use the configured vocabulary size rather than a second hard-coded copy of it.
    model_file = train_tokenizer(VOCAB_SIZE)
    encoding = spm.SentencePieceProcessor(model_file=model_file)
    train_tokens = [tok for ex in dataset["train"] for tok in encoding.encode(ex["text"], out_type=int)]
    val_tokens = [tok for ex in dataset["validation"] for tok in encoding.encode(ex["text"], out_type=int)]
    # Size the models from the tokenizer's real vocabulary, which is what the ids index into.
    vocab_size = encoding.get_piece_size()

    # Dataset Preparation
    # NOTE(review): stride is the literal 512, which equals SEQ_LEN today — confirm
    # whether it is meant to track SEQ_LEN if that constant changes.
    train_ds = StoryDataset(train_tokens, max_length=SEQ_LEN, stride=512)
    val_ds = StoryDataset(val_tokens, max_length=SEQ_LEN, stride=512)
    if len(train_ds) == 0:
        raise RuntimeError("Train dataset empty — lower SEQ_LEN or check tokenization.")

    # Print General Information
    print("General Information:")
    print(f"  Context Length: {SEQ_LEN} tokens")
    print(f"  Training Dataset Size: {len(train_ds)} samples")
    print(f"  Validation Dataset Size: {len(val_ds)} samples")
    # Report the vocabulary size actually used to build the models.
    print(f"  Vocabulary Size: {vocab_size} tokens")
    print(f"  Embedding dimension: {n_embd}")
    print(f"  Dimension of input & Hidden layer: {n_embd}")
    print(f"  Number of Epochs: {EPOCHS}")
    print(f"  Batch Size: {BATCH_SIZE}")
    print(f"  MoE Experts: {MOE_NUM_EXPERTS}")
    print(f"  MoE Top-K: {MOE_TOP_K}")
    print(f"  Train dataset size: {len(train_ds)} samples")
    print(f"  Validation dataset size: {len(val_ds)} samples")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Steps per epoch: {STEPS_PER_EPOCH}")
    print(f"  Epochs: {EPOCHS}")

    # Mixed Precision Manager
    mp_mgr = MixedPrecisionManager(device)

    # Initialize NanoKimiK2 Model
    model_kimi = NanoKimiTransformer(
        vocab_size=vocab_size,
        embed_dim=n_embd,
        num_layers=n_layer,
        num_heads=n_head,
        q_lora_rank=192,
        kv_lora_rank=32,
        rope_dim=32,
        head_dim=224,
        dropout=dropout,
        max_length=SEQ_LEN
    ).to(device, dtype=(torch.bfloat16 if mp_mgr.use_bf16 and device.type=="cuda" else torch.float32))

    # Initialize Muon Optimizer for NanoKimiK2
    optimizer_kimi = Muon(
        model_kimi.parameters(),
        lr=learning_rate,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        wd=weight_decay,
        ns_iters=5,
        precond_eps=1e-4,
        precond_update_freq=20,
        adam_bias_correction=True,
        qkclip_tau=20.0,
        qkclip_every=1,
        qkclip_calibrate_default=1.0,
        qkclip_decay=0.95
    )
    optimizer_kimi.bind_modules([model_kimi])

    # Initialize NanoGPT Model
    model_gpt = GPTLanguageModel(
        vocab_size=vocab_size,
        block_size=SEQ_LEN,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        dropout=dropout
    ).to(device, dtype=(torch.bfloat16 if mp_mgr.use_bf16 and device.type=="cuda" else torch.float32))

    # Initialize AdamW Optimizer for NanoGPT
    optimizer_gpt = torch.optim.AdamW(model_gpt.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # DataLoader Setup
    current_batch = BATCH_SIZE
    train_loader, val_loader = make_dataloaders(train_ds, val_ds, current_batch)

    def _endless_batches(loader):
        """Yield batches forever by re-iterating ``loader`` whenever it is exhausted.

        Unlike ``itertools.cycle`` this keeps no copy of the batches already
        served, so memory stays flat and each pass is a fresh iteration of the loader.
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                raise RuntimeError("Train DataLoader produced no batches — lower BATCH_SIZE or check the dataset.")

    # Learning Rate Schedulers
    # The schedulers advance once per optimizer step, i.e. once every ACCUM_STEPS
    # micro-steps, so the schedule horizon is counted in optimizer steps.
    total_steps = (EPOCHS * STEPS_PER_EPOCH) // max(1, ACCUM_STEPS)
    warmup_steps = min(1000, total_steps//10)
    scheduler_kimi = get_cosine_with_warmup_scheduler(optimizer_kimi, warmup_steps, total_steps)
    scheduler_gpt = get_cosine_with_warmup_scheduler(optimizer_gpt, warmup_steps, total_steps)

    # Print Training Configuration
    print(f"Train dataset size: {len(train_ds)} samples")
    print(f"Validation dataset size: {len(val_ds)} samples")
    print(f"Batch size: {current_batch}")
    print(f"Steps per epoch: {STEPS_PER_EPOCH}")
    print(f"Epochs: {EPOCHS}")
    print(f"With eval every {EVAL_INTERVAL} steps.\n")

    # Mixed Precision Setup
    # Only the startup mode is reported here; when fp16 is enabled later the
    # GradScaler is created lazily on ``mp_mgr`` inside the training step.
    if mp_mgr.use_bf16:
        print("Starting in bfloat16 (no GradScaler).")
    else:
        print("Starting in fp32 with GradScaler as applicable.")

    print("Starting training loop. MoE experts:", MOE_NUM_EXPERTS, "top_k:", MOE_TOP_K, "chunk_size:", MOE_CHUNK_SIZE)

    # Training Loop
    global_step = 0

    eval_prompts = [
        "Once, there was a small boy named Charlie. Charlie loved stories more than anything else."
    ]

    # Lists for Tracking Losses (one training entry per step, one validation entry per evaluation)
    train_losses_kimi = []
    train_losses_gpt = []
    val_losses_kimi = []
    val_losses_gpt = []

    for epoch in range(EPOCHS):
        model_kimi.train()
        model_gpt.train()
        # Running sums are reset per epoch only, so the logged averages are
        # averages over the epoch so far.
        running_loss_kimi = running_nll_kimi = running_aux_kimi = 0.0
        running_aux_multiplier = 0.0
        running_scaled_aux_kimi = 0.0
        running_importance_cv_kimi = running_load_cv_kimi = running_overflow_kimi = 0.0
        running_loss_gpt = 0.0
        running_count = 0

        train_iter = _endless_batches(train_loader)
        optimizer_kimi.zero_grad(set_to_none=True)
        optimizer_gpt.zero_grad(set_to_none=True)

        for step_in_epoch in range(1, STEPS_PER_EPOCH + 1):
            global_step += 1
            attempt = 0
            x, y = next(train_iter)

            # Training Step with OOM Handling
            # NOTE(review): if an attempt fails after one of the backward passes has
            # already run, the retry accumulates that model's gradients a second time.
            while attempt <= OOM_RETRY_LIMIT:
                try:
                    x = x.to(device, non_blocking=True)
                    y = y.to(device, non_blocking=True)

                    dtype = mp_mgr.autocast_dtype()
                    with torch.amp.autocast(device_type="cuda" if device.type=="cuda" else "cpu", dtype=dtype):
                        logits_kimi, aux_kimi, stats_kimi = model_kimi(x)
                        nll_kimi = F.cross_entropy(logits_kimi.view(-1, logits_kimi.size(-1)), y.view(-1))
                        eff_batch = x.size(0) * max(1, ACCUM_STEPS)
                        aux_multiplier = compute_aux_multiplier(global_step, eff_batch)
                        scaled_aux_kimi = aux_kimi * aux_multiplier
                        loss_kimi = nll_kimi + MOE_AUX_LAMBDA * scaled_aux_kimi
                        loss_kimi = loss_kimi / max(1, ACCUM_STEPS)

                        _, loss_gpt = model_gpt(x, y)
                        loss_gpt = loss_gpt / max(1, ACCUM_STEPS)

                    if mp_mgr.use_fp16:
                        if mp_mgr.scaler is None:
                            mp_mgr.scaler = GradScaler(enabled=True)
                        mp_mgr.scaler.scale(loss_kimi).backward()
                        mp_mgr.scaler.scale(loss_gpt).backward()
                    else:
                        loss_kimi.backward()
                        loss_gpt.backward()

                    # Optimizer step once every ACCUM_STEPS micro-steps.
                    if (global_step % max(1, ACCUM_STEPS)) == 0:
                        if mp_mgr.use_fp16:
                            mp_mgr.scaler.unscale_(optimizer_kimi)
                            mp_mgr.scaler.unscale_(optimizer_gpt)
                        torch.nn.utils.clip_grad_norm_(model_kimi.parameters(), 1.0)
                        torch.nn.utils.clip_grad_norm_(model_gpt.parameters(), 1.0)
                        if mp_mgr.use_fp16:
                            mp_mgr.scaler.step(optimizer_kimi)
                            mp_mgr.scaler.step(optimizer_gpt)
                            mp_mgr.scaler.update()
                        else:
                            optimizer_kimi.step()
                            optimizer_gpt.step()
                        scheduler_kimi.step()
                        scheduler_gpt.step()
                        optimizer_kimi.zero_grad(set_to_none=True)
                        optimizer_gpt.zero_grad(set_to_none=True)

                    # Undo the accumulation scaling so the recorded loss is per-batch.
                    reported_loss_kimi = float((loss_kimi * max(1, ACCUM_STEPS)).detach().cpu().item())
                    train_losses_kimi.append(reported_loss_kimi)
                    running_loss_kimi += reported_loss_kimi
                    running_nll_kimi += float(nll_kimi.detach().cpu().item())
                    running_aux_kimi += float(aux_kimi.detach().cpu().item()) if torch.is_tensor(aux_kimi) else float(aux_kimi)
                    running_aux_multiplier += aux_multiplier
                    running_scaled_aux_kimi += float(scaled_aux_kimi.detach().cpu().item())
                    running_importance_cv_kimi += stats_kimi.get("importance_cv", 0.0)
                    running_load_cv_kimi += stats_kimi.get("load_cv", 0.0)
                    running_overflow_kimi += stats_kimi.get("overflow_pct", 0.0)

                    reported_loss_gpt = float((loss_gpt * max(1, ACCUM_STEPS)).detach().cpu().item())
                    train_losses_gpt.append(reported_loss_gpt)
                    running_loss_gpt += reported_loss_gpt

                    running_count += 1

                    # Memory usage is reported with the periodic log below rather
                    # than on every single step.
                    break

                except RuntimeError as e:
                    err = str(e).lower()
                    if "out of memory" in err or "cuda out of memory" in err:
                        attempt += 1
                        print(f"[OOM] Epoch {epoch} Step {step_in_epoch} attempt {attempt}/{OOM_RETRY_LIMIT}. Strategy: ", mp_mgr.info())
                        cleanup_memory()
                        # Second empty_cache runs after gc.collect() inside cleanup_memory.
                        torch.cuda.empty_cache()
                        time.sleep(1.0)
                        if mp_mgr.use_bf16:
                            print(" -> falling back from bfloat16 to float16 + GradScaler and retrying batch.")
                            mp_mgr.enable_fp16_with_scaler()
                            model_kimi.to(device=device, dtype=torch.float16)
                            model_gpt.to(device=device, dtype=torch.float16)
                            continue
                        elif mp_mgr.use_fp16:
                            print(" -> fp16 path OOM; falling back to fp32.")
                            mp_mgr.enable_fp32()
                            model_kimi.to(device=device, dtype=torch.float32)
                            model_gpt.to(device=device, dtype=torch.float32)
                            continue
                        elif mp_mgr.use_fp32:
                            if attempt >= OOM_RETRY_LIMIT:
                                print(" -> persistent OOM in fp32; re-raising error.")
                                raise
                            else:
                                print(" -> retrying after cache clear (fp32).")
                                continue
                        else:
                            mp_mgr.enable_fp32()
                            model_kimi.to(device=device, dtype=torch.float32)
                            model_gpt.to(device=device, dtype=torch.float32)
                            continue
                    else:
                        raise

            # Logging Training Progress
            if global_step % LOG_INTERVAL == 0:
                avg_loss_kimi = running_loss_kimi / max(1, running_count)
                avg_nll_kimi = running_nll_kimi / max(1, running_count)
                avg_aux_kimi = running_aux_kimi / max(1, running_count)
                avg_aux_multiplier = running_aux_multiplier / max(1, running_count)
                avg_scaled_aux_kimi = running_scaled_aux_kimi / max(1, running_count)
                avg_imp_cv_kimi = running_importance_cv_kimi / max(1, running_count)
                avg_load_cv_kimi = running_load_cv_kimi / max(1, running_count)
                avg_overflow_kimi = running_overflow_kimi / max(1, running_count)
                avg_loss_gpt = running_loss_gpt / max(1, running_count)
                alloc = torch.cuda.memory_allocated()/1e6 if device.type=="cuda" else 0.0
                print(f"Epoch {epoch} Step {step_in_epoch} (Global {global_step}):")
                print(f"  NanoKimiK2: avg_loss {avg_loss_kimi:.6f} | avg_nll {avg_nll_kimi:.6f} | avg_aux {avg_aux_kimi:.6f} | aux_mult {avg_aux_multiplier:.6f} | scaled_aux {avg_scaled_aux_kimi:.6f} | imp_cv {avg_imp_cv_kimi:.6f} | load_cv {avg_load_cv_kimi:.6f} | overflow% {avg_overflow_kimi:.4f}")
                print(f"  NanoGPT: avg_loss {avg_loss_gpt:.6f}")
                print(f"  alloc {alloc:.1f} MB | precision {mp_mgr.autocast_dtype()}")

            # Evaluation
            if global_step % EVAL_INTERVAL == 0:
                perplexity_kimi, mean_nll_kimi, mean_aux_kimi, mean_bpc_kimi, kimi_stats = evaluate_kimi(model_kimi, val_loader, device, mp_mgr, encoding)
                perplexity_gpt, mean_nll_gpt, mean_bpc_gpt, gpt_stats = evaluate_gpt(model_gpt, val_loader, device, mp_mgr, encoding)
                # Evaluation puts the models in eval mode; switch back so the
                # remaining steps of this epoch really train in train mode.
                model_kimi.train()
                model_gpt.train()
                val_losses_kimi.append(mean_nll_kimi)
                val_losses_gpt.append(mean_nll_gpt)
                print(f"\nEvaluation at Step {global_step}:")
                print(f"  NanoKimiK2: Perplexity {perplexity_kimi:.2f}, Mean NLL {mean_nll_kimi:.6f}, Mean Aux {mean_aux_kimi:.6f}, BPC {mean_bpc_kimi:.6f}, Stats {kimi_stats}")
                print(f"  NanoGPT: Perplexity {perplexity_gpt:.2f}, Mean NLL {mean_nll_gpt:.6f}, BPC {mean_bpc_gpt:.6f}, Eval Loss {gpt_stats.get('mean_nll', float('inf')):.6f}")

            # Story Generation
            if global_step % STORY_GEN_INTERVAL == 0:
                print(f"\nGenerating sample stories at Step {global_step}:")
                for prompt in eval_prompts:
                    print(f"\nPrompt: {prompt}")
                    completion_kimi = generate_kimi(model_kimi, encoding, prompt, device, mp_mgr, max_tokens=200)
                    print(f"  NanoKimiK2 Completion: {completion_kimi}")
                    completion_gpt = generate_gpt(model_gpt, encoding, prompt, device, mp_mgr, max_tokens=200)
                    print(f"  NanoGPT Completion: {completion_gpt}")
                # Generation also switches to eval mode; return to train mode.
                model_kimi.train()
                model_gpt.train()

                # Print Training Loss Table
                if global_step % 1000 == 0:
                    print("\nTraining Loss Table:")
                    print("| Segment | NanoKimiK2 Avg Loss | NanoGPT Avg Loss |")
                    print("|---------|---------------------|------------------|")
                    for seg in range(1, (global_step // 1000) + 1):
                        start = (seg - 1) * 1000
                        end = seg * 1000
                        avg_kimi = sum(train_losses_kimi[start:end]) / len(train_losses_kimi[start:end]) if len(train_losses_kimi[start:end]) > 0 else 0.0
                        avg_gpt = sum(train_losses_gpt[start:end]) / len(train_losses_gpt[start:end]) if len(train_losses_gpt[start:end]) > 0 else 0.0
                        print(f"| {start+50}-{end} | {avg_kimi:.6f} | {avg_gpt:.6f} |")  # +50 for mid-segment avg

                    # Print Validation Loss Table
                    print("\nValidation Loss Table:")
                    print("| Segment | NanoKimiK2 Avg Val Loss | NanoGPT Avg Val Loss |")
                    print("|---------|-------------------------|----------------------|")
                    val_segments_kimi = [val_losses_kimi[i:i+5] for i in range(0, len(val_losses_kimi), 5)]  # 5 evals per 1000 steps (200 interval)
                    val_segments_gpt = [val_losses_gpt[i:i+5] for i in range(0, len(val_losses_gpt), 5)]
                    for seg, (kimi_seg, gpt_seg) in enumerate(zip(val_segments_kimi, val_segments_gpt), 1):
                        avg_kimi_val = sum(kimi_seg) / len(kimi_seg) if len(kimi_seg) > 0 else 0.0
                        avg_gpt_val = sum(gpt_seg) / len(gpt_seg) if len(gpt_seg) > 0 else 0.0
                        print(f"| {(seg-1)*1000 + 200}-{seg*1000} | {avg_kimi_val:.6f} | {avg_gpt_val:.6f} |")

# Entry Point
# Runs the whole pipeline (dataset download, tokenizer training, model training)
# only when this code is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()