<a href="https://colab.research.google.com/github/KelvinM9187/Supervised-Speech-Recognition-with-Transformers/blob/main/mini_gpt_built_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Setting Up The Environment
# In Colab, checkpoints go to Google Drive so they survive runtime resets.
# Anywhere else (local Jupyter, CI) there is no `google.colab`, so fall back to
# a local directory instead of failing with ModuleNotFoundError.
import os

try:
    from google.colab import drive
except ImportError:  # not running inside Colab
    drive = None

if drive is not None:
    drive.mount('/content/drive')
    CHECKPOINT_ROOT = '/content/drive/MyDrive/mini_gpt_checkpoints'
else:
    CHECKPOINT_ROOT = os.path.abspath('mini_gpt_checkpoints')
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)

print("Checkpoints will be saved to:", CHECKPOINT_ROOT)

In [None]:
# Necessary Libraries

import math
import os
import random
import sys
import time
import urllib.request
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

# Reproducibility: seed Python, NumPy and PyTorch from one constant
SEED = 1337
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(SEED)
has_cuda = torch.cuda.is_available()
if has_cuda:
    torch.cuda.manual_seed(SEED)

# Run on the GPU whenever one is visible
device = 'cuda' if has_cuda else 'cpu'
print("Device:", device)


In [None]:
# Download the Tiny Shakespeare dataset
DATA_URL = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
DATA_PATH = 'tiny_shakespeare.txt'

# A failed download can leave an empty file behind, so an empty file counts
# as "missing" and is fetched again instead of being silently reused.
if os.path.exists(DATA_PATH) and os.path.getsize(DATA_PATH) > 0:
    print("Found existing dataset:", DATA_PATH)
else:
    # Download under a temporary name and rename only on success, so an
    # interrupted download never masquerades as the dataset on the next run.
    # urlretrieve raises on HTTP/network errors instead of failing silently.
    tmp_path = DATA_PATH + '.part'
    urllib.request.urlretrieve(DATA_URL, tmp_path)
    os.replace(tmp_path, DATA_PATH)
    print("Downloaded tiny_shakespeare to", DATA_PATH)

# show head
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    raw = f.read()
if not raw:
    raise RuntimeError(f"Dataset file {DATA_PATH!r} is empty")
print("Dataset length (chars):", len(raw))
print("First 200 chars:\n", raw[:200].replace('\n','\\n'))


In [None]:
# Encoding & train/val/test split (80/10/10)
# Character-level vocabulary: every distinct character in the corpus, sorted
chars = sorted(set(raw))
vocab_size = len(chars)
print("Vocab size:", vocab_size)
itos = dict(enumerate(chars))              # id -> char
stoi = {ch: i for i, ch in itos.items()}   # char -> id

# encode / decode helpers
def encode(s):
    """Convert string ``s`` to a list of token ids (KeyError on unknown characters)."""
    return list(map(stoi.__getitem__, s))

def decode(l):
    """Convert an iterable of token ids ``l`` back into a string."""
    return ''.join(itos[i] for i in l)

data = torch.tensor(encode(raw), dtype=torch.long)

# Contiguous (unshuffled) 80/10/10 split; each part is moved to `device`
n = len(data)
n_train = int(0.8 * n)
n_val   = int(0.1 * n)
n_test  = n - n_train - n_val
val_end = n_train + n_val
train_data, val_data, test_data = (
    data[lo:hi].to(device)
    for lo, hi in ((0, n_train), (n_train, val_end), (val_end, n))
)

print(f"Split: train {len(train_data)}, val {len(val_data)}, test {len(test_data)}")


