<a href="https://colab.research.google.com/github/abdulraheem-hammad/mini-transformer-from-scratch/blob/main/minitransfomer%2BBPE_tokenizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
import torch
import torch.nn as nn
from torch.nn import functional as F

# ------------------------------------------
# Utility functions for BPE
# ------------------------------------------
def get_stats(input_integers):
    """Count how often each adjacent pair of tokens occurs.

    Args:
        input_integers: Sequence of token ids.

    Returns:
        Dict mapping each ``(left, right)`` pair to its number of
        occurrences, with keys in order of first appearance.
    """
    pair_counts = {}
    for pos in range(len(input_integers) - 1):
        key = (input_integers[pos], input_integers[pos + 1])
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

def merge(input_integers, input_pair, new_token):
    """Replace every non-overlapping occurrence of a pair with one token.

    The sequence is scanned left to right; whenever the two elements at
    the current position equal ``input_pair``, ``new_token`` is emitted
    and both elements are consumed.

    Args:
        input_integers: Sequence of token ids to rewrite.
        input_pair: Two-element pair ``(left, right)`` to look for.
        new_token: Token id emitted in place of each matched pair.

    Returns:
        A new list; the input sequence is not modified.
    """
    result = []
    last = len(input_integers) - 1
    pos = 0
    while pos <= last:
        current = input_integers[pos]
        is_match = (
            pos < last
            and current == input_pair[0]
            and input_integers[pos + 1] == input_pair[1]
        )
        if is_match:
            result.append(new_token)
            pos += 2
            continue
        result.append(current)
        pos += 1
    return result

# ------------------------------------------
# BPETokenizer with save/load support
# ------------------------------------------
class BPETokenizer:
    """Byte-level byte-pair-encoding (BPE) tokenizer.

    The base vocabulary is the 256 single-byte values. Training adds one
    token per merge; each merged token's byte string is the concatenation
    of the byte strings of the two tokens it replaces.

    Attributes:
        merges: ``{(left_id, right_id): new_token_id}``, kept in ascending
            ``new_token_id`` order (the order the merges were learned).
        vocab: ``{token_id: bytes}`` for every known token.
        base_vocab_size: Number of single-byte base tokens (256).
        vocab_size: Total number of tokens currently known.
    """

    def __init__(self):
        self.base_vocab_size = 256
        self.merges = {}  # {(a,b): new_token_id}
        self.vocab = {i: bytes([i]) for i in range(self.base_vocab_size)}
        self.vocab_size = self.base_vocab_size

    def train(self, text: str, target_vocab_size: int):
        """Learn merges from ``text`` until the vocabulary reaches the target.

        Any previously learned merges are discarded first, so calling
        ``train`` again yields a consistent tokenizer rather than reusing
        token ids that already map to other pairs.

        Args:
            text: Training corpus; it is UTF-8 encoded before processing.
            target_vocab_size: Desired total vocabulary size. Values at or
                below ``base_vocab_size`` result in no merges.
        """
        # Start from the bare byte vocabulary: new ids are numbered from
        # ``base_vocab_size``, which would collide with earlier merges.
        self.merges = {}
        self.vocab = {i: bytes([i]) for i in range(self.base_vocab_size)}

        tokens = list(text.encode("utf-8"))
        num_merges = target_vocab_size - self.base_vocab_size
        for i in range(num_merges):
            stats = get_stats(tokens)
            if not stats:
                # Fewer than two tokens remain; nothing left to merge.
                break
            pair = max(stats, key=stats.get)
            new_token = self.base_vocab_size + i
            tokens = merge(tokens, pair, new_token)
            self.merges[pair] = new_token
            self.vocab[new_token] = self.vocab[pair[0]] + self.vocab[pair[1]]
        self.vocab_size = len(self.vocab)

    def encode(self, text: str):
        """Convert ``text`` to a list of token ids.

        The text is UTF-8 encoded and every learned merge is then applied
        in the order stored in ``self.merges``.
        """
        tokens = list(text.encode("utf-8"))
        for pair, idx in self.merges.items():
            tokens = merge(tokens, pair, idx)
        return tokens

    def decode(self, tokens):
        """Convert an iterable of token ids back to a string.

        Byte sequences that are not valid UTF-8 are replaced with U+FFFD.

        Raises:
            KeyError: If a token id is not in the vocabulary.
        """
        raw = b"".join(self.vocab[t] for t in tokens)
        return raw.decode("utf-8", errors="replace")

    def save(self, dirpath: str):
        """Write ``merges.json`` and ``meta.json`` into ``dirpath``.

        The directory is created if it does not exist. Merge keys are
        serialised as ``"<left>_<right>"`` strings because JSON object
        keys must be strings.
        """
        os.makedirs(dirpath, exist_ok=True)
        merges_ser = {f"{a}_{b}": idx for (a, b), idx in self.merges.items()}
        with open(os.path.join(dirpath, "merges.json"), "w", encoding="utf-8") as f:
            json.dump(merges_ser, f)
        meta = {"vocab_size": self.vocab_size, "base_vocab_size": self.base_vocab_size}
        with open(os.path.join(dirpath, "meta.json"), "w", encoding="utf-8") as f:
            json.dump(meta, f)

    @classmethod
    def load(cls, dirpath: str):
        """Rebuild a tokenizer from files previously written by ``save``.

        Args:
            dirpath: Directory containing ``merges.json`` and ``meta.json``.

        Returns:
            A new tokenizer with merges, vocabulary and sizes restored.
        """
        tok = cls()
        with open(os.path.join(dirpath, "merges.json"), "r", encoding="utf-8") as f:
            merges_ser = json.load(f)
        # Order merges by token id: a merged token may be built from earlier
        # merged tokens, so the vocabulary must be rebuilt in creation order
        # regardless of the key order found in the JSON file.
        ordered = sorted(
            ((tuple(map(int, k.split("_"))), v) for k, v in merges_ser.items()),
            key=lambda item: item[1],
        )
        tok.merges = dict(ordered)
        for (a, b), idx in tok.merges.items():
            tok.vocab[idx] = tok.vocab[a] + tok.vocab[b]
        with open(os.path.join(dirpath, "meta.json"), "r", encoding="utf-8") as f:
            meta = json.load(f)
        tok.vocab_size = meta["vocab_size"]
        tok.base_vocab_size = meta["base_vocab_size"]
        return tok

