In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1bBtLgEU0TdHr55w7SYOhyDG1BKljPaRc", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/04_00_intro.mp3"))

In [None]:
# üîß Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"‚úÖ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("‚ö†Ô∏è No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime ‚Üí Change runtime type ‚Üí GPU")

print(f"\nüì¶ Python {sys.version.split()[0]}")
print(f"üî• PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"üé≤ Random seed set to {SEED}")

%matplotlib inline

# 🚀 Ablation Studies: What Really Matters in Recursive Reasoning?

*Part 4 of the Vizuara series on Tiny Recursive Models*
*Estimated time: 35 minutes*

## 1. Why Does This Matter?

We have built and trained a Tiny Recursive Model. It works. But **why** does it work?

The TRM paper's ablation study is one of its most valuable contributions. It systematically removes or changes components to reveal which design choices matter and which are just noise. The findings are surprising:

- Full backpropagation: **+30.9 percentage points**
- MLP over attention (for small grids): **+13.0 points**
- More recursion over more layers: **+7.9 points**
- EMA: **+7.5 points**

In this notebook, you will run these ablation experiments yourself and see firsthand what makes recursive reasoning tick.

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import time
%matplotlib inline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
torch.manual_seed(42)
np.random.seed(42)

In [None]:
#@title 🎧 Listen: Intuition
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_01_intuition.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

## 2. Building Intuition

### What Is an Ablation Study?

In medicine, "ablation" means removing a specific tissue or structure. In machine learning, an ablation study removes or changes one component at a time to measure its contribution.

Think of it like debugging a recipe. If your cake tastes amazing, you want to know: Is it the vanilla extract? The extra egg? The longer baking time? You bake the cake multiple times, each time leaving out one ingredient, and compare results.

### The Paper's Key Findings (Sudoku-Extreme)

| Change | Accuracy Impact (percentage points) |
|--------|----------------|
| Full backprop (vs 1-step approx) | **+30.9** |
| MLP mixing (vs attention) | **+13.0** |
| 2 layers + more recursion (vs 4 layers) | **+7.9** |
| EMA (vs no EMA) | **+7.5** |
| Two features y+z (vs y only) | Essential |

### 🤔 Think About This

Before running the experiments, make your predictions:
1. Will more recursion really beat more layers? It seems counterintuitive...
2. Why would a simple MLP beat powerful self-attention?
3. How much does the reasoning scratchpad z actually help?

Write down your predictions. Then compare them to the experimental results.

In [None]:
#@title 🎧 Listen: Math
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_02_math.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

## 3. The Mathematics (Quick Reference)

All the math was covered in Notebooks 2 and 3. Here is a quick reference:

**Recursion:**
$$z \leftarrow \text{net}(x, y, z), \quad y \leftarrow \text{net}(y, z)$$

**Prediction loss:** $\mathcal{L}_{\text{pred}} = -\sum_i y_i^{\text{true}} \log(\hat{y}_i)$

**Effective depth:** $T \times (n+1) \times n_{\text{layers}}$

The key principle for ablations: we keep the training budget fixed (same data, optimizer, and number of epochs) and vary one design factor at a time.

In [None]:
#@title 🎧 Listen: Setup
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_03_setup.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

## 4. Let's Build It — The Ablation Framework

### 4.1 Model and Data Setup

In [None]:
# Reuse components from previous notebooks
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned per-feature gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # eps is added inside the square root to avoid division by zero.
        denom = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = x / denom
        return normed * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward block: w3(silu(w1 x) * w2 x). Hidden width defaults to 4*dim."""

    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        if not hidden_dim:
            hidden_dim = 4 * dim
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)  # back to model width

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class MLPMixer(nn.Module):
    """Token mixing with a single learned (seq_len x seq_len) linear map.

    The same map is applied to every feature channel. `dim` is accepted but unused.
    """

    def __init__(self, seq_len, dim):
        super().__init__()
        self.token_mix = nn.Linear(seq_len, seq_len, bias=False)

    def forward(self, x):
        # Move axis 1 (positions) last so the Linear mixes positions, then move it back.
        by_token = x.transpose(1, 2)
        mixed = self.token_mix(by_token)
        return mixed.transpose(1, 2)

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking, no biases).

    Args:
        dim: model width; must be a positive multiple of ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads`` (or n_heads <= 0).
            Without this check the mismatch only surfaces later as a confusing
            reshape / ZeroDivision error.

    ``forward`` takes x of shape (batch, seq_len, dim) and returns the same shape.
    """
    def __init__(self, dim, n_heads=4):
        super().__init__()
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        B, L, D = x.shape
        # (B, L, 3*D) -> (B, L, 3, heads, head_dim); split into q, k, v as (B, heads, L, head_dim).
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.head_dim)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, L, D)
        return self.out_proj(out)

class TRMLayer(nn.Module):
    """Pre-norm residual block: token mixing, then a SwiGLU feed-forward.

    Mixing uses SelfAttention when ``use_attention`` is True, otherwise MLPMixer
    over ``seq_len`` positions.
    """
    def __init__(self, dim, seq_len, use_attention=False, n_heads=4):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialization consumes the RNG identically.
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        if use_attention:
            self.mixer = SelfAttention(dim, n_heads)
        else:
            self.mixer = MLPMixer(seq_len, dim)
        self.ffn = SwiGLU(dim, hidden_dim=dim * 4)

    def forward(self, x):
        mixed = x + self.mixer(self.norm1(x))
        return mixed + self.ffn(self.norm2(mixed))

