In [None]:
!pip install torch torchvision



In [None]:
import math
import os
import io
import sys
import time
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class cfg:
  """Hyper-parameters for the TinyGPT / MoE demo (read as plain attributes)."""

  # ---- model shape ----------------------------------------------------
  use_moe = True            # False -> dense baseline, no MoE blocks
  n_layer = 6
  n_head = 8
  d_model = 512
  d_mlp = 2048              # hidden width of every (dense or expert) MLP
  vocab_limit = None        # None keeps every character found in the corpus
  block_size = 256          # context length in tokens

  # ---- optimisation ---------------------------------------------------
  batch_size = 24           # sequences per micro-batch
  grad_accum_steps = 2      # micro-batches accumulated per optimizer step
  max_steps = 400           # short demo run; raise it for a lower loss
  lr = 3e-4
  weight_decay = 0.1
  warmup_steps = 0.1        # NOTE(review): < 1, so presumably a fraction of max_steps -- confirm how the training loop reads it
  compile_model = False     # torch.compile makes the first step slow
  dropout = 0.0

  # ---- mixture of experts ---------------------------------------------
  n_experts = 4
  top_k = 1                 # Switch-style: one expert per token
  capacity_factor = 1.25    # per-expert token budget multiplier
  load_balance_coef = 0.01
  zloss_coef = 0.001

  # ---- hardware / precision / reproducibility ---------------------------
  device = "cuda" if torch.cuda.is_available() else "cpu"
  # bfloat16 when the GPU supports it, otherwise float16 (including on CPU)
  dtype = (
      torch.bfloat16
      if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
      else torch.float16
  )
  seed = 42


In [None]:
# Replace the config class by a single instance. The instance deliberately
# reuses the name `cfg`, so guard the call to keep this cell safe to re-run.
if isinstance(cfg, type):
    cfg = cfg()

In [None]:
# Trailing ';' keeps the returned Generator object out of the cell output.
torch.manual_seed(cfg.seed);

<torch._C.Generator at 0x7948929bfcb0>

In [None]:
# Explicitly seed the RNG of every CUDA device as well.
torch.cuda.manual_seed_all(cfg.seed)

In [None]:
def load_tinyshakespeare(
    url="https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
    timeout=10,
):
    """Download the Tiny Shakespeare corpus and return it as a string.

    If the download or decode fails for any reason, a short Hamlet excerpt
    repeated 200 times is returned instead. A notice is printed to stderr in
    that case, so a run on the fallback corpus is not mistaken for a run on
    the real dataset.

    Args:
        url: location of the UTF-8 encoded text file.
        timeout: seconds to wait for the request.
    """
    try:
        # `with` closes the response instead of leaking the connection.
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.read().decode("utf-8")
    except Exception as exc:  # deliberate best-effort: any failure -> fallback corpus
        print(
            f"load_tinyshakespeare: download failed ({exc!r}); using fallback corpus",
            file=sys.stderr,
        )
    # Fallback tiny corpus
    return (
        "To be, or not to be, that is the question:\n"
        "Whether 'tis nobler in the mind to suffer\n"
        "The slings and arrows of outrageous fortune,\n"
        "Or to take arms against a sea of troubles\n"
        "And by opposing end them.\n"
    ) * 200

In [None]:
# Raw corpus: the downloaded text, or the small built-in fallback if the download failed.
text = load_tinyshakespeare()

In [None]:
# Character vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))

In [None]:
# Show the vocabulary as one compact string instead of a one-char-per-line list.
''.join(chars)

