In [None]:
!pip install datasets torch matplotlib tqdm --quiet

In [None]:
from datasets import load_dataset
import torch
import re

# 100k TinyStories: load the first 100,000 training examples and join their
# text into one corpus string. The "\n\n" separator is whitespace only, so the
# \w+|[^\w\s] tokenizer below discards it and no story-boundary token is produced.
dataset = load_dataset("roneneldan/TinyStories", split="train[:100000]")
texts = dataset["text"]
full_text = "\n\n".join(texts)

# Word-level tokenizer
def word_tokenize(text):
    """Split ``text`` into lowercase word and punctuation tokens.

    Each run of word characters becomes one token; every other
    non-whitespace character becomes its own single-character token.
    Whitespace is discarded.
    """
    lowered = text.lower()
    token_pattern = re.compile(r"\w+|[^\w\s]")
    return token_pattern.findall(lowered)

# Tokenize the whole corpus once and derive a sorted vocabulary from it.
words = word_tokenize(full_text)
vocab = sorted(set(words))

# Lookup tables in both directions: token -> id and id -> token.
stoi = {token: idx for idx, token in enumerate(vocab)}
itos = dict(enumerate(vocab))
vocab_size = len(stoi)

def encode(text):
    """Map ``text`` to a list of vocabulary ids; tokens missing from ``stoi`` are skipped."""
    ids = []
    for token in word_tokenize(text):
        if token in stoi:
            ids.append(stoi[token])
    return ids


def decode(tokens):
    """Map a sequence of vocabulary ids back to one space-separated string."""
    return ' '.join(itos[idx] for idx in tokens)

# `words` already holds the tokenized corpus and every one of its tokens is in
# `stoi` by construction, so map it directly instead of re-tokenizing
# `full_text` through encode(); the resulting id sequence is identical.
data = torch.tensor([stoi[w] for w in words], dtype=torch.long)

# First 90% of the token stream is used for training, the remainder for validation.
split = int(0.9 * len(data))
train_data, val_data = data[:split], data[split:]