# ------------------------------------------
# Configuration and sampling
# ------------------------------------------
INPUT_FILE = "input.txt"
TOKENIZER_DIR = "./tokenizer"
# The file is opened in text mode below, so this caps the number of
# characters read (roughly 20 MB for mostly-ASCII text).
SAMPLE_SIZE = 20 * 1024 * 1024
MAX_TOKENS = 5_000_000  # keep at most the first 5M encoded tokens
TARGET_VOCAB = 2000  # vocabulary size requested when training a tokenizer

# Pull a bounded sample of the corpus into memory.
with open(INPUT_FILE, 'r', encoding='utf-8') as f:
    text = f.read(SAMPLE_SIZE)

# Train and persist a tokenizer unless one has already been saved.
if not (os.path.isdir(TOKENIZER_DIR) and os.path.isfile(os.path.join(TOKENIZER_DIR, "merges.json"))):
    print("Training new tokenizer on sample...")
    tokenizer = BPETokenizer()
    tokenizer.train(text, TARGET_VOCAB)
    tokenizer.save(TOKENIZER_DIR)
else:
    print("Loading existing tokenizer...")
    tokenizer = BPETokenizer.load(TOKENIZER_DIR)

# Tokenize the sample and drop anything beyond the token budget.
all_tokens = tokenizer.encode(text)
if len(all_tokens) > MAX_TOKENS:
    del all_tokens[MAX_TOKENS:]

# First 90% of the tokens are used for training, the rest for validation.
data = torch.tensor(all_tokens, dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# ------------------------------------------
# Transformer and training hyperparameters
# ------------------------------------------
# Size of the token embedding table and of the output logits.
vocab_size    = tokenizer.vocab_size
# Number of sequences sampled per batch in get_batch.
batch_size    = 64
# Length of each sampled sequence; also the size of the position
# embedding table and of the causal mask, and the window kept by generate.
block_size    = 128
# Number of iterations of the training loop.
max_iters     = 500
# Losses are estimated and printed every `eval_interval` steps.
eval_interval = 50
# Learning rate passed to AdamW.
learning_rate = 3e-4
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device        = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of batches averaged per split in estimate_loss.
eval_iters    = 20
# Width of the token/position embeddings and of each block's input/output.
n_embd        = 128
# Attention heads per block; each head has size n_embd // n_head.
n_head        = 4
# Number of stacked transformer blocks.
n_layer       = 4
# Dropout probability used by every nn.Dropout in the model (0.0 disables it).
dropout       = 0.0

# ------------------------------------------
# Data loading functions
# ------------------------------------------
def get_batch(split):
    """Sample a random batch of input/target sequences.

    Args:
        split: ``'train'`` selects the training data; any other value
            selects the validation data.

    Returns:
        A pair ``(x, y)`` of tensors on ``device``, each stacking
        ``batch_size`` windows of ``block_size`` tokens; ``y`` is ``x``
        shifted one position to the right in the source data.
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the training and validation splits.

    The model is put in eval mode, ``eval_iters`` random batches are
    drawn per split, and the model is switched back to train mode
    before returning.

    Returns:
        Dict with keys ``'train'`` and ``'val'`` holding the mean loss
        of each split as a scalar tensor.
    """
    model.eval()
    averages = {}
    for split in ('train', 'val'):
        per_batch = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split)
            per_batch[step] = model(inputs, targets)[1].item()
        averages[split] = per_batch.mean()
    model.train()
    return averages