In [None]:
# Data loader utility
# Function to get random mini-batches of contiguous characters
def get_batch(split, block_size, batch_size):
    """Sample ``batch_size`` random contiguous windows from one data split.

    Args:
        split: 'train', 'val' or 'test'; anything else raises ValueError.
        block_size: window length in tokens.
        batch_size: number of windows.

    Returns:
        (x, y) tensors of shape (batch_size, block_size) on ``device``, where
        y[b, t] is the token that follows x[b, t] in the split.
    """
    for name, source in (('train', train_data), ('val', val_data), ('test', test_data)):
        if split == name:
            split_data = source
            break
    else:
        raise ValueError(split)
    # randint's upper bound is exclusive, so each window plus its
    # one-step-ahead target stays inside the split.
    starts = torch.randint(0, len(split_data) - block_size, (batch_size,))
    x = torch.stack([split_data[s:s + block_size] for s in starts])
    y = torch.stack([split_data[s + 1:s + 1 + block_size] for s in starts])
    return x.to(device), y.to(device)

# quick sanity check of the batch shapes
b_x, b_y = get_batch('train', block_size=64, batch_size=4)
print("Batch shapes:", b_x.shape, b_y.shape)


In [None]:
# Transformer Hyperparameters
# -- model shape --
block_size = 256      # context length
d_model = 256         # embedding / residual width
n_layers = 4
n_heads = 4           # head_dim = d_model // n_heads
d_ff = 4 * d_model    # feed-forward hidden width (1024)
dropout = 0.1
# -- optimisation --
batch_size = 64
learning_rate = 3e-4
grad_clip = 1.0
max_iters = 3_000
# -- logging / checkpoint cadence (in iterations) --
eval_interval = 200
save_interval = 500

print("Hyperparameters set.")


