In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import matplotlib.pyplot as plt
import time
import psutil
import numpy as np

In [None]:
# Hyperparameters & Device Setup
# ------------------------------
batch_size = 16       # sequences sampled per training batch
block_size = 32       # context length: tokens per sequence
max_iters = 10000     # training steps per model
eval_interval = 100   # steps between loss evaluations
learning_rate = 1e-3  # AdamW learning rate
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200      # batches averaged per split when estimating loss
n_embd = 64           # embedding / residual-stream width
n_head = 4            # attention heads per transformer block
n_layer = 4           # number of transformer blocks
dropout = 0.0         # dropout probability

# Each attention head gets n_embd // n_head features; require exact divisibility
# so no part of the embedding width is silently dropped.
assert n_embd % n_head == 0, "n_embd must be divisible by n_head"

# Trailing ';' stops Jupyter from echoing the returned torch.Generator object.
torch.manual_seed(1337);

In [None]:
# Data Preparation
# ------------------------------
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: the distinct characters of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = dict(zip(chars, range(vocab_size)))  # character -> token id
itos = dict(enumerate(chars))               # token id -> character


def encode(s):
    """Map each character of ``s`` to its token id (KeyError for unseen characters)."""
    return [stoi[c] for c in s]


data = torch.tensor(encode(text), dtype=torch.long)
# Contiguous split: the first 90% of tokens train, the remainder validates.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of input windows and their next-token targets.

    Args:
        split (str): ``'train'`` draws from ``train_data``; any other value
            draws from ``val_data``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(x, y)``, each of shape
        ``(batch_size, block_size)`` and moved to ``device``. ``y`` is the same
        window as ``x`` shifted forward by one token.
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    # Build every window at once: row r holds indices starts[r] .. starts[r]+block_size-1.
    positions = starts[:, None] + torch.arange(block_size)
    inputs = source[positions]
    targets = source[positions + 1]
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def estimate_loss(model):
    """Estimate the mean loss on the train and validation splits.

    Averages the loss over ``eval_iters`` random batches per split, with
    gradients disabled and the model in eval mode. The model's previous
    train/eval mode is restored afterwards (also if an exception is raised),
    so calling this never silently flips an eval-mode model back to training.

    Args:
        model (nn.Module): Called as ``model(X, Y)`` and expected to return
            ``(logits, loss)`` with a scalar ``loss``.

    Returns:
        dict[str, float]: ``{'train': mean_loss, 'val': mean_loss}``.
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean().item()
    finally:
        model.train(was_training)
    return out

In [None]:
# Model Components
# ------------------------------

# Activation Functions
# -------------------
class SwiGLU(nn.Module):
    """Gated SiLU activation.

    Splits the last dimension into two halves ``(gate, value)`` and returns
    ``silu(gate) * value``, so the output's last dimension is half the input's.
    """
    def forward(self, x):
        gate, value = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)

# Expert Layer for Mixture of Experts
# -----------------------------------
class Expert(nn.Module):
    """One feed-forward expert: Linear(dim -> 2*dim), SwiGLU (2*dim -> dim), Linear(dim -> dim)."""
    def __init__(self, dim):
        super().__init__()
        expanded = 2 * dim  # SwiGLU halves this back down to `dim`
        self.fc1 = nn.Linear(dim, expanded)
        self.act = SwiGLU()
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, x):
        hidden = self.fc1(x)
        gated = self.act(hidden)
        return self.fc2(gated)

# Mixture of Experts (MoE) Layer
# -------------------------------
class MoELayer(nn.Module):
    """Dense (soft) mixture of experts.

    Every expert processes every input; the expert outputs are blended with
    per-position softmax weights produced by a linear gate.

    Args:
        dim (int): Input/output feature size.
        num_experts (int): Number of experts and gate outputs.
    """
    def __init__(self, dim, num_experts=4):
        super().__init__()
        expert_list = [Expert(dim) for _ in range(num_experts)]
        self.experts = nn.ModuleList(expert_list)
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x):
        weights = F.softmax(self.gate(x), dim=-1)  # (..., num_experts)
        # Expert outputs stacked on a trailing axis: (..., dim, num_experts).
        stacked = torch.stack([expert(x) for expert in self.experts], dim=-1)
        return (stacked * weights[..., None, :]).sum(-1)

# Standard Feedforward Layer
# --------------------------
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4*dim, apply ReLU, project back to dim."""
    def __init__(self, dim):
        super().__init__()
        hidden_dim = 4 * dim
        layers = [nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

# Self-Attention Head
# -------------------
class Head(nn.Module):
    """One causal self-attention head.

    Projects the input to key/query/value vectors of size ``head_size`` and lets
    each position attend only to itself and earlier positions (lower-triangular
    mask).

    Args:
        head_size (int): Size of each key/query/value vector.
        embed_dim (int, optional): Input feature size; defaults to the global
            ``n_embd``.
        context_len (int, optional): Largest sequence length the causal mask
            supports; defaults to the global ``block_size``.
        dropout_p (float, optional): Dropout probability applied to the attention
            weights; defaults to the global ``dropout``.
    """
    def __init__(self, head_size, embed_dim=None, context_len=None, dropout_p=None):
        super().__init__()
        embed_dim = n_embd if embed_dim is None else embed_dim
        context_len = block_size if context_len is None else context_len
        dropout_p = dropout if dropout_p is None else dropout_p
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_len, context_len)))
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, embed_dim); returns (B, T, head_size)."""
        _, T, _ = x.size()
        k = self.key(x)
        q = self.query(x)
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # key dimension (head_size) -- not the input embedding width.
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        return wei @ v

# Multi-Head Attention Layer
# --------------------------
class MultiHeadAttention(nn.Module):
    """Several causal attention heads run in parallel.

    The heads' outputs are concatenated along the feature axis, projected back
    to ``n_embd`` features, and passed through dropout.

    Args:
        num_heads (int): Number of ``Head`` modules.
        head_size (int): Output size of each head.
    """
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenation has num_heads * head_size features, which equals
        # n_embd only when n_embd is divisible by num_heads; size the projection
        # from the actual concatenated width so other configurations still work.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        concatenated = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(concatenated))

# Transformer Block (Attention + FeedForward/MoE)
# -----------------------------------------------
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies ``x + attention(ln1(x))`` followed by ``x + ffwd(ln2(x))``, where the
    feed-forward part is a ``MoELayer`` when ``use_moe`` is true and a plain
    ``FeedForward`` otherwise.

    Args:
        n_embd (int): Residual-stream width.
        n_head (int): Number of attention heads (each of size n_embd // n_head).
        use_moe (bool): Select the mixture-of-experts feed-forward layer.
    """
    def __init__(self, n_embd, n_head, use_moe):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        if use_moe:
            self.ffwd = MoELayer(n_embd)
        else:
            self.ffwd = FeedForward(n_embd)
        self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

    def forward(self, x):
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

