In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
from datasets import load_dataset
import random
import os
from IPython.display import display
import ipywidgets as widgets
import uuid

In [None]:
# Configuration & Setup
# =======================
# Prefer the GPU when one is visible to torch; otherwise run everything on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Seed both RNGs the notebook uses (torch for init/sampling, `random` for Python-side randomness).
for seed_fn in (torch.manual_seed, random.seed):
    seed_fn(1337)

In [None]:
# Load Dataset
# =======================
# Load 10% of TinyStories dataset
dataset = load_dataset("roneneldan/TinyStories", split="train[:10%]")
texts = [item['text'] for item in dataset]

# Save texts to file for tokenizer training.
# The encoding is given explicitly: the platform default (e.g. cp1252 on Windows) can fail
# on non-ASCII characters in the stories, and SentencePiece reads its input as UTF-8.
with open("corpus.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(texts))

In [None]:
# Train SentencePiece Tokenizer
# =======================
# BPE tokenizer with a 4096-piece vocabulary trained on the corpus written above;
# the trainer writes tokenizer.model / tokenizer.vocab next to the notebook.
tokenizer_config = {
    'input': 'corpus.txt',
    'model_prefix': 'tokenizer',
    'vocab_size': 4096,
    'model_type': 'bpe',
}
spm.SentencePieceTrainer.train(**tokenizer_config)

sp = spm.SentencePieceProcessor(model_file='tokenizer.model')
vocab_size = sp.get_piece_size()  # number of pieces the model head must cover

In [None]:
# Tokenize Dataset
# =======================
# Encode every story in one batched call and concatenate the results into one token stream.
# An end-of-sequence id is appended after each story so that training windows which straddle
# two stories see an explicit boundary instead of one story running straight into the next.
eos_id = sp.eos_id()
story_separator = [eos_id] if eos_id >= 0 else []  # eos_id() is -1 when the tokenizer has no EOS piece
tokens = []
for story_ids in sp.encode(texts):
    tokens.extend(story_ids)
    tokens.extend(story_separator)
data = torch.tensor(tokens, dtype=torch.long)

In [None]:
# Train/Validation Split
# =======================
# Contiguous split of the token stream: the first 90% for training, the rest for validation.
n = int(len(data) * 0.9)
train_data = data[:n]
val_data = data[n:]

In [None]:
# Batch Preparation
# =======================
block_size = 64  # tokens per training window (also the length of the model's positional table)
batch_size = 32  # windows per batch

def get_batch(split):
    """Sample a random batch of next-token training windows from one split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
        ``(x, y)`` LongTensors of shape ``(batch_size, block_size)`` moved to ``device``,
        where ``y`` is ``x`` shifted one token to the left (the next-token targets).

    Raises:
        ValueError: if the selected split holds fewer than ``block_size + 1`` tokens.
    """
    src = train_data if split == 'train' else val_data
    # A window starting at i reads src[i : i + block_size + 1] (x plus one target token), so the
    # last valid start is len(src) - block_size - 1. randint's upper bound is exclusive.
    num_starts = len(src) - block_size
    if num_starts <= 0:
        raise ValueError(
            f"split {split!r} has {len(src)} tokens; need at least block_size + 1 = {block_size + 1}"
        )
    ix = torch.randint(0, num_starts, (batch_size,))
    x = torch.stack([src[i:i + block_size] for i in ix])
    y = torch.stack([src[i + 1:i + 1 + block_size] for i in ix])
    return x.to(device), y.to(device)

In [None]:
# Activation Functions
# =======================
class SwiGLU(nn.Module):
    """Gated SiLU activation.

    Splits the last dimension into two equal halves ``(gate, value)`` and returns
    ``SiLU(gate) * value``, so the output has half the input width.
    """

    def forward(self, x):
        gate, value = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)