['\n',
 ' ',
 '!',
 '$',
 '&',
 "'",
 ',',
 '-',
 '.',
 '3',
 ':',
 ';',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']

In [None]:
# Optionally keep only the first `vocab_limit` characters (None/0 keep all);
# encode() skips any character that is not in the resulting vocabulary.
chars = chars[: cfg.vocab_limit] if cfg.vocab_limit else chars

In [None]:
# Lookup tables: character -> token id and token id -> character.
stoi = dict(zip(chars, range(len(chars))))
itos = {idx: ch for ch, idx in stoi.items()}

In [None]:
# Number of token ids = number of characters kept in the vocabulary.
vocab_size = len(chars)

In [None]:
# Size of the embedding table / output layer used by TinyGPT.
vocab_size

65

In [None]:
def encode(s):
  """Convert string ``s`` to a 1-D ``torch.long`` tensor of token ids.

  Characters that are not keys of ``stoi`` are silently skipped.
  """
  ids = []
  for ch in s:
    if ch in stoi:
      ids.append(stoi[ch])
  return torch.tensor(ids, dtype=torch.long)

In [None]:
def decode(t):
  """Map an iterable of token ids (ints or 0-d tensors) back to a string via ``itos``."""
  return "".join(map(lambda tok: itos[int(tok)], t))

In [None]:
# Encode the whole corpus; characters outside the vocabulary are dropped by encode().
data = encode(text)

In [None]:
# The encoded corpus as a 1-D tensor of token ids.
data

tensor([18, 47, 56,  ..., 45,  8,  0])

In [None]:
# Train/validation split point: the first 90% of tokens are used for training.
n = int(len(data) * 0.9)

In [None]:
# Number of tokens in the training split.
n

1003854

In [None]:
# Contiguous split: the head of the corpus is for training, the tail for validation.
train_data = data[:n]
val_data = data[n:]

In [None]:
def get_batch(split):
  """Sample a random batch of (input, next-token target) windows.

  Args:
    split: "train" selects ``train_data``; any other value selects ``val_data``.

  Returns:
    ``(x, y)``, each a ``(cfg.batch_size, cfg.block_size)`` long tensor on
    ``cfg.device``. ``y`` holds the tokens one position after those in ``x``.

  Raises:
    ValueError: if the split is not longer than ``cfg.block_size``.
  """
  d = train_data if split == "train" else val_data
  # randint's upper bound is exclusive. The largest valid start is
  # len(d) - block_size - 1, so y's window (one token later than x's) still fits.
  max_start = len(d) - cfg.block_size
  if max_start < 1:
    raise ValueError(
        f"{split} split has {len(d)} tokens; need more than block_size={cfg.block_size}")
  ix = torch.randint(max_start, (cfg.batch_size,))
  x = torch.stack([d[i : i + cfg.block_size] for i in ix])
  y = torch.stack([d[i + 1 : i + 1 + cfg.block_size] for i in ix])
  return x.to(cfg.device), y.to(cfg.device)

In [None]:
class RMSNorm(nn.Module):
  """Root-mean-square normalisation over the last dimension with a learnable gain.

  Computes ``weight * x / (rms(x) + eps)`` with ``rms(x) = ||x||_2 / sqrt(d)``
  along the last axis. Note that ``eps`` is added to the RMS itself, not
  inside the square root.
  """
  def __init__(self, d, eps = 1e-5):
    super().__init__()
    self.weight = nn.Parameter(torch.ones(d))  # per-feature gain, initialised to 1
    self.eps = eps

  def forward(self, x):
    scale = 1.0 / math.sqrt(x.shape[-1])
    rms = scale * x.norm(dim=-1, keepdim=True)
    denom = rms + self.eps
    return self.weight * (x / denom)

In [None]:
class CausalSelfAttention(nn.Module):
  """Multi-head self-attention where each position attends only to itself and earlier ones.

  Args:
    d_model: model width; must be divisible by ``n_head``.
    n_head: number of attention heads.
    dropout: dropout probability on the attention weights, used only in training mode.

  forward(x) takes ``(B, T, d_model)`` and returns a tensor of the same shape.
  """
  def __init__(self,d_model,n_head,dropout = 0.0):
    super().__init__()
    assert d_model % n_head == 0
    self.n_head = n_head
    self.head_dim = d_model // n_head
    self.qkv = nn.Linear(d_model , 3*d_model,bias = False)
    self.proj = nn.Linear(d_model,d_model,bias = False)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    B, T, C = x.shape
    # (B, T, 3*C) -> (3, B, n_head, T, head_dim); unpack along dim 0 into q, k, v.
    q, k, v = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4)
    y = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,
        dropout_p=self.dropout.p if self.training else 0.0,
        is_causal=True,  # without this, tokens could attend to future positions
    )
    # (B, n_head, T, head_dim) -> (B, T, C)
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    return self.proj(y)