# Full Language Model
# -------------------
class LanguageModel(nn.Module):
    """Decoder-only transformer language model over ``vocab_size`` tokens.

    Token + learned position embeddings feed ``n_layer`` transformer blocks,
    a final LayerNorm, and a linear head producing next-token logits.

    Args:
        use_moe (bool): Use mixture-of-experts feed-forward layers in every block.
    """
    def __init__(self, use_moe=False):
        super().__init__()
        # Embedding Layers
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # Transformer Blocks
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, use_moe) for _ in range(n_layer)])

        # LayerNorm + Output head
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and (optionally) the cross-entropy loss.

        Args:
            idx (torch.Tensor): Token ids of shape (B, T) with T <= block_size.
            targets (torch.Tensor, optional): Target ids of shape (B, T).

        Returns:
            tuple: ``(logits, loss)`` where logits is (B, T, vocab_size) and loss
            is a scalar tensor, or None when ``targets`` is None.

        Raises:
            ValueError: If T exceeds the number of learned positions. Without this
                check the position lookup fails with an opaque IndexError on CPU
                or a device-side assert (which poisons the context) on CUDA.
        """
        B, T = idx.size()
        max_positions = self.position_embedding_table.num_embeddings
        if T > max_positions:
            raise ValueError(f"sequence length {T} exceeds block_size {max_positions}")

        # Token and Position Embeddings
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb

        # Transformer Forward Pass
        x = self.blocks(x)
        x = self.ln_f(x)

        # Output logits
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # Flatten (B, T, V) -> (B*T, V); V is read from the logits themselves.
            loss = F.cross_entropy(logits.reshape(B * T, -1), targets.reshape(B * T))
        return logits, loss


In [None]:
# Muon (TrueMuon) Optimizer and Utility Functions
# ------------------------------------------------

# Newton-Schulz orthogonalization ("matrix sign" / polar factor) used by TrueMuon
# -------------------------------------------------------------------------------
def matrix_sign(G, steps=5, eps=1e-7):
    """
    Approximately orthogonalize a 2-D matrix with a quintic Newton-Schulz iteration.

    For G = U S V^T this pushes every singular value toward 1, approximating the
    polar factor U V^T (often called the "matrix sign" or zeroth power of G).
    It is used by the TrueMuon optimizer on hidden weight updates.

    G is first divided by its Frobenius norm, which bounds the spectral norm, so
    all singular values start at most 1. Each step then maps a singular value s
    to a*s + b*s**3 + c*s**5. With the tuned coefficients below, 1 is not a fixed
    point (f(1) is about 0.70), so the result is only approximately orthogonal.

    Tall matrices are processed transposed so the Gram matrix X @ X.T is the
    smaller of the two possible products.

    Args:
        G (torch.Tensor): Input matrix; must be 2-D.
        steps (int): Number of Newton-Schulz iterations.
        eps (float): Added to the norm to avoid division by zero.

    Returns:
        torch.Tensor: Matrix with the same shape as G.
    """
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315
    is_tall = G.size(0) > G.size(1)
    X = G / (G.norm() + eps)
    if is_tall:
        X = X.T
    for _ in range(steps):
        gram = X @ X.T
        poly = b * gram + c * (gram @ gram)
        X = a * X + poly @ X
    return X.T if is_tall else X

# TrueMuon Optimizer
# -----------------
class TrueMuon(torch.optim.Optimizer):
    """
    Momentum optimizer for hidden weight matrices using an orthogonalized update.

    For each parameter p with gradient g (per step):
        m <- beta * m + (1 - beta) * g
        O <- matrix_sign(beta * m + g)      # weights with ndim > 2 are flattened
                                            # to (shape[0], -1) for this call
        p <- p - lr * (O + weight_decay * p)

    Parameters whose gradient has fewer than 2 dimensions are skipped and left
    unchanged; they must be handled by a different optimizer.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr (float): Step size.
        beta (float): Momentum coefficient.
        weight_decay (float): Decoupled weight-decay coefficient (scaled by lr).
    """
    def __init__(self, params, lr=1e-2, beta=0.9, weight_decay=0.01):
        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss, or None."""
        loss = None
        if closure is not None:
            # step() runs under no_grad, but a closure has to build a graph and
            # call backward(), so autograd is re-enabled just for the closure.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            beta = group['beta']; lr = group['lr']; wd = group['weight_decay']
            for p in group['params']:
                if p.grad is None or p.grad.ndim < 2:
                    continue
                grad = p.grad
                state = self.state[p]
                if 'momentum' not in state:
                    state['momentum'] = torch.zeros_like(p)
                m = state['momentum']
                m.mul_(beta).add_(grad, alpha=(1 - beta))
                # Nesterov-style lookahead: current gradient plus scaled momentum.
                X = beta * m + grad
                # Orthogonalize a 2-D view (a no-op reshape for matrices), then
                # restore the parameter's shape; reshape_as handles the possibly
                # transposed, non-contiguous result of matrix_sign.
                O = matrix_sign(X.reshape(X.size(0), -1)).reshape_as(p)
                if wd != 0:
                    O = O.add(p, alpha=wd)
                p.add_(O, alpha=-lr)
        return loss