# ------------------------------------------
# Model definition
# ------------------------------------------
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask: position t may only attend to positions <= t.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply masked self-attention to ``x`` of shape ``(B, T, C)``.

        Returns:
            Tensor whose last dimension is ``head_size``.
        """
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scale by 1/sqrt(d_k), where d_k is the key/query width (head_size),
        # not the input embedding width C.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Hide future positions so they receive zero weight after softmax.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        return wei @ v

class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, then projected to ``n_embd``.

    Args:
        num_heads: Number of ``Head`` modules to create.
        head_size: Output width of each head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated heads have width num_heads * head_size, which only
        # equals n_embd when n_embd is divisible by the number of heads.
        self.proj  = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output on the last dim and project it."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, dropout.

    Args:
        n_embd: Input and output width of the MLP.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP."""
        return self.net(x)

class Block(nn.Module):
    """Transformer block: self-attention then feed-forward.

    Each sub-layer is applied to a layer-normalised copy of its input
    and the result is added back to that input (residual connection).

    Args:
        n_embd: Embedding width.
        n_head: Number of attention heads; each gets ``n_embd // n_head``.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers with residuals."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

class BigramLanguageModel(nn.Module):
    """Decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through
    ``n_layer`` ``Block`` modules and a final layer norm, then projected
    to ``vocab_size`` logits.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table    = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f   = nn.LayerNorm(n_embd)
        self.lm_head= nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Token ids of shape ``(B, T)``; ``T`` must not exceed
                ``block_size`` (the position embedding table size).
            targets: Optional token ids with ``B*T`` elements.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is the cross-entropy
            against ``targets``, or ``None`` when no targets are given.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        # Build position indices on the input's own device so the model
        # works wherever it has been moved, independent of the global default.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        # cross_entropy expects (N, C) logits and (N,) targets.
        loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: Starting token ids of shape ``(B, T)``.
            max_new_tokens: Number of tokens to append to each row.

        Returns:
            ``idx`` extended along dim 1 by ``max_new_tokens`` sampled ids.
        """
        for _ in range(max_new_tokens):
            # Only the most recent block_size tokens fit the position table.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # Keep the prediction for the last position only.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# ------------------------------------------
# Training loop
# ------------------------------------------
# Build the model on the selected device and attach an AdamW optimizer.
model = BigramLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print(f"Model params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

for it in range(max_iters):
    # Report averaged losses on the final step and every eval_interval steps.
    if it == max_iters - 1 or it % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {it}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # One optimisation step: clear old gradients, run a training batch
    # through the model, backpropagate and update the parameters.
    optimizer.zero_grad(set_to_none=True)
    xb, yb = get_batch('train')
    _, loss = model(xb, yb)
    loss.backward()
    optimizer.step()

# ------------------------------------------
# Generation
# ------------------------------------------
# Seed generation with a single token of id 0 and sample 2000 more tokens.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
output  = model.generate(context, max_new_tokens=2000)[0].tolist()
# Decode through the tokenizer instance: no module-level `decode` function
# is defined in this file, so a bare `decode(...)` call raises NameError.
print(tokenizer.decode(output))


Loading existing tokenizer...
Model params: 1.32M
step 0: train loss 7.7654, val loss 7.7698
step 50: train loss 6.9903, val loss 7.0445
step 100: train loss 6.9448, val loss 7.0152
step 150: train loss 6.9127, val loss 6.9796
step 200: train loss 6.7710, val loss 6.8681
step 250: train loss 6.4937, val loss 6.6300
step 300: train loss 6.1057, val loss 6.2508
step 350: train loss 5.7784, val loss 5.9345
step 400: train loss 5.5009, val loss 5.6673
step 450: train loss 5.2684, val loss 5.4609
step 499: train loss 5.0939, val loss 5.2893
 branes. He reyellow cheered. "ok how to piced catssailed his y and did be ever ed to like the Max. The kitchens that plealato the pock anyso wave. They would take the e. But he nd and the love they got lot of hug looked very happy.
When he could peired wrong, wtakenmake happillegluwilk with a no that. We was so proud of hneeded to the fish. I will find ed it? Twes a rainy. He drelffunnearing her of some proud of herer in Me, hey .
"Ben, not .
"Look, Tim