[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](
https://colab.research.google.com/github/FranQuant/the_ai_engineer/blob/main/capstones/week03_transformers/mini_gpt_diagnostics.ipynb
)

# Mini GPT — Diagnostic & Visualization Suite

Diagnostics for the tiny decoder-only transformer trained in `train_mini_gpt.py`.

## 1. Imports & Setup
- Load the trained checkpoint `mini_gpt.pt`.
- Instantiate `MiniTransformerLM` with the same config as the training script.

In [None]:
import math
from pathlib import Path

import random
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from mini_transformer import MiniTransformerLM
from transformer_block import TransformerBlock
from multihead_attention import MultiHeadAttention
from scaled_dot_product_attention import scaled_dot_product_attention

# Seed every RNG this notebook draws from so reruns are reproducible:
# torch drives multinomial sampling, `random` picks the start token for the
# sampling comparison, and NumPy's global RNG is what scikit-learn estimators
# use when they are built without an explicit random_state.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Prefer the GPU when one is visible; everything below is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
plt.style.use("ggplot")
sns.set_context("talk")

# Training config (must match train_mini_gpt.py)
cfg = dict(
    d_model=128,
    num_heads=4,
    d_ff=256,
    num_layers=2,
    max_seq_len=128,
)

# The checkpoint is optional: later cells fall back to random weights when it
# is missing, so only record whether it exists here.
checkpoint_path = Path("mini_gpt.pt")
checkpoint_exists = checkpoint_path.exists()
print(f"Using device: {device}")
print(f"Checkpoint exists: {checkpoint_exists}")

## 2. Tokenizer Utilities
Rebuild the character tokenizer and corpus used during training.

In [None]:
text = """
Attention is all you need.
Transformers use self-attention to model dependencies.
This is a tiny training corpus for the Week 03 capstone.
"""

# Character-level vocabulary: every distinct character of the corpus, in
# sorted order, gets the id equal to its position in `chars`.
chars = sorted(set(text))
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}
vocab_size = len(chars)


def encode(s: str):
    """Return the list of token ids for ``s``, one per character.

    Raises KeyError for any character that does not occur in the corpus.
    """
    return [stoi[ch] for ch in s]


def decode(ids):
    """Return the string spelled by ``ids`` (any iterable of int-convertible ids)."""
    return "".join(itos[int(tok)] for tok in ids)


# The whole corpus as a 1-D tensor of token ids.
ids = torch.tensor(encode(text), dtype=torch.long)
print(f"Vocab size: {vocab_size}")
print(f"Sample tokens: {chars[:10]}")

## 3. Load Model & Causal Mask Helper

In [None]:
def build_causal_mask(T: int, device: torch.device):
    """Return a (T, T) lower-triangular matrix of ones on ``device``.

    Entry (i, j) is 1 when j <= i and 0 otherwise, i.e. position i may look
    at itself and at earlier positions only.
    """
    all_ones = torch.ones(T, T, device=device)
    return all_ones.tril()

# Build the network from the shared training hyperparameters, then move it to
# the selected device.
model = MiniTransformerLM(
    vocab_size=vocab_size,
    **{key: cfg[key] for key in ("d_model", "num_heads", "d_ff", "num_layers", "max_seq_len")},
).to(device)

if not checkpoint_exists:
    print("Checkpoint not found; using randomly initialized weights (diagnostics only).")
else:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("Loaded checkpoint mini_gpt.pt")

# Trailing semicolon keeps the notebook from echoing the module repr.
model.eval();

## 4. Attention Heatmaps
- Extract per-layer, per-head attention from the model.
- Visualize head-wise and averaged attention.