In [None]:
# Expert and MoE Layer
# =======================
class Expert(nn.Module):
    """Feed-forward expert: Linear(dim -> 2*dim), SwiGLU (halves the width), Linear(dim -> dim)."""

    def __init__(self, dim):
        super().__init__()
        hidden = 2 * dim
        # SwiGLU halves its input, so the output projection consumes hidden // 2 == dim features.
        layers = [nn.Linear(dim, hidden), SwiGLU(), nn.Linear(hidden // 2, dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        net = self.net
        return net(x)

class MoELayer(nn.Module):
    """Dense (soft) mixture of experts.

    Every expert runs on every token; the expert outputs are blended with per-token
    softmax weights produced by a linear gate over ``num_experts`` logits.
    """

    def __init__(self, dim, num_experts=4):
        super().__init__()
        # Experts are built before the gate so parameter initialisation order is unchanged.
        self.experts = nn.ModuleList(Expert(dim) for _ in range(num_experts))
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x):
        gate_probs = self.gate(x).softmax(dim=-1)                          # (B, T, E)
        stacked = torch.stack([expert(x) for expert in self.experts], dim=2)  # (B, T, E, D)
        return torch.einsum('bte,bted->btd', gate_probs, stacked)

In [None]:
# Latent Attention Layer
# =======================
class LatentAttention(nn.Module):
    """Causal latent cross-attention with a residual connection.

    ``num_latents`` learned query vectors attend over the token sequence. For position
    ``t`` the latents only see the keys/values of tokens ``0..t``. Their results are
    projected back to ``dim``, averaged over the latents, and added to ``x[:, t]``.

    The causal mask is required for next-token training. Without it, every position
    receives a summary of the whole window, including the tokens it is trained to predict.

    Args:
        dim: width of the input/output features.
        latent_dim: width of the latent queries, keys and values.
        num_latents: number of learned latent queries.
    """

    def __init__(self, dim, latent_dim=64, num_latents=16):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_latents, latent_dim))
        self.to_q = nn.Linear(latent_dim, latent_dim)
        self.to_kv = nn.Linear(dim, 2 * latent_dim)
        self.to_out = nn.Linear(latent_dim, dim)

    def forward(self, x):
        """Map ``x`` of shape (B, T, dim) to (B, T, dim); output t depends only on x[:, :t + 1]."""
        B, T, _ = x.shape
        q = self.to_q(self.latents.expand(B, -1, -1))                 # (B, L, d)
        k, v = self.to_kv(x).chunk(2, dim=-1)                         # (B, T, d) each
        scores = (q @ k.transpose(-2, -1)) / (k.shape[-1] ** 0.5)     # (B, L, T)
        L = scores.shape[1]
        # causal[t, s] is True when key s is visible from query position t (s <= t).
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        scores = scores.unsqueeze(1).expand(B, T, L, T)               # one score table per position
        scores = scores.masked_fill(~causal.view(1, T, 1, T), float("-inf"))
        attn = F.softmax(scores, dim=-1)   # key 0 is always visible, so no row is fully masked
        ctx = attn @ v.unsqueeze(1)                                   # (B, T, L, d)
        out = self.to_out(ctx).mean(dim=2)                            # average over latents -> (B, T, dim)
        return x + out