In [None]:
# Model Implementation (Decoder-only mini-GPT)
# We implement a minimal decoder-only transformer with causal masking.

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which position i attends only to positions j <= i.

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        attn_dropout: dropout applied to the attention probabilities.
        proj_dropout: dropout applied after the output projection.
    """

    def __init__(self, dim, n_heads, attn_dropout=0.0, proj_dropout=0.0):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5   # 1 / sqrt(head_dim)
        # Keep this registration order: it fixes parameter order (saved
        # optimizer state) and the order in which seeded init consumes the RNG.
        self.qkv = nn.Linear(dim, dim * 3)   # fused Q, K, V projection
        self.proj = nn.Linear(dim, dim)      # output projection
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj_dropout = nn.Dropout(proj_dropout)

    def forward(self, x, attn_mask=None):
        """Attend over ``x`` of shape (B, T, C) and return a tensor of the same shape.

        ``attn_mask`` is accepted but currently ignored; the causal mask is
        always rebuilt from the sequence length.
        """
        batch, seq_len, width = x.size()

        # (B, T, 3C) -> (3, B, heads, T, head_dim), then split into q, k, v
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.n_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale   # (B, heads, T, T)

        # True where key position j <= query position i
        allowed = torch.ones(seq_len, seq_len, device=x.device).tril().bool()
        scores = scores.masked_fill(~allowed, float('-inf'))

        weights = self.attn_dropout(scores.softmax(dim=-1))
        mixed = weights @ v                                # (B, heads, T, head_dim)
        merged = mixed.transpose(1, 2).reshape(batch, seq_len, width)
        return self.proj_dropout(self.proj(merged))

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout.

    The layers live in ``self.net`` (an ``nn.Sequential``) so state_dict keys
    remain ``net.0.*`` / ``net.2.*`` and existing checkpoints still load.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),   # expand to the hidden width
            nn.GELU(),
            nn.Linear(hidden_dim, dim),   # project back to the model width
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        out = self.net(x)
        return out

class TransformerBlock(nn.Module):
    """Pre-LayerNorm block: x + Attn(LN(x)), then + MLP(LN(.)); shape is preserved."""

    def __init__(self, dim, n_heads, mlp_hidden_dim, dropout=0.0):
        super().__init__()
        # Registration order (ln1, attn, ln2, mlp) is kept stable for parameter
        # ordering and seeded initialisation.
        self.ln1 = nn.LayerNorm(dim)
        self.attn = CausalSelfAttention(dim, n_heads, attn_dropout=dropout, proj_dropout=dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, mlp_hidden_dim, dropout=dropout)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))

class MiniGPT(nn.Module):
    """Decoder-only language model.

    Token embeddings plus learned absolute position embeddings pass through
    ``n_layers`` TransformerBlocks, a final LayerNorm and a bias-free linear
    head that produces logits over ``vocab_size`` symbols.
    """

    def __init__(self, vocab_size, block_size, n_layers, n_heads, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        # Creation order matters: the seeded RNG stream and the parameter order
        # seen by the optimizer both depend on it.
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.randn(1, block_size, d_model) * 0.01)
        self.drop = nn.Dropout(dropout)
        stack = [TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        self.blocks = nn.Sequential(*stack)
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Re-initialise Linear/Embedding weights; pos_emb keeps its own init above.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) weights for Linear and Embedding layers, zero Linear biases."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """Compute next-token logits for token ids ``idx`` of shape (B, T).

        Returns ``logits`` of shape (B, T, vocab_size) when ``targets`` is None,
        otherwise ``(logits, loss)`` with the mean cross-entropy against ``targets``.
        """
        B, T = idx.size()
        assert T <= self.block_size, "Sequence length longer than block size"
        # (B, T, d_model) + (1, T, d_model) broadcast over the batch
        h = self.token_emb(idx) + self.pos_emb[:, :T]
        h = self.ln_f(self.blocks(self.drop(h)))
        logits = self.head(h)
        if targets is None:
            return logits
        loss = F.cross_entropy(logits.view(B * T, -1), targets.view(-1))
        return logits, loss

# Build the model from the hyperparameter cell and move it to the active device
model = MiniGPT(
    vocab_size=vocab_size,
    block_size=block_size,
    n_layers=n_layers,
    n_heads=n_heads,
    d_model=d_model,
    d_ff=d_ff,
    dropout=dropout,
).to(device)

n_params = sum(p.numel() for p in model.parameters())
print("Model size (parameters):", n_params / 1e6, "M")


In [None]:
# Optimizer, Scheduler & Utility Functions
# AdamW over every parameter; the same weight decay applies to all of them
# (biases, LayerNorm and embedding weights included).
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=0.1)
# simple cosine lr schedule with warmup
def get_scheduler(optimizer, warmup_steps=200, total_steps=None):
    """Linear warmup followed by cosine decay to zero, wrapped in a ``LambdaLR``.

    The multiplier applied to each param group's base learning rate is:
      * ``step / warmup_steps`` while ``step < warmup_steps`` (it starts at 0);
      * ``0.5 * (1 + cos(pi * progress))`` afterwards, where ``progress`` is the
        fraction of post-warmup steps completed, clamped to [0, 1]. The clamp
        keeps the rate at 0 past ``total_steps`` instead of letting the cosine
        rise again.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        warmup_steps: number of linear warmup steps.
        total_steps: step at which the multiplier reaches 0. Defaults to the
            notebook's ``max_iters``, read when this function is called so a
            re-tuned value is picked up.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    if total_steps is None:
        total_steps = max_iters
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = min(1.0, float(step - warmup_steps) / decay_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

scheduler = get_scheduler(optimizer, warmup_steps=200)  # decays over the default horizon (max_iters)

# checkpoint helpers
def save_checkpoint(step, model, optimizer, scheduler, path):
    """Save model, optimizer and scheduler state at ``step`` to ``path``.

    The file is a ``torch.save`` dict with keys ``step``, ``model_state_dict``,
    ``optim_state_dict`` and ``sched_state_dict`` (what ``load_checkpoint``
    reads). It is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``. An interrupted save (e.g. a Colab disconnect while writing
    to Drive) therefore cannot leave a truncated file at ``path``.
    """
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optim_state_dict': optimizer.state_dict(),
        'sched_state_dict': scheduler.state_dict()
    }
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)