In [None]:
def run_with_attn(model: MiniTransformerLM, input_ids: torch.Tensor, attn_mask: torch.Tensor | None = None):
    """Run a manual forward pass and capture attention weights per layer and head.

    The attention probabilities are re-derived from each block's ``q_proj`` and
    ``k_proj`` so they can be recorded; the residual stream itself is advanced
    by calling ``blk.mha`` and ``blk.ffn``.

    Args:
        model: Model exposing ``token_embed``, ``pos_embed``, ``blocks``,
            ``final_ln`` and ``lm_head``. Each block exposes ``ln1``, ``mha``
            (with ``h``, ``dh``, ``q_proj``, ``k_proj``), ``ln2`` and ``ffn``.
        input_ids: Token ids of shape [B, T].
        attn_mask: Optional mask; entries equal to 0 are treated as blocked.
            2-D and 3-D masks are unsqueezed so they broadcast against
            [B,H,T,T] scores.

    Returns:
        ``(logits, attn_weights)`` where ``attn_weights`` holds one detached
        CPU tensor of shape [B,H,T,T] per block.
    """
    # NOTE(review): the recorded weights assume blk.mha scales by sqrt(dh) and
    # applies the mask the same way as below — confirm against multihead_attention.
    model.eval()
    _, T = input_ids.shape  # unpacking also rejects input that is not 2-D
    pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
    x = model.token_embed(input_ids) + model.pos_embed(pos)

    # The additive mask bias is identical for every layer, so build it once.
    mask_bias = None
    if attn_mask is not None:
        mask = attn_mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)
        mask_bias = (mask == 0) * (-1e9)

    attn_weights = []  # list of [B,H,T,T]
    for blk in model.blocks:
        x_norm = blk.ln1(x)
        B_, T_, _ = x_norm.shape
        h, dh = blk.mha.h, blk.mha.dh

        # Only q and k are needed to reproduce the attention pattern; the
        # value projection is applied inside blk.mha below.
        q = blk.mha.q_proj(x_norm).view(B_, T_, h, dh).transpose(1, 2)  # [B,H,T,Dh]
        k = blk.mha.k_proj(x_norm).view(B_, T_, h, dh).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(dh)  # [B,H,T,T]
        if mask_bias is not None:
            scores = scores + mask_bias
        # Shift by the row max before softmax for numerical stability.
        scores = scores - scores.max(dim=-1, keepdim=True).values
        w = torch.softmax(scores, dim=-1)  # [B,H,T,T]
        attn_weights.append(w.detach().cpu())

        # Continue the real forward pass: attention sublayer, then FFN sublayer,
        # each added back onto the residual stream.
        x = x + blk.mha(x_norm, attn_mask)
        x = x + blk.ffn(blk.ln2(x))

    logits = model.lm_head(model.final_ln(x))
    return logits, attn_weights


# Capture attention for a short prompt under a causal mask. Gradients are not
# needed for visualization, so the forward pass runs under no_grad.
sample_prompt = "Attention is"
with torch.no_grad():
    ids_prompt = torch.tensor([encode(sample_prompt)], device=device)
    mask = build_causal_mask(ids_prompt.size(1), device)
    logits, attn_maps = run_with_attn(model, ids_prompt, attn_mask=mask)

print(f"Captured {len(attn_maps)} layers of attention; shape layer0: {attn_maps[0].shape}")

# Visualize first layer, head 0
layer_idx = 0
head_idx = 0
# Batch item 0 and a single head of the [B,H,T,T] tensor -> one [T,T] map.
attn = attn_maps[layer_idx][0, head_idx].numpy()
# The tokenizer is character-level, so each character is one tick label.
tokens = list(sample_prompt)
plt.figure(figsize=(6, 5))
sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap="magma", annot=False)
plt.title(f"Layer {layer_idx} Head {head_idx} Attention")
plt.xlabel("Key / Value positions")
plt.ylabel("Query positions")
plt.show()

# Attention of the same layer averaged over heads (axis 1), batch item 0
plt.figure(figsize=(6, 5))
avg_attn = attn_maps[layer_idx].mean(1)[0]
sns.heatmap(avg_attn.numpy(), xticklabels=tokens, yticklabels=tokens, cmap="magma")
plt.title(f"Layer {layer_idx} Averaged Attention")
plt.show()

## 5. Residual Stream Diagnostics
- Capture norms before/after MHA and FFN per layer.
- Plot L2 norms and activation histograms.