In [None]:
# Transformer Model
# =======================
class NanoKimiTransformer(nn.Module):
    """Language model built from latent-attention + mixture-of-experts blocks.

    Each block is ``LayerNorm -> LatentAttention -> LayerNorm -> MoELayer`` and its output is
    added to the running hidden state (``h = h + block(h)``). A final LayerNorm and a linear
    head produce per-position logits over the vocabulary.

    Args:
        vocab_size: number of token ids covered by the embedding and the output head.
        dim: hidden width.
        depth: number of blocks.
        max_seq_len: number of learned positions; defaults to the notebook-level ``block_size``.
    """

    def __init__(self, vocab_size, dim=256, depth=6, max_seq_len=None):
        super().__init__()
        self.max_seq_len = block_size if max_seq_len is None else max_seq_len
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = nn.Embedding(self.max_seq_len, dim)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(dim),
                LatentAttention(dim),
                nn.LayerNorm(dim),
                MoELayer(dim)
            ) for _ in range(depth)
        ])
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, x):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if ``T`` exceeds ``max_seq_len``. The positional table has no row
                for such positions. Without this check, the lookup fails with an opaque
                IndexError on CPU or a device-side assert on CUDA.
        """
        B, T = x.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.max_seq_len}; "
                "crop the input to its last max_seq_len tokens"
            )
        positions = torch.arange(T, device=x.device)
        h = self.token_emb(x) + self.pos_emb(positions)
        for block in self.blocks:
            h = h + block(h)
        return self.head(self.ln(h))

In [None]:
# Muon Optimizer
# =======================
class Muon(torch.optim.Optimizer):
    """Muon optimizer for weight matrices, with an AdamW fallback for every other tensor.

    Every 2-D parameter (including the embedding and head matrices, since only the shape is
    inspected) gets a Muon update, in four steps:
      1. Keep a Nesterov momentum buffer of the gradient.
      2. Orthogonalise it with a quintic Newton-Schulz iteration.
      3. Scale it by ``0.2 * sqrt(max(rows, cols))``, the Moonlight/Kimi scaling meant to
         match AdamW's update RMS so the same ``lr`` and ``weight_decay`` apply to both rules.
      4. Subtract it from the weights.

    Parameters that are not 2-D (biases, LayerNorm weights, the 3-D latent array, ...) use
    bias-corrected AdamW. Decoupled weight decay is applied to every parameter.

    Args:
        params: iterable of tensors or param-group dicts.
        lr: learning rate shared by both update rules.
        betas: ``(beta1, beta2)``. beta1 is the Muon momentum and Adam's first-moment decay;
            beta2 is Adam's second-moment decay.
        weight_decay: decoupled weight-decay coefficient.
        ns_steps: number of Newton-Schulz iterations.
        eps: numerical-stability term (Adam denominator and Newton-Schulz normalisation).
        nesterov: use Nesterov momentum for the Muon update.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-2,
                 ns_steps=5, eps=1e-8, nesterov=True):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
            raise ValueError(f"Invalid betas: {betas}")
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay,
                        ns_steps=ns_steps, eps=eps, nesterov=nesterov)
        super().__init__(params, defaults)

    @staticmethod
    def _orthogonalize(G, steps, eps):
        """Approximately map G = U S V^T to U V^T with a quintic Newton-Schulz iteration."""
        a, b, c = 3.4445, -4.7750, 2.0315
        X = G.float()
        transposed = X.size(0) > X.size(1)
        if transposed:  # iterate on the wide orientation so X @ X.T is the smaller Gram matrix
            X = X.T
        X = X / (X.norm() + eps)  # Frobenius normalisation bounds the spectral norm by 1
        for _ in range(steps):
            A = X @ X.T
            B = b * A + c * (A @ A)
            X = a * X + B @ X
        if transposed:
            X = X.T
        return X.to(G.dtype)

    def _muon_update(self, p, g, state, group):
        """Momentum -> orthogonalisation -> scaled step, for a 2-D parameter."""
        beta1 = group["betas"][0]
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros_like(p)
        buf = state["momentum_buffer"]
        buf.mul_(beta1).add_(g)
        update = g.add(buf, alpha=beta1) if group["nesterov"] else buf
        ortho = self._orthogonalize(update, group["ns_steps"], group["eps"])
        scale = 0.2 * max(p.size(0), p.size(1)) ** 0.5
        p.add_(ortho, alpha=-group["lr"] * scale)

    def _adamw_update(self, p, g, state, group):
        """Bias-corrected Adam step, for parameters that are not 2-D."""
        beta1, beta2 = group["betas"]
        if "step" not in state:
            state["step"] = 0
            state["exp_avg"] = torch.zeros_like(p)
            state["exp_avg_sq"] = torch.zeros_like(p)
        state["step"] += 1
        t = state["step"]
        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
        denom = (exp_avg_sq / (1 - beta2 ** t)).sqrt_().add_(group["eps"])
        p.addcdiv_(exp_avg, denom, value=-group["lr"] / (1 - beta1 ** t))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if weight_decay > 0:
                    p.mul_(1 - lr * weight_decay)  # decoupled (AdamW-style) decay
                if p.ndim == 2:
                    self._muon_update(p, p.grad, state, group)
                else:
                    self._adamw_update(p, p.grad, state, group)
        return loss

In [None]:
# Training Utilities
# =======================
model = NanoKimiTransformer(vocab_size).to(device)
optimizer = Muon(model.parameters(), lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-2)