In [None]:
class ExpertMLP(nn.Module):
  """Two-layer GELU feed-forward network: d_model -> d_hidden -> d_model (with biases)."""
  def __init__(self, d_model, d_hidden):
    super().__init__()
    self.fc1 = nn.Linear(d_model, d_hidden)
    self.fc2 = nn.Linear(d_hidden, d_model)

  def forward(self, x):
    hidden = F.gelu(self.fc1(x))
    return self.fc2(hidden)

In [None]:
class Top1Router(nn.Module):
  """Top-1 gate: picks one expert per token and computes auxiliary router losses.

  forward(x) expects a rank-3 tensor ``(B, T, D)`` and returns:
    top1    -- ``(B*T,)`` index of the highest-probability expert per flattened token
    w       -- ``(B*T, 1)`` router probability of that expert (the gate value)
    lb_loss -- scalar ``n_experts * sum(mean_prob * fraction_of_tokens_routed)``
    z_loss  -- scalar mean of the squared logsumexp of the router logits
  """
  def __init__(self,d_model,n_experts = 5):
    super().__init__()
    self.proj = nn.Linear(d_model, n_experts)
    self.n_experts = n_experts

  def forward(self, x):
    batch, seq, dim = x.shape
    logits = self.proj(x.reshape(batch * seq, dim))
    probs = logits.softmax(dim=-1)
    top1 = torch.argmax(probs, dim=-1)
    w = torch.gather(probs, 1, top1[:, None])

    # Fraction of tokens routed to each expert (hard one-hot counts, no gradient).
    with torch.no_grad():
      load = F.one_hot(top1, num_classes=self.n_experts).float().mean(dim=0)
    importance = probs.mean(dim=0)  # mean router probability per expert
    lb_loss = self.n_experts * (importance * load).sum()
    z_loss = torch.logsumexp(logits, dim=-1).pow(2).mean()
    return top1, w, lb_loss, z_loss

In [None]:
class MoEMLP(nn.Module):
    """Switch-style mixture-of-experts feed-forward layer with top-1 routing.

    Each token goes to the single expert chosen by ``Top1Router``. The expert's
    output is scaled by the router probability of that expert. Each expert
    processes at most ``cap = int(capacity_factor * N / n_experts + 1)`` tokens
    (``N = B*T``). Overflow tokens, taken in flattened token order, are dropped
    and receive a zero output.

    forward(x) takes ``(B, T, D)`` and returns ``(y, aux)``:
    ``y`` has the shape and dtype of ``x``, and
    ``aux = lbl_coef * lb_loss + zloss_coef * z_loss``.
    """

    def __init__(self, d_model, d_hidden, n_experts, capacity_factor=1.25,
                 lbl_coef=0.01, zloss_coef=0.0, top_k=1):
        super().__init__()
        assert top_k == 1, "This minimal demo implements top-1 routing (Switch-style)."
        self.router = Top1Router(d_model, n_experts)
        self.experts = nn.ModuleList([ExpertMLP(d_model, d_hidden) for _ in range(n_experts)])
        self.n_experts = n_experts
        self.capacity_factor = capacity_factor
        self.lbl_coef = lbl_coef
        self.zloss_coef = zloss_coef

    def forward(self, x):
        B, T, D = x.shape
        N = B * T
        top1, w, lb_loss, z_loss = self.router(x)
        cap = int(self.capacity_factor * (N / self.n_experts) + 1)

        flat_x = x.reshape(N, D)
        out = torch.zeros_like(flat_x)

        for e, expert in enumerate(self.experts):
            # 1-D indices of the flattened tokens routed to expert `e`
            idx = (top1 == e).nonzero(as_tuple=False).flatten()
            if idx.numel() == 0:
                continue
            idx = idx[:cap]  # drop tokens beyond this expert's capacity
            ye = expert(flat_x[idx])
            we = w[idx].reshape(-1, 1)
            # Under autocast the router probabilities may be float32 while `out`
            # follows the input dtype. Index assignment requires matching dtypes,
            # so cast explicitly.
            out[idx] = (ye * we).to(out.dtype)

        y = out.reshape(B, T, D)
        aux = self.lbl_coef * lb_loss + self.zloss_coef * z_loss
        return y, aux