In [None]:
class AblationTRM(nn.Module):
    """
    TRM with configurable components for ablation studies.

    Constructor options:
    - n_layers: number of TRMLayer blocks in the shared recursive network
    - use_attention: self-attention mixing (True) vs MLP token mixing (False)
    - use_z: whether to carry the latent reasoning feature z alongside y

    Per-call option (forward_with_supervision):
    - full_backprop: backprop through every recursion (True), or only through
      the last recursion of each supervision step (False, stop-grad)
    """
    def __init__(self, n_classes=4, grid_size=4, dim=48,
                 n_layers=2, use_attention=False, use_z=True):
        super().__init__()
        self.dim = dim
        self.grid_size = grid_size
        self.n_classes = n_classes
        self.use_z = use_z
        seq_len = grid_size * grid_size

        # The recursive network sees [x, y, z] (or [x, y]) concatenated on the feature axis.
        feat_mult = 3 if use_z else 2

        # Inputs are one-hot over n_classes + 1 symbols (0 marks a blank cell).
        self.input_embed = nn.Linear(n_classes + 1, dim, bias=False)
        self.y_init = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        if use_z:
            self.z_init = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

        self.layers = nn.ModuleList([
            TRMLayer(dim * feat_mult, seq_len=seq_len, use_attention=use_attention)
            for _ in range(n_layers)
        ])

        # Project the wide combined features back to dim-sized y (and z).
        self.split_proj_y = nn.Linear(dim * feat_mult, dim, bias=False)
        if use_z:
            self.split_proj_z = nn.Linear(dim * feat_mult, dim, bias=False)
        self.output_head = nn.Linear(dim, n_classes)

    def embed_input(self, x):
        """Flatten each grid to (batch, grid_size**2), one-hot encode it and project to dim."""
        B = x.shape[0]
        x_flat = x.reshape(B, -1)
        x_onehot = F.one_hot(x_flat.long(), num_classes=self.n_classes + 1).float()
        return self.input_embed(x_onehot)

    def recurse(self, x_emb, y, z=None):
        """One pass of the shared network; returns updated (y, z), with z None when use_z is False."""
        if self.use_z and z is not None:
            combined = torch.cat([x_emb, y, z], dim=-1)
        else:
            combined = torch.cat([x_emb, y], dim=-1)
        for layer in self.layers:
            combined = layer(combined)
        y_new = self.split_proj_y(combined)
        z_new = self.split_proj_z(combined) if self.use_z else None
        return y_new, z_new

    def forward_with_supervision(self, x, T=3, n=4, full_backprop=True):
        """
        Run T supervision steps of n recursions each and collect the logits.

        Args:
            x: batch of puzzle grids (flattened per example by embed_input).
            T: number of supervision steps; one logits tensor is produced per step.
            n: recursions per supervision step.
            full_backprop: if True, gradients flow through every recursion. If
                False (stop-grad / 1-step approximation), the first n-1
                recursions of each step run under torch.no_grad() and y, z are
                detached before the final recursion, so only that final
                recursion is differentiable.

        Returns:
            List of T logit tensors, each of shape (batch, grid_size**2, n_classes).
        """
        B = x.shape[0]
        seq_len = self.grid_size * self.grid_size
        x_emb = self.embed_input(x)
        y = self.y_init.expand(B, seq_len, -1)
        z = self.z_init.expand(B, seq_len, -1) if self.use_z else None

        checkpoints = []
        for _ in range(T):
            if full_backprop:
                for _ in range(n - 1):
                    y, z = self.recurse(x_emb, y, z)
            else:
                # Detaching only the *inputs* of each early recursion is not
                # enough: the (n-1)-th recursion's output would still carry a
                # graph into the final recursion (2 steps of gradient, not 1).
                # no_grad also avoids building graphs that are never used.
                with torch.no_grad():
                    for _ in range(n - 1):
                        y, z = self.recurse(x_emb, y, z)
                # Also cut the link to earlier supervision steps (matters when n == 1).
                y = y.detach()
                z = z.detach() if z is not None else None

            # Final recursion of the step is always differentiable.
            y, z = self.recurse(x_emb, y, z)
            checkpoints.append(self.output_head(y))

        return checkpoints

print("Ablation model ready!")

### 4.2 Data Generation