# Optimizer Builder
# -----------------
def build_optimizers(model, muon_lr=1e-2, muon_beta=0.9, muon_weight_decay=0.01, adamw_lr=None):
    """
    Split model parameters into two optimizer groups.

    - Hidden parameters: ndim >= 2 and with 'blocks' in their name -> TrueMuon.
    - Everything else (embeddings, LayerNorms, biases, output head) -> AdamW.

    Args:
        model (nn.Module): The language model to optimize.
        muon_lr (float): TrueMuon learning rate.
        muon_beta (float): TrueMuon momentum coefficient.
        muon_weight_decay (float): TrueMuon decoupled weight decay.
        adamw_lr (float, optional): AdamW learning rate; defaults to the global
            ``learning_rate``.

    Returns:
        tuple: (AdamW optimizer, TrueMuon optimizer)
    """
    named = list(model.named_parameters())
    hidden = [p for name, p in named if p.ndim >= 2 and 'blocks' in name]
    hidden_ids = {id(p) for p in hidden}
    others = [p for _, p in named if id(p) not in hidden_ids]
    adamw = torch.optim.AdamW(others, lr=learning_rate if adamw_lr is None else adamw_lr)
    muon = TrueMuon(hidden, lr=muon_lr, beta=muon_beta, weight_decay=muon_weight_decay)
    return adamw, muon