In [None]:
class DenseMLP(nn.Module):
    """Plain feed-forward layer exposing the same ``(y, aux)`` interface as ``MoEMLP``.

    ``aux`` is always a zero scalar with the dtype and device of the input.
    """

    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.ff = ExpertMLP(d_model, d_hidden)

    def forward(self, x):
        return self.ff(x), x.new_tensor(0.0)

In [None]:
class Block(nn.Module):
    """Pre-norm transformer block: causal attention, then a dense or MoE MLP.

    Both sub-layers are residual. forward(x) returns ``(x_out, aux)``, where
    ``aux`` is the MLP's auxiliary loss.
    """

    def __init__(self, d_model, n_head, d_mlp, use_moe, n_experts, capacity_factor, lbl_coef, zloss_coef):
        super().__init__()
        self.ln1 = RMSNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_head)
        self.ln2 = RMSNorm(d_model)
        self.mlp = (
            MoEMLP(d_model, d_mlp, n_experts, capacity_factor, lbl_coef, zloss_coef, top_k=1)
            if use_moe
            else DenseMLP(d_model, d_mlp)
        )

    def forward(self, x):
        h = x + self.attn(self.ln1(x))
        mlp_out, aux = self.mlp(self.ln2(h))
        return h + mlp_out, aux

In [None]:
class TinyGPT(nn.Module):
    """Character-level GPT with RMSNorm and optional mixture-of-experts MLPs.

    When ``cfg.use_moe`` is True, every odd-indexed block (1, 3, ...) uses an
    ``MoEMLP`` and the other blocks use a ``DenseMLP``. Positional embeddings are a
    learned ``(1, block_size, d_model)`` parameter initialised to zeros.

    forward(idx, targets=None) takes a ``(B, T)`` long tensor with
    ``T <= cfg.block_size`` and returns ``(logits, loss, aux_total)``:
      * logits: ``(B, T, vocab_size)``.
      * loss: mean next-token cross-entropy, or None when ``targets`` is None.
      * aux_total: the sum of the per-block auxiliary MLP losses.
    """

    def __init__(self, cfg, vocab_size):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(vocab_size, cfg.d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, cfg.block_size, cfg.d_model))
        self.drop = nn.Dropout(cfg.dropout)

        blocks = []
        for i in range(cfg.n_layer):
            use_moe_this = cfg.use_moe and (i % 2 == 1)  # MoE every other block
            blocks.append(Block(cfg.d_model, cfg.n_head, cfg.d_mlp, use_moe_this,
                                cfg.n_experts, cfg.capacity_factor, cfg.load_balance_coef, cfg.zloss_coef))
        self.blocks = nn.ModuleList(blocks)
        self.ln_f = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.cfg.block_size
        x = self.tok_emb(idx) + self.pos_emb[:, :T, :]
        x = self.drop(x)

        aux_total = x.new_tensor(0.0)
        for blk in self.blocks:
            x, aux = blk(x)
            aux_total = aux_total + aux

        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            # Compute the loss in float32 even when the logits are half precision.
            loss = F.cross_entropy(
                logits.float().view(-1, logits.size(-1)), targets.view(-1)
            )
        return logits, loss, aux_total

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Args:
            idx: ``(B, T)`` long tensor of prompt token ids.
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by ``max(1e-6, temperature)``.
            top_k: if given, sample only among the ``top_k`` highest logits.
                Values larger than the vocabulary are clamped to its size.

        Returns:
            ``(B, T + max_new_tokens)`` long tensor.

        The model runs in eval mode while sampling (so dropout is off). Its
        previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the last block_size tokens fit the positional embedding.
                idx_cond = idx[:, -self.cfg.block_size:]
                logits, _, _ = self.forward(idx_cond)
                logits = logits[:, -1, :] / max(1e-6, temperature)
                if top_k is not None:
                    k = min(top_k, logits.size(-1))  # torch.topk fails if k > vocab
                    v, _ = torch.topk(logits, k)
                    logits[logits < v[:, [-1]]] = -float('inf')
                probs = F.softmax(logits.float(), dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_id), dim=1)
        finally:
            self.train(was_training)
        return idx