def _sample_batch(src):
    """Sample ``batch_size`` random next-token windows of length ``block_size`` from ``src``.

    Returns ``(x, y)`` moved to ``device``, where ``y`` is ``x`` shifted by one token.
    """
    # The last valid start is len(src) - block_size - 1 (y needs one extra token);
    # randint's upper bound is exclusive.
    ix = torch.randint(0, len(src) - block_size, (batch_size,))
    x = torch.stack([src[i:i + block_size] for i in ix])
    y = torch.stack([src[i + 1:i + 1 + block_size] for i in ix])
    return x.to(device), y.to(device)

def compute_val_loss(model, val_data, eval_iters=10):
    """Estimate ``model``'s mean cross-entropy on ``val_data``.

    Args:
        model: the language model; its train/eval mode is restored afterwards.
        val_data: 1-D token tensor to sample evaluation windows from.
        eval_iters: number of random batches to average (a single batch is very noisy).

    Returns:
        The mean loss over the sampled batches, as a Python float.

    Raises:
        ValueError: if ``eval_iters`` is smaller than 1.
    """
    if eval_iters < 1:
        raise ValueError("eval_iters must be >= 1")
    was_training = model.training
    model.eval()
    total = 0.0
    try:
        with torch.no_grad():
            for _ in range(eval_iters):
                x, y = _sample_batch(val_data)
                logits = model(x)
                total += F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1)).item()
    finally:
        model.train(was_training)
    return total / eval_iters

def train(model, optimizer, data, steps=5000, eval_interval=100):
    """Train ``model`` on random windows sampled from ``data`` for ``steps`` optimizer steps.

    Every ``eval_interval`` steps, and after the final step, prints the current batch's
    training loss together with a validation-loss estimate on the global ``val_data``.
    """
    model.train()
    for step in range(steps):
        x, y = _sample_batch(data)
        logits = model(x)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if step % eval_interval == 0 or step == steps - 1:
            val_loss = compute_val_loss(model, val_data)
            print(f"Step {step}, Train Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")

In [None]:
# Train the Model
# =======================
# Runs 5000 optimizer steps on the training split, logging train/val loss periodically.
train(model, optimizer, data=train_data, steps=5000)

In [None]:
# Text Generation
# =======================
def generate_text(prompt, steps=50, temperature=0.8, top_k=50):
    """Autoregressively continue ``prompt`` using the notebook's ``model`` and ``sp`` tokenizer.

    Args:
        prompt: text to continue; it must encode to at least one token.
        steps: number of new tokens to generate.
        temperature: softmax temperature; ``<= 0`` selects greedy (argmax) decoding.
        top_k: sample only among the ``top_k`` most likely tokens (clamped to the vocabulary
            size); ``None`` or ``<= 0`` samples from the full distribution.

    Returns:
        The decoded prompt followed by the generated continuation.

    Raises:
        ValueError: if the prompt encodes to no tokens.
    """
    prompt_ids = sp.encode(prompt)
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # inference only: do not build an autograd graph every step
            for _ in range(steps):
                # The positional table only covers block_size positions, so condition on at
                # most the last block_size tokens.
                context = tokens[:, -block_size:]
                logits = model(context)[:, -1, :]
                if temperature <= 0:
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    vocab = logits.size(-1)
                    k = vocab if top_k is None or top_k <= 0 else min(top_k, vocab)
                    top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
                    probs = F.softmax(top_k_logits, dim=-1)
                    next_token = top_k_indices.gather(-1, torch.multinomial(probs, 1))
                tokens = torch.cat([tokens, next_token], dim=1)
    finally:
        model.train(was_training)
    return sp.decode(tokens[0].tolist())

In [None]:
# Interactive Prompt Box
# =======================
prompt_box = widgets.Text(value="Once upon a time", description="Prompt:")
generate_button = widgets.Button(description="Generate")
# Output printed from a widget callback is not shown on every front end (e.g. JupyterLab);
# capturing it in an Output widget makes it appear under the controls.
generation_output = widgets.Output()

def on_generate_click(b):
    """Button callback: generate a continuation of the current prompt and show it below the controls."""
    prompt = prompt_box.value
    b.disabled = True  # prevent overlapping generations from repeated clicks
    try:
        with generation_output:
            print(f"\nPrompt: {prompt}")
            print("Generated:", generate_text(prompt, steps=50, temperature=0.8, top_k=50))
    finally:
        b.disabled = False

generate_button.on_click(on_generate_click)
display(prompt_box, generate_button, generation_output)