def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location=device):
    """Restore state written by ``save_checkpoint`` and return the saved step.

    Args:
        path: checkpoint file to read.
        model: receives ``model_state_dict``.
        optimizer: if given, receives ``optim_state_dict``.
        scheduler: if given and the checkpoint has ``sched_state_dict``,
            receives it; otherwise the scheduler is left unchanged.
        map_location: forwarded to ``torch.load``. Defaults to the notebook's
            ``device`` as it was when this cell ran.

    Returns:
        The ``step`` stored in the checkpoint.

    Note: ``torch.load`` unpickles the file; only load checkpoints you trust.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
    # Skip rather than pass an empty dict: LambdaLR.load_state_dict pops
    # 'lr_lambdas' unconditionally, so {} would raise KeyError.
    sched_state = checkpoint.get('sched_state_dict')
    if scheduler is not None and sched_state is not None:
        scheduler.load_state_dict(sched_state)
    return checkpoint['step']

# text generation (sampling)
@torch.no_grad()
def sample(model, start_text, length=500, temperature=1.0, top_k=None):
    """Autoregressively extend ``start_text`` with ``length`` sampled characters.

    Args:
        model: maps (1, T) token ids to (1, T, vocab) logits; its
            ``block_size`` (falling back to the notebook's) caps the context.
        start_text: non-empty prompt whose characters are all in the vocabulary.
        length: number of characters to generate.
        temperature: logits are divided by this; values <= 0 fall back to 1.0.
        top_k: if set, sample only among the ``top_k`` highest-scoring
            characters (clamped to the vocabulary size).

    Returns:
        The prompt followed by the generated characters.

    Raises:
        ValueError: if ``start_text`` is empty.

    The model's train/eval mode is restored on exit. Calling this from inside
    a training loop therefore does not leave dropout switched off.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    was_training = model.training
    model.eval()
    try:
        context = getattr(model, 'block_size', block_size)
        scale = temperature if temperature > 0 else 1.0
        idx = torch.tensor(encode(start_text), dtype=torch.long, device=device).unsqueeze(0)  # (1, T)
        for _ in range(length):
            idx_cond = idx[:, -context:]                  # last `context` tokens at most
            logits = model(idx_cond)[:, -1, :] / scale    # (1, vocab)
            if top_k is not None:
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, -1:]
                logits = logits.masked_fill(logits < kth_best, float('-inf'))
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)
    finally:
        model.train(was_training)
    # idx[0] (not squeeze) so a 1-token result still yields a list
    return decode(idx[0].tolist())

print("Sampling function ready.")


In [None]:
# Training Loop
# Tracks losses and plots, saves checkpoints and samples.
model.train()

train_losses, val_losses, iters = [], [], []
best_val_loss = float('inf')
start_iter = 0
val_batches = 10  # batches averaged for each validation-loss estimate

# Optional checkpoint to resume from; set ckpt_path = None to train from scratch.
# Built from CHECKPOINT_ROOT rather than a hard-coded Drive path so it follows
# wherever the setup cell stores checkpoints.
ckpt_path = os.path.join(CHECKPOINT_ROOT, 'ckpt_step_1000.pt')
if ckpt_path and os.path.exists(ckpt_path):
    start_iter = load_checkpoint(ckpt_path, model, optimizer, scheduler)
    print("Resumed from", ckpt_path, "at iter", start_iter)

