<a href="https://colab.research.google.com/github/WilliamZhang20/Transformer-from-Scratch/blob/main/MoE_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Building a GPT

Added KV Caching + Mixture of Experts

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer # Use a real tokenizer, e.g., a simple one for char-level/word-level

# --- Setup for Tokenization and Block Processing ---
# Use the same block_size from your model
block_size = 32
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# A simple tokenizer (e.g., a GPT-2 style BPE tokenizer for illustration)
# For simplicity and scale, let's use a standard word/subword tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token # Necessary for batching

vocab_size = tokenizer.vocab_size

# --- 1. Load the Dataset ---
print("Loading WikiText-103 dataset...")
raw_datasets = load_dataset("wikitext", "wikitext-103-v1")

# --- 2. Tokenize the Data ---
def tokenize_function(examples):
    # 'text' is the column name in the wikitext dataset
    return tokenizer(examples["text"], truncation=False, max_length=100000)

tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    num_proc=4, # Use multiple processes for fast tokenization
    remove_columns=["text"],
)
print("Tokenization complete.")

# --- 3. Group and Block the Data (Crucial for LM Training) ---
# Concatenate all texts and split them into fixed-size chunks (block_size)
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])

    # Drop the last chunk if it's smaller than block_size
    total_length = (total_length // block_size) * block_size

    # Split by block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }

    # 'labels' (targets) are just the shifted input 'input_ids'
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    num_proc=4,
)
print("Grouping complete.")


# --- 4. Custom DataLoader/get_batch for the Model ---
train_data = lm_datasets['train'].with_format("torch")
val_data = lm_datasets['validation'].with_format("torch")

def get_batch(split):
    data = train_data if split == 'train' else val_data

    # ix generates random indices from the dataset
    ix = torch.randint(len(data), (batch_size,))

    # The dataset now directly provides tokenized and blocked data
    # We retrieve the 'input_ids' as 'x' and 'labels' as 'y'
    batch = data[ix]
    x = batch['input_ids'].to(device)
    y = batch['labels'].to(device)

    return x, y

# Example: Get a batch (for testing)
xb, yb = get_batch('train')
print(xb.shape, yb.shape)

In [2]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-09-26 13:34:34--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-09-26 13:34:34 (40.3 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoE(nn.Module):
    """
    Simple Mixture of Experts (top-1 routing).
    - n_experts: number of expert FFNs
    - expert_hidden: hidden dim for each expert (usually 4*n_embd)
    - aux_loss_coef: coefficient for simple load balancing loss
    """
    def __init__(self, n_embd, n_experts=4, expert_hidden=None, aux_loss_coef=1e-2):
        super().__init__()
        if expert_hidden is None:
            expert_hidden = 4 * n_embd
        self.n_experts = n_experts
        self.n_embd = n_embd
        self.aux_loss_coef = aux_loss_coef

        # create expert modules (simple FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(n_embd, expert_hidden),
                nn.ReLU(),
                nn.Linear(expert_hidden, n_embd)
            ) for _ in range(n_experts)
        ])

        # gating network: produce scores per expert for each token
        self.gate = nn.Linear(n_embd, n_experts)

    def forward(self, x):
        # x: (B, T, C)
        B, T, C = x.shape
        x_flat = x.view(B*T, C)  # (BT, C)

        # gating logits -> probabilities
        gate_logits = self.gate(x_flat)           # (BT, n_experts)
        gate_probs = F.softmax(gate_logits, dim=-1)  # (BT, n_experts)

        # top-1 route
        top1_idx = torch.argmax(gate_probs, dim=-1)  # (BT,)
        # one-hot dispatch mask (BT, n_experts)
        dispatch_mask = F.one_hot(top1_idx, num_classes=self.n_experts).float()

        # --- compute expert outputs by gathering assigned tokens per expert ---
        # prepare output placeholder
        outputs = torch.zeros_like(x_flat)  # (BT, C)
        device = x_flat.device

        expert_losses = []
        for e, expert in enumerate(self.experts):
            # find indices assigned to this expert
            mask_e = (top1_idx == e)  # (BT,)
            if mask_e.any():
                x_e = x_flat[mask_e]            # (N_e, C)
                y_e = expert(x_e)               # (N_e, C)
                outputs[mask_e] = y_e
            else:
                # no tokens routed here; skip
                continue

        outputs = outputs.view(B, T, C)

        # --- simple load balancing aux loss ---
        # importance = sum of gate probs per expert, load = number of tokens chosen per expert
        importance = gate_probs.sum(dim=0)            # (n_experts,)
        load = dispatch_mask.sum(dim=0)               # (n_experts,)
        # normalize and penalize deviation from uniform
        importance_norm = importance / (importance.sum() + 1e-9)
        load_norm = load / (load.sum() + 1e-9)
        # squared deviation
        aux_loss = ((importance_norm - (1.0/self.n_experts)).pow(2).mean()
                    + (load_norm - (1.0/self.n_experts)).pow(2).mean())
        aux_loss = aux_loss * self.aux_loss_coef

        # return outputs and aux_loss (aux_loss is small scalar)
        return outputs, aux_loss

