In [None]:
#@title 🎧 Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="16tyhT-CfPx17iIUy9_IixJPaj2_P5b1I", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/03_00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# 🚀 Deep Supervision Training: Teaching a Tiny Model to Reason

*Part 3 of the Vizuara series on Tiny Recursive Models*
*Estimated time: 40 minutes*

## 1. Why Does This Matter?

In the previous notebook, we built the complete TRM architecture: RMSNorm, SwiGLU, RoPE, and the recursion loop. But an untrained model just produces random outputs.

The key challenge: **how do you train a recursive model effectively?**

The naive approach (run all recursions, check only the final answer) creates unstable gradients that must flow through dozens of steps. The TRM paper's solution is **deep supervision**: check the model's answer at multiple intermediate points during recursion and provide correction at each.

By the end of this notebook, you will:
- Implement the prediction loss (cross-entropy) and halting loss (ACT)
- Build a deep supervision training loop
- Train TRM on 4×4 Sudoku puzzles
- **Watch** the model learn to solve puzzles through recursive refinement
- Visualize how prediction confidence sharpens across recursion steps

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
%matplotlib inline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)

In [None]:
#@title 🎧 Listen: Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_01_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### Why Deep Supervision?

Imagine you are teaching someone to solve Sudoku. Two approaches:

**Approach A (End-only supervision):** You let the student work for 30 minutes, then check only the final grid. If it is wrong, you say "try again", but they have no idea which part of their 30 minutes of work went wrong.

**Approach B (Deep supervision):** Every 10 minutes, you check the student's progress. At minute 10, you notice a mistake in row 3 and correct it immediately. At minute 20, the student is on track. At minute 30, the final answer is correct.

Approach B is better because:
1. **Errors get caught early**: before they compound
2. **Gradient paths are shorter**: the model only needs to backpropagate through a few steps, not all 18
3. **Each checkpoint provides a learning signal**: more supervision means faster learning

### The Halting Mechanism

TRM also learns **when to stop thinking**. For easy puzzles, 2 recursion steps might be enough. For hard ones, you need all 18. The model outputs a "halting score" $\hat{q}$ at each step, estimating whether it already has the correct answer. If $\hat{q}$ is high, the model can stop recursing early, saving computation.
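
To make the early-stopping idea concrete, here is a minimal sketch of how halting scores could be used at inference time. The helper name `first_halting_step` and the 0.5 threshold are assumptions made for illustration; the notebook itself only trains the halting head and does not prescribe a stopping rule.

```python
import math

def sigmoid(x: float) -> float:
    """Map a raw halting logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def first_halting_step(halt_logits, threshold=0.5):
    """Return the 1-based step at which the halting probability first
    exceeds `threshold`; fall back to the last step if it never does.
    The threshold value is a hypothetical choice, not taken from the paper."""
    for step, logit in enumerate(halt_logits, start=1):
        if sigmoid(logit) > threshold:
            return step
    return len(halt_logits)

easy_puzzle = [2.0, 3.0, 4.0]    # confident from the first step
hard_puzzle = [-3.0, -1.0, 0.4]  # only confident at the end

print(first_halting_step(easy_puzzle))  # 1
print(first_halting_step(hard_puzzle))  # 3
```

An easy puzzle stops after the first step, while a hard one uses the full budget.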

### 🤔 Think About This

If a model checks its answer at 3 intermediate points (T=3) with 6 recursions between each check (n=6), how many total recursions happen? And how many gradient-carrying steps are there?

Answer: 18 total recursions, but only 3 carry gradients (one per supervision step). The other 15 run in `torch.no_grad()` mode, so the model can reason freely without the computational overhead of gradient tracking.
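
The same arithmetic as a small helper, assuming the scheme used in this notebook (each supervision step runs `n - 1` recursions without gradients followed by one with gradients). The function name `recursion_budget` is hypothetical.

```python
def recursion_budget(T: int, n: int):
    """Return (total, with_grad, without_grad) recursion counts for
    T supervision steps with n recursions per step."""
    total = T * n
    with_grad = T                 # one gradient-carrying recursion per step
    without_grad = T * (n - 1)    # the rest run under torch.no_grad()
    return total, with_grad, without_grad

print(recursion_budget(T=3, n=6))  # (18, 3, 15)
print(recursion_budget(T=3, n=4))  # (12, 3, 9)
```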

In [None]:
#@title 🎧 Listen: Math
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_02_math.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

### Prediction Loss

The prediction loss is standard softmax cross-entropy:

$$\mathcal{L}_{\text{pred}} = -\sum_{i} y_i^{\text{true}} \log(\hat{y}_i)$$

Computationally: for each cell in the grid, compare the model's predicted probability distribution over classes against the true label. If the model assigns high probability to the correct class, the loss is low.

For example, if the true label is class 2 and $\hat{y} = [0.1, 0.7, 0.2]$:
$$\mathcal{L}_{\text{pred}} = -\log(0.7) = 0.357$$
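
The worked example can be checked in a few lines of plain Python. This is only a sketch for a single cell with a one-hot target; the helper name `cross_entropy` is an assumption, and the training code later uses `F.cross_entropy` on logits instead of probabilities.

```python
import math

def cross_entropy(probs, true_index):
    """Cross-entropy for one cell with a one-hot target:
    only the probability of the true class contributes."""
    return -math.log(probs[true_index])

y_hat = [0.1, 0.7, 0.2]
# "Class 2" is the second entry, i.e. index 1 once classes are 0-indexed.
loss = cross_entropy(y_hat, true_index=1)
print(f"{loss:.3f}")  # 0.357

# A uniform guess over 4 classes gives the random baseline log(4).
baseline = cross_entropy([0.25, 0.25, 0.25, 0.25], true_index=0)
print(f"{baseline:.3f}")  # 1.386
```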

### Halting Loss

The halting loss is binary cross-entropy:

$$\mathcal{L}_{\text{halt}} = -[q \log(\hat{q}) + (1-q) \log(1-\hat{q})]$$

Where $q = 1$ if the prediction matches the ground truth (model is correct), and $q = 0$ otherwise. The model outputs $\hat{q}$, its confidence that it has the right answer.

Computationally: this teaches the model to be calibrated about its own correctness. If it has the right answer but $\hat{q}$ is low (underconfident), the loss pushes $\hat{q}$ up. If it has the wrong answer but $\hat{q}$ is high (overconfident), the loss pushes it down.
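
A quick numeric sketch of the four cases described above, using the binary cross-entropy formula directly. The helper name `halting_bce` and the probability values are illustrative assumptions.

```python
import math

def halting_bce(q, q_hat, eps=1e-7):
    """Binary cross-entropy between the correctness label q (0 or 1)
    and the predicted halting probability q_hat."""
    q_hat = min(max(q_hat, eps), 1 - eps)  # clamp for numerical stability
    return -(q * math.log(q_hat) + (1 - q) * math.log(1 - q_hat))

cases = [
    (1, 0.9, "correct, confident"),
    (1, 0.2, "correct, underconfident"),
    (0, 0.9, "wrong, overconfident"),
    (0, 0.2, "wrong, appropriately unsure"),
]
for q, q_hat, label in cases:
    print(f"{label:28s} loss = {halting_bce(q, q_hat):.3f}")
# correct, confident           loss = 0.105
# correct, underconfident      loss = 1.609
# wrong, overconfident         loss = 2.303
# wrong, appropriately unsure  loss = 0.223
```

The two miscalibrated cases are penalized far more heavily than the calibrated ones, which is what pushes $\hat{q}$ toward the true correctness label.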

### Total Loss with Deep Supervision

At each supervision step $t \in \{1, ..., T\}$:

$$\mathcal{L}_t = \mathcal{L}_{\text{pred}}^{(t)} + \mathcal{L}_{\text{halt}}^{(t)}$$

The total loss sums over all supervision steps (the training loop in Section 4.5 additionally scales the halting term by 0.1):

$$\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t$$

### EMA (Exponential Moving Average)

TRM uses EMA to stabilize training by maintaining a slow-moving copy of the model weights:

$$\theta_{\text{EMA}} \leftarrow \alpha \cdot \theta_{\text{EMA}} + (1 - \alpha) \cdot \theta$$

With $\alpha = 0.999$. Computationally: after each gradient update, the EMA model weights move 0.1% toward the current model weights. This smooths out training noise, which is critical when training on small datasets.
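
To get a feel for how slowly the EMA copy moves, here is a scalar sketch of the update rule. The helpers `ema_update` and `shadow_after` are hypothetical, and a single number stands in for a full set of weights.

```python
def ema_update(shadow, current, alpha=0.999):
    """One EMA step: shadow <- alpha * shadow + (1 - alpha) * current."""
    return alpha * shadow + (1 - alpha) * current

def shadow_after(num_updates, start=0.0, target=1.0, alpha=0.999):
    """Apply num_updates EMA steps toward a fixed target value
    and return the resulting shadow value."""
    shadow = start
    for _ in range(num_updates):
        shadow = ema_update(shadow, target, alpha)
    return shadow

for k in (1, 30, 1000):
    print(f"after {k:4d} updates: shadow = {shadow_after(k):.3f}")
# after    1 updates: shadow = 0.001
# after   30 updates: shadow = 0.030
# after 1000 updates: shadow = 0.632
```

With $\alpha = 0.999$ the averaging horizon is roughly $1/(1-\alpha) = 1000$ updates, so the shadow weights only track the trained model once the number of EMA updates is comparable to that.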

In [None]:
#@title 🎧 Listen: Model Rebuild
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_03_model_rebuild.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It: Component by Component

### 4.1 Rebuilding the TRM (from Notebook 2)

Let us quickly reconstruct the model from the previous notebook.

In [None]:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.weight

class SwiGLU(nn.Module):
    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or dim * 4
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

class MLPMixer(nn.Module):
    def __init__(self, seq_len, dim):
        super().__init__()
        self.token_mix = nn.Linear(seq_len, seq_len, bias=False)

    def forward(self, x):
        return self.token_mix(x.transpose(1, 2)).transpose(1, 2)

class TRMLayer(nn.Module):
    def __init__(self, dim, seq_len):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        self.mixer = MLPMixer(seq_len, dim)
        self.ffn = SwiGLU(dim, hidden_dim=dim * 4)

    def forward(self, x):
        x = x + self.mixer(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

class TinyRecursiveModel(nn.Module):
    def __init__(self, n_classes, grid_size, dim=64, n_layers=2):
        super().__init__()
        self.dim = dim
        self.grid_size = grid_size
        self.n_classes = n_classes
        seq_len = grid_size * grid_size

        self.input_embed = nn.Linear(n_classes + 1, dim, bias=False)
        self.y_init = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.z_init = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

        self.layers = nn.ModuleList([
            TRMLayer(dim * 3, seq_len=seq_len) for _ in range(n_layers)
        ])

        self.split_proj_y = nn.Linear(dim * 3, dim, bias=False)
        self.split_proj_z = nn.Linear(dim * 3, dim, bias=False)
        self.output_head = nn.Linear(dim, n_classes)
        self.halt_head = nn.Linear(dim, 1)

    def embed_input(self, x):
        B = x.shape[0]
        x_flat = x.reshape(B, -1)
        x_onehot = F.one_hot(x_flat.long(), num_classes=self.n_classes + 1).float()
        return self.input_embed(x_onehot)

    def recurse(self, x_emb, y, z):
        combined = torch.cat([x_emb, y, z], dim=-1)
        for layer in self.layers:
            combined = layer(combined)
        y_new = self.split_proj_y(combined)
        z_new = self.split_proj_z(combined)
        return y_new, z_new

    def forward_with_supervision(self, x, T=3, n=6):
        """
        Forward pass with deep supervision.

        Args:
            x: input grid (batch, grid_size, grid_size)
            T: number of supervision steps
            n: recursions per supervision step

        Returns:
            List of (logits, halt_logits, y_snapshot) tuples, one per supervision
            checkpoint; y_snapshot is the detached answer state y at that checkpoint
        """
        B = x.shape[0]
        seq_len = self.grid_size * self.grid_size
        x_emb = self.embed_input(x)

        y = self.y_init.expand(B, seq_len, -1)
        z = self.z_init.expand(B, seq_len, -1)

        checkpoints = []

        for t in range(T):
            # Run n-1 recursions WITHOUT gradients (free reasoning)
            with torch.no_grad():
                for _ in range(n - 1):
                    y_detached = y.detach()
                    z_detached = z.detach()
                    y, z = self.recurse(x_emb, y_detached, z_detached)

            # Run 1 recursion WITH gradients (learning step)
            y, z = self.recurse(x_emb, y, z)

            # Checkpoint: record predictions at this supervision step
            logits = self.output_head(y)
            halt_logits = self.halt_head(y).squeeze(-1)
            checkpoints.append((logits, halt_logits, y.detach()))

        return checkpoints

print("Model architecture loaded successfully!")

In [None]:
#@title 🎧 Listen: Data
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_04_data.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 Generating Sudoku Training Data

We need a dataset of 4×4 Sudoku puzzles with their solutions.
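
As a reminder of the rules the generator has to respect, here is a small standalone checker for a completed 4×4 grid: every row, column, and 2×2 box must contain 1 to 4 exactly once. The function `is_valid_solution` and the example grid are illustrative assumptions and are not used by the rest of the notebook.

```python
def is_valid_solution(grid):
    """Check a completed 4x4 Sudoku given as a list of 4 rows of 4 ints."""
    expected = {1, 2, 3, 4}
    rows = [set(row) for row in grid]
    cols = [{grid[r][c] for r in range(4)} for c in range(4)]
    boxes = [
        {grid[r][c] for r in range(br, br + 2) for c in range(bc, bc + 2)}
        for br in (0, 2)
        for bc in (0, 2)
    ]
    # A duplicate shrinks the set, so it can no longer equal {1, 2, 3, 4}.
    return all(group == expected for group in rows + cols + boxes)

solved = [
    [1, 2, 3, 4],
    [3, 4, 1, 2],
    [2, 1, 4, 3],
    [4, 3, 2, 1],
]
broken = [row[:] for row in solved]
broken[0][0], broken[0][1] = broken[0][1], broken[0][0]  # breaks two columns

print(is_valid_solution(solved))  # True
print(is_valid_solution(broken))  # False
```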

In [None]:
def is_valid_4x4(grid, r, c, num):
    """Check if placing num at (r,c) is valid in a 4x4 Sudoku."""
    if num in grid[r, :]: return False
    if num in grid[:, c]: return False
    box_r, box_c = 2 * (r // 2), 2 * (c // 2)
    if num in grid[box_r:box_r+2, box_c:box_c+2]: return False
    return True

def solve_4x4(grid):
    """Solve a 4x4 Sudoku via backtracking. Returns True if solvable."""
    for r in range(4):
        for c in range(4):
            if grid[r, c] == 0:
                nums = list(range(1, 5))
                np.random.shuffle(nums)
                for num in nums:
                    if is_valid_4x4(grid, r, c, num):
                        grid[r, c] = num
                        if solve_4x4(grid):
                            return True
                        grid[r, c] = 0
                return False
    return True

def generate_sudoku_dataset(n_puzzles, n_remove=6, seed=42):
    """Generate n_puzzles 4x4 Sudoku puzzles with solutions."""
    np.random.seed(seed)
    puzzles = []
    solutions = []

    for _ in range(n_puzzles):
        # Generate a full valid grid
        grid = np.zeros((4, 4), dtype=np.int64)
        solve_4x4(grid)
        solution = grid.copy()

        # Remove some cells
        indices = np.random.choice(16, size=n_remove, replace=False)
        puzzle = solution.copy()
        for idx in indices:
            puzzle[idx // 4, idx % 4] = 0

        puzzles.append(puzzle)
        solutions.append(solution)

    return np.array(puzzles), np.array(solutions)

# Generate datasets
train_puzzles, train_solutions = generate_sudoku_dataset(1000, n_remove=8, seed=42)
test_puzzles, test_solutions = generate_sudoku_dataset(200, n_remove=8, seed=999)

print(f"Training set: {len(train_puzzles)} puzzles")
print(f"Test set:     {len(test_puzzles)} puzzles")
print(f"\nExample puzzle:")
print(train_puzzles[0])
print(f"\nSolution:")
print(train_solutions[0])
print(f"\nEmpty cells per puzzle: {np.mean(np.sum(train_puzzles == 0, axis=(1,2))):.1f}")

In [None]:
# Create PyTorch datasets
from torch.utils.data import TensorDataset, DataLoader

train_X = torch.tensor(train_puzzles, dtype=torch.long)
train_Y = torch.tensor(train_solutions, dtype=torch.long)
test_X = torch.tensor(test_puzzles, dtype=torch.long)
test_Y = torch.tensor(test_solutions, dtype=torch.long)

train_dataset = TensorDataset(train_X, train_Y)
test_dataset = TensorDataset(test_X, test_Y)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"Train batches: {len(train_loader)}")
print(f"Test batches:  {len(test_loader)}")

In [None]:
#@title 🎧 Listen: Todo Loss
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_05_todo_loss.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 The Loss Functions

Now let us implement the prediction loss and halting loss.

### TODO: Implement the Prediction Loss

In [None]:
def prediction_loss(logits, targets, mask):
    """
    Softmax cross-entropy loss on masked cells only.

    Args:
        logits: (batch, seq_len, n_classes), raw predictions
        targets: (batch, seq_len), ground truth class indices (1-indexed)
        mask: (batch, seq_len), True for cells that were originally empty (need prediction)

    Returns:
        scalar loss
    """
    # ============ TODO ============
    # Step 1: Convert targets from 1-indexed to 0-indexed (subtract 1, clamp min=0)
    # Step 2: Reshape logits to (B*L, C) and targets to (B*L,)
    # Step 3: Compute F.cross_entropy with reduction='none'
    # Step 4: Apply the mask so that only empty cells count toward the loss
    # Step 5: Return the masked mean loss
    # ==============================

    # Convert targets from 1-indexed to 0-indexed for cross-entropy
    targets_0idx = ???  # YOUR CODE HERE

    # Reshape for cross-entropy
    B, L, C = logits.shape
    logits_flat = logits.reshape(-1, C)
    targets_flat = targets_0idx.reshape(-1)
    mask_flat = mask.reshape(-1)

    # Compute per-element loss
    loss = ???  # YOUR CODE HERE

    # Only count loss for masked (empty) cells
    if mask_flat.sum() > 0:
        return (loss * mask_flat.float()).sum() / mask_flat.float().sum()
    return loss.mean()

In [None]:
# ✅ Verification: test your prediction loss
test_logits = torch.randn(2, 16, 4)
test_targets = torch.randint(1, 5, (2, 16))
test_mask = torch.ones(2, 16).bool()

pred_l = prediction_loss(test_logits, test_targets, test_mask)
expected_random = np.log(4)  # ~1.386 for random 4-class predictions
assert 0.5 < pred_l.item() < 3.0, f"❌ Loss {pred_l.item():.3f} seems wrong for random predictions"
print(f"✅ Prediction loss (random): {pred_l.item():.3f} (expected ~{expected_random:.3f} for 4 classes)")

In [None]:
#@title 🎧 Listen: Losses
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_06_losses.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

Here is the reference implementation:

In [None]:
def prediction_loss(logits, targets, mask):
    """
    Softmax cross-entropy loss on masked cells only.

    Args:
        logits: (batch, seq_len, n_classes), raw predictions
        targets: (batch, seq_len), ground truth class indices (1-indexed)
        mask: (batch, seq_len), True for cells that were originally empty (need prediction)

    Returns:
        scalar loss
    """
    # Convert targets from 1-indexed to 0-indexed for cross-entropy
    targets_0idx = (targets - 1).clamp(min=0)  # classes 0..3

    # Reshape for cross-entropy
    B, L, C = logits.shape
    logits_flat = logits.reshape(-1, C)     # (B*L, C)
    targets_flat = targets_0idx.reshape(-1)  # (B*L,)
    mask_flat = mask.reshape(-1)             # (B*L,)

    # Compute per-element loss
    loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')  # (B*L,)

    # Only count loss for masked (empty) cells
    if mask_flat.sum() > 0:
        return (loss * mask_flat.float()).sum() / mask_flat.float().sum()
    return loss.mean()

def halting_loss(halt_logits, predictions_correct, mask):
    """
    Binary cross-entropy halting loss.

    Args:
        halt_logits: (batch, seq_len) ‚Äî raw halting predictions
        predictions_correct: (batch, seq_len) ‚Äî 1.0 if prediction matches target, 0.0 otherwise
        mask: (batch, seq_len) ‚Äî True for cells that need prediction

    Returns:
        scalar loss
    """
    halt_probs = torch.sigmoid(halt_logits)
    # Clamp for numerical stability
    halt_probs = halt_probs.clamp(1e-7, 1 - 1e-7)

    bce = -(predictions_correct * torch.log(halt_probs) +
            (1 - predictions_correct) * torch.log(1 - halt_probs))

    mask_float = mask.float()
    if mask_float.sum() > 0:
        return (bce * mask_float).sum() / mask_float.sum()
    return bce.mean()

# Test the losses
test_logits = torch.randn(2, 16, 4)
test_targets = torch.randint(1, 5, (2, 16))
test_mask = torch.randint(0, 2, (2, 16)).bool()

pred_l = prediction_loss(test_logits, test_targets, test_mask)
print(f"Prediction loss (random): {pred_l.item():.3f}")
print(f"Expected for 4 classes:   {np.log(4):.3f} (random baseline)")

In [None]:
# 📊 Visualize: How cross-entropy loss decreases as confidence increases
confidences = np.linspace(0.01, 0.99, 100)
losses = -np.log(confidences)

plt.figure(figsize=(8, 5))
plt.plot(confidences, losses, linewidth=2.5, color='#e65100')
plt.xlabel('Predicted Probability for Correct Class', fontsize=12)
plt.ylabel('Cross-Entropy Loss', fontsize=12)
plt.title('Loss Decreases as Model Becomes More Confident', fontsize=13, fontweight='bold')
plt.axhline(y=-np.log(0.25), color='#999', linestyle='--', label=f'Random baseline (4 classes): {-np.log(0.25):.2f}')
plt.axhline(y=-np.log(0.7), color='#2e7d32', linestyle='--', label=f'70% confidence: {-np.log(0.7):.2f}')
plt.axhline(y=-np.log(0.95), color='#1565c0', linestyle='--', label=f'95% confidence: {-np.log(0.95):.2f}')
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Ema
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_07_ema.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 EMA (Exponential Moving Average)

In [None]:
class EMA:
    """
    Exponential Moving Average of model parameters.
    Maintains a smoothed copy of the weights for more stable evaluation.
    """
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    @torch.no_grad()
    def update(self, model):
        """Update EMA weights: shadow = decay * shadow + (1 - decay) * current"""
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param.data

    def apply(self, model):
        """Replace model weights with EMA weights (for evaluation)."""
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data = self.shadow[name].clone()  # copy so the live model never aliases the shadow

    def restore(self, model):
        """Restore original model weights (after evaluation)."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]

print("EMA helper class ready!")

In [None]:
#@title 🎧 Listen: Training Loop
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_08_training_loop.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.5 The Deep Supervision Training Loop

This is the core innovation. Instead of training on only the final output, we supervise at T intermediate checkpoints.

In [None]:
def train_one_epoch(model, train_loader, optimizer, T=3, n=6):
    """
    Train for one epoch with deep supervision.

    At each supervision step t:
    1. Run n-1 recursions without gradients (free reasoning)
    2. Run 1 recursion with gradients
    3. Compute prediction + halting loss at this checkpoint
    4. Accumulate loss across all T checkpoints
    5. Backpropagate the total loss
    """
    model.train()
    total_loss = 0
    total_correct = 0
    total_cells = 0

    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Mask: True where the original puzzle had empty cells
        mask = (batch_x.reshape(batch_x.shape[0], -1) == 0)
        targets = batch_y.reshape(batch_y.shape[0], -1)

        optimizer.zero_grad()

        # Get checkpoints from deep supervision forward pass
        checkpoints = model.forward_with_supervision(batch_x, T=T, n=n)

        # Accumulate loss across all supervision steps
        batch_loss = 0
        for t, (logits, halt_logits, y_snapshot) in enumerate(checkpoints):
            # Prediction loss
            p_loss = prediction_loss(logits, targets, mask)

            # Check which predictions are correct
            preds = logits.argmax(dim=-1) + 1  # 1-indexed
            correct = (preds == targets).float()

            # Halting loss
            h_loss = halting_loss(halt_logits, correct, mask)

            batch_loss = batch_loss + p_loss + 0.1 * h_loss  # Weight halting loss lower

        batch_loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Track metrics from final checkpoint
        final_logits = checkpoints[-1][0]
        final_preds = final_logits.argmax(dim=-1) + 1
        correct_mask = ((final_preds == targets) & mask).sum().item()
        total_mask = mask.sum().item()

        total_loss += batch_loss.item()
        total_correct += correct_mask
        total_cells += total_mask

    avg_loss = total_loss / len(train_loader)
    accuracy = total_correct / max(total_cells, 1) * 100

    return avg_loss, accuracy

@torch.no_grad()
def evaluate(model, test_loader, T=3, n=6):
    """Evaluate model accuracy on test set."""
    model.eval()
    total_correct = 0
    total_cells = 0
    total_puzzles_perfect = 0
    total_puzzles = 0

    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        mask = (batch_x.reshape(batch_x.shape[0], -1) == 0)
        targets = batch_y.reshape(batch_y.shape[0], -1)

        checkpoints = model.forward_with_supervision(batch_x, T=T, n=n)
        final_logits = checkpoints[-1][0]
        final_preds = final_logits.argmax(dim=-1) + 1

        correct = (final_preds == targets) & mask
        total_correct += correct.sum().item()
        total_cells += mask.sum().item()

        # Check perfect puzzles (all masked cells correct)
        for i in range(batch_x.shape[0]):
            puzzle_mask = mask[i]
            if puzzle_mask.sum() > 0:
                puzzle_correct = correct[i][puzzle_mask].all().item()
                total_puzzles_perfect += puzzle_correct
            total_puzzles += 1

    cell_accuracy = total_correct / max(total_cells, 1) * 100
    puzzle_accuracy = total_puzzles_perfect / max(total_puzzles, 1) * 100

    return cell_accuracy, puzzle_accuracy

print("Training functions ready!")

In [None]:
#@title 🎧 Listen: Todo Ema
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_09_todo_ema.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 🔧 Your Turn

### TODO: Implement the Training Loop with EMA

In [None]:
def train_with_ema(model, train_loader, test_loader, n_epochs=30, lr=1e-3, T=3, n=4):
    """
    Full training loop with EMA and deep supervision.

    Returns training history for plotting.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    # ============ TODO ============
    # Step 1: Initialize the EMA helper (use decay=0.999)
    # Step 2: After each epoch's training step, update the EMA
    # Step 3: For evaluation, apply EMA weights, evaluate, then restore original weights
    # ==============================

    ema = ???  # YOUR CODE HERE: Initialize EMA

    history = {
        'train_loss': [], 'train_acc': [],
        'test_cell_acc': [], 'test_puzzle_acc': [],
        'test_cell_acc_ema': [], 'test_puzzle_acc_ema': []
    }

    for epoch in range(n_epochs):
        # Train
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, T=T, n=n)

        # Update EMA
        ???  # YOUR CODE HERE

        # Evaluate without EMA
        cell_acc, puzzle_acc = evaluate(model, test_loader, T=T, n=n)

        # Evaluate with EMA
        ???  # YOUR CODE HERE: apply EMA weights
        cell_acc_ema, puzzle_acc_ema = evaluate(model, test_loader, T=T, n=n)
        ???  # YOUR CODE HERE: restore original weights

        scheduler.step()

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_cell_acc'].append(cell_acc)
        history['test_puzzle_acc'].append(puzzle_acc)
        history['test_cell_acc_ema'].append(cell_acc_ema)
        history['test_puzzle_acc_ema'].append(puzzle_acc_ema)

        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {train_loss:.4f} | "
                  f"Train Acc: {train_acc:.1f}% | "
                  f"Test Cell: {cell_acc:.1f}% | Test Puzzle: {puzzle_acc:.1f}% | "
                  f"EMA Puzzle: {puzzle_acc_ema:.1f}%")

    return history

In [None]:
# ✅ Verification: Run the training
# (If your EMA implementation is correct, EMA accuracy should be >= non-EMA accuracy)

model = TinyRecursiveModel(n_classes=4, grid_size=4, dim=48, n_layers=2).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")
print(f"Training with T=3 supervision steps, n=4 recursions per step\n")

history = train_with_ema(model, train_loader, test_loader, n_epochs=30, lr=1e-3, T=3, n=4)

final_ema = history['test_puzzle_acc_ema'][-1]
final_no_ema = history['test_puzzle_acc'][-1]
print(f"\nFinal test puzzle accuracy (no EMA): {final_no_ema:.1f}%")
print(f"Final test puzzle accuracy (EMA):    {final_ema:.1f}%")
if final_ema >= final_no_ema - 1:
    print("✅ EMA implementation looks correct!")
else:
    print("⚠️ EMA accuracy is significantly lower: check your implementation")

In [None]:
#@title 🎧 Listen: Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_10_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Training and Results

In [None]:
# 📊 Training curves
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss
axes[0].plot(history['train_loss'], linewidth=2, color='#e65100')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss', fontsize=13, fontweight='bold')
axes[0].grid(True, alpha=0.3)

# Cell accuracy
axes[1].plot(history['train_acc'], label='Train', linewidth=2, color='#1565c0')
axes[1].plot(history['test_cell_acc'], label='Test', linewidth=2, color='#2e7d32')
axes[1].plot(history['test_cell_acc_ema'], label='Test (EMA)', linewidth=2, color='#2e7d32', linestyle='--')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Cell Accuracy (%)', fontsize=12)
axes[1].set_title('Per-Cell Accuracy', fontsize=13, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

# Puzzle accuracy
axes[2].plot(history['test_puzzle_acc'], label='Test', linewidth=2, color='#6a1b9a')
axes[2].plot(history['test_puzzle_acc_ema'], label='Test (EMA)', linewidth=2, color='#6a1b9a', linestyle='--')
axes[2].set_xlabel('Epoch', fontsize=12)
axes[2].set_ylabel('Puzzle Accuracy (%)', fontsize=12)
axes[2].set_title('Full Puzzle Accuracy', fontsize=13, fontweight='bold')
axes[2].legend(fontsize=10)
axes[2].grid(True, alpha=0.3)

plt.suptitle('TRM Training with Deep Supervision', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Recursion Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_11_recursion_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 📊 Visualize: Model solving a specific puzzle across recursion steps
@torch.no_grad()
def visualize_recursion_steps(model, puzzle, solution, T=3, n=4):
    """Show how the model's predictions evolve across recursion steps."""
    model.eval()
    x = torch.tensor(puzzle, dtype=torch.long).unsqueeze(0).to(device)
    x_emb = model.embed_input(x)
    B = 1
    seq_len = model.grid_size * model.grid_size

    y = model.y_init.expand(B, seq_len, -1)
    z = model.z_init.expand(B, seq_len, -1)

    all_preds = []
    all_confs = []

    for t in range(T):
        for i in range(n):
            y, z = model.recurse(x_emb, y, z)
            logits = model.output_head(y)
            probs = F.softmax(logits, dim=-1)
            preds = probs.argmax(dim=-1) + 1
            conf = probs.max(dim=-1).values
            all_preds.append(preds[0].cpu().numpy().reshape(4, 4))
            all_confs.append(conf[0].cpu().numpy().reshape(4, 4))

    # Plot selected steps
    n_show = min(6, len(all_preds))
    indices = np.linspace(0, len(all_preds)-1, n_show, dtype=int)

    fig, axes = plt.subplots(1, n_show + 1, figsize=(4 * (n_show + 1), 4.5))

    # Show original puzzle
    ax = axes[0]
    for r in range(4):
        for c in range(4):
            val = puzzle[r, c]
            color = '#e3f2fd' if val > 0 else '#f5f5f5'
            ax.add_patch(plt.Rectangle((c, 3-r), 1, 1, facecolor=color, edgecolor='gray', lw=2))
            if val > 0:
                ax.text(c+0.5, 3-r+0.5, str(val), ha='center', va='center',
                       fontsize=18, fontweight='bold', color='#1565c0')
            else:
                ax.text(c+0.5, 3-r+0.5, '?', ha='center', va='center',
                       fontsize=18, color='#bbb')
    for i in range(0, 5, 2):
        ax.axhline(y=i, color='black', linewidth=3)
        ax.axvline(x=i, color='black', linewidth=3)
    ax.set_xlim(0, 4); ax.set_ylim(0, 4)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_aspect('equal')
    ax.set_title('Puzzle', fontsize=12, fontweight='bold')

    # Show predictions at each selected step
    for plot_idx, step_idx in enumerate(indices):
        ax = axes[plot_idx + 1]
        pred = all_preds[step_idx]
        conf = all_confs[step_idx]

        for r in range(4):
            for c in range(4):
                if puzzle[r, c] > 0:
                    color = '#e3f2fd'
                    text_color = '#1565c0'
                    val_str = str(puzzle[r, c])
                elif pred[r, c] == solution[r, c]:
                    # Correct prediction: green intensity scales with confidence
                    alpha = conf[r, c]
                    color = (0.78 * alpha + 1 * (1-alpha),
                             0.9 * alpha + 1 * (1-alpha),
                             0.77 * alpha + 1 * (1-alpha))
                    text_color = '#2e7d32'
                    val_str = str(pred[r, c])
                else:
                    color = '#ffcdd2'
                    text_color = '#c62828'
                    val_str = str(pred[r, c])

                ax.add_patch(plt.Rectangle((c, 3-r), 1, 1, facecolor=color, edgecolor='gray', lw=2))
                ax.text(c+0.5, 3-r+0.65, val_str, ha='center', va='center',
                       fontsize=16, fontweight='bold', color=text_color)
                ax.text(c+0.5, 3-r+0.3, f'{conf[r,c]:.0%}', ha='center', va='center',
                       fontsize=8, color='#888')

        for i in range(0, 5, 2):
            ax.axhline(y=i, color='black', linewidth=3)
            ax.axvline(x=i, color='black', linewidth=3)
        ax.set_xlim(0, 4); ax.set_ylim(0, 4)
        ax.set_xticks([]); ax.set_yticks([])
        ax.set_aspect('equal')
        ax.set_title(f'Step {step_idx+1}/{len(all_preds)}', fontsize=12, fontweight='bold')

    plt.suptitle('Recursive Refinement: Predictions Sharpen Over Steps',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

# Pick a few test puzzles to visualize
for idx in [0, 5, 10]:
    print(f"\n--- Test puzzle {idx} ---")
    visualize_recursion_steps(model, test_puzzles[idx], test_solutions[idx], T=3, n=4)

In [None]:
#@title üéß Listen: Final
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_12_final.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 7. 🎯 Final Output: The Model Solves Sudoku!

In [None]:
# Count how many test puzzles the model solves perfectly
@torch.no_grad()
def evaluate_detailed(model, puzzles, solutions, T=3, n=4):
    """Evaluate puzzles one at a time; score only the originally empty cells."""
    model.eval()
    results = []

    for i in range(len(puzzles)):
        x = torch.tensor(puzzles[i], dtype=torch.long).unsqueeze(0).to(device)
        checkpoints = model.forward_with_supervision(x, T=T, n=n)
        logits = checkpoints[-1][0]
        preds = (logits.argmax(dim=-1) + 1)[0].cpu().numpy().reshape(4, 4)

        mask = puzzles[i] == 0
        correct_cells = (preds[mask] == solutions[i][mask]).sum()
        total_cells = mask.sum()
        perfect = (preds[mask] == solutions[i][mask]).all()

        results.append({
            'puzzle': puzzles[i],
            'solution': solutions[i],
            'prediction': preds,
            'mask': mask,
            'correct_cells': correct_cells,
            'total_cells': total_cells,
            'perfect': perfect
        })

    return results

results = evaluate_detailed(model, test_puzzles, test_solutions)

n_perfect = sum(r['perfect'] for r in results)
n_total = len(results)
avg_cell_acc = np.mean([r['correct_cells'] / max(r['total_cells'], 1) for r in results]) * 100

print(f"{'='*50}")
print(f"  FINAL RESULTS: TRM on 4×4 Sudoku")
print(f"{'='*50}")
print(f"  Test puzzles:        {n_total}")
print(f"  Perfectly solved:    {n_perfect} ({100*n_perfect/n_total:.1f}%)")
print(f"  Average cell acc:    {avg_cell_acc:.1f}%")
print(f"  Model parameters:    {n_params:,}")
print(f"{'='*50}")

In [None]:
# 📊 Show a gallery of solved puzzles
import matplotlib.patches as mpatches

fig, axes = plt.subplots(2, 5, figsize=(20, 9))

# Pick up to 5 perfectly solved and up to 5 imperfectly solved puzzles
correct_indices = [i for i, r in enumerate(results) if r['perfect']][:5]
wrong_indices = [i for i, r in enumerate(results) if not r['perfect']][:5]

# Pad both rows to 5 entries; the error row falls back to solved puzzles
# when fewer than 5 failures exist
while len(correct_indices) < 5:
    correct_indices.append(correct_indices[-1] if correct_indices else 0)
while len(wrong_indices) < 5:
    wrong_indices.append(correct_indices[len(wrong_indices)] if len(correct_indices) > len(wrong_indices) else 0)

for col, idx in enumerate(correct_indices):
    r = results[idx]
    ax = axes[0][col]
    for row in range(4):
        for c in range(4):
            if r['puzzle'][row, c] > 0:
                color = '#e3f2fd'
                val = str(r['puzzle'][row, c])
                fc = '#1565c0'
            elif r['prediction'][row, c] == r['solution'][row, c]:
                color = '#c8e6c9'
                val = str(r['prediction'][row, c])
                fc = '#2e7d32'
            else:
                color = '#ffcdd2'
                val = str(r['prediction'][row, c])
                fc = '#c62828'
            ax.add_patch(plt.Rectangle((c, 3-row), 1, 1, facecolor=color, edgecolor='gray', lw=1.5))
            ax.text(c+0.5, 3-row+0.5, val, ha='center', va='center', fontsize=14, fontweight='bold', color=fc)
    for i in range(0, 5, 2):
        ax.axhline(y=i, color='black', lw=2)
        ax.axvline(x=i, color='black', lw=2)
    ax.set_xlim(0,4); ax.set_ylim(0,4); ax.set_xticks([]); ax.set_yticks([])
    ax.set_aspect('equal')
    acc = r['correct_cells'] / max(r['total_cells'], 1) * 100
    ax.set_title(f'✅ {acc:.0f}%', fontsize=12, fontweight='bold', color='#2e7d32')

for col, idx in enumerate(wrong_indices):
    r = results[idx]
    ax = axes[1][col]
    for row in range(4):
        for c in range(4):
            if r['puzzle'][row, c] > 0:
                color = '#e3f2fd'; val = str(r['puzzle'][row, c]); fc = '#1565c0'
            elif r['prediction'][row, c] == r['solution'][row, c]:
                color = '#c8e6c9'; val = str(r['prediction'][row, c]); fc = '#2e7d32'
            else:
                color = '#ffcdd2'; val = str(r['prediction'][row, c]); fc = '#c62828'
            ax.add_patch(plt.Rectangle((c, 3-row), 1, 1, facecolor=color, edgecolor='gray', lw=1.5))
            ax.text(c+0.5, 3-row+0.5, val, ha='center', va='center', fontsize=14, fontweight='bold', color=fc)
    for i in range(0, 5, 2):
        ax.axhline(y=i, color='black', lw=2)
        ax.axvline(x=i, color='black', lw=2)
    ax.set_xlim(0,4); ax.set_ylim(0,4); ax.set_xticks([]); ax.set_yticks([])
    ax.set_aspect('equal')
    acc = r['correct_cells'] / max(r['total_cells'], 1) * 100
    status = '✅' if r['perfect'] else '❌'
    ax.set_title(f'{status} {acc:.0f}%', fontsize=12, fontweight='bold',
                 color='#2e7d32' if r['perfect'] else '#c62828')

axes[0][0].set_ylabel('Correct', fontsize=14, fontweight='bold')
axes[1][0].set_ylabel('Errors', fontsize=14, fontweight='bold')

# Legend
fig.text(0.2, 0.02, 'üîµ Given  ', fontsize=11, color='#1565c0')
fig.text(0.4, 0.02, 'üü¢ Correctly predicted  ', fontsize=11, color='#2e7d32')
fig.text(0.65, 0.02, 'üî¥ Incorrect  ', fontsize=11, color='#c62828')

plt.suptitle(f'üéâ TRM Sudoku Results: {n_perfect}/{n_total} Puzzles Solved Perfectly',
             fontsize=15, fontweight='bold')
plt.tight_layout(rect=[0, 0.05, 1, 0.95])
plt.show()

print(f"\nüéâ Congratulations! You have trained a Tiny Recursive Model to solve Sudoku!")
print(f"   The model uses {n_params:,} parameters and recursive reasoning to solve puzzles")
print(f"   that a single-pass model of the same size could never handle.")

In [None]:
#@title üéß Listen: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_13_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. Reflection and Next Steps

### 💡 Key Takeaways

1. **Deep supervision** provides a learning signal at multiple intermediate points, which reduces gradient instability in deep recursion chains
2. **Halting loss** teaches the model to estimate its own confidence, which enables early stopping on easy examples
3. **EMA** smooths training noise, which matters most on small datasets
4. The combination of these techniques enables training a tiny model to solve structured reasoning tasks
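
To make the third takeaway concrete, here is a minimal NumPy sketch of an exponential moving average applied to a set of noisy weights. It is a toy illustration, not the notebook's training code: the `ema_update` helper, the simulated `target` weights, and `decay=0.9` are assumptions chosen so the effect is visible in a few hundred steps (real training setups typically use a decay much closer to 1).

```python
import numpy as np

def ema_update(ema_weights, weights, decay=0.9):
    """Blend the running average one step toward the current weights."""
    return decay * ema_weights + (1.0 - decay) * weights

rng = np.random.default_rng(0)
target = np.array([1.0, -2.0])   # hypothetical stand-in for "good" weights
raw_history, ema_history = [], []
ema = target.copy()

for _ in range(500):
    # Simulate noisy training: the raw weights jitter around the target.
    raw = target + rng.normal(scale=0.5, size=target.shape)
    ema = ema_update(ema, raw)
    raw_history.append(raw)
    ema_history.append(ema)

# Compare how much each trajectory fluctuates over the run.
raw_std = np.std(raw_history, axis=0).mean()
ema_std = np.std(ema_history, axis=0).mean()
print(f"raw std: {raw_std:.3f}, EMA std: {ema_std:.3f}")
```

The averaged trajectory fluctuates less than the raw one, which is why evaluating with the EMA copy tends to give a steadier accuracy curve when each update is computed from few examples.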

### 🤔 Reflection Questions

1. What would happen if we set T=1 (only one supervision step)? Would the model still learn effectively?
2. Why does EMA help more on small datasets than large ones?
3. The halting mechanism could save computation during inference. How would you implement early stopping at test time?
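
As a starting point for question 3, the control flow of test-time early stopping can be sketched without any model at all. Everything here is hypothetical: `step_fn` stands in for one recursion step that returns an updated state, a prediction, and a halting probability, and the `threshold` of 0.9 is an arbitrary choice. Wiring it to the actual model and its halting head is left as the exercise.

```python
def run_with_early_stopping(step_fn, state, max_steps=12, threshold=0.9):
    """Call step_fn until its halting probability reaches the threshold."""
    if max_steps < 1:
        raise ValueError("max_steps must be at least 1")
    for step in range(1, max_steps + 1):
        state, prediction, halt_prob = step_fn(state)
        if halt_prob >= threshold:
            break  # confident enough: skip the remaining steps
    return prediction, step

def toy_step(state):
    """Hypothetical step: confidence grows by 0.25 per call."""
    state = state + 1
    halt_prob = min(1.0, 0.25 * state)
    return state, f"answer after {state} steps", halt_prob

prediction, steps_used = run_with_early_stopping(toy_step, state=0)
print(prediction, steps_used)
```

With a batch of puzzles, the same idea needs a per-example mask so that finished puzzles stop updating while the others continue.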

### 🏆 Optional Challenges

1. **Increase difficulty:** Generate puzzles with 10-12 empty cells instead of 8. Does the model need more recursion steps?
2. **Curriculum learning:** Start training with easy puzzles (4 empty cells) and gradually increase difficulty. Does this help?
3. **Visualize the reasoning state z:** Run PCA on the z vectors across recursion steps. Do they form interpretable clusters?
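
For challenge 3, the PCA step itself can be sketched with NumPy alone. The data below is a random placeholder, not real model state: it assumes you have already collected one pooled `z` vector per recursion step (12 steps for `T=3, n=4`) into a 2-D array, and `pca_project` and `hidden_dim=64` are illustrative names and values.

```python
import numpy as np

def pca_project(states, n_components=2):
    """Project rows of `states` onto their top principal components."""
    centered = states - states.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, sorted by decreasing variance.
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    explained = singular_values**2 / np.sum(singular_values**2)
    return centered @ vt[:n_components].T, explained[:n_components]

rng = np.random.default_rng(0)
num_steps, hidden_dim = 12, 64                          # assumed sizes
z_per_step = rng.normal(size=(num_steps, hidden_dim))   # placeholder for real z
coords, explained = pca_project(z_per_step)
print(coords.shape, explained.round(3))
```

Plotting `coords` with one point per step, colored by step index, would show whether the reasoning state follows a trajectory or jumps between clusters.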

### What's Next

In the final notebook, we will run **ablation studies**, systematically removing components to understand what really matters: recursion depth vs model size, MLP vs attention, with/without z, with/without EMA.

In [None]:
#@title üí¨ AI Teaching Assistant ‚Äî Click ‚ñ∂ to start
#@markdown This AI chatbot reads your notebook and can answer questions about any concept, code, or exercise.

import json as _json
import requests as _requests
from google.colab import output as _output
from IPython.display import display, HTML as _HTML, Markdown as _Markdown

# --- Read notebook content for context ---
def _get_notebook_context():
    try:
        from google.colab import _message
        nb = _message.blocking_request("get_ipynb", request="", timeout_sec=10)
        cells = nb.get("ipynb", {}).get("cells", [])
        parts = []
        for cell in cells:
            src = "".join(cell.get("source", []))
            tags = cell.get("metadata", {}).get("tags", [])
            if "chatbot" in tags:
                continue
            if src.strip():
                ct = cell.get("cell_type", "unknown")
                parts.append(f"[{ct.upper()}]\n{src}")
        return "\n\n---\n\n".join(parts)
    except Exception:
        return "Notebook content unavailable."

_NOTEBOOK_CONTEXT = _get_notebook_context()
_CHAT_HISTORY = []
_API_URL = "https://course-creator-brown.vercel.app/api/chat"

def _notebook_chat(question):
    global _CHAT_HISTORY
    try:
        resp = _requests.post(_API_URL, json={
            'question': question,
            'context': _NOTEBOOK_CONTEXT[:100000],
            'history': _CHAT_HISTORY[-10:],
        }, timeout=60)
        data = resp.json()
        answer = data.get('answer', 'Sorry, I could not generate a response.')
        _CHAT_HISTORY.append({'role': 'user', 'content': question})
        _CHAT_HISTORY.append({'role': 'assistant', 'content': answer})
        return answer
    except Exception as e:
        return f'Error connecting to teaching assistant: {str(e)}'

_output.register_callback('notebook_chat', _notebook_chat)

def ask(question):
    """Ask the AI teaching assistant a question about this notebook."""
    answer = _notebook_chat(question)
    display(_Markdown(answer))

print("\u2705 AI Teaching Assistant is ready!")
print("\U0001f4a1 Use the chat below, or call ask(\'your question\') in any cell.")

# --- Display chat widget ---
display(_HTML('''<style>
  .vc-wrap{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;max-width:100%;border-radius:16px;overflow:hidden;box-shadow:0 4px 24px rgba(0,0,0,.12);background:#fff;border:1px solid #e5e7eb}
  .vc-hdr{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;padding:16px 20px;display:flex;align-items:center;gap:12px}
  .vc-avatar{width:42px;height:42px;background:rgba(255,255,255,.2);border-radius:50%;display:flex;align-items:center;justify-content:center;font-size:22px}
  .vc-hdr h3{font-size:16px;font-weight:600;margin:0}
  .vc-hdr p{font-size:12px;opacity:.85;margin:2px 0 0}
  .vc-msgs{height:420px;overflow-y:auto;padding:16px;background:#f8f9fb;display:flex;flex-direction:column;gap:10px}
  .vc-msg{display:flex;flex-direction:column;animation:vc-fade .25s ease}
  .vc-msg.user{align-items:flex-end}
  .vc-msg.bot{align-items:flex-start}
  .vc-bbl{max-width:85%;padding:10px 14px;border-radius:16px;font-size:14px;line-height:1.55;word-wrap:break-word}
  .vc-msg.user .vc-bbl{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border-bottom-right-radius:4px}
  .vc-msg.bot .vc-bbl{background:#fff;color:#1a1a2e;border:1px solid #e8e8e8;border-bottom-left-radius:4px}
  .vc-bbl code{background:rgba(0,0,0,.07);padding:2px 6px;border-radius:4px;font-size:13px;font-family:'Fira Code',monospace}
  .vc-bbl pre{background:#1e1e2e;color:#cdd6f4;padding:12px;border-radius:8px;overflow-x:auto;margin:8px 0;font-size:13px}
  .vc-bbl pre code{background:none;padding:0;color:inherit}
  .vc-bbl h3,.vc-bbl h4{margin:10px 0 4px;font-size:15px}
  .vc-bbl ul,.vc-bbl ol{margin:4px 0;padding-left:20px}
  .vc-bbl li{margin:2px 0}
  .vc-chips{display:flex;flex-wrap:wrap;gap:8px;padding:0 16px 12px;background:#f8f9fb}
  .vc-chip{background:#fff;border:1px solid #d1d5db;border-radius:20px;padding:6px 14px;font-size:12px;cursor:pointer;transition:all .15s;color:#4b5563}
  .vc-chip:hover{border-color:#667eea;color:#667eea;background:#f0f0ff}
  .vc-input{display:flex;padding:12px 16px;background:#fff;border-top:1px solid #eee;gap:8px}
  .vc-input input{flex:1;padding:10px 16px;border:2px solid #e8e8e8;border-radius:24px;font-size:14px;outline:none;transition:border-color .2s}
  .vc-input input:focus{border-color:#667eea}
  .vc-input button{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border:none;border-radius:50%;width:42px;height:42px;cursor:pointer;display:flex;align-items:center;justify-content:center;font-size:18px;transition:transform .1s}
  .vc-input button:hover{transform:scale(1.05)}
  .vc-input button:disabled{opacity:.5;cursor:not-allowed;transform:none}
  .vc-typing{display:flex;gap:5px;padding:4px 0}
  .vc-typing span{width:8px;height:8px;background:#667eea;border-radius:50%;animation:vc-bounce 1.4s infinite ease-in-out}
  .vc-typing span:nth-child(2){animation-delay:.2s}
  .vc-typing span:nth-child(3){animation-delay:.4s}
  @keyframes vc-bounce{0%,80%,100%{transform:scale(0)}40%{transform:scale(1)}}
  @keyframes vc-fade{from{opacity:0;transform:translateY(8px)}to{opacity:1;transform:translateY(0)}}
  .vc-note{text-align:center;font-size:11px;color:#9ca3af;padding:8px 16px 12px;background:#fff}
</style>
<div class="vc-wrap">
  <div class="vc-hdr">
    <div class="vc-avatar">&#129302;</div>
    <div>
      <h3>Vizuara Teaching Assistant</h3>
      <p>Ask me anything about this notebook</p>
    </div>
  </div>
  <div class="vc-msgs" id="vcMsgs">
    <div class="vc-msg bot">
      <div class="vc-bbl">&#128075; Hi! I've read through this entire notebook. Ask me about any concept, code block, or exercise &mdash; I'm here to help you learn!</div>
    </div>
  </div>
  <div class="vc-chips" id="vcChips">
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Explain the main concept</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Help with the TODO exercise</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Summarize what I learned</span>
  </div>
  <div class="vc-input">
    <input type="text" id="vcIn" placeholder="Ask about concepts, code, exercises..." />
    <button id="vcSend" onclick="vcSendMsg()">&#10148;</button>
  </div>
  <div class="vc-note">AI-generated &middot; Verify important information &middot; <a href="#" onclick="vcClear();return false" style="color:#667eea">Clear chat</a></div>
</div>
<script>
(function(){
  var msgs=document.getElementById('vcMsgs'),inp=document.getElementById('vcIn'),
      btn=document.getElementById('vcSend'),chips=document.getElementById('vcChips');

  function esc(s){var d=document.createElement('div');d.textContent=s;return d.innerHTML}

  function md(t){
    return t
      .replace(/```(\w*)\n([\s\S]*?)```/g,function(_,l,c){return '<pre><code>'+esc(c)+'</code></pre>'})
      .replace(/`([^`]+)`/g,'<code>$1</code>')
      .replace(/\*\*([^*]+)\*\*/g,'<strong>$1</strong>')
      .replace(/\*([^*]+)\*/g,'<em>$1</em>')
      .replace(/^#### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^## (.+)$/gm,'<h3>$1</h3>')
      .replace(/^\d+\. (.+)$/gm,'<li>$1</li>')
      .replace(/^- (.+)$/gm,'<li>$1</li>')
      .replace(/\n\n/g,'<br><br>')
      .replace(/\n/g,'<br>');
  }

  function addMsg(text,isUser){
    var m=document.createElement('div');m.className='vc-msg '+(isUser?'user':'bot');
    var b=document.createElement('div');b.className='vc-bbl';
    b.innerHTML=isUser?esc(text):md(text);
    m.appendChild(b);msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function showTyping(){
    var m=document.createElement('div');m.className='vc-msg bot';m.id='vcTyping';
    m.innerHTML='<div class="vc-bbl"><div class="vc-typing"><span></span><span></span><span></span></div></div>';
    msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function hideTyping(){var e=document.getElementById('vcTyping');if(e)e.remove()}

  window.vcSendMsg=function(){
    var q=inp.value.trim();if(!q)return;
    inp.value='';chips.style.display='none';
    addMsg(q,true);showTyping();btn.disabled=true;
    google.colab.kernel.invokeFunction('notebook_chat',[q],{})
      .then(function(r){
        hideTyping();
        var a=r.data['application/json'];
        addMsg(typeof a==='string'?a:JSON.stringify(a),false);
      })
      .catch(function(){
        hideTyping();
        addMsg('Sorry, I encountered an error. Please check your internet connection and try again.',false);
      })
      .finally(function(){btn.disabled=false;inp.focus()});
  };

  window.vcAsk=function(q){inp.value=q;vcSendMsg()};
  window.vcClear=function(){
    msgs.innerHTML='<div class="vc-msg bot"><div class="vc-bbl">&#128075; Chat cleared. Ask me anything!</div></div>';
    chips.style.display='flex';
  };

  inp.addEventListener('keypress',function(e){if(e.key==='Enter')vcSendMsg()});
  inp.focus();
})();
</script>'''))