for it in range(start_iter + 1, max_iters + 1):
    xb, yb = get_batch('train', block_size, batch_size)
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    scheduler.step()

    train_loss = loss.item()
    train_losses.append(train_loss)
    iters.append(it)

    # periodic eval on validation set (and once at the very first step)
    if it % eval_interval == 0 or it == 1:
        model.eval()
        with torch.no_grad():
            # average several batches for a more stable estimate
            val_loss_vals = []
            for _ in range(val_batches):
                xb_val, yb_val = get_batch('val', block_size, batch_size)
                _, vloss = model(xb_val, yb_val)
                val_loss_vals.append(vloss.item())
        mean_val_loss = sum(val_loss_vals) / len(val_loss_vals)
        val_losses.append(mean_val_loss)

        # perplexity = exp(mean cross-entropy)
        train_pp = math.exp(train_loss)
        val_pp = math.exp(mean_val_loss)
        print(f"Iter {it}/{max_iters} | train loss {train_loss:.4f} | val loss {mean_val_loss:.4f} | train pp {train_pp:.2f} | val pp {val_pp:.2f}")

        # generate a sample (in eval mode) and save it
        sample_text = sample(model, start_text="ROMEO: ", length=400, temperature=1.0, top_k=40)
        sample_file = os.path.join(CHECKPOINT_ROOT, f"sample_iter_{it}.txt")
        with open(sample_file, 'w', encoding='utf-8') as f:
            f.write(sample_text)
        print("Sample saved to:", sample_file)

        # `sample` calls model.eval(); switch back to training mode afterwards
        # so dropout stays active for the remaining training steps.
        model.train()

        # keep the checkpoint with the lowest validation loss seen so far
        if mean_val_loss < best_val_loss:
            best_val_loss = mean_val_loss
            best_file = os.path.join(CHECKPOINT_ROOT, 'ckpt_best.pt')
            save_checkpoint(it, model, optimizer, scheduler, best_file)
            print(f"New best val loss {best_val_loss:.4f}; saved to {best_file}")

    # periodic checkpoint save
    if it % save_interval == 0 or it == max_iters:
        ckpt_file = os.path.join(CHECKPOINT_ROOT, f'ckpt_step_{it}.pt')
        save_checkpoint(it, model, optimizer, scheduler, ckpt_file)
        print("Saved checkpoint to", ckpt_file)

# After training, save final
final_path = os.path.join(CHECKPOINT_ROOT, 'ckpt_final.pt')
save_checkpoint(max_iters, model, optimizer, scheduler, final_path)
print("Training complete. Final checkpoint saved to:", final_path)


In [None]:
# Plot training & validation loss and perplexity
# Validation losses exist only for evaluation steps (iteration 1 and every
# eval_interval-th iteration), so rebuild those x positions from `iters`.
eval_iters = [i for i in iters if i == 1 or i % eval_interval == 0]
val_x = eval_iters[:len(val_losses)]
train_pp_curve = [math.exp(v) for v in train_losses]
val_pp_curve = [math.exp(v) for v in val_losses]

fig, (ax_loss, ax_pp) = plt.subplots(1, 2, figsize=(10, 4))

ax_loss.plot(iters, train_losses, label='train_loss')
ax_loss.plot(val_x, val_losses, label='val_loss')
ax_loss.set(xlabel='Iteration', ylabel='Loss', title='Loss curve')
ax_loss.legend()
ax_loss.grid(True)

ax_pp.plot(iters, train_pp_curve, label='train_perplexity')
ax_pp.plot(val_x, val_pp_curve, label='val_perplexity')
ax_pp.set(xlabel='Iteration', ylabel='Perplexity', title='Perplexity')
ax_pp.legend()
ax_pp.grid(True)

plot_path = os.path.join(CHECKPOINT_ROOT, 'training_plots.png')
fig.savefig(plot_path)
plt.show()
print("Plots saved to:", plot_path)


In [None]:
# Final evaluation on test set
# Deterministic pass over the whole test split in non-overlapping windows of
# block_size tokens. The reported loss is exact and reproducible, not an
# estimate from randomly sampled batches.
model.eval()
n_test_windows = (len(test_data) - 1) // block_size
if n_test_windows == 0:
    raise ValueError(f"test split ({len(test_data)} tokens) is shorter than block_size + 1")
window_starts = list(range(0, n_test_windows * block_size, block_size))
total_nll, total_tokens = 0.0, 0
with torch.no_grad():
    for b in range(0, n_test_windows, batch_size):
        batch_starts = window_starts[b:b + batch_size]
        xb_test = torch.stack([test_data[s:s + block_size] for s in batch_starts])
        yb_test = torch.stack([test_data[s + 1:s + block_size + 1] for s in batch_starts])
        _, tloss = model(xb_test, yb_test)
        # tloss is a per-token mean; weight by token count so the last
        # (possibly smaller) batch contributes proportionally.
        total_nll += tloss.item() * yb_test.numel()
        total_tokens += yb_test.numel()
test_loss = total_nll / total_tokens
test_pp = math.exp(test_loss)
print(f"Test loss: {test_loss:.4f} | Test perplexity: {test_pp:.2f}")