# Perplexity Calculation
# ----------------------
def calculate_perplexity(loss):
    """
    Convert a mean cross-entropy loss (natural-log units) to perplexity, exp(loss).

    Args:
        loss (float): Cross-entropy loss.

    Returns:
        float: Perplexity. Returns ``math.inf`` when exp(loss) overflows a float
        (loss above about 709.78, e.g. a diverged run) instead of raising
        OverflowError; NaN input yields NaN.
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return math.inf

# Inference Timing Utility
# ------------------------
@torch.no_grad()
def measure_inference_time(model, input_data, n_trials=100):
    """
    Measure the average wall-clock time of one forward pass.

    The model runs in eval mode with gradients disabled. Its previous train/eval
    mode is restored afterwards. On CUDA, the device is synchronized around each
    timed call because kernels launch asynchronously; without this the timer
    would mostly capture launch overhead.

    Args:
        model (nn.Module): Model to evaluate.
        input_data (torch.Tensor): Input tensor passed as ``model(input_data)``.
        n_trials (int): Number of timed forward passes to average (>= 1).

    Returns:
        float: Average inference time in seconds.

    Raises:
        ValueError: If ``n_trials`` is less than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    was_training = model.training
    model.eval()
    on_cuda = input_data.is_cuda
    times = []
    try:
        for _ in range(n_trials):
            if on_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            model(input_data)
            if on_cuda:
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)
    finally:
        model.train(was_training)
    return float(np.mean(times))

# Memory Usage Utility
# --------------------
def get_memory_usage():
    """
    Report this process's resident set size (RSS) as given by psutil.

    Returns:
        float: RSS in megabytes (bytes / 1024**2).
    """
    bytes_per_mb = 1024 * 1024
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / bytes_per_mb


In [None]:
# Training & Evaluation
# ------------------------------
models = {
    "NanoGPT": (LanguageModel(use_moe=False).to(device), "adamw"),
    "NanoKimi": (LanguageModel(use_moe=True).to(device), "muon")
}

results = {}
# Random token batch used only to measure inference latency.
test_input = torch.randint(0, vocab_size, (batch_size, block_size), device=device)

