[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/FranQuant/the_ai_engineer/blob/main/capstones/week03_transformers/mini_gpt_diagnostics.ipynb)

# Mini GPT — Diagnostic & Visualization Suite

Diagnostics for the tiny decoder-only transformer trained in `train_mini_gpt.py`.

## 1. Imports & Setup
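- Seed the RNGs, select the device, and define the model config (must match `train_mini_gpt.py`).
- Check whether the trained checkpoint `mini_gpt.pt` exists; the model itself is instantiated and the checkpoint loaded in Section 3.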

In [None]:
import math
from pathlib import Path

import random
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from mini_transformer import MiniTransformerLM
from transformer_block import TransformerBlock
from multihead_attention import MultiHeadAttention
from scaled_dot_product_attention import scaled_dot_product_attention

# -------------------------------------------------------------------------
# Reproducibility: seed the Python, NumPy and PyTorch RNGs with the same value.
# -------------------------------------------------------------------------
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Plot styling shared by every figure in this notebook.
plt.style.use("ggplot")
sns.set_context("talk")

# -------------------------------------------------------------------------
# Model hyper-parameters. These must be identical to train_mini_gpt.py:
# any mismatch in parameter shapes makes the strict checkpoint load fail.
# -------------------------------------------------------------------------
cfg = {
    "d_model": 64,
    "num_heads": 4,
    "d_ff": 256,
    "num_layers": 4,
    "max_seq_len": 4,
}

# Trained weights, presumably produced by train_mini_gpt.py (optional: the
# notebook falls back to random weights when the file is missing).
checkpoint_path = Path("mini_gpt.pt")
checkpoint_exists = checkpoint_path.exists()

print(f"Using device: {device}")
print(f"Checkpoint exists: {checkpoint_exists}")

## 2. Tokenizer Utilities
Rebuild the character tokenizer and corpus used during training.

In [None]:
# -------------------------------------------------------------------------
# 2. Tokenizer Utilities (LOCKED VERSION)
# This must match EXACTLY the tiny_text used inside train_mini_gpt.py
# -------------------------------------------------------------------------

tiny_text = """
hello tiny transformer
hi tiny transformer
hello week three
hi transformer
"""

# Vocabulary = every distinct character of the corpus in sorted (code-point)
# order, so token ids are deterministic and identical to the training script.
chars = sorted(set(tiny_text))
vocab_size = len(chars)
stoi = dict(zip(chars, range(vocab_size)))   # char -> id
itos = dict(enumerate(chars))                # id -> char


def encode(s: str) -> list[int]:
    """Return the token id of every character in ``s`` (KeyError if a char is out of vocabulary)."""
    return [stoi[ch] for ch in s]


def decode(ids) -> str:
    """Join the characters for an iterable of ids (Python ints or 0-d int tensors)."""
    return "".join([itos[int(tok)] for tok in ids])


# The whole corpus as one flat int64 tensor of token ids.
ids = torch.tensor(encode(tiny_text), dtype=torch.long)

print(f"Vocab size: {vocab_size}")
print(f"Sample tokens: {chars[:10]}")

## 3. Load Model & Causal Mask Helper

In [None]:
## 3. Load Model & Causal Mask Helper

def build_causal_mask(T: int, device: torch.device) -> torch.Tensor:
    """
    Return a [T, T] float "keep" mask for causal self-attention.

    Entry (i, j) is 1.0 when query position i may attend to key position j
    (i.e. j <= i) and 0.0 otherwise, so zeros mark the masked positions.
    NOTE(review): consumers are expected to turn the zeros into -inf scores;
    scaled_dot_product_attention() is said to do this internally, but it is not
    visible from this notebook — TODO confirm.
    """
    keep = torch.ones(T, T, device=device)
    return keep.tril()


# Instantiate model (must match training config)
model = MiniTransformerLM(
    vocab_size=vocab_size,
    d_model=cfg["d_model"],
    num_heads=cfg["num_heads"],
    d_ff=cfg["d_ff"],
    num_layers=cfg["num_layers"],
    max_seq_len=cfg["max_seq_len"],
).to(device)

# Load checkpoint if available.
# The file is consumed directly by load_state_dict, i.e. it is a plain state_dict
# (tensors in a dict). weights_only=True restricts unpickling to exactly that kind of
# content, so a tampered .pt file cannot execute arbitrary code on load.
if checkpoint_exists:
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state, strict=True)
    print("Loaded checkpoint mini_gpt.pt")
else:
    print("Checkpoint not found; using randomly initialized weights (diagnostics only).")

# Disable dropout etc. for every diagnostic below.
model.eval()

## 4. Attention Heatmaps
- Extract per-layer, per-head attention from the model.
- Visualize head-wise and averaged attention.

In [None]:
# -------------------------------------------------------------------------
# 4. Attention Heatmaps
# - Extract per-layer, per-head attention from the model.
# - Visualize head-wise and averaged attention.
# -------------------------------------------------------------------------

def _broadcast_attn_mask(attn_mask: torch.Tensor, batch_size: int, target_device: torch.device) -> torch.Tensor:
    """
    Normalize a keep-mask to a bool tensor that broadcasts against [B, H, T, T] scores.

    Accepted shapes: [T, T], [1, T, T], [B, T, T], and 4D masks (assumed to be
    [1, 1, T, T] or [B, 1, T, T]). Nonzero / True means "may attend".

    Raises:
      ValueError: for a 3D mask whose leading dim is neither 1 nor batch_size, or for
        a mask that is not 2D, 3D or 4D.
    """
    mask = attn_mask
    if mask.ndim == 2:                        # [T,T] -> [1,1,T,T]
        mask = mask.unsqueeze(0).unsqueeze(0)
    elif mask.ndim == 3:                      # [1,T,T] or [B,T,T] -> [1|B,1,T,T]
        if mask.shape[0] not in (1, batch_size):
            raise ValueError(f"Unsupported 3D mask shape {mask.shape}")
        mask = mask.unsqueeze(1)
    elif mask.ndim != 4:                      # 4D passes through unchanged
        raise ValueError(f"Unsupported mask shape {mask.shape}")
    return mask.to(dtype=torch.bool, device=target_device)


def run_with_attn(model: MiniTransformerLM, input_ids: torch.Tensor, attn_mask: torch.Tensor | None = None):
    """
    Manual forward pass that also captures attention weights (per layer, per head).

    Follows the pre-LN block structure (ln1 -> mha -> residual, ln2 -> ffn -> residual).
    The captured weights are recomputed from each block's q_proj / k_proj with a
    1/sqrt(dh) scale; the block output itself still comes from ``blk.mha``.
    NOTE(review): this assumes MultiHeadAttention computes its weights the same way
    (and applies no attention dropout in eval mode) — TODO confirm.

    Args:
      model: the MiniTransformerLM to probe; switched to eval mode as a side effect.
      input_ids: [B, T] token ids; moved to the notebook-level ``device``.
      attn_mask: optional keep-mask (1/True = attend, 0/False = masked) of shape
        [T, T], [1, T, T], [B, T, T], [1, 1, T, T] or [B, 1, T, T]; it is also
        forwarded unchanged to each ``blk.mha`` call.

    Returns:
      logits: [B, T, vocab_size]
      attn_weights: list of length num_layers, each a CPU tensor [B, H, T, T]

    Raises:
      ValueError: if ``attn_mask`` has an unsupported shape.
    """
    model.eval()
    input_ids = input_ids.to(device)

    # The mask is identical for every layer, so normalize it once instead of per block.
    mask_bool = (
        None
        if attn_mask is None
        else _broadcast_attn_mask(attn_mask, input_ids.size(0), input_ids.device)
    )

    # Embedding + positional encoding + (optional) embed dropout
    x = model.token_embed(input_ids)          # [B,T,D]
    x = model.pos_encoding(x)                # sinusoidal PE
    x = model.embed_dropout(x)               # same as in MiniTransformerLM.forward

    attn_weights = []

    for blk in model.blocks:
        # Pre-LN
        x_norm = blk.ln1(x)                  # [B,T,D]

        # Project to Q and K and split heads (V is not needed for the weights).
        B_, T_, _ = x_norm.shape
        h = blk.mha.h
        dh = blk.mha.dh

        q = blk.mha.q_proj(x_norm).view(B_, T_, h, dh).transpose(1, 2)  # [B,H,T,Dh]
        k = blk.mha.k_proj(x_norm).view(B_, T_, h, dh).transpose(1, 2)  # [B,H,T,Dh]

        # Raw attention scores [B,H,T,T]
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
        if mask_bool is not None:
            scores = scores.masked_fill(~mask_bool, float("-inf"))

        # torch.softmax is numerically stable on its own (it shifts by the row max).
        w = torch.softmax(scores, dim=-1)     # [B,H,T,T]
        attn_weights.append(w.detach().cpu())

        # Continue forward with the block, re-using the regular MHA
        attn_out = blk.mha(x_norm, attn_mask)             # [B,T,D]
        attn_out = blk.attn_dropout(attn_out)             # match block forward
        x = x + attn_out                                  # residual

        x_norm2 = blk.ln2(x)
        ffn_out = blk.ffn(x_norm2)
        ffn_out = blk.ffn_dropout(ffn_out)
        x = x + ffn_out

    x = model.final_ln(x)
    logits = model.lm_head(x)
    return logits, attn_weights


# Run a sample through and visualize
sample_prompt = "tiny"  # length 4, <= cfg["max_seq_len"]
assert len(sample_prompt) <= cfg["max_seq_len"]

with torch.no_grad():
    ids_prompt = torch.tensor([encode(sample_prompt)], device=device)  # [1,T]
    mask = build_causal_mask(ids_prompt.size(1), device)              # [T,T]
    logits, attn_maps = run_with_attn(model, ids_prompt, attn_mask=mask)

print(f"Captured {len(attn_maps)} layers of attention; shape layer0: {attn_maps[0].shape}")

layer_idx = 0
head_idx = 0
attn = attn_maps[layer_idx][0, head_idx].numpy()  # [T,T]
tokens = list(sample_prompt)

plt.figure(figsize=(6, 5))
sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap="magma", annot=False)
plt.title(f"Layer {layer_idx} Head {head_idx} Attention")
plt.xlabel("Key / Value positions")
plt.ylabel("Query positions")
plt.show()

# Averaged attention over heads
plt.figure(figsize=(6, 5))
avg_attn = attn_maps[layer_idx].mean(1)[0]        # [T,T]
sns.heatmap(avg_attn.numpy(), xticklabels=tokens, yticklabels=tokens, cmap="magma")
plt.title(f"Layer {layer_idx} Averaged Attention")
plt.show()

## 5. Residual Stream Diagnostics
- Capture norms before/after MHA and FFN per layer.
- Plot L2 norms and activation histograms.

In [None]:
# -------------------------------------------------------------------------
# 5. Residual Stream Diagnostics
# - Capture norms before/after MHA and FFN per layer.
# - Plot L2 norms and activation histograms.
# -------------------------------------------------------------------------

@torch.no_grad()  # pure diagnostics: no autograd graph needed
def residual_diagnostics(model: MiniTransformerLM, input_ids: torch.Tensor, attn_mask=None):
    """
    Track the residual stream through each (pre-LN) transformer block.

    Args:
      model: the MiniTransformerLM to probe; switched to eval mode as a side effect.
      input_ids: [B, T] token ids; moved to the notebook-level ``device``.
      attn_mask: optional mask forwarded unchanged to every ``blk.mha`` call.

    Returns:
      norms: one dict per layer, each value the L2 norm over the model dim,
        averaged over batch and positions, of:
          "pre_mha"  - residual stream entering the block (input to ln1)
          "pre_ffn"  - residual stream after the attention residual add (input to ln2)
          "post_ffn" - residual stream after the FFN residual add (block output)
          "ln1_out" / "ln2_out" - the LayerNorm outputs actually fed to MHA / FFN
        (the residual-stream keys are what reveal norm growth across layers; the
        LayerNorm outputs are normalized and stay roughly flat by construction).
      hists: one flattened CPU tensor per layer with the block-output activations.
    """
    model.eval()
    input_ids = input_ids.to(device)

    # Same embedding pipeline as MiniTransformerLM.forward
    x = model.token_embed(input_ids)      # [B,T,D]
    x = model.pos_encoding(x)            # add sinusoidal PE
    x = model.embed_dropout(x)

    norms = []
    hists = []

    for blk in model.blocks:
        # Residual stream entering the block.
        stats = {"pre_mha": x.norm(dim=-1).mean().item()}

        x_norm1 = blk.ln1(x)
        stats["ln1_out"] = x_norm1.norm(dim=-1).mean().item()
        attn_out = blk.attn_dropout(blk.mha(x_norm1, attn_mask))
        x = x + attn_out

        # Residual stream after attention, entering the FFN sub-layer.
        stats["pre_ffn"] = x.norm(dim=-1).mean().item()

        x_norm2 = blk.ln2(x)
        stats["ln2_out"] = x_norm2.norm(dim=-1).mean().item()
        ffn_out = blk.ffn_dropout(blk.ffn(x_norm2))
        x = x + ffn_out

        # Residual stream leaving the block.
        stats["post_ffn"] = x.norm(dim=-1).mean().item()
        norms.append(stats)
        hists.append(x.detach().cpu().flatten())

    return norms, hists


# Probe with a prompt that fits inside the model's context window.
ids_prompt = torch.tensor([encode("tiny")], device=device)
assert ids_prompt.size(1) <= cfg["max_seq_len"]

mask = build_causal_mask(ids_prompt.size(1), device)
norms, hists = residual_diagnostics(model, ids_prompt, attn_mask=mask)

# One line per measurement point, x-axis = layer index.
plt.figure(figsize=(6, 4))
for key in ("pre_mha", "pre_ffn", "post_ffn"):
    plt.plot([layer_stats[key] for layer_stats in norms], label=key)
plt.title("Residual Stream L2 Norms per Layer")
plt.xlabel("Layer")
plt.ylabel("Mean L2 norm")
plt.legend()
plt.show()

# Distribution of the final block's output activations.
plt.figure(figsize=(6, 4))
plt.hist(hists[-1].numpy(), bins=50, alpha=0.8)
plt.title("Activation Histogram (last layer)")
plt.xlabel("Activation value")
plt.ylabel("Count")
plt.show()

## 6. Embedding Space Visualization
- Project token embeddings to 2D with PCA (falling back to t-SNE only if PCA fails).
- Annotate each point with its character.

In [None]:
# -------------------------------------------------------------------------
# 6. Embedding Space Visualization (LOCKED FINAL VERSION)
# -------------------------------------------------------------------------

# Token embedding matrix as numpy. Row i is assumed to be the embedding of token id i
# (the annotation loop below indexes proj[i] with the id from itos).
emb = model.token_embed.weight.detach().cpu().numpy()

# Prefer PCA (deterministic, linear). Fall back to t-SNE only if PCA rejects the input;
# scikit-learn reports invalid input (and numpy LinAlgError, a ValueError subclass) as
# ValueError, so any other exception is a real bug and should surface.
try:
    proj = PCA(n_components=2, random_state=0).fit_transform(emb)
except ValueError as e:
    print("PCA failed, falling back to t-SNE:", e)
    proj = TSNE(
        n_components=2,
        init="random",
        learning_rate="auto",
        perplexity=5,
        random_state=0
    ).fit_transform(emb)
    axis_prefix = "t-SNE "   # t-SNE axes are not principal components
else:
    axis_prefix = "PC"

# Plot embedding projections
plt.figure(figsize=(8, 6))
plt.scatter(proj[:, 0], proj[:, 1], alpha=0.75, s=80)

# Annotate each point with its character.
for i, ch in itos.items():
    # Whitespace tokens (space, newline) would be invisible as raw text: show their repr.
    label = repr(ch) if ch.isspace() else ch
    plt.annotate(label, (proj[i, 0], proj[i, 1]), fontsize=10)

plt.title("Token Embeddings (2D Projection)", fontsize=14)
plt.xlabel(f"{axis_prefix}1")
plt.ylabel(f"{axis_prefix}2")
plt.grid(True, alpha=0.3)
plt.show()

## 7. Logit / Sampling Diagnostics
- Plot a histogram of the last-position logits and report the next-token entropy for a short context.
- List the top-k next-token probabilities and compare greedy, temperature, and top-k sampling.

In [None]:
# -------------------------------------------------------------------------
# 7. Logit / Sampling Diagnostics (LOCKED FINAL VERSION)
# -------------------------------------------------------------------------

# A short context; it must fit in the model's context window (cfg["max_seq_len"]).
context = "tin"       # length 3
ctx_ids = torch.tensor([encode(context)], device=device)   # [1,T]
assert ctx_ids.size(1) <= cfg["max_seq_len"]
T = ctx_ids.size(1)

# Build causal mask
mask = build_causal_mask(T, device)

with torch.no_grad():
    logits = model(ctx_ids, attn_mask=mask)     # [1,T,V]
    last_logits = logits[0, -1]                 # [V] next-token logits after the context
    probs = torch.softmax(last_logits, dim=-1)
    # Use log_softmax instead of probs.log(): if a probability underflows to 0,
    # probs.log() is -inf and 0 * -inf = NaN, whereas log_softmax stays finite so
    # that term correctly contributes 0 to the entropy.
    log_probs = torch.log_softmax(last_logits, dim=-1)
    entropy = -(probs * log_probs).sum().item()

print(f"Entropy (nats): {entropy:.4f}")
print("\nTop-k tokens:")

top_k = min(10, vocab_size)
top_probs, top_idx = torch.topk(probs, top_k)

for p, idx in zip(top_probs, top_idx):
    print(f"{itos[int(idx)]!r}: {p.item():.3f}")

# Logits histogram
plt.figure(figsize=(6,4))
plt.hist(last_logits.cpu().numpy(), bins=vocab_size, color="salmon", alpha=0.85)
plt.title("Last-token logits histogram")
plt.xlabel("Logit")
plt.ylabel("Count")
plt.show()


# -------------------------------------------------------------------------
# Generation function compatible with max_seq_len=4
# -------------------------------------------------------------------------
@torch.no_grad()  # sampling only: never build an autograd graph across steps
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None, max_seq_len=None):
    """
    Autoregressively sample ``max_new_tokens`` tokens from ``model``.

    Args:
      model: language model called as ``model(ids, attn_mask=mask)`` that returns
        [B, T, V] logits; switched to eval mode as a side effect.
      idx: [1, T] tensor of starting token ids (the code is written batch-agnostically,
        but this notebook only calls it with a single sequence).
      max_new_tokens: number of tokens to append.
      temperature: strictly positive divisor applied to the last-position logits;
        < 1 sharpens the distribution, > 1 flattens it.
      top_k: if given (>= 1), sample only among the k highest logits; ties with
        the k-th logit are kept.
      max_seq_len: context window; only the last ``max_seq_len`` tokens are fed to
        the model at each step. Defaults to cfg["max_seq_len"].

    Returns:
      ``idx`` with the sampled tokens appended along dim 1.

    Raises:
      ValueError: if ``temperature <= 0`` or ``top_k < 1``.
    """
    # Fail early with a clear message instead of a NaN/inf multinomial error
    # (temperature=0) or an indexing error on an empty top-k result (top_k=0).
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be >= 1 or None, got {top_k}")
    if max_seq_len is None:
        max_seq_len = cfg["max_seq_len"]

    model.eval()
    for _ in range(max_new_tokens):
        # Always truncate to the last max_seq_len tokens
        idx_cond = idx[:, -max_seq_len:]               # [1, <=max_seq_len]

        T = idx_cond.size(1)
        mask = build_causal_mask(T, idx.device)

        logits = model(idx_cond, attn_mask=mask)       # [1,T,V]
        logits = logits[:, -1, :] / temperature        # [1,V]

        if top_k is not None:
            k = min(top_k, logits.size(-1))
            kth_largest = torch.topk(logits, k).values[:, [-1]]   # [1,1]
            logits = logits.masked_fill(logits < kth_largest, float("-inf"))

        probs = torch.softmax(logits, dim=-1)          # [1,V]
        next_id = torch.multinomial(probs, num_samples=1)  # [1,1]

        idx = torch.cat([idx, next_id], dim=1)         # extend sequence

    return idx


# -------------------------------------------------------------------------
# Generate samples
# -------------------------------------------------------------------------
# One random seed token shared by every decoding strategy below.
start = torch.tensor([[random.randint(0, vocab_size - 1)]], device=device)

# 40 new tokens per strategy. The insertion order below is also the generation
# order, which fixes how the shared torch RNG stream is consumed.
samples = {}
samples["greedy"] = generate(model, start.clone(), 40, temperature=1.0, top_k=1)
samples["temp_0.8"] = generate(model, start.clone(), 40, temperature=0.8)
samples["temp_1.2"] = generate(model, start.clone(), 40, temperature=1.2)
samples["topk8"] = generate(model, start.clone(), 40, temperature=0.9, top_k=8)

for name, seq in samples.items():
    print(f"\n=== Sample ({name}) ===")
    print(decode(seq[0]))


## 8. Conclusions / Notes
- Use attention heatmaps to verify causal masking and focus patterns.
- Monitor residual norms for stability; large growth may indicate learning-rate issues.
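- Embedding projections can reveal clustering (e.g. by character class) once the model is trained; with randomly initialized weights, expect no meaningful structure.
- Logit entropy and the top-k list give a quick sense of how confident the next-token prediction is; the sampling comparisons show the effect of temperature and top-k.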