In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1TR8B7sH0a1GpUcd-1quRGmL_zhGcJtec", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/02_00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability, print library versions, and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# 🚀 Building the Tiny Recursive Model Architecture from Scratch

*Part 2 of the Vizuara series on Tiny Recursive Models*
*Estimated time: 40 minutes*

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**: it has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://pods.vizuara.ai/courses/tiny-recursive-models/practice/2/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


## 1. Why Does This Matter?

In the previous notebook, we saw that recursive constraint propagation can solve Sudoku puzzles that single-pass approaches cannot. But we hand-coded the rules: "eliminate values already used in the same row, column, or box."

What if the model could **learn** its own constraint-propagation rules from data?

That is exactly what the Tiny Recursive Model (TRM) does. With at most **7 million parameters** and a **2-layer network** applied recursively, it learns to reason about abstract patterns, achieving 87.4% on Sudoku-Extreme and 44.6% on ARC-AGI-1.

By the end of this notebook, you will have built the complete TRM architecture from scratch:
- RMSNorm normalization
- SwiGLU gated activation
- Rotary Position Embeddings (RoPE)
- The 2-layer recursive block (both MLP and Attention variants)
- The full recursion loop with solution (y) and reasoning (z) features

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name()}")

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)

In [None]:
#@title üéß Listen: Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_01_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### The Architecture at a Glance

Think of TRM as a very small brain running in a loop:

1. **Inputs:** The puzzle (x), current solution guess (y), and reasoning scratchpad (z)
2. **Processing:** A tiny 2-layer network processes all three together
3. **Outputs:** Updated solution (y) and updated reasoning state (z)
4. **Loop:** Feed the outputs back as inputs and repeat

The key insight: **the same 2-layer network is used for every pass**. Later passes get no weights of their own; the solution can still improve because each pass feeds new context (the updated y and z) into the next.
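A minimal sketch of the weight-sharing idea, using a single hypothetical `nn.Linear` layer as a stand-in for TRM's 2-layer network: the parameter count is fixed by the module, not by how many passes we run, while each pass starts from the previous pass's output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
shared = nn.Linear(8, 8)  # hypothetical stand-in for the shared 2-layer network

def run_passes(h: torch.Tensor, n_passes: int) -> torch.Tensor:
    """Apply the SAME module n_passes times; no new weights are created."""
    for _ in range(n_passes):
        h = torch.tanh(shared(h))
    return h

h0 = torch.randn(1, 8)
n_weights = sum(p.numel() for p in shared.parameters())
out_1 = run_passes(h0, 1)
out_6 = run_passes(h0, 6)
print(n_weights)                            # 72 (64 weights + 8 biases), whatever n_passes is
print((out_6 - out_1).abs().max().item())   # how much the extra passes changed the output
```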

### Why 2 Layers?

You might think more layers = better. But the TRM paper showed that **2 layers with more recursion** outperformed **4 layers with less recursion** by 7.9 percentage points on Sudoku-Extreme. Why?

- More layers = more parameters = more overfitting risk (especially on small datasets)
- More recursion = more computational depth **without** adding parameters
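- With T=3 recursion cycles, each making n+1 = 7 calls to the 2-layer network (n=6 reasoning updates plus one solution update): effective depth = 3 × 7 × 2 = **42 layers** per supervision step

The model gets the computational depth of a 42-layer network while storing only 2 layers of weights.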

### 🤔 Think About This

If you had a budget of 7 million parameters, would you rather have:
- (A) A 14-layer network, each layer with 500K parameters, processing input once?
- (B) A 2-layer network with 3.5M parameters per layer, processing input 21 times?

The TRM paper shows (B) wins dramatically. Can you intuit why?
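Before answering, it can help to do the bookkeeping. The sketch below only tallies parameters and layer applications for the two options as stated in the question; it does not explain *why* (B) wins.

```python
# Same parameter budget, very different depth.
# Option A: 14 distinct layers of 500K parameters each, applied once.
# Option B: 2 distinct layers of 3.5M parameters each, applied in 21 network calls.
option_a = {"layers": 14, "params_per_layer": 500_000, "network_calls": 1}
option_b = {"layers": 2, "params_per_layer": 3_500_000, "network_calls": 21}

for name, cfg in [("A", option_a), ("B", option_b)]:
    params = cfg["layers"] * cfg["params_per_layer"]
    depth = cfg["layers"] * cfg["network_calls"]   # layer applications per input
    print(f"Option {name}: {params:,} parameters, effective depth {depth}")
# Option A: 7,000,000 parameters, effective depth 14
# Option B: 7,000,000 parameters, effective depth 42
```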

In [None]:
#@title üéß Listen: Math
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_02_math.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

### The Recursion Equations

The TRM recursion has two phases per step:

**Phase 1: Update the reasoning state** (repeat $n$ times):
$$z \leftarrow \text{net}(x, y, z)$$

**Phase 2: Refine the solution:**
$$y \leftarrow \text{net}(y, z)$$

Where $\text{net}$ is the same 2-layer network in both phases. In this notebook's implementation, we concatenate the inputs along the feature dimension, pass them through two transformer-like layers, and project the output back to the feature sizes of $y$ and $z$.
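The two phases can be sketched in a few lines. This is an illustration under stated assumptions, not the notebook's model: `ToyNet` is a hypothetical small MLP, and it combines its inputs by element-wise addition (this notebook's model concatenates them instead). Note that Phase 2 deliberately leaves out $x$.

```python
import torch
import torch.nn as nn

class ToyNet(nn.Module):
    """Hypothetical stand-in for the shared 2-layer network (illustration only)."""
    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        # Assumption: inputs are combined by element-wise addition.
        return self.mlp(torch.stack(inputs).sum(dim=0))

def latent_recursion(net: nn.Module, x, y, z, n: int = 6):
    """Phase 1: update z n times from (x, y, z). Phase 2: update y once from (y, z)."""
    calls = 0
    for _ in range(n):
        z = net(x, y, z)   # the reasoning state sees the puzzle x
        calls += 1
    y = net(y, z)          # the solution update does not see x directly
    calls += 1
    return y, z, calls

torch.manual_seed(0)
net = ToyNet(dim=8)
x = torch.randn(2, 16, 8)                  # (batch, seq_len, dim)
y0, z0 = torch.zeros_like(x), torch.zeros_like(x)
y, z, calls = latent_recursion(net, x, y0, z0, n=6)
print(y.shape, calls)  # torch.Size([2, 16, 8]) 7
```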

### Inside Each Layer

Each of the 2 layers contains:

**RMSNorm** ‚Äî normalizes activations using root-mean-square:
$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot \gamma$$

This is simpler than LayerNorm (no mean subtraction) and typically performs comparably in practice. Computationally: compute the RMS of the vector, divide each element by it, then scale by the learnable parameter $\gamma$.
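A small worked example makes the "no mean subtraction" point concrete. The sketch applies the RMSNorm formula with $\gamma = 1$ (an assumption; $\gamma$ is learned in practice) to a 4-element vector and compares it with PyTorch's `F.layer_norm`, which does center the vector.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
eps = 1e-6

# RMSNorm with gamma = 1: divide by the root-mean-square, no mean subtraction
rms = torch.sqrt(torch.mean(x ** 2) + eps)   # sqrt((1+4+9+16)/4) = sqrt(7.5) ≈ 2.7386
x_rms = x / rms

# LayerNorm without affine parameters: subtract the mean, then divide by the std
x_ln = F.layer_norm(x, normalized_shape=(4,), eps=eps)

print(x_rms)          # ≈ [0.3651, 0.7303, 1.0954, 1.4606]
print(x_rms.mean())   # ≈ 0.9129: the mean is NOT removed
print(x_ln.mean())    # ≈ 0: LayerNorm centers the vector
```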

**SwiGLU activation** ‚Äî a gated activation function:
$$\text{SwiGLU}(x) = (\text{Swish}(xW_1) \odot xW_2) W_3$$

Where $\text{Swish}(x) = x \cdot \sigma(x)$. Computationally: project the input through two parallel linear layers, apply Swish to one of them, multiply element-wise, then project back down. The "gate" ($xW_2$) learns which features to pass through.
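The SwiGLU formula can be checked directly with plain tensors. In this sketch, `W1`, `W2`, `W3` are random matrices standing in for learned weights, and the sizes `d=16`, `h=32` are example values. The bias-free parameter count $3 \cdot d \cdot h$ should match what the `SwiGLU(dim=16, hidden_dim=32)` test in Section 4.2 reports.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, h = 16, 32                        # model dim and hidden dim (example sizes)
x = torch.randn(d)
W1, W2, W3 = torch.randn(d, h), torch.randn(d, h), torch.randn(h, d)

# Swish / SiLU is a * sigmoid(a); PyTorch provides it as F.silu
a = x @ W1
assert torch.allclose(F.silu(a), a * torch.sigmoid(a))

# SwiGLU(x) = (Swish(x W1) ⊙ x W2) W3
out = (F.silu(x @ W1) * (x @ W2)) @ W3
print(out.shape)                     # torch.Size([16])

# Three bias-free matrices: 2·d·h (two up-projections) + h·d (down-projection) = 3·d·h
print(3 * d * h)                     # 1536
```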

**Rotary Position Embeddings (RoPE):**
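$$\text{RoPE}(x_m, m) = x_m \odot \cos(m\theta) + \text{rotate}(x_m) \odot \sin(m\theta)$$

This encodes position $m$ directly into the query/key vectors via rotation in 2D subspaces. Here $\text{rotate}$ maps each dimension pair $(x_{2i}, x_{2i+1})$ to $(-x_{2i+1}, x_{2i})$, and both entries of pair $i$ use the frequency $\theta_i = 10000^{-2i/d}$. Computationally: pair up consecutive dimensions, then rotate each pair by the angle $m\theta_i$, which grows with the position.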
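A useful consequence of rotating queries and keys is that their dot product depends only on the offset between positions, not on the absolute positions. The sketch checks this numerically with a standalone `apply_rope` helper of our own, which uses the same interleaved-pair convention and $\theta_i = 10000^{-2i/d}$ as the module built in Section 4.3.

```python
import torch

def apply_rope(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate consecutive pairs (x0, x1), (x2, x3), ... of a 1-D vector by pos * theta_i."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)   # theta_i = base^(-2i/d)
    cos, sin = torch.cos(pos * theta), torch.sin(pos * theta)
    x1, x2 = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

s_a = apply_rope(q, 3) @ apply_rope(k, 1)   # offset 2
s_b = apply_rope(q, 7) @ apply_rope(k, 5)   # offset 2 again, different absolute positions
print(torch.allclose(s_a, s_b))                           # True: same offset, same score
print(torch.allclose(apply_rope(q, 5).norm(), q.norm()))  # True: rotations preserve length
```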

### Effective Depth

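With $T$ recursion cycles per supervision step, $n$ latent recursions per cycle (plus one solution update), and $n_\text{layers}$ layers per network call:

$$\text{Effective depth} = T \times (n + 1) \times n_\text{layers}$$

For TRM ($T=3$, $n=6$, $n_\text{layers}=2$): $3 \times 7 \times 2 = 42$ effective layers per supervision step.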
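To see where the formula's factors come from, the sketch below simply counts layer applications. Here $T$ counts repetitions of the full latent recursion ($n$ updates of $z$, then one of $y$). It loosely follows the TRM paper's scheme as we understand it, in which only the last cycle is run with gradients; `CountingNet`, its tanh layers, and the additive input combination are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn

class CountingNet(nn.Module):
    """Hypothetical 2-layer stand-in that counts every layer application."""
    def __init__(self, dim: int, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])
        self.layer_calls = 0

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        h = torch.stack(inputs).sum(dim=0)   # assumption: combine inputs by addition
        for layer in self.layers:
            h = torch.tanh(layer(h))
            self.layer_calls += 1
        return h

def latent_recursion(net, x, y, z, n):
    for _ in range(n):
        z = net(x, y, z)   # n updates of the reasoning state
    y = net(y, z)          # 1 update of the solution
    return y, z

def deep_recursion(net, x, y, z, n=6, T=3):
    with torch.no_grad():                       # first T-1 cycles: no gradient tracking
        for _ in range(T - 1):
            y, z = latent_recursion(net, x, y, z, n)
    return latent_recursion(net, x, y, z, n)    # last cycle: gradients flow

net = CountingNet(dim=8)
x = torch.randn(1, 16, 8)
y, z = deep_recursion(net, x, torch.zeros_like(x), torch.zeros_like(x), n=6, T=3)
print(net.layer_calls)   # 42 = T * (n + 1) * n_layers = 3 * 7 * 2
print(y.requires_grad)   # True: only the final cycle is differentiable
```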

In [None]:
#@title üéß Listen: Rmsnorm
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_03_rmsnorm.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It: Component by Component

### 4.1 RMSNorm

RMSNorm is the simplest component. It normalizes the input by its root-mean-square value.

In [None]:
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    Simpler than LayerNorm: no mean subtraction, just scale by RMS.

    Used in: LLaMA, Gemma, TRM
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # Learnable scale (gamma)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch, seq_len, dim)
        # Compute RMS along the last dimension
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        # Normalize and scale
        return (x / rms) * self.weight

# Test it
norm = RMSNorm(dim=8)
test_input = torch.randn(2, 4, 8)  # batch=2, seq_len=4, dim=8
output = norm(test_input)
print(f"Input shape:  {test_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Input RMS:    {torch.sqrt(torch.mean(test_input**2, dim=-1))}")
print(f"Output RMS:   {torch.sqrt(torch.mean(output**2, dim=-1))}")

In [None]:
# 📊 Visualize: RMSNorm stabilizes activations
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Generate data with varying scales
x = torch.randn(1, 10, 32)
x[:, 5:, :] *= 5  # Make some positions have much larger magnitude


axes[0].imshow(x[0].numpy(), cmap='RdBu', vmin=-5, vmax=5, aspect='auto')
axes[0].set_title("Before RMSNorm", fontsize=13, fontweight='bold')
axes[0].set_xlabel("Feature dimension")
axes[0].set_ylabel("Sequence position")

# Normalize over all 32 feature dims
norm32 = RMSNorm(dim=32)
x_full_normed = norm32(x)
axes[1].imshow(x_full_normed[0].detach().numpy(), cmap='RdBu', vmin=-5, vmax=5, aspect='auto')
axes[1].set_title("After RMSNorm", fontsize=13, fontweight='bold')
axes[1].set_xlabel("Feature dimension")
axes[1].set_ylabel("Sequence position")

plt.suptitle("RMSNorm Stabilizes Activation Magnitudes", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Swiglu
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_04_swiglu.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 SwiGLU Activation

SwiGLU is a gated activation function. The "gate" learns which features to let through, acting like a learned filter.

In [None]:
class SwiGLU(nn.Module):
    """
    SwiGLU gated activation function.
    Projects the input with two parallel linear maps (w1, w2) to hidden_dim,
    applies Swish to the w1 branch, multiplies element-wise by the w2 branch,
    then projects back to dim with w3.

    Swish(x) = x * sigmoid(x)
    """
    def __init__(self, dim: int, hidden_dim: int = None):
        super().__init__()
        hidden_dim = hidden_dim or dim * 4  # Standard 4x expansion
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # For Swish path
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)  # For gate path
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)   # Project back

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Swish gate: x * sigmoid(x) applied to w1 projection
        swish_out = F.silu(self.w1(x))  # silu = Swish = x * sigmoid(x)
        # Gate: linear projection through w2
        gate = self.w2(x)
        # Element-wise multiply and project back
        return self.w3(swish_out * gate)

# Test it
swiglu = SwiGLU(dim=16, hidden_dim=32)
test_x = torch.randn(2, 4, 16)
out = swiglu(test_x)
print(f"Input shape:  {test_x.shape}")
print(f"Output shape: {out.shape}")
n_params = sum(p.numel() for p in swiglu.parameters())
print(f"Parameters:   {n_params:,}")

In [None]:
# 📊 Visualize: Swish activation function
x_plot = torch.linspace(-5, 5, 200)
relu_y = F.relu(x_plot)
gelu_y = F.gelu(x_plot)
swish_y = F.silu(x_plot)

plt.figure(figsize=(10, 5))
plt.plot(x_plot.numpy(), relu_y.numpy(), label='ReLU', linewidth=2, alpha=0.7)
plt.plot(x_plot.numpy(), gelu_y.numpy(), label='GELU', linewidth=2, alpha=0.7)
plt.plot(x_plot.numpy(), swish_y.numpy(), label='Swish (SiLU)', linewidth=2.5, color='#e65100')
plt.xlabel('Input', fontsize=12)
plt.ylabel('Output', fontsize=12)
plt.title('Swish vs ReLU vs GELU: Swish is Smooth and Allows Negative Signal', fontsize=13, fontweight='bold')
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.axhline(y=0, color='black', linewidth=0.5)
plt.axvline(x=0, color='black', linewidth=0.5)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Rope
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_05_rope.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 Rotary Position Embeddings (RoPE)

RoPE encodes position information by rotating pairs of dimensions. Position $m$ gets a rotation angle $m\theta_i$ for each dimension pair $i$.

In [None]:
class RotaryPositionEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE).
    Encodes position by rotating consecutive dimension pairs.
    The rotation angle depends on the position, so nearby positions get similar rotations.
    """
    def __init__(self, dim: int, max_seq_len: int = 1024, base: float = 10000.0):
        super().__init__()
        # Compute rotation frequencies for each dimension pair
        # theta_i = 1 / (base^(2i/dim)) for i = 0, 1, ..., dim/2 - 1
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute sin/cos for all positions
        positions = torch.arange(max_seq_len).float()
        # angles shape: (max_seq_len, dim/2)
        angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
        self.register_buffer('cos_cache', angles.cos())
        self.register_buffer('sin_cache', angles.sin())

    def forward(self, x: torch.Tensor, seq_len: int = None) -> torch.Tensor:
        """
        Apply rotary embeddings to input tensor.
        x shape: (batch, seq_len, dim)
        """
        if seq_len is None:
            seq_len = x.shape[1]
        cos = self.cos_cache[:seq_len]  # (seq_len, dim/2)
        sin = self.sin_cache[:seq_len]  # (seq_len, dim/2)

        # Split x into pairs of consecutive dimensions
        x1 = x[..., 0::2]  # Even indices
        x2 = x[..., 1::2]  # Odd indices

        # Apply rotation: [x1, x2] -> [x1*cos - x2*sin, x1*sin + x2*cos]
        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x1 * sin + x2 * cos

        # Interleave back
        out = torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
        return out

# Test it
rope = RotaryPositionEmbedding(dim=16, max_seq_len=81)
test_x = torch.randn(1, 9, 16)
out = rope(test_x)
print(f"Input shape:  {test_x.shape}")
print(f"Output shape: {out.shape}")

In [None]:
# üìä Visualize: How RoPE rotations change with position
rope_vis = RotaryPositionEmbedding(dim=8, max_seq_len=20)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Show rotation angles for different dimension pairs
positions = np.arange(20)
for i in range(4):
    angles = positions * rope_vis.inv_freq[i].item()
    axes[0].plot(positions, np.cos(angles), label=f'Dim pair {i}', linewidth=2)

axes[0].set_xlabel('Position', fontsize=12)
axes[0].set_ylabel('cos(position × frequency)', fontsize=12)
axes[0].set_title('RoPE Rotation Patterns per Dimension Pair', fontsize=13, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Show how similar positions get similar embeddings
x_same = torch.ones(1, 20, 8)  # Same content at every position
x_rotated = rope_vis(x_same)

# Compute cosine similarity between all position pairs
sims = F.cosine_similarity(
    x_rotated[0].unsqueeze(0).expand(20, -1, -1),
    x_rotated[0].unsqueeze(1).expand(-1, 20, -1),
    dim=-1
)
im = axes[1].imshow(sims.detach().numpy(), cmap='viridis', vmin=-1, vmax=1)
axes[1].set_xlabel('Position', fontsize=12)
axes[1].set_ylabel('Position', fontsize=12)
axes[1].set_title('Position Similarity (Closer = More Similar)', fontsize=13, fontweight='bold')
plt.colorbar(im, ax=axes[1])

plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Mixer Attn
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_06_mixer_attn.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 The MLP Mixing Layer

For small, fixed-size grids (like 9×9 Sudoku = 81 tokens), TRM uses a simple linear layer over the token dimension instead of self-attention. Its weight is a learnable [L, L] matrix (L = number of tokens) that determines how tokens communicate.
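The transpose-Linear-transpose trick used by the mixer is just a weighted sum over source tokens: $\text{out}[b, i, d] = \sum_j W[i, j]\, x[b, j, d]$. The sketch checks this equivalence with `torch.einsum` and compares the mixer's $L^2$ weights with the $4D^2$ weights of bias-free Q, K, V and output projections in an attention layer of width $D$ (example sizes: a 4×4 grid, $D=32$).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, D = 2, 16, 32                       # batch, tokens (4×4 grid), features
x = torch.randn(B, L, D)
token_mix = nn.Linear(L, L, bias=False)   # the learnable [L, L] matrix W

# Route 1: put tokens last, apply the Linear, transpose back
out_linear = token_mix(x.transpose(1, 2)).transpose(1, 2)

# Route 2: explicit sum over source tokens j
out_einsum = torch.einsum("ij,bjd->bid", token_mix.weight, x)

print(torch.allclose(out_linear, out_einsum, atol=1e-6))   # True
print(token_mix.weight.numel(), "vs", 4 * D * D)          # 256 vs 4096
```

The trade-off: the mixer's cost grows as $L^2$ and is tied to one fixed sequence length, which is why it only suits small, fixed-size grids.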

In [None]:
class MLPMixer(nn.Module):
    """
    MLP token mixing layer: replaces self-attention for fixed-size contexts.
    A simple [L, L] matrix that learns pairwise token interactions.
    Much cheaper than attention when seq_len is small and fixed.
    """
    def __init__(self, seq_len: int, dim: int):
        super().__init__()
        self.seq_len = seq_len
        # This is the key: a learnable [L, L] mixing matrix
        self.token_mix = nn.Linear(seq_len, seq_len, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch, seq_len, dim)
        # Transpose to mix across token dimension, then transpose back
        x = x.transpose(1, 2)       # (batch, dim, seq_len)
        x = self.token_mix(x)        # (batch, dim, seq_len): mixing happens here
        x = x.transpose(1, 2)       # (batch, seq_len, dim)
        return x

# Test it
mixer = MLPMixer(seq_len=16, dim=32)
test_x = torch.randn(2, 16, 32)
out = mixer(test_x)
print(f"Input shape:  {test_x.shape}")
print(f"Output shape: {out.shape}")
n_params = sum(p.numel() for p in mixer.parameters())
print(f"Parameters:   {n_params:,} (just a {mixer.seq_len}×{mixer.seq_len} matrix!)")

### 4.5 Self-Attention (for variable-size contexts)

For larger, variable-size grids, TRM uses standard self-attention with RoPE.
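The `SelfAttention` module in this section computes scaled dot-product attention by hand. As a sanity check (and a possible drop-in alternative, assuming PyTorch 2.0 or newer), PyTorch's fused `F.scaled_dot_product_attention` should give the same result on random queries, keys and values. RoPE is omitted here for brevity; this only compares the attention arithmetic.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, L, Dh = 2, 4, 16, 8                      # batch, heads, tokens, head_dim
q, k, v = (torch.randn(B, H, L, Dh) for _ in range(3))

# Manual scaled dot-product attention, as in the SelfAttention module
scores = (q @ k.transpose(-2, -1)) * Dh ** -0.5   # (B, H, L, L)
manual = F.softmax(scores, dim=-1) @ v            # (B, H, L, Dh)

# Fused kernel; its default scale is 1/sqrt(Dh)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))   # True
```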

In [None]:
class SelfAttention(nn.Module):
    """
    Standard multi-head self-attention with RoPE.
    Used for variable-length contexts (e.g., 30x30 mazes).
    """
    def __init__(self, dim: int, n_heads: int = 4, max_seq_len: int = 1024):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        assert dim % n_heads == 0

        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, D = x.shape

        # Project to Q, K, V
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # Each: (B, L, n_heads, head_dim)

        # Move the head axis ahead of the sequence axis
        q = q.transpose(1, 2)  # (B, n_heads, L, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Apply RoPE per head
        q_rope = self.rope(q.reshape(B * self.n_heads, L, self.head_dim))
        k_rope = self.rope(k.reshape(B * self.n_heads, L, self.head_dim))
        q = q_rope.reshape(B, self.n_heads, L, self.head_dim)
        k = k_rope.reshape(B, self.n_heads, L, self.head_dim)

        # Scaled dot-product attention
        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale  # (B, n_heads, L, L)
        attn = F.softmax(attn, dim=-1)

        # Apply attention to values
        out = (attn @ v)  # (B, n_heads, L, head_dim)
        out = out.transpose(1, 2).reshape(B, L, D)  # (B, L, D)
        return self.out_proj(out)

# Test it
attn = SelfAttention(dim=32, n_heads=4)
test_x = torch.randn(2, 16, 32)
out = attn(test_x)
print(f"Input shape:  {test_x.shape}")
print(f"Output shape: {out.shape}")

In [None]:
#@title üéß Listen: Trm Layer
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_07_trm_layer.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.6 The TRM Layer

Now we combine everything into a single TRM layer: RMSNorm → Mixing (MLP or Attention) → residual add → RMSNorm → SwiGLU FFN → residual add.
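Because every sublayer in this notebook is bias-free, the parameter count of an MLP-mixing layer can be predicted by hand. `trm_layer_params` is a helper of our own (not part of any library); its result should agree with the count printed by the test cell that builds `TRMLayer(dim=32, seq_len=16)` in this section.

```python
def trm_layer_params(dim: int, seq_len: int, ffn_mult: int = 4) -> int:
    """Parameter count of one MLP-mixing TRMLayer as built in this notebook (no biases)."""
    norms = 2 * dim                  # two RMSNorm gamma vectors
    mixer = seq_len * seq_len        # one [L, L] token-mixing matrix
    hidden = dim * ffn_mult
    ffn = 3 * dim * hidden           # SwiGLU: w1, w2 (dim -> hidden) and w3 (hidden -> dim)
    return norms + mixer + ffn

print(trm_layer_params(dim=32, seq_len=16))   # 12608 = 64 + 256 + 12288
```

Note that the SwiGLU FFN dominates: the token mixer is only 256 of the 12,608 parameters.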

In [None]:
class TRMLayer(nn.Module):
    """
    One layer of the Tiny Recursive Model.
    Architecture: RMSNorm → Token Mixing → Residual → RMSNorm → SwiGLU FFN → Residual

    This is inspired by the Pre-Norm Transformer block but can use either
    MLP mixing (for fixed-size) or self-attention (for variable-size).
    """
    def __init__(self, dim: int, seq_len: int = None, use_attention: bool = False,
                 n_heads: int = 4, ffn_mult: int = 4):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)

        # Token mixing: MLP or Attention
        if use_attention:
            self.mixer = SelfAttention(dim, n_heads=n_heads, max_seq_len=seq_len or 1024)
        else:
            assert seq_len is not None, "MLP mixer requires fixed seq_len"
            self.mixer = MLPMixer(seq_len, dim)

        # Feed-forward network with SwiGLU
        self.ffn = SwiGLU(dim, hidden_dim=dim * ffn_mult)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm architecture with residual connections
        x = x + self.mixer(self.norm1(x))  # Mixing + residual
        x = x + self.ffn(self.norm2(x))    # FFN + residual
        return x

# Test it
layer = TRMLayer(dim=32, seq_len=16, use_attention=False)
test_x = torch.randn(2, 16, 32)
out = layer(test_x)
print(f"Input shape:  {test_x.shape}")
print(f"Output shape: {out.shape}")
n_params = sum(p.numel() for p in layer.parameters())
print(f"Layer parameters: {n_params:,}")

In [None]:
#@title üéß Listen: Full Trm
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_08_full_trm.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.7 The Full TRM

Now for the main event: the complete Tiny Recursive Model with the recursion loop.

In [None]:
class TinyRecursiveModel(nn.Module):
    """
    A simplified Tiny Recursive Model (TRM) for this notebook.

    Architecture:
    - Input embedding: projects the one-hot puzzle x into the hidden dimension
    - Solution/reasoning init: y and z start from learned vectors (y_init, z_init)
    - 2-layer recursive block, shared across all recursion steps
    - Output head: decodes y into class predictions
    - Halting head: decodes y into a per-cell confidence logit

    Recursion loop (simplified):
    1. Concatenate [x, y, z] along the feature dimension
    2. Pass through the 2-layer block to get updated features
    3. Project back to new y and z
    4. Repeat n_recursions times
    Unlike the paper's two-phase schedule (n updates of z, then one of y),
    this version updates y and z jointly on every step.
    """
    def __init__(self, n_classes: int, grid_size: int, dim: int = 64,
                 n_layers: int = 2, use_attention: bool = False, n_heads: int = 4):
        super().__init__()
        self.dim = dim
        self.grid_size = grid_size
        seq_len = grid_size * grid_size

        # Input embedding: one-hot class → dim
        self.input_embed = nn.Linear(n_classes + 1, dim, bias=False)  # +1 for "empty" token
        # Solution and reasoning are initialized as learned embeddings
        self.y_init = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.z_init = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

        # The 2-layer recursive block
        self.layers = nn.ModuleList([
            TRMLayer(dim * 3, seq_len=seq_len, use_attention=use_attention, n_heads=n_heads)
            for _ in range(n_layers)
        ])

        # Projection to split 3*dim back into dim for y and z
        self.split_proj_y = nn.Linear(dim * 3, dim, bias=False)
        self.split_proj_z = nn.Linear(dim * 3, dim, bias=False)

        # Output head: y → class predictions
        self.output_head = nn.Linear(dim, n_classes)

        # Halting head: y → confidence score
        self.halt_head = nn.Linear(dim, 1)

    def embed_input(self, x: torch.Tensor) -> torch.Tensor:
        """Convert integer grid to embedded features."""
        # x shape: (batch, grid_size, grid_size) with integer values 0..n_classes
        B = x.shape[0]
        x_flat = x.reshape(B, -1)  # (batch, seq_len)
        # One-hot encode
        x_onehot = F.one_hot(x_flat.long(), num_classes=self.output_head.out_features + 1).float()
        return self.input_embed(x_onehot)  # (batch, seq_len, dim)

    def recurse(self, x_emb: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> tuple:
        """
        One simplified recursion step: y and z are updated jointly from net(x, y, z).
        We concatenate [x, y, z], pass the result through the layers, and project back to y and z.
        """
        # Concatenate along feature dimension
        combined = torch.cat([x_emb, y, z], dim=-1)  # (batch, seq_len, 3*dim)

        # Pass through the 2-layer block
        for layer in self.layers:
            combined = layer(combined)

        # Split back into y and z updates
        y_new = self.split_proj_y(combined)
        z_new = self.split_proj_z(combined)

        return y_new, z_new

    def forward(self, x: torch.Tensor, n_recursions: int = 6) -> dict:
        """
        Full forward pass with n recursions.

        Args:
            x: input grid, shape (batch, grid_size, grid_size), integer values
            n_recursions: number of recursion iterations

        Returns:
            dict with 'logits', 'halt_logits', and 'y_history' (for visualization)
        """
        B = x.shape[0]
        seq_len = self.grid_size * self.grid_size

        # Embed input
        x_emb = self.embed_input(x)  # (batch, seq_len, dim)

        # Initialize y and z
        y = self.y_init.expand(B, seq_len, -1)  # (batch, seq_len, dim)
        z = self.z_init.expand(B, seq_len, -1)  # (batch, seq_len, dim)

        # Track y at each recursion step (for visualization)
        y_history = []

        # Recursion loop
        for step in range(n_recursions):
            y, z = self.recurse(x_emb, y, z)
            y_history.append(y.detach())

        # Decode final y into class predictions
        logits = self.output_head(y)  # (batch, seq_len, n_classes)
        halt_logits = self.halt_head(y).squeeze(-1)  # (batch, seq_len)

        return {
            'logits': logits,
            'halt_logits': halt_logits,
            'y_history': y_history
        }

# Create a TRM for 4x4 Sudoku (values 1-4, plus 0 for empty)
model = TinyRecursiveModel(
    n_classes=4,
    grid_size=4,
    dim=32,
    n_layers=2,
    use_attention=False
)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")
print(f"\nFor comparison:")
print(f"  TRM paper (Sudoku):  ~5M parameters")
print(f"  TRM paper (ARC-AGI): ~7M parameters")
print(f"  Our mini model:      {n_params:,} parameters")

In [None]:
# Test forward pass
test_grid = torch.randint(0, 5, (2, 4, 4))  # batch=2, 4x4 grid
output = model(test_grid, n_recursions=6)

print(f"Input shape:       {test_grid.shape}")
print(f"Logits shape:      {output['logits'].shape}  (batch, seq_len, n_classes)")
print(f"Halt logits shape: {output['halt_logits'].shape}  (batch, seq_len)")
print(f"Recursion steps:   {len(output['y_history'])}")

In [None]:
# 📊 Visualize: How y evolves across recursion steps
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for step_idx in range(6):
    ax = axes[step_idx // 3][step_idx % 3]
    y_step = output['y_history'][step_idx][0]  # First sample

    # Show the prediction at this step
    logits_step = model.output_head(y_step)
    probs = F.softmax(logits_step, dim=-1).detach().numpy()

    # Show probability distribution as a heatmap
    im = ax.imshow(probs, cmap='Blues', aspect='auto', vmin=0, vmax=1)
    ax.set_title(f'Recursion Step {step_idx + 1}', fontsize=12, fontweight='bold')
    ax.set_xlabel('Class (1-4)')
    ax.set_ylabel('Cell position')
    ax.set_xticks(range(4))
    ax.set_xticklabels(['1', '2', '3', '4'])

plt.suptitle('Prediction Confidence Evolves Across Recursion Steps\n(untrained model: random predictions)',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Todo
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_09_todo.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 🔧 Your Turn

### TODO: Implement the Attention Variant

The model above uses MLP mixing. Now implement the attention variant by modifying the model creation.

In [None]:
# ============ TODO ============
# Create a TRM that uses self-attention instead of MLP mixing.
# Hint: just change one parameter in the constructor.
# Then compare the parameter counts.
# ==============================

model_attn = TinyRecursiveModel(
    n_classes=4,
    grid_size=4,
    dim=32,
    n_layers=2,
    use_attention=???,  # YOUR CODE HERE
    n_heads=4
)

n_params_attn = sum(p.numel() for p in model_attn.parameters())
print(f"MLP variant parameters:       {n_params:,}")
print(f"Attention variant parameters: {n_params_attn:,}")
print(f"Difference: {n_params_attn - n_params:,} more parameters with attention")

In [None]:
# ✅ Verification
assert n_params_attn > n_params, "❌ Attention variant should have more parameters than MLP variant"
print("✅ Correct! Attention adds parameters for the Q, K, V and output projections.")
print(f"   For a 4×4 grid (16 tokens), each MLP mixer is a 16×16 = {16*16}-parameter matrix")
print(f"   Each attention mixer projects the {3*32}-wide [x, y, z] features: 4 × {3*32}² = {4 * (3*32)**2:,} params")

### TODO: Count Effective Depth

In [None]:
# ============ TODO ============
# Calculate the effective depth for different configurations.
# Formula: effective_depth = T * (n + 1) * n_layers
#
# Fill in the values:
# ==============================

T = 3          # recursion cycles per supervision step
n = 6          # latent (z) recursions per cycle
n_layers = 2   # layers per network call

effective_depth = ???  # YOUR CODE HERE

print(f"Configuration: T={T}, n={n}, n_layers={n_layers}")
print(f"Effective depth: {effective_depth}")
print(f"That is the computational depth of a {effective_depth}-layer network,")
print(f"but with only {n_layers} layers of unique weights.")

In [None]:
# ✅ Verification
assert effective_depth == 42, f"❌ Expected 42, got {effective_depth}. Check formula: T × (n+1) × n_layers"
print("✅ Correct! TRM achieves 42-layer depth with just 2 layers of weights.")

In [None]:
#@title üéß Listen: Final
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_10_final.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Putting It All Together

Let us verify the complete model works end-to-end with a proper forward pass.
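Once the model is trained, a forward pass that merely runs is not enough; we also want to know whether a predicted grid is a valid solution. The sketch below is a hypothetical helper, `is_valid_4x4_sudoku`, that checks every row, column and 2×2 box of a 4×4 grid. After the forward-pass cell, it could be applied to `predictions[b].reshape(4, 4)`.

```python
import torch

def is_valid_4x4_sudoku(grid: torch.Tensor) -> bool:
    """Hypothetical helper: True if every row, column and 2×2 box holds exactly {1, 2, 3, 4}."""
    target = {1, 2, 3, 4}
    rows = [grid[r, :] for r in range(4)]
    cols = [grid[:, c] for c in range(4)]
    boxes = [grid[r:r + 2, c:c + 2].flatten() for r in (0, 2) for c in (0, 2)]
    return all(set(unit.tolist()) == target for unit in rows + cols + boxes)

solution = torch.tensor([[1, 2, 3, 4], [3, 4, 1, 2], [2, 1, 4, 3], [4, 3, 2, 1]])
broken = solution.clone()
broken[0, 0], broken[0, 1] = 2, 1      # swap two cells: row 0 stays valid, columns break
print(is_valid_4x4_sudoku(solution))   # True
print(is_valid_4x4_sudoku(broken))     # False
```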

In [None]:
def count_parameters(model):
    """Count total and per-component parameters."""
    total = 0
    components = {}
    for name, param in model.named_parameters():
        n = param.numel()
        total += n
        component = name.split('.')[0]
        components[component] = components.get(component, 0) + n
    return total, components

total, components = count_parameters(model)
print("Parameter breakdown:")
print(f"{'Component':<20} {'Parameters':>12} {'Percentage':>10}")
print("-" * 44)
for comp, n in sorted(components.items(), key=lambda x: -x[1]):
    print(f"{comp:<20} {n:>12,} {100*n/total:>9.1f}%")
print("-" * 44)
print(f"{'TOTAL':<20} {total:>12,} {'100.0%':>10}")

In [None]:
# Full forward pass with a real-looking puzzle
# Create a batch of 4x4 grids with some zeros (empty cells)
batch_size = 4
grids = torch.zeros(batch_size, 4, 4, dtype=torch.long)
base = torch.tensor([[1,2,3,4],[3,4,1,2],[2,1,4,3],[4,3,2,1]])  # a valid 4x4 solution
for b in range(batch_size):
    # Mask a different random subset of cells for each example
    mask = torch.rand(4, 4) > 0.4  # Keep ~60% of cells
    grids[b] = base * mask.long()

print("Input grids (0 = empty):")
for b in range(batch_size):
    print(f"\nGrid {b+1}:")
    print(grids[b].numpy())

# Forward pass with 6 recursions
output = model(grids, n_recursions=6)
predictions = output['logits'].argmax(dim=-1) + 1  # +1 because classes are 1-indexed

print(f"\nPredictions shape: {predictions.shape}")
print(f"(These are random since the model is untrained)")

In [None]:
# üìä Final architecture summary
print("=" * 60)
print("  TINY RECURSIVE MODEL ‚Äî ARCHITECTURE SUMMARY")
print("=" * 60)
print(f"""
  Grid size:        4×4 (16 cells)
  Classes:          4 (values 1-4)
  Hidden dim:       32
  Layers:           2 (shared across all recursions)
  Recursions:       6 per supervision step
  Supervision:      3 steps
  Effective depth:  42 layers

  Total parameters: {total:,}

  The same 2 layers process the input 21 times.
  No new weights, just new information from each pass.
""")
print("=" * 60)
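
The numbers in this summary are linked. Assuming each supervision step runs the shared network $n$ times to update the reasoning state **z** and once more to update the solution **y** (the reading consistent with "21 times" and "42 layers" above), the effective depth is $D_{\text{eff}} = T\,(n+1)\,L = 3 \times (6+1) \times 2 = 42$, where $T$ is the number of supervision steps and $L$ the number of layers. A minimal sketch, with `trm_depth` as a hypothetical helper:

```python
def trm_depth(n_layers=2, n_recursions=6, n_steps=3):
    """Passes and effective depth, assuming n z-updates plus one y-update per step."""
    passes = n_steps * (n_recursions + 1)  # forward passes through the shared network
    effective_depth = passes * n_layers    # layers a single input flows through
    return passes, effective_depth

print(trm_depth())                 # (21, 42)
print(trm_depth(n_recursions=12))  # (39, 78)
```

Doubling the recursions nearly doubles the effective depth while the parameter count stays fixed, since only the number of passes changes.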

## 7. 🎯 Final Output: Architecture Visualization

In [None]:
# Create a visual summary of the TRM architecture
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

fig, ax = plt.subplots(1, 1, figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')

# Draw the recursion loop
# Input box
rect = mpatches.FancyBboxPatch((0.5, 6), 2.5, 1.5, boxstyle="round,pad=0.2",
                                facecolor='#e3f2fd', edgecolor='#1565c0', linewidth=2)
ax.add_patch(rect)
ax.text(1.75, 6.75, 'Input (x)', ha='center', va='center', fontsize=12, fontweight='bold', color='#1565c0')

# y box
rect = mpatches.FancyBboxPatch((0.5, 4), 2.5, 1.5, boxstyle="round,pad=0.2",
                                facecolor='#c8e6c9', edgecolor='#2e7d32', linewidth=2)
ax.add_patch(rect)
ax.text(1.75, 4.75, 'Solution (y)', ha='center', va='center', fontsize=12, fontweight='bold', color='#2e7d32')

# z box
rect = mpatches.FancyBboxPatch((0.5, 2), 2.5, 1.5, boxstyle="round,pad=0.2",
                                facecolor='#fff3e0', edgecolor='#e65100', linewidth=2)
ax.add_patch(rect)
ax.text(1.75, 2.75, 'Reasoning (z)', ha='center', va='center', fontsize=12, fontweight='bold', color='#e65100')

# Arrows to network
for y_pos in [6.75, 4.75, 2.75]:
    ax.annotate('', xy=(4, 4.75), xytext=(3, y_pos),
                arrowprops=dict(arrowstyle='->', color='#555', lw=1.5))

# Network box
rect = mpatches.FancyBboxPatch((4, 2.5), 3.5, 5, boxstyle="round,pad=0.3",
                                facecolor='#e8eaf6', edgecolor='#3949ab', linewidth=2.5)
ax.add_patch(rect)
ax.text(5.75, 7, '2-Layer Network', ha='center', va='center', fontsize=13, fontweight='bold', color='#283593')
ax.text(5.75, 6.2, 'RMSNorm', ha='center', va='center', fontsize=10, color='#555')
ax.text(5.75, 5.5, 'Token Mixing', ha='center', va='center', fontsize=10, color='#555')
ax.text(5.75, 4.8, '(MLP or Attention)', ha='center', va='center', fontsize=9, color='#888')
ax.text(5.75, 4, 'RMSNorm', ha='center', va='center', fontsize=10, color='#555')
ax.text(5.75, 3.3, 'SwiGLU FFN', ha='center', va='center', fontsize=10, color='#555')

# Output arrows
ax.annotate('', xy=(9, 5.75), xytext=(7.5, 5.75),
            arrowprops=dict(arrowstyle='->', color='#2e7d32', lw=2))
ax.annotate('', xy=(9, 3.75), xytext=(7.5, 3.75),
            arrowprops=dict(arrowstyle='->', color='#e65100', lw=2))

# Updated y
rect = mpatches.FancyBboxPatch((9, 5), 3, 1.5, boxstyle="round,pad=0.2",
                                facecolor='#c8e6c9', edgecolor='#2e7d32', linewidth=2)
ax.add_patch(rect)
ax.text(10.5, 5.75, 'Updated y', ha='center', va='center', fontsize=12, fontweight='bold', color='#2e7d32')

# Updated z
rect = mpatches.FancyBboxPatch((9, 3), 3, 1.5, boxstyle="round,pad=0.2",
                                facecolor='#fff3e0', edgecolor='#e65100', linewidth=2)
ax.add_patch(rect)
ax.text(10.5, 3.75, 'Updated z', ha='center', va='center', fontsize=12, fontweight='bold', color='#e65100')

# Recursion arrow (loop back)
ax.annotate('', xy=(1.75, 3.5), xytext=(10.5, 2.8),
            arrowprops=dict(arrowstyle='->', color='#3949ab', lw=2.5,
                           connectionstyle='arc3,rad=0.3'))
ax.text(6, 1.5, 'Repeat n times', ha='center', va='center', fontsize=14,
        fontweight='bold', color='#3949ab', style='italic')

# Title
ax.text(7, 9.3, 'Tiny Recursive Model: Architecture', ha='center', va='center',
        fontsize=16, fontweight='bold')
ax.text(7, 8.7, f'{total:,} parameters • 2 layers • 6 recursions per supervision step',
        ha='center', va='center', fontsize=11, color='#666')

plt.tight_layout()
plt.show()

print("🎉 You have built the complete TRM architecture from scratch!")
print("   In the next notebook, we will train it with deep supervision.")

In [None]:
#@title 🎧 Listen: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_11_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. Reflection and Next Steps

### 💡 Key Takeaways

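1. **RMSNorm** rescales each vector by its root-mean-square; it drops LayerNorm's mean subtraction and bias, so it is cheaper while typically performing comparably
2. **SwiGLU** is a gated feed-forward unit: a SiLU-activated gate is multiplied elementwise with a linear projection, so the network learns which features to pass through
3. **RoPE** encodes position by rotating pairs of query/key features through position-dependent angles, so attention scores depend on the relative offset between positions
4. **MLP token mixing** is cheaper than attention when the context is small and fixed in size (here, 16 cells)
5. **Weight sharing** across recursions gives an effective depth of 42 layers from only 2 layers of weights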
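
The first three takeaways can be checked numerically. The sketch below is a compact, self-contained restatement rather than the notebook's own modules: `RMSNormSketch`, `SwiGLUSketch`, and `rope_sketch` are hypothetical names, and details such as `eps=1e-6`, the hidden size 64, the frequency base 10000, and interleaved feature pairs are assumptions that may differ from the earlier implementations. The last lines check the RoPE property stated above: two query/key pairs with the same positional offset give the same score.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNormSketch(nn.Module):
    """y = g * x / sqrt(mean(x^2) + eps): no mean subtraction and no bias, unlike LayerNorm."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-feature gain g

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

class SwiGLUSketch(nn.Module):
    """Gated FFN: down(SiLU(gate(x)) * up(x)); the SiLU branch decides what passes through."""
    def __init__(self, dim, hidden):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

def rope_sketch(x, pos, base=10000.0):
    """Rotate feature pairs (x[2i], x[2i+1]) by angle pos * base**(-2i/d)."""
    d = x.shape[-1]
    freqs = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)  # one frequency per pair
    cos, sin = torch.cos(pos * freqs), torch.sin(pos * freqs)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

gen = torch.Generator().manual_seed(0)
demo_x = torch.randn(2, 16, 32, generator=gen)

with torch.no_grad():
    normed = RMSNormSketch(32)(demo_x)
    ffn_out = SwiGLUSketch(32, 64)(demo_x)
print(normed.pow(2).mean(dim=-1).sqrt()[0, :3])  # each token's RMS is now ~1
print(ffn_out.shape)                              # torch.Size([2, 16, 32])

# RoPE: the rotated dot product depends only on the offset between positions
q = torch.randn(32, dtype=torch.float64, generator=gen)
k = torch.randn(32, dtype=torch.float64, generator=gen)
score_a = rope_sketch(q, 5.0) @ rope_sketch(k, 3.0)    # offset 2
score_b = rope_sketch(q, 12.0) @ rope_sketch(k, 10.0)  # offset 2
print(torch.allclose(score_a, score_b))                # True
```

Because each pair is rotated by a pure rotation, RoPE also leaves vector norms unchanged; only the relative angle between query and key depends on position.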

### 🤔 Reflection Questions

1. Why is the MLP mixer limited to fixed-size inputs, while attention handles variable sizes?
2. What role does the **z** (reasoning) feature play that **y** (solution) alone cannot handle? (Hint: think about information that is useful for reasoning but should not appear in the final answer.)
3. If we increased the hidden dimension from 32 to 256, how would the parameter count change? Would this help or hurt on a small dataset?

### 🏆 Optional Challenges

1. **Add skip connections** between recursion steps (connect y from step $i$ to step $i+2$)
2. **Implement a multi-head version** of the MLP mixer where each head mixes information over a different spatial neighborhood
3. **Profile memory usage** ‚Äî how does memory scale with the number of recursions?
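
For challenge 3, one possible starting point is to count the bytes autograd keeps for the backward pass, which is what grows with recursion depth. This sketch uses `torch.autograd.graph.saved_tensors_hooks` on a hypothetical `toy_block` standing in for the shared 2-layer network; shared weights are counted each time an op saves them, so treat the totals as a proxy rather than exact allocator usage. On a GPU, `torch.cuda.reset_peak_memory_stats()` followed by `torch.cuda.max_memory_allocated()` measures peak memory directly.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the shared block; swap in the real TRM block to profile it
toy_block = nn.Sequential(nn.Linear(32, 64), nn.SiLU(), nn.Linear(64, 32))

def saved_bytes(n_recursions, batch=8, cells=16, dim=32):
    """Total bytes of tensors autograd saves for backward over n shared-weight recursions."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t

    def unpack(t):
        return t

    z = torch.randn(batch, cells, dim)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        for _ in range(n_recursions):
            z = z + toy_block(z)  # residual update reusing the same weights
        loss = z.pow(2).mean()
    loss.backward()               # confirms the recorded graph is usable
    return total

for n in [1, 2, 4, 8, 16]:
    print(f"{n:>2} recursions: {saved_bytes(n) / 1024:.1f} KiB saved for backward")
```

Each extra recursion adds the same number of saved bytes, so memory grows linearly with the number of recursions when gradients flow through all of them.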

### What's Next

In Notebook 3, we will implement the **deep supervision training loop**, the secret sauce that makes TRM work. We will train our model on 4×4 Sudoku and watch the recursion steps progressively solve puzzles.