In [None]:
def residual_diagnostics(model: MiniTransformerLM, input_ids: torch.Tensor, attn_mask=None):
    """Collect per-layer activation norms and post-block activations.

    Args:
        model: Model exposing ``token_embed``, ``pos_embed`` and ``blocks``;
            each block exposes ``ln1``, ``mha``, ``ln2`` and ``ffn``.
        input_ids: Token ids of shape [B, T].
        attn_mask: Optional mask forwarded unchanged to each ``blk.mha``.

    Returns:
        ``(norms, hists)``. ``norms`` has one dict per block with the keys
        ``"pre_mha"``, ``"pre_ffn"`` and ``"post_ffn"``; each value is the L2
        norm over the last dimension, averaged over all batch items and
        positions. ``hists`` has one flattened, detached CPU tensor per block
        holding the activations after the FFN residual add.
    """
    # NOTE(review): "pre_mha" and "pre_ffn" are measured on the ln1/ln2
    # outputs, not on the raw residual stream; only "post_ffn" is the
    # un-normalized residual. Confirm this is the intended quantity.
    model.eval()
    _, T = input_ids.shape  # unpacking also rejects input that is not 2-D
    pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
    norms = []
    hists = []
    # Everything returned is a Python float or a detached tensor, so the
    # autograd graph is never used; skip building it.
    with torch.no_grad():
        x = model.token_embed(input_ids) + model.pos_embed(pos)
        for blk in model.blocks:
            layer_norms = {}

            x_norm1 = blk.ln1(x)
            layer_norms["pre_mha"] = x_norm1.norm(dim=-1).mean().item()
            x = x + blk.mha(x_norm1, attn_mask)

            x_norm2 = blk.ln2(x)
            layer_norms["pre_ffn"] = x_norm2.norm(dim=-1).mean().item()
            x = x + blk.ffn(x_norm2)

            layer_norms["post_ffn"] = x.norm(dim=-1).mean().item()
            norms.append(layer_norms)
            hists.append(x.detach().cpu().flatten())
    return norms, hists


# Run the diagnostics on a single-word prompt. Note that `ids_prompt` and
# `mask` are rebound here, replacing the values from the attention cell.
ids_prompt = torch.tensor([encode("Transformers")], device=device)
mask = build_causal_mask(ids_prompt.size(1), device)
norms, hists = residual_diagnostics(model, ids_prompt, attn_mask=mask)

# Plot L2 norms: one line per measurement point, x-axis is the block index
plt.figure(figsize=(6,4))
for key in ["pre_mha", "pre_ffn", "post_ffn"]:
    plt.plot([n[key] for n in norms], label=key)
plt.title("Residual Stream L2 Norms per Layer")
plt.xlabel("Layer")
plt.ylabel("Mean L2 norm")
plt.legend()
plt.show()

# Activation histogram (last layer): all values after the final block's FFN add
plt.figure(figsize=(6,4))
plt.hist(hists[-1].numpy(), bins=50, color="steelblue", alpha=0.8)
plt.title("Activation Histogram (last layer)")
plt.xlabel("Activation value")
plt.ylabel("Count")
plt.show()

## 6. Embedding Space Visualization
- Project token embeddings to 2D via PCA (falls back to t-SNE only if PCA raises an error).
- Annotate characters.

In [None]:
# Token embedding matrix as a NumPy array, one row per vocabulary entry.
emb = model.token_embed.weight.detach().cpu().numpy()
try:
    proj = PCA(n_components=2).fit_transform(emb)
except Exception as err:
    # Best-effort fallback: report why PCA failed instead of hiding it.
    print(f"PCA failed ({err!r}); falling back to t-SNE.")
    proj = TSNE(
        n_components=2,
        init="random",
        learning_rate="auto",
        # t-SNE requires perplexity < n_samples; the default of 30 is too
        # large for a small character vocabulary, so cap it.
        perplexity=min(30.0, emb.shape[0] - 1),
    ).fit_transform(emb)
    axis_labels = ("t-SNE 1", "t-SNE 2")
else:
    axis_labels = ("PC1", "PC2")