In [None]:
# Build the network and move its parameters to the configured device.
model = TinyGPT(cfg, vocab_size)
model = model.to(cfg.device)

In [None]:
# Module tree; with cfg.use_moe the blocks alternate dense (even index) / MoE (odd index) MLPs.
model

TinyGPT(
  (tok_emb): Embedding(65, 512)
  (drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (ln1): RMSNorm()
      (attn): CausalSelfAttention(
        (qkv): Linear(in_features=512, out_features=1536, bias=False)
        (proj): Linear(in_features=512, out_features=512, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ln2): RMSNorm()
      (mlp): DenseMLP(
        (ff): ExpertMLP(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
    )
    (1): Block(
      (ln1): RMSNorm()
      (attn): CausalSelfAttention(
        (qkv): Linear(in_features=512, out_features=1536, bias=False)
        (proj): Linear(in_features=512, out_features=512, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ln2): RMSNorm()
      (mlp): MoEMLP(
        (router): Top1Router(
          (proj): Linear(in_features=512, 

In [None]:
# Gradient scaling is only enabled for float16, and only has an effect on CUDA.
# torch.amp.GradScaler("cuda", ...) replaces the deprecated
# torch.cuda.amp.GradScaler; fall back to the old name on older PyTorch.
_scaler_enabled = cfg.dtype == torch.float16 and cfg.device == "cuda"
if hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler("cuda", enabled=_scaler_enabled)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=_scaler_enabled)

  scaler = torch.cuda.amp.GradScaler(enabled = (cfg.dtype == torch.float16))


In [None]:
# Low-precision compute dtype for autocast: bfloat16 if configured, else float16.
model_dtype = cfg.dtype if cfg.dtype == torch.bfloat16 else torch.float16


In [None]:
# Keep the parameters (and hence the optimizer state) in float32. Low-precision
# compute comes from torch.autocast(dtype=model_dtype) around the forward pass.
# Casting the weights themselves to float16 defeats mixed precision, and
# GradScaler refuses to unscale float16 gradients.
model = model.to(dtype=torch.float32)


In [None]:

# Optional graph compilation; skipped on PyTorch builds without torch.compile.
if hasattr(torch, "compile") and cfg.compile_model:
    model = torch.compile(model)

In [None]:
# Apply weight decay only to matrix-shaped weights (linear layers, embeddings).
# 1-D parameters -- RMSNorm gains and biases -- are excluded, so decay does not
# shrink the normalisation gains toward zero.
decay_params = [p for p in model.parameters() if p.requires_grad and p.dim() >= 2]
no_decay_params = [p for p in model.parameters() if p.requires_grad and p.dim() < 2]
optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": cfg.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=cfg.lr,
    betas=(0.9, 0.95),
)

In [None]:
# Show the optimizer's parameter groups and hyper-parameters.
optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0003
    maximize: False
    weight_decay: 0.1
)

In [None]:


def count_params(m):
    """Return the total number of scalar elements over all parameters of module ``m``.

    Includes every expert of an MoE layer, not only the experts active per token.
    """
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total

In [None]:


# Quick run summary. "Total params" counts every expert, not just the active one.
print("\n".join([
    f"Vocab size: {vocab_size}",
    f"Total params: {count_params(model)/1e6:.2f}M",
    f"Using MoE: {cfg.use_moe}, n_experts={cfg.n_experts} (active experts/token: {1 if cfg.use_moe else 'N/A'})",
    f"Device: {cfg.device}, dtype: {model_dtype}",
]))

Vocab size: 65
Total params: 38.00M
Using MoE: True, n_experts=4 (active experts/token: 1)
Device: cpu, dtype: torch.float16


In [None]:
def cosine_lr(step, max_steps, base_lr, warmup):
    """Learning rate for ``step``: linear warmup, then cosine decay to 0.

    Args:
        step: current optimizer step (0-based).
        max_steps: total number of steps; the cosine reaches 0 here and stays there.
        base_lr: peak learning rate, reached at the end of warmup.
        warmup: warmup length in steps, or, if ``0 < warmup < 1``, a fraction
            of ``max_steps`` (e.g. 0.1 -> the first 10% of training).

    Returns:
        The learning rate as a float; it never exceeds ``base_lr`` for step >= 0.
    """
    if 0 < warmup < 1:
        warmup = warmup * max_steps
    if step < warmup:
        # Linear ramp. min() guards warmups shorter than one step, which would
        # otherwise overshoot base_lr.
        return base_lr * min(1.0, (step + 1) / warmup)
    progress = (step - warmup) / max(1, (max_steps - warmup))
    progress = min(1.0, max(0.0, progress))  # hold at 0 after max_steps
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))


In [None]:
@torch.no_grad()
def estimate_loss(iters=20):
    """Average loss and auxiliary MoE loss over ``iters`` random batches per split.

    Uses the notebook globals ``model``, ``get_batch``, ``cfg`` and
    ``model_dtype``. Each forward pass runs under ``torch.autocast``. The model is
    put in eval mode for the measurement and afterwards returned to the mode it
    was in before, also if an exception is raised.

    Returns:
        dict with keys "train", "val", "aux_train", "aux_val".
    """
    was_training = model.training
    model.eval()
    device_type = "cuda" if cfg.device == "cuda" else "cpu"
    results = {}
    try:
        for split in ["train", "val"]:
            ltot, atot = 0.0, 0.0
            for _ in range(iters):
                xb, yb = get_batch(split)
                with torch.autocast(device_type=device_type, dtype=model_dtype):
                    _, loss, aux = model(xb, yb)
                ltot += loss.item()
                atot += aux.item()
            results[split] = ltot / iters
            results[f"aux_{split}"] = atot / iters
    finally:
        model.train(was_training)
    return {key: results[key] for key in ("train", "val", "aux_train", "aux_val")}


In [None]:


# Switch back to training mode; ';' suppresses the full module repr in the output.
model.train();

TinyGPT(
  (tok_emb): Embedding(65, 512)
  (drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (ln1): RMSNorm()
      (attn): CausalSelfAttention(
        (qkv): Linear(in_features=512, out_features=1536, bias=False)
        (proj): Linear(in_features=512, out_features=512, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ln2): RMSNorm()
      (mlp): DenseMLP(
        (ff): ExpertMLP(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
    )
    (1): Block(
      (ln1): RMSNorm()
      (attn): CausalSelfAttention(
        (qkv): Linear(in_features=512, out_features=1536, bias=False)
        (proj): Linear(in_features=512, out_features=512, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ln2): RMSNorm()
      (mlp): MoEMLP(
        (router): Top1Router(
          (proj): Linear(in_features=512, 

In [None]:
# Prompt text used to seed generation.
context = "JULIET: "


In [None]:
# Prompt as a (1, len(context)) id tensor; unknown characters fall back to id 0.
ctx = torch.tensor(
    [list(map(lambda ch: stoi.get(ch, 0), context))],
    dtype=torch.long,
    device=cfg.device,
)


In [None]:
# Sample a continuation of the prompt and show it as text. The sampled ids are
# otherwise only stored in `out` and never displayed.
out = model.generate(ctx, max_new_tokens=300, temperature=0.8, top_k=50)
print(decode(out[0].tolist()))
