In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and seed every random number generator used below.
# (Dependencies are installed in a separate cell further down.)

import random
import sys

import numpy as np
import torch

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Seeding alone does not make GPU runs repeatable: cuDNN autotuning may select
# different (and non-deterministic) convolution algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Case Study: Wafer Defect Detection with Variational Autoencoders
## Implementation Notebook

Welcome to the implementation notebook for the NovaSilicon wafer defect detection case study. In this notebook, you will build a Convolutional Variational Autoencoder (VAE) for unsupervised anomaly detection in semiconductor wafer inspection.

**Business Context:** NovaSilicon, a semiconductor manufacturer producing safety-critical automotive chips, faces a $109M annual cost from inspection inadequacies. Their current rule-based inspection system suffers from throughput bottlenecks, inconsistent human inspectors (78% agreement rate), and an 11-day lag in detecting novel defect types. The VAE approach learns what "normal" wafers look like, then flags anything that deviates -- enabling detection of previously unseen defect types without labeled defect data.

**What you will build:**
- A data pipeline for anomaly detection using FashionMNIST as a proxy dataset
- Exploratory data analysis to understand normal vs. anomalous patterns
- A PCA baseline for anomaly detection (linear autoencoder)
- A Convolutional VAE from scratch with encoder, decoder, and reparameterization
- A training pipeline with KL annealing and cosine learning rate scheduling
- A full evaluation pipeline with AUROC, AUPRC, and ROC curves
- Error analysis with reconstruction heatmaps
- Deployment benchmarking with TorchScript export

**Prerequisites:** Familiarity with PyTorch, convolutional neural networks, and the VAE mathematical framework. Read Sections 1-2 of the accompanying case study document before starting.

**Runtime:** This notebook is designed to run on Google Colab with a T4 GPU. Total runtime is approximately 15-20 minutes.

---

## 3.1 Data Acquisition Strategy

We use FashionMNIST as a proxy dataset for wafer anomaly detection. Class 0 (T-shirt/Top) serves as the "normal" class representing defect-free wafer images, while all other classes serve as "anomalies" representing defective wafers with various defect types.

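This proxy is reasonable because the anomaly detection protocol is data-agnostic: we train exclusively on the normal class and evaluate the model's ability to distinguish normal from anomalous at test time. The methodology transfers directly to real wafer imagery: training on normal data only, scoring by reconstruction error, and evaluating with AUROC/AUPRC. The absolute metric values obtained on FashionMNIST do not transfer, so expect different numbers on real inspection images.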

**Mapping to the NovaSilicon scenario:**
- T-shirt/Top (class 0) = defect-free wafer images
- All other classes = various defect types (scratches, particles, pattern shifts, etc.)
- Training set: normal images only (matching NovaSilicon's abundant normal data)
- Test set: mix of normal and anomalous (matching production inspection)

In [None]:
# Install dependencies
!pip install -q torch torchvision matplotlib scikit-learn tqdm

In [None]:
# Core imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset, Subset
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, auc
from sklearn.decomposition import PCA
from tqdm import tqdm
import os
import time
import warnings
# Blanket suppression: this also hides deprecation/user warnings raised by
# torch, torchvision and sklearn. Narrow the filter when debugging a cell.
warnings.filterwarnings('ignore')

# Reproducibility: re-seed with the SEED constant defined in the setup cell so
# that changing SEED there propagates here instead of silently using 42.
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration: every tunable knob of the experiment in one place.

# --- Data ---
IMG_SIZE = 128      # FashionMNIST is upsampled to 128x128 (still fast on Colab)
NORMAL_CLASS = 0    # Class 0 (T-shirt/Top) stands in for a defect-free wafer
BATCH_SIZE = 64

# --- Model ---
LATENT_DIM = 64     # Width of the VAE latent space

# --- Optimization ---
NUM_EPOCHS = 30
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5
KL_ANNEAL_EPOCHS = 10  # beta grows linearly from 0 to 1 across these epochs

# FashionMNIST label names, indexed by class id (reference only)
FASHION_CLASSES = [
    'T-shirt/Top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]
print(f"Normal class: {FASHION_CLASSES[NORMAL_CLASS]} (proxy for defect-free wafer)")
print(f"Anomaly classes: {[c for i, c in enumerate(FASHION_CLASSES) if i != NORMAL_CLASS]}")

In [None]:
# Fetch FashionMNIST (cached under ./data after the first download).
# Every image is upsampled to IMG_SIZE x IMG_SIZE and turned into a float
# tensor; ToTensor already rescales the uint8 pixels into [0, 1].
base_transform = transforms.Compose(
    [transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor()]
)

# Training split first, then the test split.
train_full, test_full = (
    datasets.FashionMNIST(root='./data', train=is_train, download=True,
                          transform=base_transform)
    for is_train in (True, False)
)

print(f"Full training set: {len(train_full)} images")
print(f"Full test set: {len(test_full)} images")

In [None]:
# Split into normal (training) and anomaly (testing)
# Labels are read straight from the datasets' `targets` tensors. Enumerating
# `train_full` / `test_full` instead would run Resize + ToTensor on every image
# just to look at its label, which is very slow.

# Training: ONLY normal images (class 0)
train_normal_indices = (train_full.targets == NORMAL_CLASS).nonzero(as_tuple=True)[0].tolist()

# Test set: every normal test image plus a random subset of anomalies
test_normal_indices = (test_full.targets == NORMAL_CLASS).nonzero(as_tuple=True)[0].tolist()
test_anomaly_indices = (test_full.targets != NORMAL_CLASS).nonzero(as_tuple=True)[0].tolist()

# Subsample anomalies to create a realistic imbalance (more normal than anomalous in test).
# A dedicated RandomState(42) draws exactly the same subset as `np.random.seed(42)`
# followed by `np.random.choice(...)`, without resetting the notebook-wide NumPy RNG.
N_TEST_ANOMALIES = 400
anomaly_rng = np.random.RandomState(42)
test_anomaly_subset = anomaly_rng.choice(
    test_anomaly_indices, size=N_TEST_ANOMALIES, replace=False
).tolist()

# Create datasets
train_dataset = Subset(train_full, train_normal_indices)

# Split training into train/val (80/20)
n_train = int(0.8 * len(train_dataset))
n_val = len(train_dataset) - n_train
train_split, val_split = torch.utils.data.random_split(
    train_dataset, [n_train, n_val],
    generator=torch.Generator().manual_seed(42)
)

# Test dataset with labels: 0 = normal, 1 = anomaly
test_indices = test_normal_indices + test_anomaly_subset
test_labels = [0] * len(test_normal_indices) + [1] * len(test_anomaly_subset)
test_dataset = Subset(test_full, test_indices)

print(f"Training set (normal only): {len(train_split)} images")
print(f"Validation set (normal only): {len(val_split)} images")
print(f"Test set: {len(test_normal_indices)} normal + {len(test_anomaly_subset)} anomaly = {len(test_indices)} total")

In [None]:
# Create DataLoaders
# Pinned host memory only speeds up host->GPU copies; on a CPU-only runtime
# PyTorch warns that it cannot be used, so enable it only when CUDA is active.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2,
                     pin_memory=(device.type == 'cuda'))
train_loader = DataLoader(train_split, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_split, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Verify shapes
sample_batch, _ = next(iter(train_loader))
print(f"Batch shape: {sample_batch.shape}")
print(f"Pixel range: [{sample_batch.min():.3f}, {sample_batch.max():.3f}]")
print(f"Training batches per epoch: {len(train_loader)}")

### TODO: Data Augmentation Pipeline

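In a real wafer inspection system, certain augmentations are physically meaningful while others are not. Horizontal and vertical flips are valid because wafer images have no canonical orientation on the inspection microscope. Small rotations are valid for the same reason. Gaussian noise simulates sensor noise variations across inspection stations. However, color jitter is NOT appropriate because intensity variations in grayscale wafer images carry physical meaning (e.g., film thickness variations). Note that on the FashionMNIST proxy, garments *do* have a canonical orientation: an upside-down T-shirt looks unusual. Flip augmentation therefore teaches the model that such images are "normal", which can make the proxy task harder than the real wafer task.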

In [None]:
def create_augmented_transform(img_size=128):
    """
    Create an augmented transform pipeline for training wafer images.

    Implement the following augmentation pipeline (order matters):
    1. Resize to img_size x img_size
    2. RandomHorizontalFlip with p=0.5
    3. RandomVerticalFlip with p=0.5
    4. RandomRotation up to 15 degrees
    5. ToTensor (converts the PIL image to a float tensor in [0, 1])
    6. Add Gaussian noise (implement as a custom transform)

    Steps 1-4 operate on PIL images. The Gaussian-noise transform operates on
    the tensor produced by ToTensor, so it must come AFTER ToTensor.

    Hints:
    - For Gaussian noise, create a class that adds N(0, sigma) noise
      and clamps the result to [0, 1]. Use sigma=0.02.
    - Do NOT use ColorJitter -- intensity in grayscale wafer images
      carries physical meaning (film thickness).
    - Use RandomRotation(15, fill=0) so the corners exposed by the rotation
      are filled with black (0). The default expand=False keeps the output
      at img_size x img_size; it does not control the fill value.

    Args:
        img_size: Target image size (default 128)

    Returns:
        augmented_transform: transforms.Compose pipeline
    """
    # TODO: Implement the GaussianNoise custom transform class
    # class GaussianNoise:
    #     def __init__(self, sigma=0.02):
    #         ...
    #     def __call__(self, tensor):
    #         ...  # Add noise and clamp to [0, 1]

    # TODO: Compose the augmentation pipeline
    augmented_transform = None  # Your code here

    return augmented_transform

# Verification: the pipeline must turn a raw PIL image into a
# (1, IMG_SIZE, IMG_SIZE) float tensor whose values stay in [0, 1].
aug_transform = create_augmented_transform(IMG_SIZE)
assert aug_transform is not None, "Augmented transform not implemented"

raw_pil = transforms.ToPILImage()(train_full.data[0])  # untransformed uint8 image
aug_sample = aug_transform(raw_pil)
assert isinstance(aug_sample, torch.Tensor), \
    f"Pipeline must return a tensor (is ToTensor missing?), got {type(aug_sample)}"
assert aug_sample.shape == (1, IMG_SIZE, IMG_SIZE), \
    f"Unexpected output shape {tuple(aug_sample.shape)}"
assert aug_sample.min() >= 0 and aug_sample.max() <= 1, "Pixel values must stay in [0, 1]"
print("Augmented transform created successfully.")

**Question for reflection:** Why is it important NOT to use augmented data for the validation and test sets? How would augmentation during evaluation affect our anomaly detection metrics?

---

## 3.2 Exploratory Data Analysis

Before building any model, we must understand the data distributions. This is especially important for anomaly detection: we need to confirm that the normal class has consistent visual patterns and that anomalies are visually distinct.

In [None]:
# Visualize the first 32 normal training images in a 4x8 grid
fig, axes = plt.subplots(4, 8, figsize=(16, 8))
fig.suptitle('Normal Training Images (Defect-Free Wafer Proxy)', fontsize=14, fontweight='bold')

for sample_idx, panel in enumerate(axes.flat):
    normal_img = train_split[sample_idx][0]
    panel.imshow(normal_img.squeeze(), cmap='gray', vmin=0, vmax=1)
    panel.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Visualize anomalous test images ("defect types"), each titled with its
# original FashionMNIST class name
fig, axes = plt.subplots(4, 8, figsize=(16, 8))
fig.suptitle('Anomalous Test Images (Defective Wafer Proxy)', fontsize=14, fontweight='bold')

# Positions within test_dataset whose binary label marks an anomaly
anomaly_only_indices = [pos for pos, is_anomaly in enumerate(test_labels) if is_anomaly == 1]

for panel_no, panel in enumerate(axes.flat):
    anomaly_img, fashion_label = test_dataset[anomaly_only_indices[panel_no]]
    panel.imshow(anomaly_img.squeeze(), cmap='gray', vmin=0, vmax=1)
    panel.set_title(FASHION_CLASSES[fashion_label], fontsize=7)
    panel.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Pixel intensity distributions: normal vs anomalous
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Flattened pixels of (up to) the first 200 normal training images ...
normal_pixels = np.concatenate(
    [train_split[k][0].numpy().flatten() for k in range(min(200, len(train_split)))]
)
# ... and of the first 200 anomalous test images
anomaly_pixels = np.concatenate(
    [test_dataset[k][0].numpy().flatten() for k in anomaly_only_indices[:200]]
)

# Overlaid, density-normalized histograms
for pixel_values, bar_color, legend_name in (
    (normal_pixels, 'steelblue', 'Normal'),
    (anomaly_pixels, 'crimson', 'Anomaly'),
):
    axes[0].hist(pixel_values, bins=50, alpha=0.7, color=bar_color, density=True, label=legend_name)
axes[0].set(xlabel='Pixel Intensity', ylabel='Density', title='Pixel Intensity Distribution')
axes[0].legend()

# Per-pixel mean and standard deviation over (up to) 500 normal training images
all_normal_imgs = torch.stack([train_split[k][0] for k in range(min(500, len(train_split)))])
mean_img = all_normal_imgs.mean(dim=0).squeeze()
std_img = all_normal_imgs.std(dim=0).squeeze()

axes[1].imshow(mean_img, cmap='gray')
axes[1].set_title('Mean Normal Image')
axes[1].axis('off')

plt.tight_layout()
plt.show()

# Standard deviation image: bright regions vary most across normal images
fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(std_img, cmap='hot')
ax.set_title('Pixel Standard Deviation (Normal Images)')
ax.axis('off')
fig.colorbar(im, ax=ax, fraction=0.046)
plt.tight_layout()
plt.show()

print(f"Normal pixel mean: {normal_pixels.mean():.4f}, std: {normal_pixels.std():.4f}")
print(f"Anomaly pixel mean: {anomaly_pixels.mean():.4f}, std: {anomaly_pixels.std():.4f}")

### TODO: EDA Analysis Questions

In [None]:
def answer_eda_questions():
    """
    Record your written answers to the EDA questions below.
    Fill in each entry with a string of 2-3 sentences.

    Questions:
    1. What is the dominant visual pattern in the normal images?
       Think about texture, shape and intensity distribution.

    2. How do the anomalous images differ visually from normal ones?
       Think about structure, intensity patterns and shape variations.

    3. Are the differences localized (small regions) or global (entire image)?
       Think about where in the image the differences are most apparent.

    4. Based on the pixel distributions alone, could simple thresholding on
       mean pixel intensity detect anomalies? Why or why not?
       Think about how much the normal and anomaly distributions overlap.

    Returns:
        dict mapping each question key to your answer string.
    """
    answers = dict(
        q1_dominant_pattern=None,     # Your answer here
        q2_anomaly_differences=None,  # Your answer here
        q3_localized_or_global=None,  # Your answer here
        q4_simple_thresholding=None,  # Your answer here
    )
    return answers

# Verification: every answer must be an actual string of 2-3 sentences.
eda_answers = answer_eda_questions()
for key, value in eda_answers.items():
    assert value is not None, f"Please answer {key}"
    # Guard before len(): a list/tuple of sentences would otherwise pass or fail confusingly.
    assert isinstance(value, str), f"Answer for {key} must be a string, got {type(value).__name__}"
    assert len(value) > 30, f"Answer for {key} is too short. Provide 2-3 sentences."
print("All EDA questions answered.")
for key, value in eda_answers.items():
    print(f"\n{key}:\n  {value}")

---

## 3.3 PCA Baseline

Before building the VAE, we establish a performance floor with PCA. PCA can be viewed as a linear autoencoder: it projects data into a low-dimensional subspace and reconstructs it. The reconstruction error serves as an anomaly score because PCA captures the dominant modes of variation in normal images, and anomalies deviate from these modes.

NovaSilicon's current rule-based system relies on hand-crafted features. Like PCA, these capture only fixed, largely linear patterns, so PCA is a reasonable stand-in for it. Beating PCA with the VAE demonstrates the value of nonlinear feature learning.

In [None]:
# Build the PCA design matrices: one flattened image per row.
print("Preparing data for PCA...")
train_flat = np.array(
    [train_split[j][0].numpy().flatten() for j in range(len(train_split))]
)

# Same flattening for the mixed normal/anomaly test set
test_flat = np.array(
    [test_dataset[j][0].numpy().flatten() for j in range(len(test_dataset))]
)

test_labels_arr = np.array(test_labels)

print(f"Training matrix: {train_flat.shape}")
print(f"Test matrix: {test_flat.shape}")
print(f"Feature dimension: {train_flat.shape[1]} (= {IMG_SIZE}x{IMG_SIZE})")

### TODO: PCA Anomaly Detection Baseline

In [None]:
def pca_anomaly_detection(train_data, test_data, test_labels, k_values=(10, 50, 100, 200)):
    """
    Implement PCA-based anomaly detection for multiple values of k (number of components).

    For each k in k_values:
    1. Fit PCA with k components on train_data (normal images only)
    2. Transform test_data into the k-dimensional PCA space
    3. Reconstruct the test_data by inverse-transforming back to the original space
    4. Compute per-image MSE between original and reconstructed test images
    5. Use the MSE as the anomaly score
    6. Compute AUROC using sklearn.metrics.roc_auc_score(test_labels, anomaly_scores)

    Hints:
    - Use sklearn.decomposition.PCA(n_components=k)
    - pca.fit(train_data) to fit on normal data only
    - pca.transform(test_data) to project test data
    - pca.inverse_transform(projected) to reconstruct
    - MSE = np.mean((original - reconstructed) ** 2, axis=1) gives per-image scores

    Args:
        train_data: np.array of shape (n_train, d) -- flattened normal images
        test_data: np.array of shape (n_test, d) -- flattened test images
        test_labels: np.array of shape (n_test,) -- 0 for normal, 1 for anomaly
        k_values: iterable of integers -- number of PCA components to try.
            The default is a tuple (not a list) so it can never be mutated
            and leak state between calls.

    Returns:
        results: dict mapping k -> {'auroc': float, 'scores': np.array, 'pca': fitted PCA object}
    """
    results = {}

    for k in k_values:
        # TODO: Step 1 - Fit PCA with k components on train_data

        # TODO: Step 2 - Transform and reconstruct test_data

        # TODO: Step 3 - Compute per-image MSE as anomaly score

        # TODO: Step 4 - Compute AUROC

        # results[k] = {'auroc': auroc, 'scores': scores, 'pca': pca}
        pass

    return results

# Run the PCA baseline on the flattened normal-only training matrix
pca_results = pca_anomaly_detection(train_flat, test_flat, test_labels_arr)

# Report one AUROC per component count, smallest k first
print("PCA Baseline Results:")
print("-" * 40)
for k in sorted(pca_results):
    res = pca_results[k]
    print(f"  k={k:4d} components: AUROC = {res['auroc']:.4f}")

In [None]:
# Verification: a non-empty result dict whose every entry carries a
# better-than-chance AUROC and the raw per-image scores.
assert len(pca_results) > 0, "PCA results dictionary is empty -- implement pca_anomaly_detection()"
for k in pca_results:
    res = pca_results[k]
    assert 'auroc' in res, f"Missing 'auroc' key for k={k}"
    assert 0.5 <= res['auroc'] <= 1.0, f"AUROC for k={k} is {res['auroc']:.4f} -- should be between 0.5 and 1.0"
    assert 'scores' in res, f"Missing 'scores' key for k={k}"
print("PCA baseline verification passed.")

In [None]:
# Plot ROC curves for all PCA configurations (left) and AUROC per k (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors = ['#2196F3', '#4CAF50', '#FF9800', '#F44336']
for i, (k, res) in enumerate(sorted(pca_results.items())):
    fpr, tpr, _ = roc_curve(test_labels_arr, res['scores'])
    axes[0].plot(fpr, tpr, color=colors[i % len(colors)],
                 label=f'k={k} (AUROC={res["auroc"]:.3f})', linewidth=2)

axes[0].plot([0, 1], [0, 1], 'k--', alpha=0.3)  # random-classifier diagonal
axes[0].set_xlabel('False Positive Rate')
axes[0].set_ylabel('True Positive Rate')
axes[0].set_title('PCA Baseline: ROC Curves')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Bar chart of AUROC vs k
k_vals = sorted(pca_results.keys())
aurocs = [pca_results[k]['auroc'] for k in k_vals]
axes[1].bar([str(k) for k in k_vals], aurocs, color='steelblue', alpha=0.8)
axes[1].set_xlabel('Number of PCA Components (k)')
axes[1].set_ylabel('AUROC')
axes[1].set_title('PCA Baseline: AUROC vs Components')
# The value labels below are drawn at a + 0.01; with a hard 1.0 cap, labels for
# AUROC >= 0.99 would land outside the axes, so leave some headroom.
axes[1].set_ylim(0.5, 1.05)
axes[1].grid(True, alpha=0.3, axis='y')

for i, a in enumerate(aurocs):
    axes[1].text(i, a + 0.01, f'{a:.3f}', ha='center', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.show()

# Store best PCA result for later comparison
best_k = max(pca_results, key=lambda k: pca_results[k]['auroc'])
pca_best_auroc = pca_results[best_k]['auroc']
print(f"\nBest PCA baseline: k={best_k}, AUROC={pca_best_auroc:.4f}")

---

## 3.4 Convolutional VAE Model

Now we build the core model: a Convolutional VAE. The architecture uses strided convolutions for downsampling in the encoder and transposed convolutions for upsampling in the decoder.

**Key design decisions for anomaly detection:**
- LeakyReLU avoids dead neurons: negative pre-activations still pass a small signal. Every unit therefore keeps receiving gradients during training and keeps responding to unusual inputs, including anomalous ones at test time.
- BatchNorm stabilizes training with higher learning rates.
- Sigmoid output ensures pixel values are in [0, 1], matching our input normalization.
- The latent dimension (64) balances reconstruction quality against latent space regularity.

### Encoder Architecture

The encoder compresses a 128x128 grayscale image down to a latent vector through four convolutional blocks. Each block halves the spatial dimensions via stride-2 convolution.

| Layer | Output Shape | Description |
|-------|-------------|-------------|
| Input | (1, 128, 128) | Grayscale image |
| Conv Block 1 | (32, 64, 64) | Conv2d(1, 32, 4, stride=2, padding=1) + BN + LeakyReLU |
| Conv Block 2 | (64, 32, 32) | Conv2d(32, 64, 4, stride=2, padding=1) + BN + LeakyReLU |
| Conv Block 3 | (128, 16, 16) | Conv2d(64, 128, 4, stride=2, padding=1) + BN + LeakyReLU |
| Conv Block 4 | (256, 8, 8) | Conv2d(128, 256, 4, stride=2, padding=1) + BN + LeakyReLU |
| Flatten | (16384,) | 256 * 8 * 8 |
| Linear (mu) | (64,) | Mean of latent distribution |
| Linear (logvar) | (64,) | Log-variance of latent distribution |

### TODO: Implement the Encoder

In [None]:
class ConvEncoder(nn.Module):
    """
    Convolutional encoder for the VAE.

    Takes a (1, img_size, img_size) grayscale image -- (1, 128, 128) by
    default -- and produces mu and logvar vectors of dimension latent_dim.

    Architecture:
    - 4 convolutional blocks: Conv2d -> BatchNorm2d -> LeakyReLU(0.2)
    - Each conv uses kernel_size=4, stride=2, padding=1 (halves spatial dims)
    - Channel progression: 1 -> 32 -> 64 -> 128 -> 256
    - Flatten the output of the last conv block
    - Two separate Linear layers: one for mu, one for logvar

    Args:
        latent_dim: Dimensionality of mu / logvar (default 64).
        img_size: Input height and width (default 128). Must be divisible by
            16, because the four stride-2 convolutions each halve the spatial
            size (128 -> 64 -> 32 -> 16 -> 8).

    Raises:
        ValueError: If img_size is not divisible by 16.
    """

    def __init__(self, latent_dim=64, img_size=128):
        super().__init__()
        if img_size % 16 != 0:
            raise ValueError(f"img_size must be divisible by 16, got {img_size}")
        self.latent_dim = latent_dim

        # Convolutional layers (these are provided)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Side length of the last feature map (8 for 128x128 inputs)
        self.feat_size = img_size // 16
        # Flattened dimension: 256 * 8 * 8 = 16384 for 128x128 inputs
        self.flat_dim = 256 * self.feat_size * self.feat_size

        # Linear heads for mu and logvar
        self.fc_mu = nn.Linear(self.flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.flat_dim, latent_dim)

    def forward(self, x):
        """
        Forward pass through the encoder.

        TODO: Implement this method.

        Steps:
        1. Pass x through conv1 -> bn1 -> LeakyReLU(0.2)
        2. Pass through conv2 -> bn2 -> LeakyReLU(0.2)
        3. Pass through conv3 -> bn3 -> LeakyReLU(0.2)
        4. Pass through conv4 -> bn4 -> LeakyReLU(0.2)
        5. Flatten the output (use x.view(x.size(0), -1))
        6. Compute mu = self.fc_mu(flattened)
        7. Compute logvar = self.fc_logvar(flattened)
        8. Return (mu, logvar)

        Hints:
        - Use F.leaky_relu(x, 0.2) for the activation
        - The batch dimension (dim 0) must be preserved throughout
        - After step 4, the shape should be
          (batch, 256, self.feat_size, self.feat_size), i.e. (batch, 256, 8, 8)
          for 128x128 inputs

        Args:
            x: Input tensor of shape (batch, 1, img_size, img_size)

        Returns:
            mu: Mean tensor of shape (batch, latent_dim)
            logvar: Log-variance tensor of shape (batch, latent_dim)
        """
        # TODO: Implement the forward pass
        raise NotImplementedError("Implement ConvEncoder.forward()")

In [None]:
# Verification: the encoder must map a (2, 1, IMG_SIZE, IMG_SIZE) batch to
# two (2, LATENT_DIM) tensors. Skips gracefully while forward() is a TODO.
encoder = ConvEncoder(latent_dim=LATENT_DIM).to(device)
test_input = torch.randn(2, 1, IMG_SIZE, IMG_SIZE).to(device)
try:
    mu, logvar = encoder(test_input)
except NotImplementedError:
    print("Encoder forward() not yet implemented. Complete the TODO above.")
else:
    assert mu.shape == (2, LATENT_DIM), f"mu shape is {mu.shape}, expected (2, {LATENT_DIM})"
    assert logvar.shape == (2, LATENT_DIM), f"logvar shape is {logvar.shape}, expected (2, {LATENT_DIM})"
    print(f"Encoder output shapes: mu={mu.shape}, logvar={logvar.shape}")
    print("Encoder verification PASSED.")

### Decoder Architecture

The decoder mirrors the encoder, using transposed convolutions to upsample from the latent vector back to a full-resolution image. The final sigmoid activation ensures output pixels are in [0, 1].

In [None]:
class ConvDecoder(nn.Module):
    """
    Convolutional decoder for the VAE.

    Takes a latent vector of dimension latent_dim and produces a reconstructed
    image of shape (1, img_size, img_size) -- (1, 128, 128) by default -- with
    pixel values in [0, 1].

    Args:
        latent_dim: Dimensionality of the input latent vector (default 64).
        img_size: Output height and width (default 128). Must be divisible by
            16, because the four stride-2 transposed convolutions each double
            the spatial size starting from img_size // 16.

    Raises:
        ValueError: If img_size is not divisible by 16.
    """

    def __init__(self, latent_dim=64, img_size=128):
        super().__init__()
        if img_size % 16 != 0:
            raise ValueError(f"img_size must be divisible by 16, got {img_size}")
        self.latent_dim = latent_dim
        # Side length of the seed feature map (8 for 128x128 outputs)
        self.feat_size = img_size // 16
        self.flat_dim = 256 * self.feat_size * self.feat_size

        # Project from latent space to spatial feature map
        self.fc = nn.Linear(latent_dim, self.flat_dim)

        # Transposed convolutions (mirror of encoder); each doubles H and W
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.deconv4 = nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1)

    def forward(self, z):
        """
        Forward pass through the decoder.

        Args:
            z: Latent vector of shape (batch, latent_dim)

        Returns:
            x_recon: Reconstructed image of shape (batch, 1, img_size, img_size)
        """
        x = self.fc(z)
        x = x.view(x.size(0), 256, self.feat_size, self.feat_size)

        x = F.leaky_relu(self.bn1(self.deconv1(x)), 0.2)
        x = F.leaky_relu(self.bn2(self.deconv2(x)), 0.2)
        x = F.leaky_relu(self.bn3(self.deconv3(x)), 0.2)
        x = torch.sigmoid(self.deconv4(x))  # Output in [0, 1]

        return x

In [None]:
# Verify decoder
decoder = ConvDecoder(latent_dim=LATENT_DIM).to(device)
z_test = torch.randn(2, LATENT_DIM).to(device)
recon_test = decoder(z_test)
assert recon_test.shape == (2, 1, IMG_SIZE, IMG_SIZE), f"Decoder output shape: {recon_test.shape}"
assert recon_test.min() >= 0 and recon_test.max() <= 1, "Decoder output out of [0, 1] range"
print(f"Decoder output shape: {recon_test.shape}")
print(f"Decoder output range: [{recon_test.min():.4f}, {recon_test.max():.4f}]")
print("Decoder verification PASSED.")

### TODO: Reparameterization Trick and Full VAE

The reparameterization trick is the key insight that makes VAE training possible. Instead of sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ directly (which is not differentiable), we reparameterize as:

$$z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

This moves the stochasticity to $\epsilon$, which does not depend on the model parameters, making the entire computation differentiable with respect to $\mu$ and $\sigma$.

In [None]:
class ConvVAE(nn.Module):
    """
    Convolutional VAE used as an unsupervised anomaly detector.

    The encoder maps an image to the parameters (mu, logvar) of a diagonal
    Gaussian q(z|x); a latent code is drawn from it with the
    reparameterization trick and handed to the decoder for reconstruction.
    """

    def __init__(self, latent_dim=64):
        super().__init__()
        self.latent_dim = latent_dim
        # Submodules are registered encoder first, then decoder.
        self.encoder = ConvEncoder(latent_dim)
        self.decoder = ConvDecoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: draw z ~ q(z|x) = N(mu, sigma^2) differentiably.

        TODO: Implement this method.

        Recipe:
        1. std = exp(0.5 * logvar)
           (logvar is log(sigma^2); halving it gives log(sigma), exp gives sigma)
        2. eps = standard-normal noise shaped like std -- torch.randn_like(std)
        3. z = mu + std * eps
        4. Return z

        Hints:
        - torch.randn_like(std) returns a tensor with std's shape, dtype and
          device, filled with N(0, 1) samples.
        - forward() calls this method in eval mode too, so repeated forward
          passes on the same image give a distribution of reconstructions
          (and therefore of anomaly scores).

        Args:
            mu: Mean of the latent distribution, shape (batch, latent_dim)
            logvar: Log-variance of the latent distribution, shape (batch, latent_dim)

        Returns:
            z: Sampled latent vector, shape (batch, latent_dim)
        """
        # TODO: Implement reparameterization trick
        raise NotImplementedError("Implement ConvVAE.reparameterize()")

    def forward(self, x):
        """
        Encode, sample a latent code, and decode it.

        Args:
            x: Input image, shape (batch, 1, 128, 128)

        Returns:
            Tuple (x_recon, mu, logvar):
            x_recon: Reconstructed image, shape (batch, 1, 128, 128)
            mu: Latent mean, shape (batch, latent_dim)
            logvar: Latent log-variance, shape (batch, latent_dim)
        """
        mu, logvar = self.encoder(x)
        return self.decoder(self.reparameterize(mu, logvar)), mu, logvar

In [None]:
# Verification: end-to-end forward pass (encoder -> reparameterize -> decoder).
# Skips gracefully while either TODO is still unimplemented.
model = ConvVAE(latent_dim=LATENT_DIM).to(device)
test_input = torch.randn(2, 1, IMG_SIZE, IMG_SIZE).to(device)
try:
    x_recon, mu, logvar = model(test_input)
except NotImplementedError as e:
    print(f"Not yet implemented: {e}")
    print("Complete the TODOs in ConvEncoder.forward() and ConvVAE.reparameterize()")
else:
    assert x_recon.shape == test_input.shape, f"Reconstruction shape mismatch: {x_recon.shape}"
    assert mu.shape == (2, LATENT_DIM), f"mu shape: {mu.shape}"
    assert logvar.shape == (2, LATENT_DIM), f"logvar shape: {logvar.shape}"
    print(f"VAE forward pass: input={test_input.shape} -> recon={x_recon.shape}")
    print(f"Latent space: mu={mu.shape}, logvar={logvar.shape}")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {n_params:,} ({n_params/1e6:.2f}M)")
    print("VAE verification PASSED.")

### VAE Loss Function

The VAE loss has two terms:

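1. **Reconstruction loss** (BCE): measures how well the model reconstructs the input image. We use binary cross-entropy because pixel values are in [0, 1]. Each pixel intensity is treated as a Bernoulli probability, a standard (if approximate) choice for grayscale images.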

$$\mathcal{L}_{\text{recon}} = -\sum_{i=1}^{D} [x_i \log \hat{x}_i + (1 - x_i) \log(1 - \hat{x}_i)]$$

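2. **KL divergence**: regularizes the approximate posterior $q(z|x) = \mathcal{N}(\mu, \sigma^2)$ toward the standard normal prior $p(z) = \mathcal{N}(0, I)$. Here $D$ is the number of pixels per image and $J$ is the latent dimensionality.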

$$\mathcal{L}_{\text{KL}} = -\frac{1}{2} \sum_{j=1}^{J} (1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2)$$

The total loss with KL annealing:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{recon}} + \beta \cdot \mathcal{L}_{\text{KL}}$$

In [None]:
def vae_loss(x_recon, x, mu, logvar, beta=1.0):
    """
    VAE objective: reconstruction BCE + beta * KL(q(z|x) || N(0, I)).

    Each term is summed over its elements (pixels for BCE, latent dimensions
    for KL) and then divided by the batch size, so all returned values are
    per-sample averages.

    Args:
        x_recon: Reconstructed image in [0, 1], shape (batch, 1, H, W)
        x: Original image in [0, 1], shape (batch, 1, H, W)
        mu: Latent mean, shape (batch, latent_dim)
        logvar: Latent log-variance, shape (batch, latent_dim)
        beta: Weight on the KL term (ramped up during KL annealing)

    Returns:
        Tuple of scalar tensors (total_loss, recon_loss, kl_loss); the last
        two are returned separately for logging.
    """
    batch_size = x.size(0)

    # Binary cross-entropy summed over every pixel of every image
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum') / batch_size

    # Closed-form KL between N(mu, sigma^2) and N(0, I):
    #   -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_loss = -0.5 * kl_terms.sum() / batch_size

    return recon_loss + beta * kl_loss, recon_loss, kl_loss

# Smoke test on random tensors: the values are meaningless, this only checks
# that vae_loss runs on the target device and returns three scalars.
dummy_img_shape = (4, 1, IMG_SIZE, IMG_SIZE)
x_dummy = torch.rand(*dummy_img_shape).to(device)
x_recon_dummy = torch.rand(*dummy_img_shape).to(device)
mu_dummy = torch.randn(4, LATENT_DIM).to(device)
logvar_dummy = torch.randn(4, LATENT_DIM).to(device)

total, recon, kl = vae_loss(x_recon_dummy, x_dummy, mu_dummy, logvar_dummy, beta=0.5)
print(f"Loss test: total={total.item():.2f}, recon={recon.item():.2f}, kl={kl.item():.2f}")

---

## 3.5 Training

We train the VAE with three important techniques:

1. **KL Annealing:** $\beta$ linearly increases from 0 to 1 over the first 10 epochs. This helps avoid posterior collapse, a failure mode in which the KL term dominates early training. The approximate posterior collapses onto the prior, and the decoder learns to ignore the latent space entirely.

2. **Cosine Annealing LR:** The learning rate follows a cosine schedule from 1e-3 down to 1e-5, providing smooth decay without the sharp drops of step schedules.

3. **Gradient Clipping:** Clips the global gradient norm (computed over all model parameters) to 1.0 to prevent training instability from occasional large gradients.

### TODO: Implement the Training Loop

In [None]:
def train_vae(model, train_loader, val_loader, num_epochs=30, lr=1e-3,
              weight_decay=1e-5, kl_anneal_epochs=10, device='cuda', loss_fn=None):
    """
    Train the VAE with KL annealing, a cosine LR schedule, and gradient clipping.

    Per epoch:
    1. beta = min(1.0, epoch / kl_anneal_epochs). Epochs are 0-indexed, so beta
       starts at 0 and reaches 1 at epoch == kl_anneal_epochs.
       kl_anneal_epochs <= 0 disables annealing (beta = 1 throughout).
    2. For each training batch (model.train()):
       zero_grad -> forward -> loss -> backward -> clip the global gradient
       norm to 1.0 -> optimizer.step().
       Clipping must come after backward (gradients exist) and before step
       (so the clipped gradients are the ones applied).
    3. Evaluate on val_loader in model.eval() mode under torch.no_grad(), using
       the same beta as training for that epoch.
    4. Step the cosine LR scheduler (AdamW, T_max=num_epochs, eta_min=1e-5).
    5. Append sample-weighted epoch means to `history`. Print a summary every
       5 epochs and on the final epoch.

    Batches may be (images, labels) tuples/lists (only images are used) or bare
    image tensors. An empty loader yields NaN for that epoch's entries.

    Args:
        model: ConvVAE model (or any module returning (x_recon, mu, logvar))
        train_loader: DataLoader with normal training images
        val_loader: DataLoader with normal validation images
        num_epochs: Number of training epochs
        lr: Initial learning rate
        weight_decay: AdamW weight decay
        kl_anneal_epochs: Number of epochs to anneal beta from 0 to 1
        device: 'cuda' or 'cpu' (or a torch.device)
        loss_fn: Optional callable with vae_loss's signature and return value,
            (x_recon, x, mu, logvar, beta) -> (total, recon, kl).
            Defaults to vae_loss.

    Returns:
        history: dict with keys 'train_total', 'train_recon', 'train_kl',
                 'val_total', 'val_recon', 'val_kl', each mapping to a list
                 of per-epoch values
    """
    if loss_fn is None:
        loss_fn = vae_loss

    history = {
        'train_total': [], 'train_recon': [], 'train_kl': [],
        'val_total': [], 'val_recon': [], 'val_kl': [],
    }

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # T_max is clamped to >= 1 so num_epochs=0 cannot produce a degenerate schedule.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, num_epochs), eta_min=1e-5)

    def _run_epoch(loader, beta, train):
        """One pass over `loader`; returns sample-weighted mean losses (NaN if empty)."""
        sums = {'total': 0.0, 'recon': 0.0, 'kl': 0.0}
        n_samples = 0
        for batch in loader:
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device)
            if train:
                optimizer.zero_grad()
            x_recon, mu, logvar = model(x)
            total, recon, kl = loss_fn(x_recon, x, mu, logvar, beta)
            if train:
                total.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            # vae_loss returns per-sample means (divided by batch size), so
            # re-weight by batch size to get an exact epoch mean.
            # Assumes a custom loss_fn follows the same convention.
            batch_size = x.size(0)
            sums['total'] += total.item() * batch_size
            sums['recon'] += recon.item() * batch_size
            sums['kl'] += kl.item() * batch_size
            n_samples += batch_size
        if n_samples == 0:
            return {key: float('nan') for key in sums}
        return {key: value / n_samples for key, value in sums.items()}

    for epoch in range(num_epochs):
        beta = 1.0 if kl_anneal_epochs <= 0 else min(1.0, epoch / kl_anneal_epochs)
        epoch_lr = optimizer.param_groups[0]['lr']  # LR actually used this epoch

        model.train()
        train_stats = _run_epoch(train_loader, beta, train=True)

        model.eval()
        with torch.no_grad():
            val_stats = _run_epoch(val_loader, beta, train=False)

        scheduler.step()

        for key in ('total', 'recon', 'kl'):
            history[f'train_{key}'].append(train_stats[key])
            history[f'val_{key}'].append(val_stats[key])

        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            print(f"Epoch {epoch + 1:3d}/{num_epochs} | beta={beta:.2f} | lr={epoch_lr:.2e} | "
                  f"train={train_stats['total']:.2f} "
                  f"(recon {train_stats['recon']:.2f}, kl {train_stats['kl']:.2f}) | "
                  f"val={val_stats['total']:.2f}")

    return history

# Train the model: a fresh ConvVAE on the configured device, fit on normal-only data
model = ConvVAE(latent_dim=LATENT_DIM).to(device)
print(f"Training ConvVAE for {NUM_EPOCHS} epochs...")
print(f"Parameters: {sum(map(torch.numel, model.parameters())):,}")

history = train_vae(
    model,
    train_loader,
    val_loader,
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    kl_anneal_epochs=KL_ANNEAL_EPOCHS,
    device=device,
)

In [None]:
# Verification: one logged loss per epoch, and training must have reduced the loss
assert len(history['train_total']) == NUM_EPOCHS, (
    f"Expected {NUM_EPOCHS} training loss values, got {len(history['train_total'])}"
)
assert history['train_total'][0] > history['train_total'][-1], (
    "Training loss did not decrease -- check your training loop"
)
for split_label, split_key in (("training", "train_total"), ("validation", "val_total")):
    print(f"Final {split_label} loss: {history[split_key][-1]:.2f}")
print("Training verification PASSED.")

In [None]:
# Plot training curves: total, reconstruction, and KL loss (KL panel also shows beta)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

epochs_range = range(1, NUM_EPOCHS + 1)

# (history suffix, y-axis label, panel title) for each of the three panels
panel_specs = [
    ('total', 'Total Loss', 'Total Loss (Recon + beta * KL)'),
    ('recon', 'Reconstruction Loss (BCE)', 'Reconstruction Loss'),
    ('kl', 'KL Divergence', 'KL Divergence + Beta Schedule'),
]
for ax, (suffix, ylabel, title) in zip(axes, panel_specs):
    ax.plot(epochs_range, history[f'train_{suffix}'], 'b-', label='Train', linewidth=2)
    ax.plot(epochs_range, history[f'val_{suffix}'], 'r--', label='Validation', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

for ax in axes[:2]:
    ax.legend()

# Overlay the beta schedule on a secondary y-axis of the KL panel
beta_schedule = [min(1.0, e / KL_ANNEAL_EPOCHS) for e in range(NUM_EPOCHS)]
ax2 = axes[2].twinx()
ax2.plot(epochs_range, beta_schedule, 'g:', label='Beta (KL weight)', linewidth=2, alpha=0.7)
ax2.set_ylabel('Beta', color='green')
ax2.tick_params(axis='y', labelcolor='green')
ax2.set_ylim(0, 1.1)
axes[2].legend(loc='upper left')
ax2.legend(loc='upper right')

plt.suptitle('VAE Training Curves', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Visualize reconstructions on validation data (first validation batch)
model.eval()
with torch.no_grad():
    val_batch, _ = next(iter(val_loader))
    val_batch = val_batch.to(device)
    recon_batch, _, _ = model(val_batch)

# The first batch may contain fewer than 8 images.
n_show = min(8, val_batch.size(0))
# squeeze=False keeps `axes` 2-D even when n_show == 1.
fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)
fig.suptitle('Validation Reconstructions: Original (top) vs Reconstructed (bottom)',
             fontsize=12, fontweight='bold')

for i in range(n_show):
    for row, images in enumerate((val_batch, recon_batch)):
        ax = axes[row, i]
        ax.imshow(images[i].cpu().squeeze(), cmap='gray', vmin=0, vmax=1)
        # Hide ticks and frame instead of axis('off'): axis('off') also hides
        # axis labels, which would make the row labels below invisible.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

axes[0, 0].set_ylabel('Original', fontsize=10)
axes[1, 0].set_ylabel('Reconstructed', fontsize=10)
plt.tight_layout()
plt.show()

---

## 3.6 Evaluation

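Now we evaluate the trained VAE as an anomaly detector. The core idea: the VAE was trained only on normal images, so it reconstructs normal images well (low reconstruction error) but tends to reconstruct anomalous images poorly (high reconstruction error). The per-image reconstruction error -- the mean squared error (MSE) between the input and its reconstruction -- is our anomaly score. Note that this is a scoring function only; the training loss uses BCE for the reconstruction term.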

### TODO: Implement the Evaluation Pipeline

In [None]:
def compute_anomaly_scores(model, data_loader, device='cuda'):
    """
    Compute anomaly scores for all images in the data loader.

    The anomaly score for each image is the mean squared error (MSE) between the
    original and reconstructed image, averaged over every non-batch dimension
    (channels, height, width for (batch, 1, H, W) inputs).

    Puts the model in eval mode (and leaves it there) and runs under
    torch.no_grad(). Scores are returned in loader order, so they line up with
    the dataset only if the loader does not shuffle.

    Args:
        model: Trained ConvVAE model (any module returning (x_recon, mu, logvar))
        data_loader: DataLoader with test images. Batches may be
            (images, labels) tuples/lists or bare image tensors.
        device: 'cuda' or 'cpu' (or a torch.device)

    Returns:
        scores: np.array of shape (n_images,) with anomaly scores
            (empty if the loader yields no batches)
    """
    model.eval()
    batch_scores = []
    with torch.no_grad():
        for batch in data_loader:
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device)
            x_recon, _, _ = model(x)
            # Flatten everything but the batch dim so any image rank works.
            per_image_mse = ((x - x_recon) ** 2).flatten(start_dim=1).mean(dim=1)
            batch_scores.append(per_image_mse.cpu().numpy())

    if not batch_scores:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(batch_scores)

# Compute anomaly scores on the full test set, then split them by ground-truth label
print("Computing anomaly scores on test set...")
anomaly_scores = compute_anomaly_scores(model, test_loader, device)
test_labels_arr = np.array(test_labels)

normal_scores = anomaly_scores[test_labels_arr == 0]
anomaly_scores_only = anomaly_scores[test_labels_arr == 1]

for group_name, group_scores in (("Normal", normal_scores), ("Anomaly", anomaly_scores_only)):
    print(f"{group_name} scores: mean={group_scores.mean():.6f}, std={group_scores.std():.6f}")

In [None]:
# Verification: one score per test image, and anomalies must score higher on average
assert anomaly_scores is not None, "Anomaly scores not computed"
assert len(anomaly_scores) == len(test_labels), (
    f"Score count ({len(anomaly_scores)}) != label count ({len(test_labels)})"
)
assert normal_scores.mean() < anomaly_scores_only.mean(), (
    "Anomaly scores should be higher than normal scores on average"
)
print("Anomaly score computation verification PASSED.")

In [None]:
# Compute evaluation metrics
TARGET_FPR = 0.05  # maximum tolerated false-positive rate for the operating point

vae_auroc = roc_auc_score(test_labels_arr, anomaly_scores)
precision, recall, _ = precision_recall_curve(test_labels_arr, anomaly_scores)
vae_auprc = auc(recall, precision)

# FPR and TPR for ROC curve
fpr, tpr, thresholds = roc_curve(test_labels_arr, anomaly_scores)

# Operating point at FPR <= 5%: take the LAST ROC point whose FPR is within budget.
# fpr and tpr are non-decreasing, so this is the highest-TPR point that respects
# the budget. (The point merely *closest* to 5% can overshoot the budget.)
# roc_curve always starts at fpr[0] == 0, so at least one point qualifies.
fpr_5_idx = int(np.flatnonzero(fpr <= TARGET_FPR)[-1])
tpr_at_fpr5 = tpr[fpr_5_idx]
threshold_at_fpr5 = thresholds[fpr_5_idx]
fnr_at_fpr5 = 1 - tpr_at_fpr5

print("=" * 50)
print("VAE Anomaly Detection Results")
print("=" * 50)
print(f"AUROC:                    {vae_auroc:.4f}")
print(f"AUPRC:                    {vae_auprc:.4f}")
print(f"TPR at FPR=5%:            {tpr_at_fpr5:.4f}")
print(f"FNR at FPR=5%:            {fnr_at_fpr5:.4f}")
print(f"Actual FPR at threshold:  {fpr[fpr_5_idx]:.4f}")
print(f"Threshold at FPR=5%:      {threshold_at_fpr5:.6f}")
print(f"PCA best AUROC:           {pca_best_auroc:.4f}")
print(f"VAE improvement over PCA: {vae_auroc - pca_best_auroc:+.4f}")

In [None]:
# Performance plots: score distributions, ROC (VAE vs PCA), and precision-recall
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
ax_hist, ax_roc, ax_pr = axes

# Score distribution histograms with the FPR=5% decision threshold marked
for group_scores, group_color, group_label in ((normal_scores, 'steelblue', 'Normal'),
                                               (anomaly_scores_only, 'crimson', 'Anomaly')):
    ax_hist.hist(group_scores, bins=50, alpha=0.7, color=group_color, label=group_label, density=True)
ax_hist.axvline(threshold_at_fpr5, color='green', linestyle='--', linewidth=2,
                label='Threshold (FPR=5%)')
ax_hist.set_xlabel('Anomaly Score (MSE)')
ax_hist.set_ylabel('Density')
ax_hist.set_title('Anomaly Score Distributions')

# ROC curve, with the best PCA baseline for comparison
best_pca_scores = pca_results[best_k]['scores']
pca_fpr, pca_tpr, _ = roc_curve(test_labels_arr, best_pca_scores)
ax_roc.plot(fpr, tpr, 'b-', linewidth=2, label=f'VAE (AUROC={vae_auroc:.3f})')
ax_roc.plot(pca_fpr, pca_tpr, 'r--', linewidth=2, label=f'PCA k={best_k} (AUROC={pca_best_auroc:.3f})')
ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.3)
ax_roc.plot(fpr[fpr_5_idx], tpr[fpr_5_idx], 'go', markersize=12,
            label=f'FPR=5% (TPR={tpr_at_fpr5:.3f})')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('ROC Curve: VAE vs PCA Baseline')

# Precision-recall curve
ax_pr.plot(recall, precision, 'b-', linewidth=2, label=f'VAE (AUPRC={vae_auprc:.3f})')
ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.set_title('Precision-Recall Curve')

for ax in axes:
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('VAE Anomaly Detection Performance', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

---

## 3.7 Error Analysis

Understanding where the model fails is critical for production deployment. False negatives (missed defects) are far more costly than false positives (false alarms) in safety-critical semiconductor inspection.

In [None]:
# Binarize scores at the FPR=5% threshold and split test indices by outcome
predictions = (anomaly_scores >= threshold_at_fpr5).astype(int)

is_anomaly = test_labels_arr == 1
is_normal = test_labels_arr == 0
flagged = predictions == 1
passed = predictions == 0

false_negatives = np.flatnonzero(is_anomaly & passed)
false_positives = np.flatnonzero(is_normal & flagged)
true_positives = np.flatnonzero(is_anomaly & flagged)
true_negatives = np.flatnonzero(is_normal & passed)

print("Confusion Matrix at FPR=5% threshold:")
for caption, outcome_idx in (
    ("True Negatives  (correct normal):  ", true_negatives),
    ("True Positives  (caught defects):  ", true_positives),
    ("False Positives (false alarms):    ", false_positives),
    ("False Negatives (missed defects):  ", false_negatives),
):
    print(f"  {caption}{len(outcome_idx)}")

In [None]:
# Visualize false negatives: defective images the model MISSED
# Assumes test_dataset[idx] lines up with anomaly_scores / test_labels_arr (i.e. the
# test loader was not shuffled) and yields (image, original class id) -- TODO confirm.
model.eval()
with torch.no_grad():
    n_show = min(5, len(false_negatives))
    if n_show > 0:
        # squeeze=False keeps `axes` 2-D even when only one case is shown
        fig, axes = plt.subplots(3, n_show, figsize=(3 * n_show, 9), squeeze=False)
        fig.suptitle('False Negatives: Defects the Model Missed', fontsize=13, fontweight='bold')

        for i in range(n_show):
            idx = false_negatives[i]
            img, original_label = test_dataset[idx]
            img_tensor = img.unsqueeze(0).to(device)
            recon, _, _ = model(img_tensor)

            original = img.squeeze().cpu().numpy()
            reconstructed = recon.squeeze().cpu().numpy()
            heatmap = (original - reconstructed) ** 2

            axes[0, i].imshow(original, cmap='gray', vmin=0, vmax=1)
            axes[0, i].set_title(f'Original\n({FASHION_CLASSES[original_label]})', fontsize=9)

            axes[1, i].imshow(reconstructed, cmap='gray', vmin=0, vmax=1)
            axes[1, i].set_title(f'Reconstructed\nScore: {anomaly_scores[idx]:.5f}', fontsize=9)

            axes[2, i].imshow(heatmap, cmap='hot', vmin=0)
            axes[2, i].set_title('Error Heatmap', fontsize=9)

            # Hide ticks and frame instead of axis('off'): axis('off') also hides
            # axis labels, so the row labels set below would never appear.
            for ax in axes[:, i]:
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

        axes[0, 0].set_ylabel('Original', fontsize=10)
        axes[1, 0].set_ylabel('Reconstruction', fontsize=10)
        axes[2, 0].set_ylabel('Error Heatmap', fontsize=10)
        plt.tight_layout()
        plt.show()
    else:
        print("No false negatives at this threshold -- the model caught all anomalies.")

In [None]:
# Visualize false positives: normal images the model incorrectly flagged
# Assumes test_dataset[idx] lines up with anomaly_scores / test_labels_arr (i.e. the
# test loader was not shuffled) -- TODO confirm.
# Set eval mode explicitly rather than relying on a previous cell having done it.
model.eval()
with torch.no_grad():
    n_show = min(5, len(false_positives))
    if n_show > 0:
        # squeeze=False keeps `axes` 2-D even when only one case is shown
        fig, axes = plt.subplots(3, n_show, figsize=(3 * n_show, 9), squeeze=False)
        fig.suptitle('False Positives: Normal Images Incorrectly Flagged', fontsize=13, fontweight='bold')

        for i in range(n_show):
            idx = false_positives[i]
            img, _ = test_dataset[idx]
            img_tensor = img.unsqueeze(0).to(device)
            recon, _, _ = model(img_tensor)

            original = img.squeeze().cpu().numpy()
            reconstructed = recon.squeeze().cpu().numpy()
            heatmap = (original - reconstructed) ** 2

            axes[0, i].imshow(original, cmap='gray', vmin=0, vmax=1)
            axes[0, i].set_title('Original\n(Normal)', fontsize=9)

            axes[1, i].imshow(reconstructed, cmap='gray', vmin=0, vmax=1)
            axes[1, i].set_title(f'Reconstructed\nScore: {anomaly_scores[idx]:.5f}', fontsize=9)

            axes[2, i].imshow(heatmap, cmap='hot', vmin=0)
            axes[2, i].set_title('Error Heatmap', fontsize=9)

            # Hide ticks and frame instead of axis('off'): axis('off') also hides
            # axis labels, so the row labels set below would never appear.
            for ax in axes[:, i]:
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

        axes[0, 0].set_ylabel('Original', fontsize=10)
        axes[1, 0].set_ylabel('Reconstruction', fontsize=10)
        axes[2, 0].set_ylabel('Error Heatmap', fontsize=10)
        plt.tight_layout()
        plt.show()
    else:
        print("No false positives at this threshold.")

### TODO: Error Analysis and Failure Mode Categorization

In [None]:
def categorize_failure_modes(false_negatives, false_positives, anomaly_scores,
                             test_dataset, test_labels, fashion_classes):
    """
    Analyze and categorize the model's failure modes.

    1. False negatives (missed defects) are grouped by their original
       FashionMNIST class (the proxy "defect type"). Defect types are ranked by
       miss RATE: missed / total anomalies of that class. Ties are broken by raw
       miss count, then by class name.
    2. False positives (false alarms) are summarized by their score range and
       mean, compared with the median normal score.
    3. Exactly three failure modes are reported, each with one mitigation:
       - the top-ranked defect types with at least one miss;
       - then the false-alarm mode if fewer than three defect types were missed;
       - then a "no further failure mode" placeholder if still short.

    test_dataset[idx] is expected to yield (image, original class id), as in the
    false-negative visualization cell. Indices must be aligned with
    anomaly_scores / test_labels.

    Args:
        false_negatives: np.array of indices of missed defects
        false_positives: np.array of indices of false alarms
        anomaly_scores: np.array of all anomaly scores
        test_dataset: The test Dataset object
        test_labels: np.array of test labels (0=normal, 1=anomaly)
        fashion_classes: List of class names

    Returns:
        analysis: dict with keys:
          - 'fn_by_class': dict mapping class_name -> count of false negatives
            (every anomalous class appears, including those with 0 misses)
          - 'top_3_failure_modes': list of 3 strings describing failure modes
          - 'mitigations': list of 3 strings with proposed mitigations
          - 'miss_rate_by_class': dict mapping class_name -> miss rate in [0, 1]
          - 'fp_summary': dict with 'count', 'mean_score', 'min_score',
            'max_score' for false positives, or None if there are none
    """
    from collections import Counter

    n_modes = 3  # the verification cell requires exactly 3 modes and 3 mitigations
    scores = np.asarray(anomaly_scores, dtype=float)
    labels = np.asarray(test_labels)

    def _class_name(idx):
        # Only the class id is needed; the image is ignored.
        _, class_id = test_dataset[int(idx)]
        return fashion_classes[int(class_id)]

    # --- 1. False negatives by defect type (with totals, so we can rank by rate) ---
    total_by_class = Counter(_class_name(i) for i in np.flatnonzero(labels == 1))
    fn_names = [_class_name(i) for i in false_negatives]
    fn_counts = Counter(fn_names)
    fn_scores_by_class = {}
    for idx, name in zip(false_negatives, fn_names):
        fn_scores_by_class.setdefault(name, []).append(float(scores[int(idx)]))

    classes = sorted(set(total_by_class) | set(fn_counts))
    fn_by_class = {cls: fn_counts.get(cls, 0) for cls in classes}
    # max(...) guards against an inconsistent FN list larger than the class total.
    totals = {cls: max(total_by_class.get(cls, 0), fn_by_class[cls], 1) for cls in classes}
    miss_rate_by_class = {cls: fn_by_class[cls] / totals[cls] for cls in classes}

    normal_scores = scores[labels == 0]
    normal_median = float(np.median(normal_scores)) if normal_scores.size else float('nan')

    ranked = sorted((cls for cls in classes if fn_by_class[cls] > 0),
                    key=lambda cls: (-miss_rate_by_class[cls], -fn_by_class[cls], cls))

    mitigation_templates = [
        "'{cls}': train an ensemble of VAEs with different latent dimensions and combine "
        "their scores (e.g. max or mean), so defects that one model happens to "
        "reconstruct well are still flagged by another.",
        "'{cls}': replace the global mean-MSE score with multi-scale / patch-level scoring "
        "(e.g. max error over local windows) so small localized deviations are not "
        "averaged away.",
        "'{cls}': route near-threshold scores to human review and add confirmed examples "
        "of this defect type to a labeled validation set used to tune the threshold.",
    ]

    modes, mitigations = [], []
    for rank, cls in enumerate(ranked[:n_modes]):
        missed_scores = fn_scores_by_class.get(cls, [])
        modes.append(
            f"Missed '{cls}' defects: {fn_by_class[cls]}/{totals[cls]} undetected "
            f"({miss_rate_by_class[cls]:.1%} miss rate); mean score of missed cases "
            f"{np.mean(missed_scores):.5f} vs. median normal score {normal_median:.5f} -- "
            f"their reconstruction error falls inside the normal range"
        )
        mitigations.append(mitigation_templates[rank].format(cls=cls))

    # --- 2. False alarms ---
    fp_scores = scores[np.asarray(false_positives, dtype=int)]
    fp_summary = None
    if fp_scores.size:
        fp_summary = {
            'count': int(fp_scores.size),
            'mean_score': float(fp_scores.mean()),
            'min_score': float(fp_scores.min()),
            'max_score': float(fp_scores.max()),
        }

    # --- 3. Fill remaining slots so exactly n_modes entries are reported ---
    if len(modes) < n_modes and fp_summary is not None:
        modes.append(
            f"False alarms: {fp_summary['count']} normal images scored above the threshold "
            f"(scores {fp_summary['min_score']:.5f}-{fp_summary['max_score']:.5f}, "
            f"mean {fp_summary['mean_score']:.5f} vs. median normal {normal_median:.5f}) -- "
            f"atypical but legitimate images in the upper tail of the normal score distribution"
        )
        mitigations.append(
            "Reduce false alarms by augmenting normal training data (small shifts, rotations, "
            "brightness changes) so legitimate variation is reconstructed well, and send "
            "near-threshold cases to human review instead of auto-rejecting them."
        )
    while len(modes) < n_modes:
        modes.append("No further failure mode observed at this threshold")
        mitigations.append(
            "Re-run the analysis on a larger, more diverse test set (or at other thresholds) "
            "to surface additional failure modes."
        )

    return {
        'fn_by_class': fn_by_class,
        'top_3_failure_modes': modes,
        'mitigations': mitigations,
        'miss_rate_by_class': miss_rate_by_class,
        'fp_summary': fp_summary,
    }

# Run the failure-mode analysis and print a readable report
analysis = categorize_failure_modes(
    false_negatives, false_positives, anomaly_scores,
    test_dataset, test_labels_arr, FASHION_CLASSES
)

if analysis['fn_by_class'] is None:
    print("Complete the categorize_failure_modes() TODO above.")
else:
    print("False Negatives by Defect Type:")
    for cls, count in sorted(analysis['fn_by_class'].items(), key=lambda item: -item[1]):
        print(f"  {cls:20s}: {count}")

    for heading, entries in (("\nTop 3 Failure Modes:", analysis['top_3_failure_modes']),
                             ("\nProposed Mitigations:", analysis['mitigations'])):
        print(heading)
        for rank, text in enumerate(entries, 1):
            print(f"  {rank}. {text}")

In [None]:
# Verification: the analysis must be filled in, with exactly 3 modes and 3 mitigations
assert analysis['fn_by_class'] is not None, "Failure mode analysis not implemented"
for field, missing_msg, count_msg in (
    ('top_3_failure_modes', "Top 3 failure modes not identified", "Exactly 3 failure modes required"),
    ('mitigations', "Mitigations not proposed", "Exactly 3 mitigations required"),
):
    assert analysis[field] is not None, missing_msg
    assert len(analysis[field]) == 3, count_msg
print("Error analysis verification PASSED.")

---

## 3.8 Deployment Considerations

For NovaSilicon's production deployment, the model must meet strict latency requirements (< 50ms per image on a T4 GPU), stay within a 15M-parameter budget, and be exportable to a format suitable for serving with NVIDIA Triton or TorchServe.

### TODO: Benchmark and Export the Model

In [None]:
def benchmark_inference(model, img_size=128, device='cuda', n_warmup=10, n_runs=100,
                        batch_sizes=(1, 8, 32, 64)):
    """
    Benchmark the model's inference latency.

    Puts `model` in eval mode (and leaves it there). Times forward passes under
    torch.no_grad() on random inputs of shape (batch, 1, img_size, img_size).
    Each measured input shape gets `n_warmup` untimed passes first. On CUDA,
    torch.cuda.synchronize() is called before reading the timer, so timings
    cover completed GPU work rather than just kernel launches.

    Args:
        model: Trained ConvVAE model (assumed to already live on `device`)
        img_size: Input image size
        device: 'cuda' or 'cpu' (or a torch.device)
        n_warmup: Number of warmup iterations per measured input shape
        n_runs: Number of timed iterations per measured input shape (>= 1)
        batch_sizes: Batch sizes for the per-image batch benchmark

    Returns:
        results: dict with:
          - 'single_mean_ms': Mean latency for single image
          - 'single_p99_ms': P99 (99th-percentile) latency for single image
          - 'batch_results': dict mapping batch_size -> mean latency per image in ms

    Raises:
        ValueError: if n_runs < 1.
    """
    import time

    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    dev = torch.device(device)

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the timings are real.
        if dev.type == 'cuda':
            torch.cuda.synchronize(dev)

    def _time_forward_ms(x):
        """Return an array of n_runs wall-clock forward-pass times for `x`, in ms."""
        for _ in range(n_warmup):
            model(x)
        _sync()
        times_ms = []
        for _ in range(n_runs):
            start = time.perf_counter()
            model(x)
            _sync()
            times_ms.append((time.perf_counter() - start) * 1000.0)
        return np.asarray(times_ms)

    model.eval()
    with torch.no_grad():
        single_ms = _time_forward_ms(torch.randn(1, 1, img_size, img_size, device=dev))

        batch_results = {}
        for batch_size in batch_sizes:
            batch_ms = _time_forward_ms(torch.randn(batch_size, 1, img_size, img_size, device=dev))
            batch_results[int(batch_size)] = float(batch_ms.mean() / batch_size)

    return {
        'single_mean_ms': float(single_ms.mean()),
        'single_p99_ms': float(np.percentile(single_ms, 99)),
        'batch_results': batch_results,
    }

# Run the latency benchmarks and compare against the 50 ms requirement
print("Benchmarking inference latency...")
bench_results = benchmark_inference(model, img_size=IMG_SIZE, device=device)

if bench_results['single_mean_ms'] is None:
    print("Complete the benchmark_inference() TODO above.")
else:
    print("\nSingle Image Inference:")
    print(f"  Mean latency: {bench_results['single_mean_ms']:.2f} ms")
    print(f"  P99  latency: {bench_results['single_p99_ms']:.2f} ms")
    meets_requirement = bench_results['single_mean_ms'] < 50
    requirement_verdict = 'YES' if meets_requirement else 'NO'
    print(f"  Meets 50ms requirement: {requirement_verdict}")

    per_image_latency = bench_results['batch_results']
    if per_image_latency is not None:
        print("\nBatch Inference (per-image latency):")
        for bs in sorted(per_image_latency):
            print(f"  Batch size {bs:3d}: {per_image_latency[bs]:.2f} ms/image")

In [None]:
# Export to TorchScript
# Export errors (trace / save / load) are reported without stopping the notebook.
# A numerical mismatch between the traced and eager model is a genuine
# verification failure, so it is asserted OUTSIDE the try: otherwise the broad
# `except` would swallow it and mis-report it as an unsupported-operation error.
print("Exporting model to TorchScript...")
# Defined before the try so later cells can always check for the file.
torchscript_path = "vae_wafer_detector.pt"
model.eval()
model_cpu = model.cpu()
dummy_input = torch.randn(1, 1, IMG_SIZE, IMG_SIZE)

loaded_model = None
try:
    # TorchScript export via tracing, then reload to verify the saved artifact
    traced_model = torch.jit.trace(model_cpu, dummy_input)
    traced_model.save(torchscript_path)
    loaded_model = torch.jit.load(torchscript_path)
except Exception as e:
    print(f"TorchScript export failed: {e}")
    print("This may happen if the model uses operations not supported by TorchScript.")
    print("In production, consider using torch.onnx.export() as an alternative.")

max_diff = None
if loaded_model is not None:
    with torch.no_grad():
        original_out, _, _ = model_cpu(dummy_input)
        loaded_out, _, _ = loaded_model(dummy_input)
    max_diff = (original_out - loaded_out).abs().max().item()
    print(f"TorchScript model saved to: {torchscript_path}")
    print(f"Max output difference (original vs loaded): {max_diff:.8f}")

# Move model back to device BEFORE the assertion, so a failure does not strand it on CPU
model = model.to(device)

if max_diff is not None:
    # A large difference may mean forward() samples z even in eval mode
    # (tracing freezes one random draw) -- TODO confirm against ConvVAE.
    assert max_diff < 1e-5, "TorchScript output differs from original"
    print("TorchScript export verification PASSED.")

In [None]:
# Model size and memory analysis
PARAM_BUDGET = 15_000_000  # deployment budget on total parameter count

n_params = sum(p.numel() for p in model.parameters())
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)

print("Model Summary:")
print(f"  Total parameters:     {n_params:,}")
print(f"  Model size (params):  {model_size_mb:.2f} MB")

# torchscript_path only exists if the export cell got far enough to define it;
# look it up defensively instead of risking a NameError.
exported_path = globals().get('torchscript_path')
if exported_path is not None and os.path.exists(exported_path):
    disk_size_mb = os.path.getsize(exported_path) / (1024 * 1024)
    print(f"  TorchScript on disk:  {disk_size_mb:.2f} MB")

if device.type == 'cuda':
    torch.cuda.reset_peak_memory_stats()
    dummy = torch.randn(1, 1, IMG_SIZE, IMG_SIZE).to(device)
    with torch.no_grad():
        _ = model(dummy)
    peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
    print(f"  Peak GPU memory:      {peak_mem_mb:.2f} MB")

# The parameter budget is device-independent, so check it on CPU runs too
print(f"  Meets 15M param budget: {'YES' if n_params < PARAM_BUDGET else 'NO'}")

In [None]:
# Verification for benchmarking: single-image and batch results must both be filled in
for result_key, failure_msg in (('single_mean_ms', "Benchmarking not implemented"),
                                ('batch_results', "Batch benchmarking not implemented")):
    assert bench_results[result_key] is not None, failure_msg
print("Deployment benchmarking verification PASSED.")

---

## 3.9 Ethical and Regulatory Analysis

Even in industrial manufacturing, deploying ML systems carries ethical responsibilities. NovaSilicon's chips power safety-critical automotive applications, making the ethical implications of inspection failures particularly severe.

### TODO: Ethical Impact Assessment

In [None]:
def ethical_impact_assessment():
    """
    Return a 300-500 word ethical impact assessment for deploying the VAE-based
    wafer inspection system at NovaSilicon.

    The assessment addresses all four required points, each with at least one
    concrete mitigation, and acknowledges both benefits and risks:

    1. FAILURE MODE IMPACT: consequences of a missed defect in a safety-critical
       automotive chip; comparison of the model's false negative rate with human
       inspectors (22% disagreement rate); human-in-the-loop review of
       borderline cases.
    2. AUTOMATION BIAS: over-reliance on the model by inspectors and how the
       review workflow can counteract it.
    3. WORKER DISPLACEMENT: a responsible transition plan for the 28 human
       inspectors if ~80% of inspections are automated.
    4. REGULATORY COMPLIANCE: documentation and change control expected under
       ISO 9001 and IATF 16949 for an ML-based inspection system.

    The text is static. It deliberately avoids quoting model metrics, which
    depend on the training run.

    Returns:
        assessment: str, the ethical impact assessment (300-500 words)
    """
    import textwrap

    assessment = textwrap.dedent("""
        Deploying a VAE-based inspection system at NovaSilicon offers clear
        benefits: consistent, around-the-clock screening, faster detection of
        novel defect types than the current 11-day lag, and error heatmaps that
        make each decision reviewable. Because these chips end up in
        safety-critical automotive systems, the risks must be managed as
        carefully as the benefits are pursued.

        1. Failure mode impact. A missed defect can escape into a vehicle
        control unit, where it may cause field failures, recalls, or, in the
        worst case, physical harm. The model's false negative rate at the chosen
        operating point must therefore be measured on held-out data and compared
        directly with human performance; the 78% inter-inspector agreement rate
        (22% disagreement) shows that manual inspection is itself far from
        perfect, but that comparison is only meaningful on the same wafers.
        Mitigation: a three-tier decision engine (PASS / REVIEW / REJECT) that
        routes borderline anomaly scores to human review, with thresholds set
        conservatively in favor of catching defects.

        2. Automation bias. Inspectors who see the model's verdict first may
        stop looking critically. Mitigations include blind review of a random
        sample of wafers (the inspector decides before seeing the model's
        score), showing the error heatmap rather than a bare confidence number,
        periodically inserting known-defect audit images, and tracking the
        human-model agreement rate: agreement that approaches 100% is a warning
        sign, not a success metric.

        3. Worker displacement. If roughly 80% of inspections are automated,
        the role of the 28 inspectors changes substantially. A responsible plan
        uses phased deployment, guaranteed retraining, and new roles that draw
        directly on inspectors' expertise: edge-case review, defect labeling,
        model monitoring, and root-cause analysis with process engineers.
        Inspectors should be involved in designing the review workflow, and the
        transition plan should be communicated before deployment begins.

        4. Regulatory compliance. ISO 9001 and IATF 16949 require that
        inspection methods be validated, controlled, and traceable. For an ML
        system this means documented validation protocols (including
        measurement system analysis against a reference set), version control
        of models, data, and thresholds, formal change management so that every
        retraining is treated as a process change requiring revalidation,
        per-wafer traceability of the model version and score used, and
        periodic recertification against drift. Customer notification may be
        required before changing an approved inspection method.

        Handled this way, the system can improve both quality and working
        conditions, but only if human judgment remains an active part of the
        loop.
    """).strip()
    return assessment

# Verification: the assessment must exist and be roughly within the 300-500 word target
assessment = ethical_impact_assessment()
assert assessment is not None, "Ethical impact assessment not written"
word_count = len(assessment.split())
assert 250 <= word_count <= 600, f"Assessment is {word_count} words -- target 300-500"
print(f"Ethical impact assessment: {word_count} words")
rule = "=" * 60
print("\n".join([rule, assessment, rule]))
print("Ethical analysis verification PASSED.")

---

## Summary and Next Steps

In this notebook, you built a complete anomaly detection pipeline for semiconductor wafer inspection:

1. **Data Pipeline:** Prepared a proxy anomaly detection dataset with normal-only training and mixed normal/anomalous testing.

2. **PCA Baseline:** Established a performance floor using linear reconstruction error as an anomaly score.

3. **Convolutional VAE:** Implemented a VAE from scratch with:
   - Convolutional encoder with BatchNorm and LeakyReLU
   - Reparameterization trick for differentiable sampling
   - Transposed convolutional decoder with sigmoid output
   - Combined BCE reconstruction + KL divergence loss

4. **Training:** Trained with KL annealing, cosine LR schedule, and gradient clipping.

5. **Evaluation:** Computed AUROC, AUPRC, and ROC curves. Compared to PCA baseline.

6. **Error Analysis:** Identified failure modes and proposed mitigations.

7. **Deployment:** Benchmarked latency, exported to TorchScript, and analyzed model size.

8. **Ethics:** Assessed the societal impact of deploying this system in safety-critical manufacturing.

**Key takeaway:** The VAE's ability to detect anomalies without labeled defect data makes it ideal for manufacturing inspection, where defect types are diverse and evolving. The reconstruction-based approach provides interpretable outputs (heatmaps) that human inspectors can use to understand and verify the model's decisions.

**Next steps:** Read Section 4 of the case study document for the full production system design, including the three-tier decision engine (PASS / REVIEW / REJECT), the human-in-the-loop workflow, the monitoring and retraining pipeline, and the NVIDIA Triton serving infrastructure.