def get_batch(split, block_size=64, batch_size=32):
    """Sample a random batch of (input, target) windows from the token stream.

    Args:
        split: ``'train'`` draws from ``train_data``; any other value draws
            from ``val_data``.
        block_size: number of tokens in each window.
        batch_size: number of windows to sample.

    Returns:
        ``(x, y)``: two ``(batch_size, block_size)`` tensors, where ``y`` is
        ``x`` shifted one token to the right (next-token targets).

    Raises:
        ValueError: if the chosen split has fewer than ``block_size + 1`` tokens.
    """
    data_split = train_data if split == 'train' else val_data
    n_starts = len(data_split) - block_size
    if n_starts <= 0:
        raise ValueError(
            f"split has {len(data_split)} tokens; need at least {block_size + 1} "
            f"for block_size={block_size}"
        )
    ix = torch.randint(n_starts, (batch_size,))
    # Broadcast every start offset against 0..block_size-1 so all windows are
    # gathered with a single indexing op instead of one Python slice per sample.
    positions = ix.unsqueeze(1) + torch.arange(block_size)
    x = data_split[positions]
    y = data_split[positions + 1]
    return x, y


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Gated activation that halves the last dimension.

    The input is split into two chunks along the last dimension; the first
    chunk is multiplied element-wise by SiLU of the second.
    """

    def forward(self, x):
        """Return ``value * silu(gate)`` for the two last-dim chunks of ``x``."""
        value, gate = torch.chunk(x, 2, dim=-1)
        activated_gate = F.silu(gate)
        return value * activated_gate

class MuonClip(torch.optim.Optimizer):
    """Gradient-descent optimizer with per-tensor gradient-norm clipping and L2 weight decay.

    For every parameter tensor ``p`` with gradient ``g`` one step performs::

        g = g * (clip_thresh / ||g||)   if ||g|| > clip_thresh
        g = g + weight_decay * p        if weight_decay != 0
        p = p - lr * g

    No momentum or other per-parameter state is kept.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        weight_decay: coefficient of the L2 term added to the (clipped) gradient.
        clip_thresh: maximum L2 norm allowed for each parameter's gradient.
    """

    def __init__(self, params, lr=1e-3, weight_decay=0.01, clip_thresh=1.0):
        defaults = dict(lr=lr, weight_decay=weight_decay, clip_thresh=clip_thresh)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            clip_thresh = group['clip_thresh']
            for p in group['params']:
                if p.grad is None:
                    continue
                update = p.grad
                grad_norm = update.norm()
                if grad_norm > clip_thresh:
                    update = update * (clip_thresh / grad_norm)
                if weight_decay != 0:
                    # Out-of-place on purpose: when no clipping happened `update`
                    # still aliases p.grad, and an in-place add would corrupt the
                    # stored gradient.
                    update = update.add(p, alpha=weight_decay)
                p.add_(update, alpha=-lr)
        return loss

class MoELayer(nn.Module):
    """Mixture-of-experts feed-forward layer with top-k routing.

    A linear router scores every expert for every token. The ``k`` best-scoring
    experts per token are kept, their scores are softmax-normalised over those
    ``k`` experts, and the layer output is the gate-weighted sum of the chosen
    experts' outputs. Each expert is ``Linear -> SwiGLU -> Linear``.

    Args:
        d_model: token feature size (input and output).
        d_ff: hidden size of each expert after the SwiGLU gate.
        n_experts: number of experts.
        k: number of experts combined per token.

    Raises:
        ValueError: if ``k`` is not in ``[1, n_experts]``.
    """

    def __init__(self, d_model, d_ff, n_experts=4, k=2):
        super().__init__()
        if not 1 <= k <= n_experts:
            raise ValueError(f"k must be in [1, n_experts]; got k={k}, n_experts={n_experts}")
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff * 2),  # doubled width: SwiGLU halves it again
                SwiGLU(),
                nn.Linear(d_ff, d_model)
            ) for _ in range(n_experts)
        ])
        self.router = nn.Linear(d_model, n_experts)
        self.k = k

    def forward(self, x):
        """Route every token of ``x`` (``(..., d_model)``) through its top-k experts."""
        scores = self.router(x)                                   # (..., n_experts)
        top_vals, top_idx = torch.topk(scores, self.k, dim=-1)    # (..., k)
        # Normalise over the k selected experts of each token, so the gate
        # weights of a token sum to 1 and do not depend on any other token.
        gate = torch.softmax(top_vals, dim=-1)
        # Scatter into a dense (..., n_experts) matrix that is zero for
        # unselected experts; each expert then runs exactly once.
        dense_gate = gate.new_zeros(scores.shape).scatter(-1, top_idx, gate)
        out = torch.zeros_like(x)
        for j, expert in enumerate(self.experts):
            out = out + dense_gate[..., j:j + 1] * expert(x)
        return out

class TransformerBlock(nn.Module):
    """Post-norm transformer block: causal self-attention followed by an MoE feed-forward.

    Both sub-layers use a residual connection with dropout on the sub-layer
    output, and LayerNorm is applied after each residual sum.

    Args:
        d_model: token feature size.
        n_heads: number of attention heads.
        d_ff: hidden size passed to the MoE experts.
        dropout: dropout probability applied to both sub-layer outputs.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.moe = MoELayer(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, seq_len, d_model)``."""
        _, T, _ = x.shape
        # True above the diagonal = "may not attend": position t sees only <= t.
        # Built as bool directly on x's device to avoid a float->bool cast and a
        # host-to-device copy on every forward pass.
        causal_mask = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(diagonal=1)
        # The attention weights are never used, so do not ask for them.
        attn_out, _ = self.attn(x, x, x, attn_mask=causal_mask, need_weights=False)
        x = self.norm1(x + self.dropout(attn_out))
        moe_out = self.moe(x)
        return self.norm2(x + self.dropout(moe_out))

class NanoKimi(nn.Module):
    """Decoder-only language model: token + learned position embeddings, a stack of
    ``TransformerBlock`` layers, a final LayerNorm and a linear vocabulary head.

    Args:
        vocab_size: number of token ids.
        block_size: maximum sequence length (size of the position-embedding table).
        d_model: embedding / hidden size.
        n_heads: attention heads per block.
        d_ff: expert hidden size per block.
        n_layers: number of transformer blocks.
        dropout: dropout probability used for embeddings and inside the blocks.
    """

    def __init__(self, vocab_size, block_size=256, d_model=512, n_heads=8, d_ff=1024, n_layers=6, dropout=0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(block_size, d_model)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.output_head = nn.Linear(d_model, vocab_size)
        self.block_size = block_size

    def forward(self, x):
        """Return next-token logits of shape ``(batch, seq_len, vocab_size)``.

        Args:
            x: ``(batch, seq_len)`` tensor of token ids, ``seq_len <= block_size``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``block_size``.
        """
        B, T = x.size()
        if T > self.block_size:
            # Fail early with a readable message; otherwise the position lookup
            # below fails with an opaque out-of-range index error.
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        token_emb = self.token_embed(x)
        pos = self.pos_embed(torch.arange(T, device=x.device))  # (T, d_model), broadcast over batch
        x = self.dropout(token_emb + pos)
        x = self.blocks(x)
        x = self.norm(x)
        return self.output_head(x)


In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt

# Seed the global RNG so weight init, batch sampling and dropout are repeatable.
SEED = 1337
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
block_size = 256
batch_size = 32
steps = 2000
eval_interval = 100

model = NanoKimi(vocab_size, block_size=block_size).to(device)
optimizer = MuonClip(model.parameters(), lr=1e-3, weight_decay=1e-2)

# Logged (step, loss) pairs. Steps are stored explicitly so the plot below
# stays consistent even if the loop is interrupted before `steps` iterations.
loss_steps = []
losses = []

model.train()
for step in tqdm(range(steps)):
    xb, yb = get_batch("train", block_size, batch_size)
    xb, yb = xb.to(device), yb.to(device)

    logits = model(xb)
    loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % eval_interval == 0:
        loss_value = loss.item()
        # tqdm.write prints without breaking the progress bar.
        tqdm.write(f"Step {step} | Loss: {loss_value:.4f}")
        loss_steps.append(step)
        losses.append(loss_value)

# Training loss Plot
plt.plot(loss_steps, losses)
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss Curve")
plt.grid(True)
plt.show()


In [None]:
def generate(model, prompt, max_new_tokens=100, temperature=1.0, top_k=20, repetition_penalty=1.2):
    """Sample a continuation of ``prompt`` from ``model`` and return it as text.

    Each step scales the last-position logits by ``1 / temperature``, applies a
    repetition penalty to every token id already in the sequence, restricts
    sampling to the ``top_k`` highest logits and draws one token from the
    softmax over them.

    Args:
        model: language model mapping ``(1, T)`` token ids to ``(1, T, vocab)`` logits.
        prompt: text to continue; words outside the vocabulary are dropped by ``encode``.
        max_new_tokens: number of tokens to sample.
        temperature: softmax temperature; must be positive.
        top_k: number of candidates kept per step (clamped to the vocabulary size).
        repetition_penalty: values > 1 make already-seen tokens less likely.

    Returns:
        The decoded prompt plus generated tokens, space-separated.

    Raises:
        ValueError: if ``temperature <= 0`` or the prompt encodes to no tokens.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    prompt_ids = encode(prompt)
    if not prompt_ids:
        raise ValueError("prompt contains no tokens from the vocabulary")

    model.eval()
    # Follow the model rather than notebook globals for device and context length.
    model_device = next(model.parameters()).device
    context_len = model.block_size if hasattr(model, "block_size") else block_size
    input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=model_device).unsqueeze(0)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids[:, -context_len:])
            logits = logits[:, -1, :] / temperature

            # Penalise tokens already present. Dividing a negative logit would
            # raise its probability, so negative logits are multiplied instead.
            seen = torch.unique(input_ids[0])
            seen_logits = logits[0, seen]
            logits[0, seen] = torch.where(
                seen_logits < 0,
                seen_logits * repetition_penalty,
                seen_logits / repetition_penalty,
            )

            k = min(top_k, logits.size(-1))
            top_k_logits, top_k_indices = torch.topk(logits, k=k, dim=-1)
            probs = F.softmax(top_k_logits, dim=-1)
            # multinomial returns a position inside the top-k list; map it back to a vocab id.
            next_token = top_k_indices.gather(-1, torch.multinomial(probs, 1))
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return decode(input_ids[0].tolist())

# Sample and print a 100-token continuation of a story opening using `model`.
print(generate(model, "Once upon a time", max_new_tokens=100))