for name, (model, opt_type) in models.items():
    param_count = sum(p.numel() for p in model.parameters()) / 1e6  # in millions
    print(f"\nTraining {name} with {param_count:.2f}M parameters")
    if opt_type == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        muon = None
    else:
        # AdamW for non-hidden params, TrueMuon for 2-D weights inside the blocks.
        optimizer, muon = build_optimizers(model)

    best_val = float('inf')
    train_times = []
    memory_usages = []

    for it in range(max_iters):
        # Evaluate OUTSIDE the timed region: estimate_loss runs 2 * eval_iters
        # extra forward passes, which would otherwise inflate avg_train_time.
        if it % eval_interval == 0 or it == max_iters - 1:
            losses = estimate_loss(model)
            print(f"{name} | step {it}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            best_val = min(best_val, losses['val'])

        # CUDA kernels run asynchronously; synchronize around the timed region so
        # it measures actual execution rather than launch overhead.
        if device == 'cuda':
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        xb, yb = get_batch('train')
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        if muon is not None:
            muon.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if muon is not None:
            muon.step()

        if device == 'cuda':
            torch.cuda.synchronize()
        train_times.append(time.perf_counter() - start_time)
        # Process RSS (host RAM) after this step. NOTE: excludes CUDA device
        # memory, and includes models trained earlier in this loop because they
        # are still referenced by `models`.
        memory_usages.append(get_memory_usage())

    inference_time = measure_inference_time(model, test_input)
    perplexity = calculate_perplexity(best_val)
    # Perplexity divided by parameter count in millions.
    efficiency = perplexity / param_count

    results[name] = {
        "val_loss": best_val,
        "perplexity": perplexity,
        "avg_train_time": np.mean(train_times),
        "avg_memory_usage": np.mean(memory_usages),
        "inference_time": inference_time,
        "param_efficiency": efficiency,
        "param_count": param_count
    }

In [None]:
# Results Summary & Visualization
# ------------------------------
print("\n--- Comparison Report ---")
if results['NanoKimi']['val_loss'] < results['NanoGPT']['val_loss']:
    print("NanoKimi outperformed NanoGPT in terms of validation loss.")
else:
    print("NanoGPT performed better or equally in terms of validation loss.")

for model_name, stats in results.items():
    print(f"\n{model_name}")
    print(f"  - Validation Loss        : {stats['val_loss']:.4f}")
    print(f"  - Perplexity            : {stats['perplexity']:.2f}")
    print(f"  - Avg Training Time (s) : {stats['avg_train_time']:.4f}")
    print(f"  - Avg Memory Usage (MB) : {stats['avg_memory_usage']:.2f}")
    print(f"  - Inference Time (s)    : {stats['inference_time']:.4f}")
    print(f"  - Param Efficiency      : {stats['param_efficiency']:.2f}")
    print(f"  - Parameter Count (M)   : {stats['param_count']:.2f}")

# Plot: one bar chart per metric, one bar per model.
# A concrete list is passed to bar() rather than a dict_keys view.
model_names = list(results)
fig, axes = plt.subplots(4, 1, figsize=(8, 12))
metrics = [
    ('val_loss', 'Validation Loss', 'blue'),
    ('avg_train_time', 'Average Training Time (s)', 'green'),
    ('avg_memory_usage', 'Average Memory Usage (MB)', 'red'),
    ('inference_time', 'Inference Time (s)', 'purple')
]

for ax, (metric, title, color) in zip(axes, metrics):
    ax.bar(model_names, [results[m][metric] for m in model_names], color=color)
    ax.set_title(title)
    ax.set_ylabel(metric.replace('_', ' ').title())
    ax.set_xlabel("Model")
    ax.grid(True, linestyle='--', alpha=0.6)

fig.tight_layout()
fig.savefig("comparison_graph.png")
print("Saved comparison graphs to comparison_graph.png")