In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU: pick the compute device for the session and report the runtime.
# (The status strings previously contained mis-encoded UTF-8 and printed as garbage.)
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; shown here in decimal gigabytes.
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

# One seed shared by every RNG the notebook touches: Python's `random`,
# NumPy's legacy global generator, and PyTorch (CPU plus all CUDA devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# (The status string previously contained mis-encoded UTF-8.)
print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Case Study: Edge-Deployed Wafer Defect Pattern Classification Using Recursive Reasoning
## Implementation Notebook

This notebook implements a Tiny Recursive Model (TRM) for classifying wafer map defect patterns in semiconductor manufacturing. You will build the full pipeline from data loading through model evaluation, implementing key components yourself.

**Context**: SilicaAI needs a model with <10M parameters that achieves >93% accuracy on 9-class wafer defect classification, running on edge hardware (NVIDIA Jetson Orin Nano) with <50ms inference latency. The TRM architecture uses recursive reasoning — applying a tiny 2-layer network repeatedly — to achieve the computational depth of a 42-layer model with only 7M parameters.

**Dataset**: WM-811K wafer bin map dataset (811K real wafer maps from semiconductor production).

**What you will build**:
1. Data preprocessing pipeline with augmentation
2. Spatial feature analysis and rule-based baseline
3. CNN baseline for comparison
4. Full TRM architecture with rotary position embeddings, deep supervision, and adaptive halting
5. Training loop with EMA and class-weighted loss
6. Comprehensive evaluation and error analysis
7. Latency profiling for edge deployment

---

## 3.1 Data Acquisition and Preprocessing

We use the WM-811K wafer map dataset, which contains 811,457 wafer bin maps from real semiconductor production lines. Each map is a 2D grid of die outcomes labeled with one of 9 defect pattern classes.

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant** — it has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://course-creator-brown.vercel.app/courses/tiny-recursive-models/practice/0/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


In [None]:
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Mount Google Drive if running in Colab. Outside Colab the `google.colab`
# package does not exist, so skip the mount instead of crashing the cell.
try:
    from google.colab import drive
except ImportError:
    print("google.colab not available -- skipping Drive mount; point DATA_PATH at a local file.")
else:
    drive.mount('/content/drive')

# Load WM-811K dataset
# Download from: https://www.kaggle.com/datasets/qingyi/wm811k-wafer-map
# Upload the pickle file to your Google Drive
DATA_PATH = "/content/drive/MyDrive/datasets/wm811k.pkl"  # Update with your path

# SECURITY: unpickling can execute arbitrary code. Only load a file that you
# downloaded yourself from the trusted source above.
df = pd.read_pickle(DATA_PATH)
print(f"Total wafer maps: {len(df)}")
print(f"Columns: {df.columns.tolist()}")
print(f"\nLabel distribution:")
# NOTE(review): value_counts() needs hashable entries -- presumably failureType
# holds plain strings / NaN in this pickle; confirm against the actual file.
print(df['failureType'].value_counts())

In [None]:
# Preprocessing constants and class mapping
GRID_SIZE = 26  # Side length every wafer map is resized to
NUM_CLASSES = 9

# Defect-pattern names listed in label-id order (position == integer label).
CLASS_MAP = dict(zip(
    ('none', 'Center', 'Donut', 'Edge-Loc', 'Edge-Ring',
     'Loc', 'Near-full', 'Random', 'Scratch'),
    range(NUM_CLASSES),
))

# Reverse lookup: integer label -> defect-pattern name.
INV_CLASS_MAP = dict(zip(CLASS_MAP.values(), CLASS_MAP.keys()))

def preprocess_wafer_map(wafer_map, target_size=GRID_SIZE):
    """
    Rescale a wafer map of arbitrary size onto a square grid.

    Args:
        wafer_map: 2D numpy array of die codes. This notebook treats the codes
            as {0: pass, 1: fail, 2: out-of-wafer}.
            NOTE(review): the public WM-811K release is commonly documented as
            0 = no die, 1 = normal die, 2 = defective die -- confirm the
            encoding of your pickle and remap if it differs.
        target_size: Side length of the square output grid.

    Returns:
        np.int32 array of shape (target_size, target_size).
    """
    from skimage.transform import resize

    output_shape = (target_size, target_size)
    map_as_float = wafer_map.astype(float)
    # order=0 with anti-aliasing disabled and preserve_range=True keeps the
    # output restricted to the original die codes (no blended values).
    rescaled = resize(
        map_as_float,
        output_shape,
        order=0,
        preserve_range=True,
        anti_aliasing=False,
    )
    return rescaled.astype(np.int32)

def flatten_to_sequence(wafer_map):
    """Collapse a 2D wafer grid into a row-major 1D token sequence for the TRM."""
    # A 26x26 map becomes a length-676 sequence.
    return wafer_map.ravel()

In [None]:
# Preprocess all wafer maps
print("Preprocessing wafer maps...")
processed_maps = []
labels = []

# Walk the two relevant columns in lockstep; the other columns are not needed.
for wm, label_str in zip(df['waferMap'], df['failureType']):
    # Skip rows whose map is missing or empty.
    has_map = isinstance(wm, np.ndarray) and wm.size != 0
    if not has_map:
        continue

    # Skip labels that are neither a known class name nor missing (NaN).
    is_known_class = label_str in CLASS_MAP
    if not is_known_class and not pd.isna(label_str):
        continue

    # Missing labels fall through to the default: 'none' (class 0).
    # NOTE(review): this treats unlabeled wafers as defect-free -- confirm that
    # is intended for this dataset.
    labels.append(CLASS_MAP.get(label_str, 0))
    processed_maps.append(preprocess_wafer_map(wm, GRID_SIZE))

processed_maps = np.array(processed_maps)
labels = np.array(labels)
print(f"Processed {len(processed_maps)} wafer maps")
print(f"Shape: {processed_maps.shape}")
print(f"Label distribution: {dict(Counter(labels))}")

### TODO: Data Augmentation Pipeline

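Wafer maps have rotational symmetry (the wafer is circular), so rotations and reflections are label-preserving; translations and crops are not, because the defect location relative to the wafer center is diagnostic. Implement augmentation to address class imbalance and improve generalization.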

In [None]:
def augment_wafer_map(wafer_map, label):
    """
    Apply data augmentation to a wafer map.

    Exercise stub: the body is left for the reader and currently always raises
    NotImplementedError.

    Wafer maps have rotational symmetry (the wafer is circular), so rotations
    and reflections are label-preserving for the defect types used here. Some
    augmentations must respect the defect pattern semantics:

    - Rotations (90, 180, 270 degrees): SAFE for all classes
    - Horizontal/vertical flip: SAFE for all classes (symmetric wafer)
    - Random noise injection (flip 1-2% of passing die, value 0, to failing,
      value 1): SAFE, simulates real measurement noise
    - DO NOT apply translations or crops -- defect location relative to wafer
      center is diagnostic

    Args:
        wafer_map: np.array of shape (26, 26), values in {0, 1, 2}
        label: int, class label (0-8)

    Returns:
        augmented_map: np.array of shape (26, 26)

    Raises:
        NotImplementedError: always, until the exercise is implemented.

    Hints:
        1. Use np.rot90 for rotations (k parameter controls number of 90-degree rotations)
        2. Use np.flipud and np.fliplr for reflections
        3. For noise injection, never touch out-of-wafer cells (value == 2);
           only cells inside the wafer boundary may change
        4. Randomly choose ONE augmentation per call (not all at once)
        5. Return the original map with probability 0.3 (not every sample needs augmentation)
    """
    # TODO: Implement augmentation pipeline
    # Step 1: With 30% probability, return original (no augmentation)
    # Step 2: Randomly select one augmentation type
    # Step 3: Apply the selected augmentation
    # Step 4: Ensure out-of-wafer cells (value 2) are only moved by the geometric
    #         transform, never created or removed -- the number of value-2 cells
    #         must be unchanged after augmentation
    raise NotImplementedError("Implement wafer map augmentation")

In [None]:
# Verification cell for augmentation
def verify_augmentation(num_trials=25):
    """Check that augment_wafer_map preserves shape, value set and die count.

    The augmentation picks a random branch on each call (identity, rotation,
    flip, noise), so a single call would exercise only one of them. The same
    input is therefore augmented ``num_trials`` times.

    Args:
        num_trials: Number of independent augmentation calls to check.

    Raises:
        AssertionError: if any augmented map violates one of the invariants.
    """
    test_map = np.random.choice([0, 1, 2], size=(26, 26), p=[0.7, 0.2, 0.1])
    original_die_count = np.sum(test_map != 2)

    for _ in range(num_trials):
        # Pass a copy so an in-place implementation cannot alter the reference map.
        augmented = augment_wafer_map(test_map.copy(), label=1)

        assert augmented.shape == (26, 26), f"Shape changed: {augmented.shape}"
        assert set(np.unique(augmented)).issubset({0, 1, 2}), "Invalid values introduced"
        augmented_die_count = np.sum(augmented != 2)
        # In-wafer die count may drift by at most 5% of the original count.
        assert abs(original_die_count - augmented_die_count) <= original_die_count * 0.05, \
            "Too many die added/removed"
    print("All augmentation checks passed!")

verify_augmentation()

**Thought questions:**
- Why is translation NOT a safe augmentation for wafer maps, even though it is commonly used for natural images?
- How would you handle augmentation differently for the "Scratch" class vs the "Center" class?

---

## 3.2 Exploratory Data Analysis

Understanding the data distribution and spatial patterns before building models.

In [None]:
# Class distribution analysis
label_counts = Counter(labels)
classes = [INV_CLASS_MAP[class_idx] for class_idx in range(NUM_CLASSES)]
counts = [label_counts.get(class_idx, 0) for class_idx in range(NUM_CLASSES)]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Same bar chart twice: linear scale on the left, log scale on the right so
# the minority classes become visible.
panel_specs = (
    (axes[0], None, 'Count', 'Class Distribution in WM-811K'),
    (axes[1], 'log', 'Count (log scale)', 'Class Distribution (Log Scale)'),
)
for ax, yscale, ylabel, title in panel_specs:
    ax.bar(classes, counts, color='steelblue')
    if yscale is not None:
        ax.set_yscale(yscale)
    ax.set_xlabel('Defect Pattern')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# Ratio of the largest class to the smallest non-empty class.
print(f"\nImbalance ratio (max/min): {max(counts)/max(min(c for c in counts if c > 0),1):.1f}x")

In [None]:
# Visualize example wafer maps from each class
# ListedColormap is imported from its public home; `plt.cm.colors` only works
# because matplotlib.cm happens to import the colors module internally.
from matplotlib.colors import ListedColormap

fig, axes = plt.subplots(2, 5, figsize=(18, 8))
axes = axes.flatten()

# Hide every panel up front: the 10th (spare) subplot and any class without
# samples then stay blank instead of showing an empty framed axis.
for ax in axes:
    ax.axis('off')

# Colors indexed by die code: 0 -> white, 1 -> red, 2 -> light gray.
cmap = ListedColormap(['white', 'red', 'lightgray'])

for class_id in range(NUM_CLASSES):
    mask = labels == class_id
    if mask.sum() == 0:
        continue
    # Show the first sample of the class as its representative.
    sample_idx = np.where(mask)[0][0]
    sample = processed_maps[sample_idx]

    ax = axes[class_id]
    ax.imshow(sample, cmap=cmap, interpolation='nearest', vmin=0, vmax=2)
    ax.set_title(f'{INV_CLASS_MAP[class_id]} (n={mask.sum():,})')

plt.suptitle('Example Wafer Maps by Defect Class', fontsize=14)
plt.tight_layout()
plt.show()

### TODO: Spatial Feature Analysis

In [None]:
def compute_spatial_features(wafer_map):
    """
    Compute spatial statistics that characterize defect patterns.

    Exercise stub: the body is left for the reader and currently always raises
    NotImplementedError.

    These features will help you understand WHY certain patterns are
    challenging to distinguish, and will serve as the basis for the
    rule-based baseline.

    Args:
        wafer_map: np.array of shape (H, W), values in {0, 1, 2}

    Returns:
        dict with keys:
            - 'defect_ratio': fraction of in-wafer die that are defective
            - 'radial_profile': np.array of shape (13,) -- average defect rate
              at each radial distance from center (13 bins for 26x26 grid)
            - 'angular_profile': np.array of shape (8,) -- average defect rate
              in each 45-degree angular sector
            - 'centroid_distance': distance of defect centroid from wafer center,
              normalized by wafer radius
            - 'spatial_entropy': Shannon entropy of the 2D defect distribution
              (higher = more spread out, lower = more concentrated)

    Raises:
        NotImplementedError: always, until the exercise is implemented.

    Hints:
        1. Compute wafer center as the centroid of all in-wafer die (value != 2)
        2. For radial profile, compute distance of each die from center, bin into 13 equal-width bins
        3. For angular profile, compute angle of each die from center using np.arctan2, bin into 8 sectors
        4. Centroid of defects = mean (row, col) of all defective die (value == 1)
        5. For spatial entropy, divide the grid into 4x4 blocks, compute defect rate per block,
           then compute entropy over the block-level distribution
           (26 is not a multiple of 4 -- decide how the ragged edge blocks are handled)
        6. Decide explicitly what to return for maps with no defective die and for
           empty radial/angular bins, so that no feature ends up NaN
    """
    # TODO: Implement spatial feature computation
    raise NotImplementedError("Implement spatial feature computation")

In [None]:
# Verification cell for spatial features
def verify_spatial_features():
    """Check compute_spatial_features on a synthetic center-defect wafer.

    The test map has 20 in-wafer rows of 26 die (520 die) and a 4x4 block of
    defects (16 die) in the middle, so the expected defect ratio is 16/520 and
    the defect centroid coincides with the wafer center.

    Raises:
        AssertionError: if a documented key is missing or a feature is wrong.
    """
    # Create a center-defect pattern: defects concentrated in the middle
    test_map = np.full((26, 26), 0)
    test_map[:3, :] = 2  # Top rows out-of-wafer
    test_map[-3:, :] = 2  # Bottom rows out-of-wafer
    test_map[11:15, 11:15] = 1  # Center cluster of defects

    features = compute_spatial_features(test_map)

    # Every key promised by the compute_spatial_features docstring must be
    # present, so a missing one fails with a clear message, not a KeyError.
    for key in ('defect_ratio', 'radial_profile', 'angular_profile',
                'centroid_distance', 'spatial_entropy'):
        assert key in features, f"Missing '{key}'"
    assert features['radial_profile'].shape == (13,), f"Wrong radial shape: {features['radial_profile'].shape}"
    assert features['angular_profile'].shape == (8,), f"Wrong angular shape: {features['angular_profile'].shape}"
    assert abs(features['defect_ratio'] - 16 / 520) < 1e-6, \
        f"Wrong defect ratio: {features['defect_ratio']} (expected 16/520)"
    assert features['centroid_distance'] < 0.3, "Center defect should have small centroid distance"
    print(f"Defect ratio: {features['defect_ratio']:.3f}")
    print(f"Centroid distance: {features['centroid_distance']:.3f}")
    print("All spatial feature checks passed!")

verify_spatial_features()

**Thought questions:**
1. Which defect patterns have the most similar radial profiles? What additional features would help distinguish them?
2. The class distribution is heavily imbalanced. What strategies could address this during training? How might class weighting interact with the TRM's deep supervision?
3. Why is spatial entropy useful for distinguishing "random" defects from structured patterns like "scratch" or "edge-ring"?

---

## 3.3 Baseline Approach

Establish performance bounds with a rule-based baseline and a simple CNN.

In [None]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.model_selection import train_test_split

In [None]:
# Prepare train/val/test splits
# Two stratified cuts: hold out 30% of the data first, then halve that
# hold-out into validation and test (70 / 15 / 15 overall).
X_train, X_temp, y_train, y_temp = train_test_split(
    processed_maps,
    labels,
    test_size=0.3,
    stratify=labels,
    random_state=42,
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=0.5,
    stratify=y_temp,
    random_state=42,
)
print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

### TODO: Rule-Based Baseline

In [None]:
def build_rule_based_baseline(X_train, y_train, X_test, y_test):
    """
    Build a decision tree classifier over spatial features as the rule-based baseline.

    Exercise stub: the body is left for the reader and currently always raises
    NotImplementedError.

    This represents the traditional approach used in semiconductor fabs:
    compute handcrafted spatial statistics, then apply a simple classifier.

    Args:
        X_train: np.array of shape (N_train, 26, 26) -- wafer maps
        y_train: np.array of shape (N_train,) -- class labels
        X_test: np.array of shape (N_test, 26, 26) -- wafer maps
        y_test: np.array of shape (N_test,) -- class labels

    Returns:
        dict with keys:
            - 'accuracy': float, overall accuracy on test set
            - 'macro_f1': float, macro-averaged F1
            - 'predictions': np.array of test set predictions
            - 'model': the fitted DecisionTreeClassifier

    Raises:
        NotImplementedError: always, until the exercise is implemented.

    Steps:
        1. Compute spatial features for all training and test wafer maps
           using compute_spatial_features()
        2. Stack features into feature matrices (the feature dict mixes scalars
           with the (13,) radial and (8,) angular profiles, so flatten each
           dict into one fixed-length row first)
        3. Fit a DecisionTreeClassifier with max_depth=10 on training features
        4. Predict on test set
        5. Compute overall accuracy and macro F1
        6. Print the classification report

    Hint: Use sklearn.metrics.classification_report with output_dict=True for per-class F1
    """
    # TODO: Implement rule-based baseline
    raise NotImplementedError("Implement rule-based baseline")

### TODO: CNN Baseline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

In [None]:
def build_cnn_baseline(X_train, y_train, X_val, y_val, X_test, y_test,
                       num_epochs=20, batch_size=128):
    """
    Build a simple CNN baseline for wafer map classification.

    Exercise stub: the body is left for the reader and currently always raises
    NotImplementedError.

    Architecture: 3 conv layers (32, 64, 128 filters) with BatchNorm and MaxPool,
    followed by a global average pool and linear classifier.

    Args:
        X_train, y_train: Training data and labels
        X_val, y_val: Validation data and labels
        X_test, y_test: Test data and labels
        num_epochs: Number of training epochs
        batch_size: Batch size

    Returns:
        dict with keys:
            - 'accuracy': float, test accuracy
            - 'macro_f1': float, test macro F1
            - 'model': trained CNN model
            - 'param_count': int, number of parameters

    Raises:
        NotImplementedError: always, until the exercise is implemented.

    Steps:
        1. Define a CNN with 3 conv blocks: Conv2d -> BatchNorm -> ReLU -> MaxPool
        2. Add global average pooling and a linear layer for 9-class output
        3. Use CrossEntropyLoss with class weights (inverse frequency) to handle imbalance
        4. Train with Adam optimizer, lr=1e-3, for num_epochs
        5. Evaluate on test set

    Hints:
        - Input shape: (batch, 1, 26, 26) -- single channel wafer map
        - Conv filter sizes: 3x3 with padding=1
        - MaxPool: 2x2
        - After 3 pooling steps, spatial dim is 26->13->6->3, so GAP output is (batch, 128)
        - Total params should be roughly 94K with default Conv2d biases and affine
          BatchNorm (320 + 64 + 18,496 + 128 + 73,856 + 256 + 1,161 = 94,281),
          far smaller than SilicaAI's 45M cloud model
    """
    # TODO: Implement CNN baseline
    raise NotImplementedError("Implement CNN baseline")

In [None]:
# Compare baselines
def compare_baselines(rule_results, cnn_results):
    """Print a fixed-width table comparing the two baselines with the TRM targets.

    Args:
        rule_results: dict providing 'accuracy' and 'macro_f1'.
        cnn_results: dict providing 'accuracy', 'macro_f1' and 'param_count'.
    """
    def _table_lines():
        # Lazily produce one printed line at a time, so each row's values are
        # looked up only when that row is about to be printed.
        yield f"\n{'Method':<25} {'Accuracy':>10} {'Macro F1':>10} {'Params':>12}"
        yield "-" * 60
        yield f"{'Rule-based (DTree)':<25} {rule_results['accuracy']:>10.3f} {rule_results['macro_f1']:>10.3f} {'N/A':>12}"
        yield f"{'CNN (3-layer)':<25} {cnn_results['accuracy']:>10.3f} {cnn_results['macro_f1']:>10.3f} {cnn_results['param_count']:>12,}"
        yield f"\n{'Target (TRM)':<25} {'>0.93':>10} {'>0.90':>10} {'<10M':>12}"

    for line in _table_lines():
        print(line)

**Thought questions:**
1. What classes does the rule-based baseline struggle with most? Why?
2. How does the CNN baseline's accuracy compare to SilicaAI's cloud model (89%)? What accounts for the difference?
3. If you were to improve the CNN baseline without changing the architecture, what training strategies would you try?

---

## 3.4 Model Design: Tiny Recursive Model

Now we implement the core TRM architecture adapted for wafer defect classification.

In [None]:
import math

### Building Blocks: RMSNorm and SwiGLU

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Unlike LayerNorm there is no mean subtraction: each feature vector is
    divided by its RMS and then scaled by a learned per-feature gain.

        RMSNorm(x) = x / RMS(x) * weight
        RMS(x)     = sqrt(mean(x^2) + eps)

    Args:
        dim: Size of the last (feature) dimension.
        eps: Constant added under the square root for numerical stability.
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learned gain, initialised to 1 so the layer starts as pure normalization.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = (x ** 2).mean(dim=-1, keepdim=True)
        denom = torch.sqrt(mean_square + self.eps)
        normalized = x / denom
        return normalized * self.weight


class SwiGLU(nn.Module):
    """Gated feed-forward layer with a SiLU (swish) gate.

        SwiGLU(x) = W3( swish(W1 x) * (W2 x) )
        swish(z)  = z * sigmoid(z)

    The W1 branch acts as a learned gate on the W2 branch; W3 projects the
    gated hidden activation back to the model width. All three projections
    are bias-free.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the gated hidden layer; falls back to 4 * dim
            when not given.
    """
    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        if not hidden_dim:
            hidden_dim = dim * 4
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

### TODO: 2D Rotary Position Embeddings

In [None]:
class RotaryPositionEmbedding2D(nn.Module):
    """
    2D Rotary Position Embeddings for grid-structured data.

    Exercise stub: both __init__ and forward are left for the reader and
    currently always raise NotImplementedError.

    Standard rotary embeddings encode 1D position. For wafer maps, we need
    2D position encoding because defect patterns are defined by their (row, col)
    location relative to the wafer center.

    The key idea: split the embedding dimension in half. The first half encodes
    the row position, the second half encodes the column position. Each half
    uses standard rotary embeddings.

    Args:
        dim: Embedding dimension (must be divisible by 4 for 2D)
        grid_size: Size of the square grid (26 for our wafer maps)

    Forward args:
        x: Tensor of shape (batch, seq_len, dim) where seq_len = grid_size^2

    Returns:
        Tensor of same shape with positional information encoded

    Implementation steps:
        1. In __init__:
           a. Precompute frequency bases: theta_i = 1 / (10000^(2i/d)) for i in [0, d/2),
              where d = dim/2 is the width of one axis half. This gives dim/4
              frequencies -- one per rotated pair of features.
           b. Precompute row and column indices for each position in the flattened grid
           c. Precompute sin and cos tables for row and column positions,
              each of shape (grid_size^2, dim/4)

        2. In forward:
           a. Split x into 4 chunks along the last dimension: [x_r1, x_r2, x_c1, x_c2]
              (each chunk has dim/4 features)
           b. Apply rotation to row components:
              x_r1' = x_r1 * cos(row_pos) - x_r2 * sin(row_pos)
              x_r2' = x_r1 * sin(row_pos) + x_r2 * cos(row_pos)
           c. Apply rotation to col components similarly
           d. Concatenate [x_r1', x_r2', x_c1', x_c2'] and return

    Hints:
        - Use torch.arange and integer division/modulo to get (row, col) from flat index
        - Register sin/cos tables as buffers (self.register_buffer) so they move to GPU automatically
        - The frequency base formula must produce dim//4 values so that the sin/cos
          tables match the chunk width in step 2:
          freqs = 1.0 / (10000.0 ** (torch.arange(0, dim//2, 2).float() / (dim//2)))
    """
    def __init__(self, dim, grid_size=GRID_SIZE):
        super().__init__()
        # TODO: Implement initialization
        # Step 1: Compute frequency bases
        # Step 2: Compute row/col position indices for flattened grid
        # Step 3: Compute and register sin/cos buffers
        raise NotImplementedError("Implement 2D rotary position embeddings")

    def forward(self, x):
        # TODO: Implement forward pass
        # Step 1: Split x into 4 chunks
        # Step 2: Apply rotary transformation to row and col components
        # Step 3: Concatenate and return
        raise NotImplementedError("Implement rotary forward pass")

In [None]:
# Verification for rotary embeddings
def verify_rotary():
    """Test that rotary embeddings preserve norms and encode position.

    Checks performed:
        1. Output shape equals input shape.
        2. Per-token L2 norms are unchanged (rotations are norm-preserving).
        3. Identical input vectors placed at different grid positions produce
           different outputs, i.e. position is actually encoded.

    Raises:
        AssertionError: if any of the checks fails.
    """
    rope = RotaryPositionEmbedding2D(dim=128, grid_size=26)
    x = torch.randn(2, 676, 128)
    y = rope(x)

    assert y.shape == x.shape, f"Shape mismatch: {y.shape} vs {x.shape}"
    # Rotary embeddings should approximately preserve norms
    x_norms = torch.norm(x, dim=-1)
    y_norms = torch.norm(y, dim=-1)
    assert torch.allclose(x_norms, y_norms, atol=1e-4), "Norms not preserved"

    # Feed the same vector at every position. Flat index 0 is (row 0, col 0)
    # and flat index 27 is (row 1, col 1) on a 26-wide grid, so the two
    # outputs must differ if position is encoded.
    y_const = rope(torch.ones(1, 676, 128))
    assert not torch.allclose(y_const[0, 0], y_const[0, 27]), \
        "Output does not depend on position"
    print("Rotary embedding checks passed!")

verify_rotary()

### TODO: TRM Block (Single Recursion Step)

In [None]:
class TRMBlock(nn.Module):
    """
    A single Tiny Recursive Model block -- one layer of the 2-layer network.

    Exercise stub: both __init__ and forward are left for the reader and
    currently always raise NotImplementedError.

    This block processes three inputs (x, y, z) and produces updated (y, z).
    It uses the attention variant since our wafer maps (676 tokens) have
    seq_len >> hidden_dim.

    Architecture per block:
        1. Concatenate inputs: cat([x, y, z]) along feature dim -> project to dim
        2. RMSNorm
        3. Self-attention with rotary position embeddings
        4. Residual connection
        5. RMSNorm
        6. SwiGLU feedforward
        7. Residual connection
        8. Project to output (y_update, z_update)

    Args:
        dim: Hidden dimension (128)
        num_heads: Number of attention heads (8); dim must be divisible by num_heads
        grid_size: Grid size for position embeddings (26)

    Forward args:
        x: Input wafer map embedding, shape (batch, seq_len, dim)
        y: Current solution state, shape (batch, seq_len, dim)
        z: Current reasoning state, shape (batch, seq_len, dim)

    Returns:
        y_new: Updated solution state, shape (batch, seq_len, dim)
        z_new: Updated reasoning state, shape (batch, seq_len, dim)

    Implementation hints:
        1. Input combination: concatenate [x, y, z] along feature dim -> (batch, seq_len, 3*dim)
           then project to dim with a linear layer
        2. For self-attention: implement manually with Q, K, V projections
           (nn.Linear(dim, dim) for each), apply rotary embeddings to Q and K,
           then compute scaled dot-product attention
           (decide whether the rotary module works on the full dim before the
           head split or on head_dim = dim // num_heads after it, and construct
           it with the matching width)
        3. The residual connection adds the combined input to the attention output
        4. Output projection: one linear layer from dim -> 2*dim, then split into y_update and z_update
    """
    def __init__(self, dim=128, num_heads=8, grid_size=GRID_SIZE):
        super().__init__()
        # TODO: Implement initialization
        # Define: input projection (3*dim -> dim), RMSNorm layers,
        # Q/K/V projections, SwiGLU feedforward, output projection (dim -> 2*dim),
        # rotary embeddings
        raise NotImplementedError("Implement TRMBlock.__init__")

    def forward(self, x, y, z):
        # TODO: Implement forward pass
        # Step 1: Combine inputs: h = project(cat([x, y, z], dim=-1))
        # Step 2: norm1 -> attention (with rotary) -> residual
        # Step 3: norm2 -> SwiGLU -> residual
        # Step 4: Project to y_new, z_new = split(output_proj(h), dim=-1)
        raise NotImplementedError("Implement TRMBlock.forward")

### TODO: Full TRM with Recursion and Deep Supervision

In [None]:
class TinyRecursiveModel(nn.Module):
    """
    Full Tiny Recursive Model for wafer defect classification.

    Exercise stub: both __init__ and forward are left for the reader and
    currently always raise NotImplementedError.

    Architecture:
        - Input embedding: maps wafer map tokens (0, 1, 2) to dim-dimensional vectors
        - 2 TRMBlock layers (the core recursive unit)
        - Classification head: pools sequence -> 9-class logits
        - Halting head: pools sequence -> scalar halt probability

    Recursion:
        - T supervision steps, n recursion iterations per step
        - At each supervision step: run n iterations, then compute loss
        - Deep supervision: loss is computed at each of the T steps
        - With the defaults (T=3, n=6, num_layers=2) the loop described in
          step 2c below applies a TRMBlock 3 * 6 * 2 = 36 times per forward pass

    Args:
        dim: Hidden dimension (128)
        num_heads: Attention heads (8)
        num_layers: Layers in recursive unit (2)
        num_classes: Output classes (9)
        grid_size: Wafer grid size (26)
        T: Number of supervision steps (3)
        n: Recursion iterations per supervision step (6)

    Forward args:
        wafer_map: LongTensor of shape (batch, seq_len) with values in {0, 1, 2}

    Returns:
        dict with:
            - 'logits': list of T tensors, each shape (batch, num_classes)
            - 'halt_probs': list of T tensors, each shape (batch, 1)
            - 'final_logits': tensor of shape (batch, num_classes) -- last supervision step

    Implementation steps:
        1. __init__:
           a. Embedding layer: nn.Embedding(3, dim) for input tokens
           b. Stack of num_layers TRMBlock modules (use nn.ModuleList so the
              blocks are registered as parameters)
           c. Classification head: LayerNorm -> Linear(dim, num_classes)
           d. Halting head: LayerNorm -> Linear(dim, 1) -> Sigmoid

        2. forward:
           a. Embed input: x = embedding(wafer_map)  # (batch, seq_len, dim)
           b. Initialize y = zeros(batch, seq_len, dim), z = zeros(batch, seq_len, dim)
              (create them on the same device/dtype as x)
           c. For each supervision step t in [1, T]:
               i.   For each recursion iteration i in [1, n]:
                       For each layer in self.layers:
                           y, z = layer(x, y, z)
               ii.  Compute logits_t = cls_head(y.mean(dim=1))
               iii. Compute halt_t = halt_head(y.mean(dim=1))
               iv.  Append to output lists
           d. Return dict with all outputs

        3. IMPORTANT: during training, gradients flow through ALL recursion
           iterations within each supervision step (full backprop, not 1-step
           approximation). This is the key insight that gives +30.9% accuracy.
    """
    def __init__(self, dim=128, num_heads=8, num_layers=2, num_classes=NUM_CLASSES,
                 grid_size=GRID_SIZE, T=3, n=6):
        super().__init__()
        # TODO: Implement initialization
        raise NotImplementedError("Implement TinyRecursiveModel.__init__")

    def forward(self, wafer_map):
        # TODO: Implement forward pass with recursion and deep supervision
        raise NotImplementedError("Implement TinyRecursiveModel.forward")

In [None]:
# Verification for TRM
def verify_trm():
    """Test TRM architecture basics: parameter budget and output contract.

    Builds the model with the notebook defaults, checks that it stays under
    the 10M-parameter deployment budget, and runs one forward pass on a dummy
    batch to validate the structure of the returned dict.

    Raises:
        AssertionError: if the budget or any output shape/range check fails.
    """
    model = TinyRecursiveModel(dim=128, num_heads=8, num_layers=2,
                                num_classes=9, grid_size=26, T=3, n=6)
    model = model.to(DEVICE)

    # Count parameters
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {param_count:,}")
    # Enforce the budget that is stated as the target below (<10M).
    assert param_count < 10_000_000, f"Too many parameters: {param_count:,}"

    # Test forward pass
    dummy_input = torch.randint(0, 3, (4, 676)).to(DEVICE)  # batch of 4
    # Only the outputs are inspected, so skip building the autograd graph
    # through all recursion steps.
    with torch.no_grad():
        output = model(dummy_input)

    assert 'logits' in output, "Missing 'logits' in output"
    assert 'halt_probs' in output, "Missing 'halt_probs' in output"
    assert 'final_logits' in output, "Missing 'final_logits' in output"
    assert len(output['logits']) == 3, f"Expected 3 supervision steps, got {len(output['logits'])}"
    assert output['logits'][0].shape == (4, 9), f"Wrong logits shape: {output['logits'][0].shape}"
    assert output['final_logits'].shape == (4, 9), f"Wrong final logits shape: {output['final_logits'].shape}"
    assert output['halt_probs'][0].shape == (4, 1), f"Wrong halt shape: {output['halt_probs'][0].shape}"
    assert 0 <= output['halt_probs'][0].min() <= output['halt_probs'][0].max() <= 1, "Halt probs out of range"

    print(f"Model parameters: {param_count:,} (target: <10M)")
    print(f"Output logits shape per step: {output['logits'][0].shape}")
    print(f"Number of supervision steps: {len(output['logits'])}")
    print("All TRM checks passed!")

verify_trm()

**Thought questions:**
1. Why do we initialize y and z to zeros rather than random values? What would happen if we used random initialization?
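2. The effective depth is 42 layers (3 x 7 x 2): in the original TRM formulation each of the T = 3 supervision steps runs n + 1 = 7 passes through the 2-layer network (n = 6 reasoning updates plus one solution update). The simplified loop in this notebook, which updates y and z together, applies 3 x 6 x 2 = 36 layers. A standard 42-layer transformer would have roughly 21x more parameters than the shared 2-layer unit (42 separate layers vs. 2). What is the tradeoff? In what situations might the standard transformer outperform TRM despite having more parameters?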
3. Why does the halting head use sigmoid (outputting a probability) rather than outputting a discrete stop/continue decision?

---

## 3.5 Training Strategy

In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

**Why AdamW?** The loss landscape of recursive models is complex -- per-parameter learning rates navigate it more effectively than a single global rate (SGD). Weight decay decoupling is critical when using parameter sharing.

**Why cosine schedule?** Provides smooth decay that avoids abrupt drops. Important because the deep supervision signal evolves during training.

**EMA (Exponential Moving Average):** Maintains a shadow copy of weights for evaluation stability. Critical for small datasets (+7.5% accuracy in ablations).

In [None]:
class EMA:
    """Exponential Moving Average of model parameters.

    Keeps a "shadow" copy of every trainable parameter and blends the live
    value into it on each `update()`:

        shadow = decay * shadow + (1 - decay) * param

    For evaluation, `apply_shadow()` swaps the averaged values into the model
    and `restore()` puts the training weights back.

    Args:
        model: Module whose trainable parameters are tracked.
        decay: Weight of the previous shadow value in each update.
    """
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        # Filled by apply_shadow(); empty means the live weights are in place.
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Blend the current parameter values into the shadow copy."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = (
                    self.decay * self.shadow[name] + (1 - self.decay) * param.data
                )

    def apply_shadow(self):
        """Replace model params with EMA params for evaluation."""
        self.backup = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                # Copy values instead of rebinding param.data to the shadow
                # tensor. Rebinding would alias the two, so any in-place
                # change to the parameter would corrupt the shadow.
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Restore original params after evaluation.

        Does nothing when there is no backup (apply_shadow() was not called,
        or restore() already ran).
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

### TODO: Deep Supervision Loss

In [None]:
def compute_trm_loss(output, targets, class_weights):
    """
    Compute the TRM loss with deep supervision.

    Student exercise: until the TODO below is implemented, calling this
    function raises NotImplementedError.

    The loss is the sum of prediction loss + halting loss across all T
    supervision steps.

    Args:
        output: dict from TinyRecursiveModel.forward() with 'logits' and 'halt_probs'
            (indexed per supervision step as logits[t] / halt_probs[t] below)
        targets: LongTensor of shape (batch,) with true class labels (0-8)
        class_weights: FloatTensor of shape (num_classes,) for weighted cross-entropy

    Returns:
        dict with:
            - 'total_loss': scalar, the combined loss
            - 'pred_losses': list of T prediction losses
            - 'halt_losses': list of T halting losses
            - 'per_step_accuracy': list of T accuracy values (for logging)

    Steps:
        1. For each supervision step t:
           a. Compute weighted cross-entropy: CE(logits[t], targets, weight=class_weights)
           b. Compute per-sample correctness: q = (argmax(logits[t]) == targets).float()
           c. Compute halting loss: BCE(halt_probs[t].squeeze(), q)
           d. Step loss = CE + BCE
        2. Total loss = sum of all step losses
        3. Compute per-step accuracy for logging

    Hints:
        - Use F.cross_entropy with the weight parameter for class-weighted CE
        - Use F.binary_cross_entropy for halting loss (halt_probs already has sigmoid applied)
        - Detach q when computing BCE -- we do not want gradients flowing through the correctness check

    NOTE(review): if halt_probs[t] has shape (batch, 1), prefer .squeeze(-1) in
    step 1c -- a bare .squeeze() also drops the batch dimension when a batch
    holds a single sample. Confirm against the model's actual output shape.
    NOTE(review): the key here is 'per_step_accuracy' (singular) while the
    train_one_epoch/evaluate docstrings use 'per_step_accuracies' -- keep the
    two straight when aggregating.
    """
    # TODO: Implement loss computation
    raise NotImplementedError("Implement TRM loss with deep supervision")

### TODO: Training Loop

In [None]:
def train_one_epoch(model, train_loader, optimizer, class_weights, ema, device):
    """
    Train the TRM for one epoch.

    Student exercise: until the TODO below is implemented, calling this
    function raises NotImplementedError.

    Args:
        model: TinyRecursiveModel instance
        train_loader: DataLoader for training data
        optimizer: AdamW optimizer
        class_weights: Tensor of class weights for imbalanced data
        ema: EMA instance
        device: torch device

    Returns:
        dict with:
            - 'avg_loss': float, average total loss over epoch
              (the main training loop reads this key for logging)
            - 'avg_accuracy': float, average accuracy (final supervision step)
            - 'per_step_accuracies': list of T floats, average accuracy per step

    Steps:
        1. Set model to train mode
        2. For each batch:
           a. Move data to device
           b. Forward pass through model
           c. Compute loss via compute_trm_loss
           d. Backward pass and optimizer step
           e. Update EMA weights
           f. Log running metrics
        3. Return epoch-level metrics

    Important:
        - Use torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
          to prevent gradient explosion through the deep recursion chain
          (call it after loss.backward() and before optimizer.step())
        - Zero gradients BEFORE the forward pass, not after
        - Call ema.update() only after optimizer.step(), so the shadow weights
          average the freshly updated parameters
    """
    # TODO: Implement training loop
    raise NotImplementedError("Implement training loop")

In [None]:
def evaluate(model, test_loader, class_weights, device):
    """
    Evaluate the TRM on a test/validation set.

    Student exercise: until the TODO below is implemented, calling this
    function raises NotImplementedError.

    This function evaluates whatever weights are currently loaded in `model`;
    the callers in this notebook swap the EMA weights in with
    ema.apply_shadow() before calling it and ema.restore() afterwards.

    Args:
        model: TinyRecursiveModel instance (should use EMA weights)
        test_loader: DataLoader for evaluation data
        class_weights: Tensor of class weights
        device: torch device

    Returns:
        dict with:
            - 'accuracy': float, overall accuracy
            - 'macro_f1': float, macro-averaged F1 score
            - 'per_class_f1': dict mapping class_name -> f1
              (callers print the keys with string formatting, so use str names)
            - 'avg_loss': float, average loss
            - 'confusion_matrix': np.array of shape (9, 9)
            - 'per_step_accuracies': list of T accuracies

    Steps:
        1. Set model to eval mode, use torch.no_grad()
        2. Collect all predictions and targets
        3. Compute metrics using sklearn
           (e.g. accuracy_score, f1_score(average='macro'), confusion_matrix)
    """
    # TODO: Implement evaluation
    raise NotImplementedError("Implement evaluation")

### Main Training Loop

In [None]:
# Prepare data loaders
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 64

# Flatten wafer maps to sequences: one row of tokens per wafer, (N, 676)
X_train_flat, X_val_flat, X_test_flat = (
    split.reshape(len(split), -1) for split in (X_train, X_val, X_test)
)

# Inputs and labels are both stored as int64 tensors.
train_dataset, val_dataset, test_dataset = (
    TensorDataset(
        torch.tensor(features, dtype=torch.long),
        torch.tensor(labels, dtype=torch.long),
    )
    for features, labels in (
        (X_train_flat, y_train),
        (X_val_flat, y_val),
        (X_test_flat, y_test),
    )
)

# Only the training loader shuffles; validation/test keep a fixed order.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Initialize model and training
model = TinyRecursiveModel(dim=128, num_heads=8, num_layers=2,
                            num_classes=NUM_CLASSES, grid_size=GRID_SIZE, T=3, n=6)
model = model.to(DEVICE)

param_count = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {param_count:,}")

# Compute class weights from training set (inverse frequency, rescaled so the
# weights sum to NUM_CLASSES).
# NOTE(review): a class with zero training samples gets weight ~1e6 before
# normalisation and would dominate the loss -- confirm every class is present.
train_class_counts = np.bincount(y_train, minlength=NUM_CLASSES)
class_weights = torch.tensor(
    1.0 / (train_class_counts + 1e-6), dtype=torch.float32
).to(DEVICE)
class_weights = class_weights / class_weights.sum() * NUM_CLASSES

NUM_EPOCHS = 50
PATIENCE = 10  # epochs without a validation-F1 improvement before stopping early

optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# Tie the cosine period to the epoch budget so the two cannot drift apart.
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
ema = EMA(model, decay=0.999)

# Training loop
best_f1 = 0.0
best_state = None
patience_counter = 0
history = {'train_loss': [], 'val_accuracy': [], 'val_f1': []}

for epoch in range(NUM_EPOCHS):
    train_metrics = train_one_epoch(model, train_loader, optimizer,
                                     class_weights, ema, DEVICE)
    scheduler.step()

    # Validate with the EMA weights. The checkpoint is captured while those
    # weights are still loaded, so the saved state is exactly the one that
    # produced the validation score. try/finally guarantees the raw training
    # weights come back even if evaluation fails.
    ema.apply_shadow()
    try:
        val_metrics = evaluate(model, val_loader, class_weights, DEVICE)
        improved = val_metrics['macro_f1'] > best_f1
        if improved:
            best_f1 = val_metrics['macro_f1']
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
    finally:
        ema.restore()

    history['train_loss'].append(train_metrics['avg_loss'])
    history['val_accuracy'].append(val_metrics['accuracy'])
    history['val_f1'].append(val_metrics['macro_f1'])

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
          f"Loss: {train_metrics['avg_loss']:.4f} | "
          f"Val Acc: {val_metrics['accuracy']:.4f} | "
          f"Val F1: {val_metrics['macro_f1']:.4f} | "
          f"LR: {scheduler.get_last_lr()[0]:.6f}")

    if improved:
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

print(f"\nBest validation F1: {best_f1:.4f}")
if best_state is not None:
    model.load_state_dict(best_state)
    # Point the EMA shadow at the best checkpoint too, so any later
    # ema.apply_shadow() evaluates these weights rather than the last epoch's.
    for name, param in model.named_parameters():
        if name in ema.shadow:
            ema.shadow[name] = param.data.clone()
else:
    # Validation F1 never exceeded 0.0, so there is no checkpoint to load.
    print("Warning: validation F1 never improved; keeping the final weights.")

**Thought questions:**
1. Why do we clip gradients to max_norm=1.0? What would happen without gradient clipping in a model that unrolls 18 recursion steps?
2. The EMA decay is set to 0.999. What would happen if we set it to 0.9 (faster EMA) or 0.9999 (slower EMA)? When is each appropriate?
3. Deep supervision computes the loss at 3 intermediate points. Could we use different learning rates or loss weights for each supervision step? What would be the motivation?

---

## 3.6 Evaluation

Quantitative evaluation of the trained TRM against both baselines.

In [None]:
# Final test evaluation with EMA weights
ema.apply_shadow()
try:
    test_metrics = evaluate(model, test_loader, class_weights, DEVICE)
finally:
    # Always swap the raw weights back, even if evaluation raises; otherwise
    # re-running this cell would back up the EMA weights over the raw ones.
    ema.restore()

print("\nFinal Test Results:")
print(f"  Overall Accuracy: {test_metrics['accuracy']:.4f}")
print(f"  Macro F1:         {test_metrics['macro_f1']:.4f}")
print("\nPer-class F1:")
for cls_name, f1_val in test_metrics['per_class_f1'].items():
    print(f"  {cls_name:<15} {f1_val:.4f}")

### TODO: Evaluation Visualizations

In [None]:
def plot_evaluation_results(test_metrics, history):
    """
    Generate comprehensive evaluation plots.

    Student exercise: until the TODO below is implemented, calling this
    function raises NotImplementedError.

    Create a 2x2 figure with:
        1. Confusion matrix heatmap (top-left)
        2. Per-class F1 bar chart (top-right)
        3. Training curves: loss and val accuracy over epochs (bottom-left)
        4. Per-supervision-step accuracy (bottom-right) -- showing how accuracy
           improves from step 1 to step 3

    Args:
        test_metrics: dict from evaluate() for TRM; the plots need its
            'confusion_matrix', 'per_class_f1' and 'per_step_accuracies' entries
        history: dict with training history; the training loop fills the keys
            'train_loss', 'val_accuracy' and 'val_f1' with one value per epoch

    Hints:
        - Use plt.imshow for confusion matrix with 'Blues' colormap
        - Annotate each cell of confusion matrix with the count
        - Use twin axes (ax.twinx()) for overlaying loss and accuracy
        - Annotate the per-step accuracy plot with the improvement from step 1 to step 3
    """
    # TODO: Implement evaluation plots
    raise NotImplementedError("Implement evaluation visualizations")

plot_evaluation_results(test_metrics, history)

In [None]:
# Deployment readiness check
def check_targets(test_metrics, model=None):
    """Verify model meets deployment requirements and print a PASS/FAIL report.

    Checks overall accuracy (> 93%), macro F1 (> 0.90) and total parameter
    count (< 10M). Inference latency is not checked here.

    Args:
        test_metrics: dict with at least 'accuracy' and 'macro_f1' entries.
        model: module whose parameters are counted. If omitted, the
            notebook-level variable named ``model`` is used, which keeps
            existing ``check_targets(test_metrics)`` calls working.
    """
    if model is None:
        # Backward-compatible fallback to the notebook-global model.
        model = globals()["model"]
    param_count = sum(p.numel() for p in model.parameters())
    checks = {
        'Accuracy > 93%': test_metrics['accuracy'] > 0.93,
        'Macro F1 > 0.90': test_metrics['macro_f1'] > 0.90,
        f'Params < 10M ({param_count:,})': param_count < 10_000_000,
    }
    print("\nDeployment Readiness Check:")
    for check, passed in checks.items():
        status = "PASS" if passed else "FAIL"
        print(f"  [{status}] {check}")

    if all(checks.values()):
        print("\nModel meets all deployment criteria!")
    else:
        print("\nModel does NOT meet all criteria. Review failure modes and iterate.")

# Print the PASS/FAIL report for the held-out test metrics computed earlier.
check_targets(test_metrics)

**Thought questions:**
1. Which classes have the lowest F1 scores? Examine the confusion matrix -- which classes are most commonly confused? Does this make physical sense given the defect patterns?
2. How much does accuracy improve from supervision step 1 to step 3? What does this tell you about the value of recursive reasoning for this task?
3. If you could add one more evaluation metric relevant to the business problem, what would it be and why?

---

## 3.7 Error Analysis

Systematic investigation of failure modes to guide improvements.

In [None]:
def collect_errors(model, test_loader, device):
    """Collect all misclassified wafer maps with predictions and confidences.

    Puts `model` in eval mode, runs it over `test_loader` without gradients,
    and records one dict per sample whose argmax over
    softmax(output['final_logits']) differs from its label.

    Args:
        model: callable returning a dict with a 'final_logits' entry.
        test_loader: iterable yielding (wafer_maps, labels) batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        list of dicts with keys 'wafer_map' (CPU tensor), 'true_label' (int),
        'pred_label' (int), 'confidence' (probability of the predicted class),
        'true_prob' (probability assigned to the true class) and 'all_probs'
        (numpy array of all class probabilities). Empty if nothing is wrong.
    """
    model.eval()
    errors = []
    with torch.no_grad():
        for wafer_maps, batch_labels in test_loader:
            wafer_maps, batch_labels = wafer_maps.to(device), batch_labels.to(device)
            output = model(wafer_maps)
            probs = F.softmax(output['final_logits'], dim=-1)
            preds = probs.argmax(dim=-1)

            # Compute the misclassified indices once per batch rather than
            # re-running nonzero() for every error.
            wrong_indices = (preds != batch_labels).nonzero(as_tuple=True)[0].tolist()
            for idx in wrong_indices:
                errors.append({
                    'wafer_map': wafer_maps[idx].cpu(),
                    'true_label': batch_labels[idx].item(),
                    'pred_label': preds[idx].item(),
                    'confidence': probs[idx, preds[idx]].item(),
                    'true_prob': probs[idx, batch_labels[idx]].item(),
                    'all_probs': probs[idx].cpu().numpy()
                })
    return errors

# NOTE(review): this runs on the model's currently loaded weights without
# ema.apply_shadow(), whereas the final test metrics were computed with the EMA
# weights applied -- confirm that analysing a different set of weights is intended.
errors = collect_errors(model, test_loader, DEVICE)
print(f"Total errors: {len(errors)}")

### TODO: Error Categorization

In [None]:
def categorize_errors(errors):
    """
    Categorize misclassifications into failure modes.

    Student exercise: until the TODO below is implemented, calling this
    function raises NotImplementedError.

    Args:
        errors: list of dicts from collect_errors(); each dict has the keys
            'wafer_map', 'true_label', 'pred_label', 'confidence',
            'true_prob' and 'all_probs'

    Returns:
        dict with:
            - 'high_confidence_errors': list of errors where confidence > 0.8
            - 'confusion_pairs': dict mapping (true, pred) -> count
            - 'boundary_errors': list of errors where true and pred are "neighboring" classes

    Additionally, print:
        - Top 3 most common confusion pairs
        - Average confidence on correct vs incorrect predictions
        - Percentage of high-confidence errors

    Hints:
        1. "Neighboring" classes: (center, donut), (edge-loc, edge-ring),
           (loc, random), (near-full, random)
        2. Sort confusion pairs by count to find the top 3

    NOTE(review): collect_errors() stores only misclassified samples, so the
    average confidence on *correct* predictions cannot be computed from
    `errors` alone -- it needs a separate pass over the test set.
    NOTE(review): 'true_label' / 'pred_label' are integer class indices, while
    the neighboring pairs above are given by class name; map indices to names
    before comparing.
    """
    # TODO: Implement error categorization
    raise NotImplementedError("Implement error categorization")

error_categories = categorize_errors(errors)

In [None]:
def visualize_top_errors(errors, n=6):
    """
    Visualize the most informative error cases.

    Student exercise: until the TODO below is implemented, calling this
    function raises NotImplementedError.

    Show the top-n errors by confidence (high-confidence misclassifications),
    with each subplot showing:
    - The wafer map (reshaped to 26x26)
    - True label and predicted label
    - Confidence bar chart over all 9 classes

    Args:
        errors: list of error dicts as produced by collect_errors(); the plots
            use each dict's 'wafer_map', 'true_label', 'pred_label',
            'confidence' and 'all_probs' entries
        n: number of errors to visualize (handle len(errors) < n gracefully)
    """
    # TODO: Implement error visualization
    raise NotImplementedError("Implement error visualization")

visualize_top_errors(errors, n=6)

**Thought questions:**
1. Identify the top 3 failure modes. For each, propose a specific intervention -- could it be addressed by architecture changes, data augmentation, or post-processing?
2. Are high-confidence errors clustered in specific classes? What does this imply for the halting mechanism?
3. How could the error analysis inform the deployment strategy? For example, should certain predictions be automatically flagged for human review?

---

## 3.8 Latency Profiling and Deployment

In [None]:
import time

def profile_inference_latency(model, device, num_samples=1000, warmup=50,
                              seq_len=676, vocab_size=3):
    """Profile single-sample inference latency on the given device.

    Runs `warmup` untimed forward passes, then times `num_samples` forward
    passes on one random integer input of shape (1, seq_len) and prints the
    p50 / p90 / p99 / mean latency in milliseconds.

    Args:
        model: module to profile; it is switched to eval mode.
        device: torch.device or device string (e.g. 'cpu', 'cuda').
        num_samples: number of timed forward passes (must be >= 1).
        warmup: number of untimed passes run first.
        seq_len: length of the dummy input sequence (default 676).
        vocab_size: exclusive upper bound of the random input values.

    Returns:
        np.ndarray of shape (num_samples,) with per-call latencies in ms.

    Raises:
        ValueError: if num_samples < 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    device = torch.device(device)  # also accept plain strings such as 'cuda'
    on_cuda = device.type == 'cuda'

    model.eval()
    dummy_input = torch.randint(0, vocab_size, (1, seq_len)).to(device)

    # Warmup
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(dummy_input)

    # Profile. CUDA kernels launch asynchronously, so synchronize on both
    # sides of the timed region to measure the actual execution time.
    latencies = []
    with torch.no_grad():
        for _ in range(num_samples):
            if on_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            _ = model(dummy_input)
            if on_cuda:
                torch.cuda.synchronize()
            latencies.append((time.perf_counter() - start) * 1000)

    latencies = np.array(latencies)
    print(f"Inference Latency Profile ({device}):")
    print(f"  p50:  {np.percentile(latencies, 50):.2f} ms")
    print(f"  p90:  {np.percentile(latencies, 90):.2f} ms")
    print(f"  p99:  {np.percentile(latencies, 99):.2f} ms")
    print(f"  mean: {np.mean(latencies):.2f} ms")
    return latencies

# Single-sample (batch size 1) latency on the current device; compare the tail
# percentiles with the 50 ms classification budget.
latencies = profile_inference_latency(model, DEVICE)

### TODO: Adaptive Inference with Early Halting

In [None]:
def inference_with_halting(model, wafer_map, halt_threshold=0.9, device=DEVICE):
    """
    Run inference with adaptive halting -- stop recursing when the model
    is confident it has the right answer.

    Student exercise: until the TODO below is implemented, calling this
    function raises NotImplementedError.

    Args:
        model: TinyRecursiveModel instance
        wafer_map: LongTensor of shape (1, 676) -- single wafer map
        halt_threshold: float, stop when halt probability exceeds this
        device: torch device (the default is whatever DEVICE was when this
            cell was executed -- re-run the cell if DEVICE changes)

    Returns:
        dict with:
            - 'prediction': int, predicted class
            - 'confidence': float, prediction confidence
            - 'num_steps_used': int, how many supervision steps were actually run
            - 'latency_ms': float, actual inference time
            - 'halt_probs': list of halt probabilities at each step

        compare_inference_modes() reads 'prediction' (compared against an int
        label) and 'num_steps_used' from this dict.

    Implementation:
        1. Run the model step-by-step (not all T steps at once)
        2. After each supervision step, check the halt probability
        3. If halt_prob > halt_threshold, stop and return current prediction
        4. Otherwise, continue to next supervision step
        5. Time the entire process

    Hint: For simplicity, run the full forward pass and check halting at each step.
    The latency savings in production come from actually stopping the recursion early.
    If no step exceeds the threshold, fall back to the final step's prediction.
    """
    # TODO: Implement adaptive inference
    raise NotImplementedError("Implement inference with halting")

In [None]:
# Compare full vs halted inference
def compare_inference_modes(model, test_loader, device, halt_threshold=0.9):
    """Compare accuracy and latency of full vs halted inference.

    For every batch, the full forward pass is scored from
    output['final_logits'], then each sample is re-run individually through
    inference_with_halting() to score the early-halting prediction and record
    how many supervision steps it used. Results are printed, not returned.

    Args:
        model: model to evaluate; it is switched to eval mode.
        test_loader: iterable yielding (wafer_maps, labels) batches.
        device: device the batches are moved to.
        halt_threshold: threshold forwarded to inference_with_halting().
    """
    # Local import so this cell does not depend on an earlier cell's imports.
    from collections import Counter

    model.eval()
    full_correct = 0
    halt_correct = 0
    total = 0
    steps_used = []

    with torch.no_grad():
        for wafer_maps, batch_labels in test_loader:
            wafer_maps, batch_labels = wafer_maps.to(device), batch_labels.to(device)

            # Full inference
            output = model(wafer_maps)
            full_preds = output['final_logits'].argmax(dim=-1)
            full_correct += (full_preds == batch_labels).sum().item()

            # Halted inference (per-sample)
            for i in range(len(wafer_maps)):
                result = inference_with_halting(model, wafer_maps[i:i+1],
                                                halt_threshold, device)
                if result['prediction'] == batch_labels[i].item():
                    halt_correct += 1
                steps_used.append(result['num_steps_used'])

            total += len(batch_labels)

    if total == 0:
        # An empty loader would otherwise divide by zero below.
        print("No samples to evaluate.")
        return

    print(f"\nFull inference accuracy:  {full_correct/total:.4f}")
    print(f"Halted inference accuracy: {halt_correct/total:.4f}")
    # The denominator assumes the model runs T=3 supervision steps.
    print(f"Average steps used: {np.mean(steps_used):.2f} / 3")
    print(f"Step distribution: {Counter(steps_used)}")

# Run on a subset for speed: at most the first 500 test samples, fed one at a
# time so full and halted inference see identical single-sample batches.
_subset_indices = range(min(500, len(test_dataset)))
small_test = torch.utils.data.Subset(test_dataset, _subset_indices)
small_loader = DataLoader(small_test, batch_size=1, shuffle=False)
compare_inference_modes(model, small_loader, DEVICE)

**Thought questions:**
1. What is the tradeoff between halt_threshold and accuracy? At what threshold does accuracy start to drop noticeably?
2. The Jetson Orin Nano delivers up to 40 TOPS at INT8 precision. How would you quantize the TRM model for deployment? What accuracy loss would you expect?
3. The production pipeline has 50ms for classification. How would you allocate the time budget across preprocessing, inference, and postprocessing?

---

## 3.9 Ethical and Regulatory Analysis

### TODO: Ethical Impact Assessment

In [None]:
def ethical_impact_assessment():
    """
    Write a brief ethical impact assessment for deploying the TRM-based
    defect classifier in semiconductor production.

    Student exercise: until the TODO below is implemented, calling this
    function raises NotImplementedError. The function takes no arguments and
    is expected to print the assessment rather than return it.

    Address the following (print your answers):

    1. BIAS AND FAIRNESS
       - The WM-811K dataset comes from specific foundries. How might this introduce bias?
       - If the model performs poorly on a specific defect type, what is the downstream impact?
       - How would you monitor for performance degradation on minority classes over time?

    2. AUTOMATION AND HUMAN OVERSIGHT
       - The TRM model will replace or augment human inspectors. What is the appropriate
         level of human oversight?
       - What is the failure mode if the model encounters a novel defect pattern not in training?
       - How should the system handle out-of-distribution inputs?

    3. REGULATORY COMPLIANCE
       - ITAR: Some semiconductor manufacturing data may be ITAR-controlled. What constraints
         does this place on model training, deployment, and updates?
       - Export controls for international deployment
       - Data retention policies for production quality audits

    4. ENVIRONMENTAL IMPACT
       - Compare energy consumption: TRM edge inference vs cloud CNN inference
       - Estimate annual energy savings for a foundry processing 10,000 wafers/day
       - Edge: ~15W per Jetson * 40ms per wafer; Cloud: ~50W per GPU * 280ms per wafer
         (energy per wafer in joules = power in watts * time in seconds; scale by
         wafers/day and days/year, and state any overheads you assume)

    Print your assessment as a structured document with headers and bullet points.
    """
    # TODO: Write ethical impact assessment
    raise NotImplementedError("Write ethical impact assessment")

ethical_impact_assessment()

---

## Summary

In this notebook, you built a complete pipeline for edge-deployed wafer defect classification using a Tiny Recursive Model:

1. **Data**: Loaded and preprocessed the WM-811K wafer map dataset, analyzed class distributions and spatial patterns
2. **Baselines**: Implemented rule-based (spatial features + decision tree) and CNN baselines
3. **Model**: Built the TRM architecture from scratch -- rotary position embeddings, RMSNorm, SwiGLU, recursive blocks with dual state (solution y + reasoning z), and deep supervision
4. **Training**: Trained with AdamW, cosine schedule, EMA, gradient clipping, and class-weighted deep supervision loss
5. **Evaluation**: Measured accuracy, macro F1, and compared against baselines
6. **Error Analysis**: Categorized failure modes and identified high-confidence misclassifications
7. **Deployment**: Profiled latency and implemented adaptive halting for edge inference
8. **Ethics**: Assessed bias, automation oversight, regulatory compliance, and environmental impact

The key insight: recursive reasoning with a tiny shared-weight network achieves the computational depth of a 42-layer model with only 7M parameters -- making it suitable for edge deployment while maintaining accuracy that exceeds much larger single-pass architectures.

For further reading on production deployment, system design, monitoring, and CI/CD for ML, refer to **Section 4** of the full case study document (`case_study.md`).