In [None]:
def is_valid_4x4(grid, r, c, num):
    """Return True if `num` can be placed at (r, c) without repeating in its row, column or 2x2 box."""
    top, left = 2 * (r // 2), 2 * (c // 2)
    conflicts = (
        num in grid[r, :],
        num in grid[:, c],
        num in grid[top:top + 2, left:left + 2],
    )
    return not any(conflicts)

def solve_4x4(grid):
    """Fill `grid` in place with a randomized backtracking search; return True on success.

    Blank cells are 0. The first blank in row-major order is filled by trying
    the digits 1-4 in a freshly shuffled order (global NumPy RNG).
    """
    empty = next(((r, c) for r in range(4) for c in range(4) if grid[r, c] == 0), None)
    if empty is None:
        return True
    r, c = empty
    candidates = list(range(1, 5))
    np.random.shuffle(candidates)
    for value in candidates:
        if not is_valid_4x4(grid, r, c, value):
            continue
        grid[r, c] = value
        if solve_4x4(grid):
            return True
        grid[r, c] = 0
    return False

def generate_dataset(n, n_remove=8, seed=42):
    """Build `n` random 4x4 Sudoku puzzles, each with `n_remove` cells blanked to 0.

    Reseeds the global NumPy RNG with `seed`. Returns (puzzles, solutions), each
    an int64 array of shape (n, 4, 4).
    """
    np.random.seed(seed)
    puzzle_list, solution_list = [], []
    for _ in range(n):
        solution = np.zeros((4, 4), dtype=np.int64)
        solve_4x4(solution)
        blanks = np.random.choice(16, size=n_remove, replace=False)
        puzzle = solution.copy()
        puzzle.flat[blanks] = 0
        puzzle_list.append(puzzle)
        solution_list.append(solution)
    return np.array(puzzle_list), np.array(solution_list)

from torch.utils.data import TensorDataset, DataLoader

train_puz, train_sol = generate_dataset(1000, n_remove=8, seed=42)
test_puz, test_sol = generate_dataset(200, n_remove=8, seed=999)


def _make_loader(puzzles, solutions, shuffle):
    """Wrap paired puzzle/solution arrays in a batch-size-64 DataLoader."""
    dataset = TensorDataset(torch.tensor(puzzles), torch.tensor(solutions))
    return DataLoader(dataset, batch_size=64, shuffle=shuffle)


train_loader = _make_loader(train_puz, train_sol, shuffle=True)
test_loader = _make_loader(test_puz, test_sol, shuffle=False)
print(f"Train: {len(train_puz)} puzzles | Test: {len(test_puz)} puzzles")

### 4.3 Unified Training Function

In [None]:
def run_experiment(config, n_epochs=25, verbose=True):
    """
    Train and evaluate a single ablation configuration.

    Reads the notebook-level ``train_loader``, ``test_loader`` and ``device``.

    config: dict with keys
        name (required), dim, n_layers, use_attention, use_z, T, n,
        full_backprop, use_ema, and optionally ema_decay (default 0.999).
        Missing keys fall back to the baseline TRM settings below.
    n_epochs: passes over ``train_loader``; must be >= 1.
    verbose: print a one-line summary when finished.

    Returns a dict with 'name', 'final_acc' (last-epoch test puzzle accuracy, %),
    'cell_acc' (last-epoch test accuracy on blank cells, %), 'n_params',
    'history' (test puzzle accuracy per epoch), 'time' (seconds) and 'config'.

    Raises ValueError if n_epochs < 1.
    """
    if n_epochs < 1:
        raise ValueError(f"n_epochs must be >= 1, got {n_epochs}")

    name = config['name']
    dim = config.get('dim', 48)
    n_layers = config.get('n_layers', 2)
    use_attention = config.get('use_attention', False)
    use_z = config.get('use_z', True)
    T = config.get('T', 3)
    n = config.get('n', 4)
    full_backprop = config.get('full_backprop', True)
    use_ema = config.get('use_ema', True)
    ema_decay = config.get('ema_decay', 0.999)

    model = AblationTRM(
        n_classes=4, grid_size=4, dim=dim,
        n_layers=n_layers, use_attention=use_attention, use_z=use_z
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters())
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    # scheduler.step() runs once per *batch*, so the cosine period must be counted
    # in batches. With T_max=n_epochs the LR swung between max and ~0 every
    # n_epochs batches, i.e. many times over training.
    total_steps = n_epochs * max(len(train_loader), 1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

    trainable = [(pname, p) for pname, p in model.named_parameters() if p.requires_grad]
    ema_shadow = None
    if use_ema:
        ema_shadow = {pname: p.detach().clone() for pname, p in trainable}
    ema_updates = 0

    history = []
    cell_acc = 0.0
    start_time = time.time()

    for _ in range(n_epochs):
        # ---- Train ----
        model.train()
        for bx, by in train_loader:
            bx, by = bx.to(device), by.to(device)
            # The loss counts only blank cells (value 0 in the puzzle).
            blank = (bx.reshape(bx.shape[0], -1) == 0).reshape(-1).float()
            # Solution digits are 1..4; shift them to 0-based class indices.
            targets_0idx = (by.reshape(by.shape[0], -1) - 1).clamp(min=0).reshape(-1)
            n_blank = blank.sum().clamp(min=1)

            optimizer.zero_grad()
            checkpoints = model.forward_with_supervision(bx, T=T, n=n, full_backprop=full_backprop)

            # Deep supervision: sum the masked loss over every checkpoint.
            loss = 0
            for logits in checkpoints:
                per_elem = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets_0idx,
                                           reduction='none')
                loss = loss + (per_elem * blank).sum() / n_blank

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            if ema_shadow is not None:
                # With the default data (1000 puzzles, batch 64) and 25 epochs
                # there are only ~400 updates; a fixed 0.999 decay would leave
                # ~0.999**400 ≈ 67% of the random *initial* weights in the
                # shadow. Warm the decay up (TF ExponentialMovingAverage style).
                ema_updates += 1
                decay = min(ema_decay, (1 + ema_updates) / (10 + ema_updates))
                with torch.no_grad():
                    for pname, p in trainable:
                        ema_shadow[pname].mul_(decay).add_(p.detach(), alpha=1 - decay)

        # ---- Evaluate (with EMA weights swapped in, if enabled) ----
        backup = None
        if ema_shadow is not None:
            with torch.no_grad():
                backup = {pname: p.detach().clone() for pname, p in trainable}
                for pname, p in trainable:
                    p.copy_(ema_shadow[pname])

        model.eval()
        correct = total = puzzles_correct = puzzles_total = 0
        try:
            with torch.no_grad():
                for bx, by in test_loader:
                    bx, by = bx.to(device), by.to(device)
                    mask = (bx.reshape(bx.shape[0], -1) == 0)
                    targets = by.reshape(by.shape[0], -1)
                    checkpoints = model.forward_with_supervision(bx, T=T, n=n, full_backprop=True)
                    preds = checkpoints[-1].argmax(dim=-1) + 1
                    hit = (preds == targets) & mask
                    correct += hit.sum().item()
                    total += mask.sum().item()
                    # Solved = every blank cell correct. Puzzles without blanks
                    # count toward the total but never as solved.
                    solved = mask.any(dim=1) & (hit | ~mask).all(dim=1)
                    puzzles_correct += solved.sum().item()
                    puzzles_total += bx.shape[0]
        finally:
            # Always restore the raw training weights, even if evaluation fails.
            if backup is not None:
                with torch.no_grad():
                    for pname, p in trainable:
                        p.copy_(backup[pname])

        cell_acc = 100 * correct / max(total, 1)
        history.append(100 * puzzles_correct / max(puzzles_total, 1))

    elapsed = time.time() - start_time
    final_acc = history[-1]

    if verbose:
        eff_depth = T * (n+1) * n_layers if full_backprop else T * n_layers
        print(f"  {name:<35s} | Params: {n_params:>7,} | Eff.Depth: {eff_depth:>3} | "
              f"Puzzle Acc: {final_acc:>5.1f}% | Time: {elapsed:.0f}s")

    return {
        'name': name, 'final_acc': final_acc, 'cell_acc': cell_acc, 'n_params': n_params,
        'history': history, 'time': elapsed, 'config': config
    }

In [None]:
#@title 🎧 Listen: Todo Rankings
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_04_todo_rankings.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

### 4.4 Your Turn: Predict the Rankings

Before running the experiments, make your predictions!

### TODO: Rank the ablation impacts

In [None]:
# ============ TODO ============
# Based on your understanding from Notebooks 1-3, rank these ablations
# from MOST impactful to LEAST impactful (which removal hurts the most?).
#
# Assign ranks 1-5 (1 = most impactful removal, 5 = least impactful):
# Replace every ??? with an integer; the cell is a SyntaxError until you do
# (that is intentional for the exercise).
# ==============================

# Keys mirror experiments 2-6 of the ablation run further below.
your_rankings = {
    'No full backprop':     ???,  # Rank 1-5
    'Attention vs MLP':     ???,  # Rank 1-5
    '4 layers vs 2+recur':  ???,  # Rank 1-5
    'No z feature':         ???,  # Rank 1-5
    'No EMA':               ???,  # Rank 1-5
}

# Echo the predictions in ascending rank order.
print("Your predicted impact rankings (1=most impactful removal):")
for name, rank in sorted(your_rankings.items(), key=lambda x: x[1]):
    print(f"  #{rank}: {name}")
print("\nLet's see if you're right! Running experiments now...\n")

In [None]:
#@title 🎧 Listen: Experiments
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_05_experiments.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

### 4.5 Running the Ablation Experiments

Now let us run the experiments. We will test 6 configurations:

In [None]:
print("=" * 90)
print("  ABLATION STUDY ‚Äî 4√ó4 Sudoku")
print("=" * 90)
print()

experiments = [
    # Baseline: full TRM configuration
    {'name': '1. Full TRM (baseline)',
     'n_layers': 2, 'use_attention': False, 'use_z': True,
     'T': 3, 'n': 4, 'full_backprop': True, 'use_ema': True},

    # Ablation: stop-grad (no full backprop)
    {'name': '2. No full backprop (stop-grad)',
     'n_layers': 2, 'use_attention': False, 'use_z': True,
     'T': 3, 'n': 4, 'full_backprop': False, 'use_ema': True},

    # Ablation: attention instead of MLP
    {'name': '3. Attention (instead of MLP)',
     'n_layers': 2, 'use_attention': True, 'use_z': True,
     'T': 3, 'n': 4, 'full_backprop': True, 'use_ema': True},

    # Ablation: 4 layers with less recursion
    {'name': '4. 4 layers, n=2 recursions',
     'n_layers': 4, 'use_attention': False, 'use_z': True,
     'T': 3, 'n': 2, 'full_backprop': True, 'use_ema': True},

    # Ablation: no z feature
    {'name': '5. No z (solution only)',
     'n_layers': 2, 'use_attention': False, 'use_z': False,
     'T': 3, 'n': 4, 'full_backprop': True, 'use_ema': True},

    # Ablation: no EMA
    {'name': '6. No EMA',
     'n_layers': 2, 'use_attention': False, 'use_z': True,
     'T': 3, 'n': 4, 'full_backprop': True, 'use_ema': False},
]

results = []
for config in experiments:
    torch.manual_seed(42)
    np.random.seed(42)
    result = run_experiment(config, n_epochs=25)
    results.append(result)

print()
print("=" * 90)

In [None]:
# 📊 Horizontal bar chart: final test puzzle accuracy for each experiment
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

names = [r['name'] for r in results]
accs = [r['final_acc'] for r in results]
# Baseline in green, each ablation in its own color.
colors = ['#2e7d32', '#ef5350', '#ff9800', '#42a5f5', '#ab47bc', '#78909c']
positions = list(range(len(names)))

bars = ax.barh(positions, accs, color=colors, edgecolor='gray', linewidth=1.5, height=0.6)

# Label each bar with its accuracy, just past the end of the bar.
for pos, acc in zip(positions, accs):
    ax.text(acc + 0.5, pos, f'{acc:.1f}%', va='center', fontsize=12, fontweight='bold')

ax.set_yticks(positions)
ax.set_yticklabels(names, fontsize=11)
ax.set_xlabel('Test Puzzle Accuracy (%)', fontsize=12)
ax.set_title('Ablation Study: What Matters in TRM?', fontsize=14, fontweight='bold')
ax.invert_yaxis()  # experiment 1 at the top
ax.set_xlim(0, max(accs) + 10)
ax.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Curves
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_06_curves.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

In [None]:
# 📊 Per-epoch test puzzle accuracy for every ablation experiment
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

colors_line = ['#2e7d32', '#ef5350', '#ff9800', '#42a5f5', '#ab47bc', '#78909c']
for color, result in zip(colors_line, results):
    ax.plot(result['history'], color=color, linewidth=2, label=result['name'])

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Test Puzzle Accuracy (%)', fontsize=12)
ax.set_title('Training Progression Across Ablation Experiments', fontsize=14, fontweight='bold')
ax.legend(fontsize=9, loc='lower right')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Todo Custom
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_07_todo_custom.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

## 5. 🔧 Your Turn

### TODO: Design Your Own Ablation

Choose one hypothesis and test it by designing a new configuration.

In [None]:
# ============ TODO ============
# Pick ONE hypothesis and create a config to test it:
#
# Hypothesis A: "More recursion always helps"
#   → Test n=8 vs n=4 vs n=2 (keep everything else the same)
#
# Hypothesis B: "Bigger hidden dimension helps"
#   → Test dim=32 vs dim=48 vs dim=64
#
# Hypothesis C: "Deep supervision steps matter"
#   → Test T=1 vs T=3 vs T=5
#
# Write your config below:
# Each entry is a config dict for run_experiment; keys you omit fall back to
# run_experiment's defaults. Replace ??? with a dict — the cell is a
# SyntaxError until you do (intentional for the exercise).
# ==============================

my_configs = [
    # Example: testing more recursion
    {'name': 'Custom: n=2 recursions',
     'n_layers': 2, 'use_attention': False, 'use_z': True,
     'T': 3, 'n': 2, 'full_backprop': True, 'use_ema': True},

    {'name': 'Custom: n=4 recursions (baseline)',
     'n_layers': 2, 'use_attention': False, 'use_z': True,
     'T': 3, 'n': 4, 'full_backprop': True, 'use_ema': True},

    # YOUR CONFIG HERE — test n=8 or dim=64 or T=5
    ???
]

print("Running your custom ablation...\n")
my_results = []
for config in my_configs:
    # Same per-run seeding as the main ablation loop, so results are comparable.
    torch.manual_seed(42)
    np.random.seed(42)
    result = run_experiment(config, n_epochs=25)
    my_results.append(result)

In [None]:
# ✅ Verification: plot and summarize your custom ablation runs
fig, ax = plt.subplots(1, 1, figsize=(10, 4))
for run in my_results:
    ax.plot(run['history'], linewidth=2, label=run['name'])
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Test Puzzle Accuracy (%)', fontsize=12)
ax.set_title('Your Custom Ablation Results', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Text summary of final accuracies.
print("\nYour ablation results:")
for run in my_results:
    print(f"  {run['name']}: {run['final_acc']:.1f}%")

In [None]:
#@title 🎧 Listen: Big Picture
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_08_big_picture.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

## 6. Putting It All Together — The Parameter vs Compute Trade-off

In [None]:
# 📊 The big picture: parameters vs accuracy, and effective depth vs accuracy
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

params = [r['n_params'] for r in results]
accs = [r['final_acc'] for r in results]
colors_scatter = ['#2e7d32', '#ef5350', '#ff9800', '#42a5f5', '#ab47bc', '#78909c']


def _effective_depth(cfg):
    """Effective depth using the same formula and defaults as run_experiment's summary line."""
    T, n, n_layers = cfg.get('T', 3), cfg.get('n', 4), cfg.get('n_layers', 2)
    return T * (n + 1) * n_layers if cfg.get('full_backprop', True) else T * n_layers


eff_depths = [_effective_depth(r['config']) for r in results]

panels = [
    (axes[0], params, 'Parameters', 'Parameters vs Accuracy'),
    (axes[1], eff_depths, 'Effective Depth (layers)', 'Effective Depth vs Accuracy'),
]
for ax, xs, xlabel, title in panels:
    for x_val, acc, color, r in zip(xs, accs, colors_scatter, results):
        ax.scatter(x_val, acc, c=color, s=200, edgecolors='gray', linewidth=1.5, zorder=5)
        # Label points with the experiment number, e.g. "3."
        ax.annotate(r['name'].split('.')[0] + '.',
                    (x_val, acc), textcoords="offset points",
                    xytext=(10, 5), fontsize=9)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel('Test Puzzle Accuracy (%)', fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.suptitle('The Parameters vs Compute Trade-off', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

# Report the baseline-vs-config-4 comparison from the measured numbers instead of
# asserting an outcome: by the formula above config 4 actually has the larger
# effective depth (3*3*4=36 vs 3*5*2=30), so "baseline is deeper" would be false.
baseline_r, deep_r = results[0], results[3]
print("💡 Baseline (1.) vs 4-layer stack (4.):")
print(f"   1.: {baseline_r['n_params']:,} params | eff. depth {eff_depths[0]} | "
      f"{baseline_r['final_acc']:.1f}% puzzle acc")
print(f"   4.: {deep_r['n_params']:,} params | eff. depth {eff_depths[3]} | "
      f"{deep_r['final_acc']:.1f}% puzzle acc")
if baseline_r['final_acc'] >= deep_r['final_acc']:
    print("   Reusing 2 layers recursively matched or beat stacking 4 layers.")
else:
    print("   In this run the 4-layer stack scored higher; compare with the paper's +7.9 point result.")

In [None]:
#@title 🎧 Listen: Report
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_09_report.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

## 7. 🎯 Final Output — Complete Ablation Report

In [None]:
# Generate a comprehensive ablation report
print("=" * 80)
print("  ABLATION STUDY REPORT — Tiny Recursive Models on 4×4 Sudoku")
print("=" * 80)
print()

baseline_acc = results[0]['final_acc']
print(f"  Baseline accuracy: {baseline_acc:.1f}%")
print(f"  (2 layers, MLP mixing, with z, T=3, n=4, full backprop, EMA)")
print()
print(f"  {'Experiment':<35s} | {'Accuracy':>8s} | {'Δ vs baseline':>13s} | {'Key finding'}")
print("  " + "-" * 100)

# findings[i] is the expected lesson for results[i] (index 0 = baseline).
findings = [
    "Reference configuration",
    "Full backprop is critical for learning",
    "MLP beats attention on small fixed grids",
    "More recursion > more layers",
    "Reasoning scratchpad z is essential",
    "EMA stabilizes training"
]

for r, finding in zip(results, findings):
    delta = r['final_acc'] - baseline_acc
    delta_str = f"{delta:+.1f}%"
    print(f"  {r['name']:<35s} | {r['final_acc']:>7.1f}% | {delta_str:>13s} | {finding}")

print()
print("  " + "=" * 100)
print()

# Rank ablations by accuracy drop. Pair results[1:] with findings[1:] so each
# ablation gets its own finding (enumerating results[1:] from 0 was off by one).
impacts = sorted(
    ((r['name'], baseline_acc - r['final_acc'], finding)
     for r, finding in zip(results[1:], findings[1:])),
    key=lambda item: item[1],
    reverse=True,
)

print("  Components ranked by impact (removing each one):")
for name, impact, finding in impacts:
    print(f"    {impact:+.1f}% — {name}: {finding}")

In [None]:
# 📊 Final visualization: accuracy drop caused by removing each component
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# Recomputed here so this cell does not depend on the report cell having run.
baseline_acc = results[0]['final_acc']

# Short labels for results[1:], in the order of the `experiments` list.
names_short = [
    'No full backprop',
    'Attention\n(vs MLP)',
    '4 layers\n(vs 2+recursion)',
    'No z feature',
    'No EMA'
]
impacts_val = [baseline_acc - r['final_acc'] for r in results[1:]]
# Red = removal hurt accuracy; green = it did not.
bar_colors = ['#ef5350' if v > 0 else '#66bb6a' for v in impacts_val]

bars = ax.bar(range(len(names_short)), impacts_val, color=bar_colors,
              edgecolor='gray', linewidth=1.5, width=0.6)

for bar, val in zip(bars, impacts_val):
    # Put the label above positive bars and below zero/negative ones.
    y_pos = bar.get_height() + 0.3 if val > 0 else bar.get_height() - 1
    ax.text(bar.get_x() + bar.get_width()/2, y_pos,
            f'{val:+.1f}%', ha='center', fontsize=12, fontweight='bold')

ax.set_xticks(range(len(names_short)))
ax.set_xticklabels(names_short, fontsize=10)
ax.set_ylabel('Accuracy Drop (percentage points)', fontsize=12)
ax.set_title('Impact of Removing Each Component\n(Higher = More Important)',
             fontsize=14, fontweight='bold')
ax.axhline(y=0, color='black', linewidth=0.5)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n🎉 Congratulations! You have completed the ablation study.")
print("   You now understand WHY each component of TRM matters, not just WHAT it does.")
print("\n   The key lesson: recursive depth (thinking longer) beats parameter count (thinking bigger).")

In [None]:
#@title 🎧 Listen: Closing
# Plays this section's narration clip if the first cell downloaded it.
import os as _os
from IPython.display import Audio, display

_f = "/content/narration/04_10_closing.mp3"
if not _os.path.exists(_f):
    print("Run the first cell to download narration audio.")
else:
    display(Audio(_f))

## 8. Reflection and Next Steps

### 💡 Key Takeaways

1. **Full backpropagation** is the single most important factor — it enables the model to learn from the full recursion chain
2. **MLP mixing** beats attention for small, fixed-size grids — simpler is better when the context fits in a small matrix
3. **More recursion with fewer layers** beats more layers with less recursion — computational depth without parameter growth
4. **The reasoning scratchpad z** is essential for generalization — the model needs internal state that is not directly part of the answer
5. **EMA** provides stability, especially on small datasets

### 🤔 Reflection Questions

1. Why does removing full backpropagation hurt so much? (Hint: think about what the model can learn from 1-step gradients vs multi-step gradients.)
2. The paper's Sudoku-Extreme results show a +30.9 percentage-point gain from full backprop. How does that compare with the gap you measured on 4×4, and why would the impact be larger on harder puzzles?
3. If you had unlimited compute but only 1,000 training examples, how would you design the TRM? What T, n, and n_layers would you choose?

### 🏆 Optional Challenges

1. **Recursion budget ablation:** Fix the total number of forward passes (e.g., 12) and vary the split between T and n. Is T=3, n=4 optimal, or would T=2, n=6 be better?
2. **Data efficiency:** Train with 100, 500, and 1000 examples. At what data size does EMA stop helping?
3. **Transfer learning:** Train on easy puzzles (4 empty cells), then fine-tune on hard puzzles (10 empty cells). Does pre-training help?

### Series Summary

Across these 4 notebooks, you have:
1. **Understood** recursive reasoning through constraint propagation
2. **Built** the complete TRM architecture from scratch (RMSNorm, SwiGLU, RoPE, MLP/Attention variants)
3. **Trained** TRM with deep supervision, prediction loss, halting loss, and EMA
4. **Validated** each design choice through systematic ablation experiments

The central lesson of Tiny Recursive Models: **Less is More.** A tiny model that thinks recursively outperforms giant models that think once. Recursive depth beats parameter count. Thinking longer beats thinking bigger.

In [None]:
#@title 💬 AI Teaching Assistant — Click ▶ to start
#@markdown This AI chatbot reads your notebook and can answer questions about any concept, code, or exercise.

import json as _json
import requests as _requests
from google.colab import output as _output
from IPython.display import display, HTML as _HTML, Markdown as _Markdown

# --- Read notebook content for context ---
def _get_notebook_context():
    """Return the notebook's non-empty cells as text, skipping cells tagged "chatbot".

    Each cell becomes "[CELL_TYPE]\\n<source>", and cells are joined by "---" separators.
    Any failure (for example, not running inside Colab) yields a placeholder string.
    """
    try:
        from google.colab import _message
        nb = _message.blocking_request("get_ipynb", request="", timeout_sec=10)
        sections = []
        for cell in nb.get("ipynb", {}).get("cells", []):
            text = "".join(cell.get("source", []))
            if "chatbot" in cell.get("metadata", {}).get("tags", []):
                continue
            if not text.strip():
                continue
            kind = cell.get("cell_type", "unknown")
            sections.append(f"[{kind.upper()}]\n{text}")
        return "\n\n---\n\n".join(sections)
    except Exception:
        return "Notebook content unavailable."

# Snapshot of the notebook text, taken once when this cell runs; later edits are
# not visible to the assistant until the cell is re-run.
_NOTEBOOK_CONTEXT = _get_notebook_context()
# Running list of {'role', 'content'} messages; _notebook_chat sends the last 10.
_CHAT_HISTORY = []
# Remote endpoint that receives the question, notebook context and history.
_API_URL = "https://course-creator-brown.vercel.app/api/chat"

def _notebook_chat(question):
    """Send ``question`` to the teaching-assistant API and return the reply text.

    The POST body carries the question, the first 100,000 characters of
    ``_NOTEBOOK_CONTEXT`` and the last 10 entries of ``_CHAT_HISTORY``.
    Only a successful, non-empty string answer is recorded in
    ``_CHAT_HISTORY`` (as a user turn plus an assistant turn).

    Args:
        question: The user's question text.

    Returns:
        The assistant's answer. On any network, HTTP-status or JSON error it
        returns "Error connecting to teaching assistant: <reason>". When the
        response has no usable answer it returns
        "Sorry, I could not generate a response.". It never raises.
    """
    try:
        resp = _requests.post(_API_URL, json={
            'question': question,
            'context': _NOTEBOOK_CONTEXT[:100000],
            'history': _CHAT_HISTORY[-10:],
        }, timeout=60)
        # Treat 4xx/5xx as failures instead of trying to read an answer out of an error body.
        resp.raise_for_status()
        data = resp.json()
    except Exception as e:
        return f'Error connecting to teaching assistant: {e}'

    answer = data.get('answer') if isinstance(data, dict) else None
    if not isinstance(answer, str) or not answer.strip():
        # Don't record failed turns: they would be replayed to the model as history.
        return 'Sorry, I could not generate a response.'

    # Mutating the module-level list in place, so no `global` declaration is needed.
    _CHAT_HISTORY.append({'role': 'user', 'content': question})
    _CHAT_HISTORY.append({'role': 'assistant', 'content': answer})
    return answer

def _notebook_chat_json(question):
    """Colab callback: answer ``question`` and return it as a JSON display object.

    ``google.colab.kernel.invokeFunction`` delivers only the MIME bundle of the
    callback's return value to JavaScript. A plain ``str`` is not published
    under ``application/json``, so the answer is wrapped in
    ``IPython.display.JSON`` and the widget reads it as ``data.answer``.
    """
    from IPython.display import JSON as _JSON
    return _JSON({'answer': _notebook_chat(question)})

_output.register_callback('notebook_chat', _notebook_chat_json)

def ask(question):
    """Ask the AI teaching assistant a question about this notebook.

    The reply is rendered as Markdown in the calling cell's output.
    """
    answer = _notebook_chat(question)
    display(_Markdown(answer))

print("\u2705 AI Teaching Assistant is ready!")
print("\U0001f4a1 Use the chat below, or call ask(\'your question\') in any cell.")

# --- Display chat widget ---
# Raw string on purpose: the embedded JS regexes use backslash escapes (\n, \d,
# \*, \s) that must reach the browser unchanged. In a normal string Python turns
# "\n" into a real newline. That is a SyntaxError inside a JS regex literal and
# would stop the entire <script> from running.
display(_HTML(r'''<style>
  .vc-wrap{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;max-width:100%;border-radius:16px;overflow:hidden;box-shadow:0 4px 24px rgba(0,0,0,.12);background:#fff;border:1px solid #e5e7eb}
  .vc-hdr{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;padding:16px 20px;display:flex;align-items:center;gap:12px}
  .vc-avatar{width:42px;height:42px;background:rgba(255,255,255,.2);border-radius:50%;display:flex;align-items:center;justify-content:center;font-size:22px}
  .vc-hdr h3{font-size:16px;font-weight:600;margin:0}
  .vc-hdr p{font-size:12px;opacity:.85;margin:2px 0 0}
  .vc-msgs{height:420px;overflow-y:auto;padding:16px;background:#f8f9fb;display:flex;flex-direction:column;gap:10px}
  .vc-msg{display:flex;flex-direction:column;animation:vc-fade .25s ease}
  .vc-msg.user{align-items:flex-end}
  .vc-msg.bot{align-items:flex-start}
  .vc-bbl{max-width:85%;padding:10px 14px;border-radius:16px;font-size:14px;line-height:1.55;word-wrap:break-word}
  .vc-msg.user .vc-bbl{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border-bottom-right-radius:4px}
  .vc-msg.bot .vc-bbl{background:#fff;color:#1a1a2e;border:1px solid #e8e8e8;border-bottom-left-radius:4px}
  .vc-bbl code{background:rgba(0,0,0,.07);padding:2px 6px;border-radius:4px;font-size:13px;font-family:'Fira Code',monospace}
  .vc-bbl pre{background:#1e1e2e;color:#cdd6f4;padding:12px;border-radius:8px;overflow-x:auto;margin:8px 0;font-size:13px}
  .vc-bbl pre code{background:none;padding:0;color:inherit}
  .vc-bbl h3,.vc-bbl h4{margin:10px 0 4px;font-size:15px}
  .vc-bbl ul,.vc-bbl ol{margin:4px 0;padding-left:20px}
  .vc-bbl li{margin:2px 0}
  .vc-chips{display:flex;flex-wrap:wrap;gap:8px;padding:0 16px 12px;background:#f8f9fb}
  .vc-chip{background:#fff;border:1px solid #d1d5db;border-radius:20px;padding:6px 14px;font-size:12px;cursor:pointer;transition:all .15s;color:#4b5563}
  .vc-chip:hover{border-color:#667eea;color:#667eea;background:#f0f0ff}
  .vc-input{display:flex;padding:12px 16px;background:#fff;border-top:1px solid #eee;gap:8px}
  .vc-input input{flex:1;padding:10px 16px;border:2px solid #e8e8e8;border-radius:24px;font-size:14px;outline:none;transition:border-color .2s}
  .vc-input input:focus{border-color:#667eea}
  .vc-input button{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border:none;border-radius:50%;width:42px;height:42px;cursor:pointer;display:flex;align-items:center;justify-content:center;font-size:18px;transition:transform .1s}
  .vc-input button:hover{transform:scale(1.05)}
  .vc-input button:disabled{opacity:.5;cursor:not-allowed;transform:none}
  .vc-typing{display:flex;gap:5px;padding:4px 0}
  .vc-typing span{width:8px;height:8px;background:#667eea;border-radius:50%;animation:vc-bounce 1.4s infinite ease-in-out}
  .vc-typing span:nth-child(2){animation-delay:.2s}
  .vc-typing span:nth-child(3){animation-delay:.4s}
  @keyframes vc-bounce{0%,80%,100%{transform:scale(0)}40%{transform:scale(1)}}
  @keyframes vc-fade{from{opacity:0;transform:translateY(8px)}to{opacity:1;transform:translateY(0)}}
  .vc-note{text-align:center;font-size:11px;color:#9ca3af;padding:8px 16px 12px;background:#fff}
</style>
<div class="vc-wrap">
  <div class="vc-hdr">
    <div class="vc-avatar">&#129302;</div>
    <div>
      <h3>Vizuara Teaching Assistant</h3>
      <p>Ask me anything about this notebook</p>
    </div>
  </div>
  <div class="vc-msgs" id="vcMsgs">
    <div class="vc-msg bot">
      <div class="vc-bbl">&#128075; Hi! I've read through this entire notebook. Ask me about any concept, code block, or exercise &mdash; I'm here to help you learn!</div>
    </div>
  </div>
  <div class="vc-chips" id="vcChips">
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Explain the main concept</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Help with the TODO exercise</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Summarize what I learned</span>
  </div>
  <div class="vc-input">
    <input type="text" id="vcIn" placeholder="Ask about concepts, code, exercises..." />
    <button id="vcSend" onclick="vcSendMsg()">&#10148;</button>
  </div>
  <div class="vc-note">AI-generated &middot; Verify important information &middot; <a href="#" onclick="vcClear();return false" style="color:#667eea">Clear chat</a></div>
</div>
<script>
(function(){
  var msgs=document.getElementById('vcMsgs'),inp=document.getElementById('vcIn'),
      btn=document.getElementById('vcSend'),chips=document.getElementById('vcChips');

  function esc(s){var d=document.createElement('div');d.textContent=s;return d.innerHTML}

  // Minimal markdown renderer. The whole text is HTML-escaped FIRST, so any
  // markup inside an answer (which comes from a remote service) is shown as
  // text instead of being injected into the page. The substitutions below
  // then only introduce known tags.
  function md(t){
    return esc(t)
      .replace(/```(\w*)\n([\s\S]*?)```/g,function(_,l,c){return '<pre><code>'+c+'</code></pre>'})
      .replace(/`([^`]+)`/g,'<code>$1</code>')
      .replace(/\*\*([^*]+)\*\*/g,'<strong>$1</strong>')
      .replace(/\*([^*]+)\*/g,'<em>$1</em>')
      .replace(/^#### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^## (.+)$/gm,'<h3>$1</h3>')
      .replace(/^\d+\. (.+)$/gm,'<li>$1</li>')
      .replace(/^- (.+)$/gm,'<li>$1</li>')
      .replace(/\n\n/g,'<br><br>')
      .replace(/\n/g,'<br>');
  }

  function addMsg(text,isUser){
    var m=document.createElement('div');m.className='vc-msg '+(isUser?'user':'bot');
    var b=document.createElement('div');b.className='vc-bbl';
    b.innerHTML=isUser?esc(text):md(text);
    m.appendChild(b);msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function showTyping(){
    var m=document.createElement('div');m.className='vc-msg bot';m.id='vcTyping';
    m.innerHTML='<div class="vc-bbl"><div class="vc-typing"><span></span><span></span><span></span></div></div>';
    msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function hideTyping(){var e=document.getElementById('vcTyping');if(e)e.remove()}

  window.vcSendMsg=function(){
    // The Enter key bypasses the disabled button, so block re-entry explicitly
    // while a request is still in flight.
    if(btn.disabled)return;
    var q=inp.value.trim();if(!q)return;
    inp.value='';chips.style.display='none';
    addMsg(q,true);showTyping();btn.disabled=true;
    google.colab.kernel.invokeFunction('notebook_chat',[q],{})
      .then(function(r){
        hideTyping();
        // The Python callback returns IPython.display.JSON({'answer': ...}),
        // which arrives under the 'application/json' MIME key.
        var a=(r&&r.data)?r.data['application/json']:undefined;
        if(a&&typeof a==='object'&&typeof a.answer==='string')a=a.answer;
        if(typeof a!=='string')throw new Error('Unexpected response from notebook_chat');
        addMsg(a,false);
      })
      .catch(function(){
        hideTyping();
        addMsg('Sorry, I encountered an error. Please check your internet connection and try again.',false);
      })
      .finally(function(){btn.disabled=false;inp.focus()});
  };

  window.vcAsk=function(q){inp.value=q;vcSendMsg()};
  window.vcClear=function(){
    msgs.innerHTML='<div class="vc-msg bot"><div class="vc-bbl">&#128075; Chat cleared. Ask me anything!</div></div>';
    chips.style.display='flex';
  };

  inp.addEventListener('keypress',function(e){if(e.key==='Enter')vcSendMsg()});
  inp.focus();
})();
</script>'''))