plt.figure(figsize=(8,6))
plt.scatter(proj[:,0], proj[:,1], alpha=0.7)
for i, ch in itos.items():
    # repr() makes whitespace characters such as newline visible as escapes.
    plt.annotate(repr(ch).strip("'"), (proj[i,0], proj[i,1]), fontsize=9)
plt.title("Token Embeddings (2D projection)")
# Label the axes after the projection that actually produced them.
plt.xlabel(axis_labels[0])
plt.ylabel(axis_labels[1])
plt.show()

## 7. Logit / Sampling Diagnostics
- Plot logits histogram and entropy for a context.
- Show top-k probabilities and temperature effects.

In [None]:
# Next-token distribution for a fixed context.
context = "Attention is"
ctx_ids = torch.tensor([encode(context)], device=device)
mask = build_causal_mask(ctx_ids.size(1), device)
with torch.no_grad():
    logits = model(ctx_ids, attn_mask=mask)
    last_logits = logits[0, -1]  # batch item 0, last position
    probs = torch.softmax(last_logits, dim=-1)
    # Take log-probabilities from log_softmax rather than probs.log(): a
    # probability that underflows to 0 would give 0 * -inf = NaN.
    log_probs = torch.log_softmax(last_logits, dim=-1)
    entropy = -(probs * log_probs).sum().item()

# torch.topk raises if k exceeds the number of entries, so cap it.
top_k = min(10, probs.numel())
top_probs, top_idx = torch.topk(probs, top_k)
print("Entropy (nats):", entropy)
print("Top-k tokens:")
for p, idx in zip(top_probs, top_idx):
    print(f"{itos[int(idx)]!r}: {p.item():.3f}")

plt.figure(figsize=(6,4))
plt.hist(last_logits.cpu().numpy(), bins=50, color="salmon", alpha=0.8)
plt.title("Last-token logits histogram")
plt.xlabel("Logit")
plt.ylabel("Count")
plt.show()

def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

    Args:
        model: Callable as ``model(ids, attn_mask=mask)`` returning logits
            whose last dimension indexes the vocabulary.
        idx: Integer tensor of shape [B, T] holding the starting context.
        max_new_tokens: Number of tokens to append.
        temperature: Divisor applied to the logits before softmax; must be
            non-zero.
        top_k: If given, only the ``top_k`` largest logits stay eligible.
            Values larger than the vocabulary are capped. ``top_k=1`` keeps
            only the maximum logit.

    Returns:
        Tensor of shape [B, T + max_new_tokens]: ``idx`` with the sampled ids
        appended along dim 1.
    """
    model.eval()
    # Sampling only needs forward passes; the result is an integer tensor, so
    # no autograd graph is ever required.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Keep at most max_seq_len trailing tokens as context.
            idx_cond = idx[:, -cfg["max_seq_len"] :]
            mask = build_causal_mask(idx_cond.size(1), idx.device)
            logits = model(idx_cond, attn_mask=mask)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # torch.topk raises when k exceeds the vocabulary size.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                # Everything below the k-th largest logit gets zero probability.
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
    return idx

# One random vocabulary id, shaped [1, 1], is the shared starting context.
start = torch.tensor([[random.randint(0, vocab_size - 1)]], device=device)

# Generate 80 tokens per decoding setting, in this fixed order; each run gets
# its own copy of the start token.
samples = {
    label: generate(model, start.clone(), 80, **settings)
    for label, settings in (
        ("greedy", dict(temperature=1.0, top_k=1)),
        ("temp_0.8", dict(temperature=0.8)),
        ("temp_1.2", dict(temperature=1.2)),
        ("topk8", dict(temperature=0.9, top_k=8)),
    )
}

for name, seq in samples.items():
    print(f"\nSample ({name}):\n{decode(seq[0].tolist())}")

## 8. Conclusions / Notes
- Use attention heatmaps to verify causal masking and focus patterns.
- Monitor residual norms for stability; large growth may indicate learning-rate issues.
- Embedding projections reveal clustering by character class.
- Logit entropy and top-k give a quick sense of calibration; sampling comparisons show temperature/top-k effects.