In [13]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
# ------------

torch.manual_seed(1337)

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = MoE(n_embd, n_experts=4, expert_hidden=4*n_embd, aux_loss_coef=1e-2)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        ff_out, ff_aux = self.ffwd(self.ln2(x))
        x = x + ff_out
        return x, ff_aux

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        # run through blocks, accumulating auxiliary MoE losses
        total_aux_loss = 0.0
        for block in self.blocks:
            # if block is MoE-enabled, it returns (x, aux_loss)
            out = block(x)
            if isinstance(out, tuple):
                x, aux_loss = out
                total_aux_loss = total_aux_loss + aux_loss
            else:
                x = out

        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            ce_loss = F.cross_entropy(logits, targets)
            # add aux MoE loss
            loss = ce_loss + total_aux_loss

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = BigramLanguageModel()
model = model.to(device)
m = torch.compile(model)

@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(m)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

0.607825 M parameters
step 0: train loss 4.3679, val loss 4.3681
step 100: train loss 2.6425, val loss 2.6571
step 200: train loss 2.5153, val loss 2.5222
step 300: train loss 2.4343, val loss 2.4386
step 400: train loss 2.3731, val loss 2.3890
step 500: train loss 2.3285, val loss 2.3297
step 600: train loss 2.2852, val loss 2.2995
step 700: train loss 2.2419, val loss 2.2681
step 800: train loss 2.2045, val loss 2.2313
step 900: train loss 2.1785, val loss 2.2027
step 1000: train loss 2.1332, val loss 2.1704
step 1100: train loss 2.1123, val loss 2.1552
step 1200: train loss 2.0774, val loss 2.1168
step 1300: train loss 2.0546, val loss 2.1066
step 1400: train loss 2.0253, val loss 2.0788
step 1500: train loss 1.9972, val loss 2.0657
step 1600: train loss 1.9735, val loss 2.0449
step 1700: train loss 1.9511, val loss 2.0506
step 1800: train loss 1.9491, val loss 2.0413
step 1900: train loss 1.9325, val loss 2.0117
step 2000: train loss 1.9056, val loss 2.0005
step 2100: train loss 1.

In [14]:
# generate from the model
import time
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


And they brides.

SOTRLOTES:
Kind Peomb, you whaS greats that hause:
Whither us hath but comelancase away, my fears
acumzonour
Yoursel'sful their vart; dill, at misters, hain;
Sir, that GrEt, and Warrner one wars!
Alaring that this much him sptup; and all,
yet lord's know patelives; but you, then again Wilst our Ceiizase,
Stillourion a guident to-disore-homb
To king thrust for their hand: mowhe is expath,
Madaried my offery his burs, you ar
ards be his gentle up the king
And kiry to-charmost! My fire youk,
If you see, thy my mesore and see--'Sir;
But with ready the custil son't weep one.

Thun? Sevomats:
Decks lord
I issure to eirsurable I deabere over a mains!

RAMILLO:
I cam stie so upon thou fear, as nyther's,
Why, knows hone duste tee, our hohearth.
I'll againce's, with soul. Dive lore made so lack.

