<a href="https://colab.research.google.com/github/Algocrat/simplified-MoR-mvp/blob/main/MoR_Colab_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Small-Data LM Demo (with optional Mixture-of-Recursions)
**Goal:** Produce meaningful generations even on a **tiny dataset**, then (optionally) swap in a placeholder wrapper that marks where a simplified **Mixture-of-Recursions (MoR)** block would plug in.

In [None]:
import sys, torch
print("Python:", sys.version.split()[0])
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))
else:
    print("Tip: In Colab, use Runtime → Change runtime type → GPU")

Python: 3.11.13
PyTorch: 2.6.0+cu124
CUDA available: True
CUDA device: Tesla T4


In [None]:
!pip -q install tiktoken datasets

## Data: Tiny Shakespeare (default)
This is a ~1MB text file (about 1.1M characters) that trains quickly and gives recognizably Shakespeare-style completions with a small model.

> You can also try **FineWeb‑Edu (streaming)** later for variety (but Tiny Shakespeare is best for quick wins).

In [None]:
import os, requests, pathlib

DATA_DIR = pathlib.Path("data"); DATA_DIR.mkdir(exist_ok=True)
TS_PATH = DATA_DIR / "tiny_shakespeare.txt"
URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

if not TS_PATH.exists():
    print("Downloading Tiny Shakespeare...")
    r = requests.get(URL, timeout=30)
    r.raise_for_status()
    TS_PATH.write_text(r.text, encoding="utf-8")
else:
    print("Found existing:", TS_PATH)

text = TS_PATH.read_text(encoding="utf-8")
print("Chars:", len(text))
print("Preview:\n", text[:400])

Downloading Tiny Shakespeare...
Chars: 1115394
Preview:
 First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it 


### Optional: FineWeb‑Edu (stream just a few rows)
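Uncomment the cell below to stream 50 rows and append them to the corpus for variety. Run it before the tokenizer cell. Because the appended text lands at the end of the corpus and the split below keeps the last 10% of tokens for validation, these rows go into the validation split first.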

In [None]:
# from datasets import load_dataset
# from itertools import islice
# ds_stream = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True)
# tiny_rows = list(islice(ds_stream, 50))
# text += "\n\n" + "\n".join([r.get("text","") for r in tiny_rows])
# print("Augmented corpus length:", len(text))

## Tokenizer & dataset builder
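We use **GPT‑2 BPE** via `tiktoken`, split the token stream 90/10 into train/val, and draw random fixed-length windows of `ctx_len + 1` tokens from each split. Each window yields an input `x` and a target `y` that is `x` shifted by one token.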

In [None]:
import math, random
import tiktoken
import torch

enc = tiktoken.get_encoding("gpt2")  # 50k BPE
ids = enc.encode(text)

# Train/val split
split = int(0.9 * len(ids))
train_ids = ids[:split]
val_ids = ids[split:]

ctx_len = 256

def make_examples(token_ids, num_examples=2000):
    """Sample random (x, y) pairs; y is x shifted left by one token."""
    examples = []
    if len(token_ids) < ctx_len + 1:
        return examples
    for _ in range(num_examples):
        # randint is inclusive: the last valid start still leaves ctx_len + 1 tokens.
        i = random.randint(0, len(token_ids) - ctx_len - 1)
        seq = token_ids[i:i+ctx_len+1]
        x = torch.tensor(seq[:-1], dtype=torch.long)
        y = torch.tensor(seq[1:], dtype=torch.long)
        examples.append((x, y))
    return examples

train_pairs = make_examples(train_ids, num_examples=3000)
val_pairs = make_examples(val_ids, num_examples=300)
print(f"Train pairs: {len(train_pairs)} | Val pairs: {len(val_pairs)} | Vocab≈{enc.n_vocab}")

Train pairs: 3000 | Val pairs: 300 | Vocab≈50257
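The cell above draws windows at random, so windows overlap and some training tokens may never be sampled. If you instead want literal chunking of the token stream (for example, a fixed validation set that covers each token once), a minimal sketch with a hypothetical `chunk_examples` helper could look like this:

```python
import torch

def chunk_examples(token_ids, ctx_len):
    """Split a token list into consecutive, non-overlapping (x, y) pairs.

    Window i covers token_ids[i*ctx_len : i*ctx_len + ctx_len + 1]; x drops the
    last token and y drops the first, so each target position appears once.
    """
    pairs = []
    # A window starting at `start` needs ctx_len + 1 tokens, so stop early enough.
    for start in range(0, len(token_ids) - ctx_len, ctx_len):
        seq = token_ids[start:start + ctx_len + 1]
        pairs.append((torch.tensor(seq[:-1], dtype=torch.long),
                      torch.tensor(seq[1:], dtype=torch.long)))
    return pairs

demo = chunk_examples(list(range(20)), ctx_len=4)
print(len(demo), demo[0][0].tolist(), demo[0][1].tolist())
```

In the notebook this would be used as `val_pairs = chunk_examples(val_ids, ctx_len)`; the random sampler remains a reasonable choice for training.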


In [None]:
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    def __init__(self, pairs): self.pairs = pairs
    def __len__(self): return len(self.pairs)
    def __getitem__(self, i): return self.pairs[i]

def collate(batch):
    xs, ys = zip(*batch)
    return torch.stack(xs), torch.stack(ys)

train_loader = DataLoader(PairDataset(train_pairs), batch_size=16, shuffle=True, collate_fn=collate)
val_loader   = DataLoader(PairDataset(val_pairs),   batch_size=16, shuffle=False, collate_fn=collate)
print("Batches:", len(train_loader), len(val_loader))

Batches: 188 19


## Small Transformer (baseline)
We build a small GPT‑style model: learned **positional embeddings**, causal **multi-head attention** (via `F.scaled_dot_product_attention`), a **SwiGLU** MLP, and pre-norm **RMSNorm**.
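For reference, the two less familiar layers compute the following, matching the `RMSNorm` and `SwiGLU` classes in the next cell ($w$ is the learned scale, $\epsilon = 10^{-6}$, $d$ the model width, $W_1, W_2 \in \mathbb{R}^{h \times d}$ and $W_3 \in \mathbb{R}^{d \times h}$ with $h$ = `mlp_h`):

$$
\mathrm{RMSNorm}(x) = w \odot \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}},
\qquad
\mathrm{SwiGLU}(x) = W_3\big(\mathrm{SiLU}(W_1 x) \odot W_2 x\big),
\qquad
\mathrm{SiLU}(z) = z\,\sigma(z).
$$

Each block is pre-norm: $x \leftarrow x + \mathrm{MHA}(\mathrm{RMSNorm}_1(x))$, then $x \leftarrow x + \mathrm{SwiGLU}(\mathrm{RMSNorm}_2(x))$.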

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.w = nn.Parameter(torch.ones(d)); self.eps = eps
    def forward(self, x): return self.w * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class SwiGLU(nn.Module):
    def __init__(self, d, h):
        super().__init__()
        self.w1 = nn.Linear(d, h, bias=False)
        self.w2 = nn.Linear(d, h, bias=False)
        self.w3 = nn.Linear(h, d, bias=False)
    def forward(self, x): return self.w3(F.silu(self.w1(x)) * self.w2(x))

class MHA(nn.Module):
    def __init__(self, d, n_heads):
        super().__init__()
        assert d % n_heads == 0
        self.d, self.h, self.dh = d, n_heads, d // n_heads
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.o = nn.Linear(d, d, bias=False)
    def forward(self, x):
        B, T, D = x.shape
        q = self.q(x).view(B,T,self.h,self.dh).transpose(1,2)
        k = self.k(x).view(B,T,self.h,self.dh).transpose(1,2)
        v = self.v(x).view(B,T,self.h,self.dh).transpose(1,2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1,2).contiguous().view(B,T,D)
        return self.o(y)

class Block(nn.Module):
    def __init__(self, d, n_heads, mlp_h):
        super().__init__()
        self.n1 = RMSNorm(d); self.attn = MHA(d, n_heads)
        self.n2 = RMSNorm(d); self.mlp  = SwiGLU(d, mlp_h)
    def forward(self, x):
        x = x + self.attn(self.n1(x))
        x = x + self.mlp(self.n2(x))
        return x

class TinyGPT(nn.Module):
    def __init__(self, vocab, d=256, n_layers=4, n_heads=8, mlp_h=768, max_t=2048):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(max_t, d)
        self.blocks = nn.ModuleList([Block(d, n_heads, mlp_h) for _ in range(n_layers)])
        self.norm = RMSNorm(d)
        self.head = nn.Linear(d, vocab, bias=False)
    def forward(self, idx):
        B,T = idx.shape
        x = self.tok(idx) + self.pos(torch.arange(T, device=idx.device)[None,:])
        for b in self.blocks: x = b(x)
        x = self.norm(x)
        return self.head(x)
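To make "small" concrete, here is the parameter arithmetic for the configuration trained below (`d=256`, 4 layers, `mlp_h=768`, `max_t=2048`, GPT-2 vocabulary of 50,257). It is a hand count from the layer shapes above, assuming bias-free `Linear` layers and untied embedding/head weights, as in the cell above; `tinygpt_param_count` is a hypothetical helper.

```python
def tinygpt_param_count(vocab=50257, d=256, n_layers=4, mlp_h=768, max_t=2048):
    """Count TinyGPT parameters from layer shapes (all Linear layers are bias-free)."""
    tok = vocab * d          # token embedding
    pos = max_t * d          # learned positional embedding
    attn = 4 * d * d         # q, k, v, o projections
    mlp = 3 * d * mlp_h      # SwiGLU w1, w2 (d -> h) and w3 (h -> d)
    norms = 2 * d            # two RMSNorm weight vectors per block
    blocks = n_layers * (attn + mlp + norms)
    final_norm = d
    head = vocab * d         # output projection, not tied to `tok`
    return {"embeddings+head": tok + head, "pos": pos, "blocks": blocks,
            "total": tok + pos + blocks + final_norm + head}

counts = tinygpt_param_count()
print(counts)  # total 29,666,048; 25,731,584 of it is token embedding + head
```

You can cross-check in the notebook with `sum(p.numel() for p in model.parameters())`. Since most of the budget sits in the two vocabulary-sized matrices, tying them (`model.head.weight = model.tok.weight`) is one common way to shrink the model; this notebook does not evaluate that option.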

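## Train for 400 epochs
Each iteration of the loop below is a full pass over `train_loader` (188 batches), so 400 iterations amount to roughly 75k optimizer updates on the same 3,000 sampled windows. That is enough for Shakespeare-flavored completions, but the model also memorizes its training windows: in the log, training loss drops below 1 while validation loss stays above 9.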
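The training cell calls `sched.step()` once per batch, so any schedule must be sized in optimizer updates (`steps * len(train_loader)`), not epochs. If you also want a short linear warmup before a single cosine decay, one option is a `LambdaLR` multiplier like the sketch below. The helper name `warmup_cosine`, the 100-update warmup, and the 10% floor are arbitrary assumptions, not values from this notebook.

```python
import math
import torch

def warmup_cosine(total_updates, warmup=100, min_ratio=0.1):
    """LR multiplier: linear warmup to 1.0, then one cosine decay down to `min_ratio`."""
    def fn(step):
        if step < warmup:
            return (step + 1) / warmup
        progress = min(1.0, (step - warmup) / max(1, total_updates - warmup))
        return min_ratio + (1 - min_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
    return fn

# Toy check on a single parameter. In the notebook you would pass `opt` and
# total_updates = steps * len(train_loader) instead.
toy_param = torch.nn.Parameter(torch.zeros(1))
toy_opt = torch.optim.AdamW([toy_param], lr=3e-4)
toy_sched = torch.optim.lr_scheduler.LambdaLR(toy_opt, lr_lambda=warmup_cosine(total_updates=1000))
lrs = []
for _ in range(1000):
    lrs.append(toy_opt.param_groups[0]["lr"])  # LR used for this update
    toy_opt.step()
    toy_sched.step()
print(lrs[0], lrs[99], lrs[-1])
```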

In [None]:
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyGPT(vocab=enc.n_vocab, d=256, n_layers=4, n_heads=8, mlp_h=768).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

steps = 400  # outer loop iterations; each one is a full pass (epoch) over train_loader
# sched.step() runs once per batch, so T_max must count optimizer updates, not epochs.
# With T_max=500 the cosine curve would oscillate (period 1000 updates) instead of decaying once.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=steps * len(train_loader))

def run_epoch(loader, train=True):
    model.train(train)
    total = 0.0; n = 0
    # Skip autograd bookkeeping during evaluation to save memory and time.
    with torch.set_grad_enabled(train):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), label_smoothing=0.05)
            if train:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                sched.step()
            total += loss.item(); n += 1
    return total / max(1, n)

for s in range(1, steps + 1):  # `s` counts epochs; the log labels them "step"
    train_loss = run_epoch(train_loader, train=True)
    if s % 50 == 0:
        val_loss = run_epoch(val_loader, train=False)
        print(f"step {s:03d}  train_loss={train_loss:.3f}  val_loss={val_loss:.3f}")

step 050  train_loss=1.157  val_loss=8.614
step 100  train_loss=0.818  val_loss=9.776
step 150  train_loss=0.852  val_loss=9.579
step 200  train_loss=0.807  val_loss=9.484
step 250  train_loss=0.782  val_loss=9.467
step 300  train_loss=0.794  val_loss=9.285
step 350  train_loss=0.776  val_loss=9.344
step 400  train_loss=0.922  val_loss=9.124
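For scale: a model that spreads probability uniformly over all 50,257 GPT-2 tokens scores a cross-entropy of $\ln 50257 \approx 10.82$ nats (with or without label smoothing, since every token gets the same probability). The logged numbers are approximate, though: they include `label_smoothing=0.05` and average per batch, so the smaller last batch counts as much as a full one. A minimal sketch of a plain, token-weighted evaluation, using a hypothetical `evaluate_nll` helper:

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def evaluate_nll(model, loader, device="cpu"):
    """Token-weighted mean cross-entropy (no label smoothing) and perplexity."""
    model.eval()
    total_nll, total_tokens = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        # reduction="sum" so batches of different sizes are weighted by token count
        total_nll += F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                     y.reshape(-1), reduction="sum").item()
        total_tokens += y.numel()
    mean_nll = total_nll / max(1, total_tokens)
    return mean_nll, math.exp(mean_nll)

# Reference point: uniform prediction over the GPT-2 vocabulary.
print("uniform baseline:", math.log(50257))  # ≈ 10.82
```

In the notebook: `val_nll, val_ppl = evaluate_nll(model, val_loader, device)`.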


## Sampling decode (temperature/top‑k + repetition penalty)
Use this instead of greedy decoding, which tends to fall into repetitive loops with small models.
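As a toy illustration of what the first two knobs do, the sketch below applies temperature scaling and then top-k filtering (the same order as the sampler that follows) to a four-token distribution. `tempered_topk_probs` is a hypothetical helper for inspection only; it returns probabilities instead of sampling.

```python
import torch

def tempered_topk_probs(logits, temperature=1.0, top_k=None):
    """Distribution after temperature scaling, then top-k filtering."""
    scaled = logits / max(1e-6, float(temperature))
    if top_k is not None and 0 < top_k < scaled.numel():
        kth = torch.topk(scaled, top_k).values[-1]                 # smallest surviving logit
        scaled = scaled.masked_fill(scaled < kth, float("-inf"))  # ties with kth are kept
    return torch.softmax(scaled, dim=-1)

logits = torch.tensor([2.0, 1.0, 0.0, -1.0])
print("T=1.0:  ", tempered_topk_probs(logits, 1.0).tolist())
print("T=0.5:  ", tempered_topk_probs(logits, 0.5).tolist())             # sharper distribution
print("top_k=2:", tempered_topk_probs(logits, 1.0, top_k=2).tolist())    # only two tokens left
```

Lower temperature concentrates mass on the top token; top-k removes the tail entirely, which is why it helps avoid rare, off-topic tokens.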

In [None]:
@torch.no_grad()
def sample_next_token(logits, seq_ids, temperature=0.9, top_k=50, repetition_penalty=1.1, penalty_ctx=128):
    """Sample one token id from a 1-D logits vector over the vocabulary."""
    logits = logits.clone()
    if seq_ids and repetition_penalty != 1.0:
        # Sign-aware repetition penalty on recently seen tokens: shrink positive logits and
        # push negative ones further down, so the penalty always lowers their probability.
        recent = torch.tensor(seq_ids[-penalty_ctx:], dtype=torch.long, device=logits.device).unique()
        vals = logits[recent]
        logits[recent] = torch.where(vals > 0, vals / repetition_penalty, vals * repetition_penalty)
    logits = logits / max(1e-6, float(temperature))
    if top_k is not None and top_k > 0:
        # Keep the top_k logits and set everything else to -inf.
        topk_vals, topk_idx = torch.topk(logits, k=min(top_k, logits.numel()))
        mask = torch.full_like(logits, float("-inf"))
        mask.scatter_(0, topk_idx, topk_vals)
        logits = mask
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())

@torch.no_grad()
def generate_text(model, prompt, max_new_tokens=120, temperature=0.9, top_k=50, repetition_penalty=1.1):
    model.eval()
    dev = next(model.parameters()).device
    ids = enc.encode(prompt)  # reuse the GPT-2 tokenizer from the tokenizer cell
    x = torch.tensor(ids, dtype=torch.long, device=dev)[None, :]
    for _ in range(max_new_tokens):
        logits = model(x[:, -ctx_len:])  # crop to the training context length
        next_id = sample_next_token(logits[0, -1], x[0].tolist(), temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty)
        x = torch.cat([x, torch.tensor([[next_id]], dtype=torch.long, device=dev)], dim=1)
    return enc.decode(x[0].tolist())

print(generate_text(model, "ROMEO:", max_new_tokens=120))
print("----")
print(generate_text(model, "Once upon a time, in Verona,", max_new_tokens=120))

ROMEO:
But that are thy true!
O, dear love goes hard; and giBear from mine ears,
I enter me up blest.

FRIARLook, to teach my br marry,door of wilt thou forth thy birth.

BENVOLIO:
O noble kinsman steeled:
O, have made me, nurse!
Not so did taste else a mad disdaineth in the sun.
Sometimeain that fellow is wings of night
To take the thing, could he further sleep for the Capulet,
Though sometimes royal
----
Once upon a time, in Verona, till comes the sun sets oHath
Doth say, this is come to die.
FRIARENCE:
Upon my life, good C prick.

ESTER:
I know not for my body! belike it well-women:
You me much before:
To have your hands. I pray you will hold me, excepts,
But shall not had a set on thee, than yours.

QUEEN ELIZABETH:
And if you suppose join'd for my word;
Or delay hath so young of his majesty and


## (Optional) MoR sanity swap
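Wrap the last block in a pass-through `MoRWrapper` (one recursion, no routing). It calls the wrapped block exactly once with the same weights, so the model computes the same function as before; the swap only shows where MoR logic (routing and recursion-wise KV caching) would plug in.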

In [None]:
import torch.nn as nn

class MoRWrapper(nn.Module):
    def __init__(self, block: Block):
        super().__init__()
        self.block = block
    def forward(self, x):
        # Future work: add expert-choice routing and recursion-wise KV here.
        return self.block(x)

model.blocks[-1] = MoRWrapper(model.blocks[-1]).to(device)
print("MoR wrapper installed on last block. Try generation again.")

MoR wrapper installed on last block. Try generation again.
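For a sense of what could eventually replace the pass-through, here is a heavily simplified, hypothetical sketch of recursion with expert-choice routing: one shared block is applied up to `n_rec` times, and before each pass a linear router keeps only the highest-scoring `capacity` fraction of the tokens still active. This is not the MoR paper's implementation. It runs the block on every token and masks the update (so it saves no compute), has no recursion-wise KV caching, and its top-k over the whole sequence looks at future tokens, which a causal model cannot do at inference time. `ToyResidualBlock` and `RecursiveExpertChoice` are made-up names; the toy block stands in for the notebook's `Block`.

```python
import torch
import torch.nn as nn

class ToyResidualBlock(nn.Module):
    """Hypothetical stand-in for the notebook's Block: returns x + f(x)."""
    def __init__(self, d):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.ff = nn.Linear(d, d)
    def forward(self, x):
        return x + self.ff(self.norm(x))

class RecursiveExpertChoice(nn.Module):
    """Hypothetical MoR-style wrapper: shared block + per-recursion top-k token routing."""
    def __init__(self, block: nn.Module, d: int, n_rec: int = 2, capacity: float = 0.5):
        super().__init__()
        self.block, self.n_rec, self.capacity = block, n_rec, capacity
        self.router = nn.Linear(d, 1, bias=False)  # one routing score per token

    def forward(self, x):
        B, T, _ = x.shape
        active = torch.ones(B, T, dtype=torch.bool, device=x.device)
        for _ in range(self.n_rec):
            scores = self.router(x).squeeze(-1)                   # (B, T)
            scores = scores.masked_fill(~active, float("-inf"))  # dropped tokens stay dropped
            k = max(1, int(self.capacity * active.sum(dim=1).min().item()))
            top = scores.topk(k, dim=-1).indices                  # (B, k) tokens chosen this pass
            chosen = torch.zeros(B, T, device=x.device).scatter_(1, top, 1.0).bool()
            # Sigmoid gate keeps the router trainable; unchosen tokens get gate 0.
            gate = torch.sigmoid(scores).masked_fill(~chosen, 0.0).unsqueeze(-1)
            update = self.block(x) - x                            # residual update of the shared block
            x = x + gate * update
            active = chosen
        return x

torch.manual_seed(0)
mor = RecursiveExpertChoice(ToyResidualBlock(16), d=16, n_rec=2, capacity=0.5)
x = torch.randn(2, 8, 16)
y = mor(x)
print(y.shape, (y == x).all(dim=-1).sum(dim=1).tolist())  # tokens never routed stay unchanged
```

To try it in the notebook you could wrap a fresh `Block` (for example `RecursiveExpertChoice(Block(256, 8, 768), d=256)`) and retrain; this sketch has not been evaluated on Tiny Shakespeare.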