Prive, must home,
And I shark poir noth now thie,
Thou hath diiusorthern, thee subgan appyter'd what thou the cusles,
Be we king carrick, the Lorderward, time to tee, for Gly thee?

S

In [None]:
!pip install --quiet torch_tensorrt

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/40.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.7/40.7 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.1/15.1 MB[0m [31m40.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m71.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for tensorrt (setup.py) ... [?25l[?25hdone
  Building wheel for tensorrt_cu12 (setup.py) ... [?25l[?25hdone
  Building wheel for tensorrt-cu12-libs (pyproject.toml) ... [?25l[?25hdone


In [None]:
# TensorRT acceleration
import torch_tensorrt

class BigramInferenceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, idx):
        idx = idx.to(torch.int32)  # TensorRT-safe
        # patch positional indices
        T = idx.shape[1]
        B = idx.shape[0]
        tok_emb = self.model.token_embedding_table(idx)  # (B,T,C)
        pos_idx = torch.arange(T, device=idx.device, dtype=torch.int32)
        pos_emb = self.model.position_embedding_table(pos_idx)
        x = tok_emb + pos_emb
        x = self.model.blocks(x)
        x = self.model.ln_f(x)
        logits = self.model.lm_head(x)
        return logits

model = model.to('cuda')
inference_model = BigramInferenceWrapper(model)

example_input = torch.zeros((1, 1), dtype=torch.int32, device='cuda')

# Trace the wrapper model
traced_model = torch.jit.trace(inference_model, example_input)
traced_model.save("bigram_traced.pt")



In [None]:
trt_model = torch_tensorrt.compile(
    traced_model,
    inputs=[torch_tensorrt.Input((1, block_size), dtype=torch.int32)],
    enabled_precisions={torch.float32},
    workspace_size=1 << 20,
    truncate_long_and_double=True
)



In [None]:
import torch
import torch.nn.functional as F

def generate_trt(trt_model, idx, max_new_tokens, block_size, decode_fn):
    idx = idx.to(torch.int32)  # Ensure input is int32 for TensorRT compatibility
    for _ in range(max_new_tokens):
        # Crop to the last block_size tokens
        idx_cond = idx[:, -block_size:]  # Shape: (1, block_size)

        # Ensure the input is the correct shape
        if idx_cond.shape[1] < block_size:
            # Pad with zeros if the sequence is shorter than block_size
            padding = torch.zeros((idx_cond.shape[0], block_size - idx_cond.shape[1]),
                                dtype=torch.int32, device=idx.device)
            idx_cond = torch.cat((padding, idx_cond), dim=1)
        idx_cond = idx_cond.to(torch.int32)  # Shape: (1, block_size)

        # Run TensorRT model
        logits = trt_model(idx_cond)  # Shape: (1, block_size, vocab_size)

        # Use the logits for the last time step
        logits = logits[:, -1, :]  # Shape: (1, vocab_size)
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)  # Shape: (1, 1)

        # Append new token
        idx = torch.cat((idx, idx_next), dim=1)

    return decode_fn(idx[0].tolist())

context = torch.zeros((1, 1), dtype=torch.int32, device='cuda')
generated_text = generate_trt(trt_model, context, max_new_tokens=500, block_size=32, decode_fn=decode)
print(generated_text)


SLERSCUMEREY:
Let is not not you wast. But, and my virture to controse wove?

Secomfort, firsts osys have?

HORTIONA:
O, I screts I some.
I is Alme my notly have beyeing make:
And I shame itto me your clannst of
my sounk, silewiman our sweenger to-discoustractar wames
Wetcher him you.
Ummon teat the sadlingna
To carse, whose been misely opge,
And there such but kind this is guryanny to purt butwere my warre morries
Both loves your wonds it hate.

LUCENTIO:
WhavEN thrus than to my unaght, this sw


As seen below, TensorRT is about 8x faster.

In [None]:
# Benchmark TensorRT
start_time = time.time()
_ = generate_trt(trt_model, context, max_new_tokens=2000, block_size=32, decode_fn=decode)
end_time = time.time()
print(end_time - start_time)

1.6884980201721191
