# Gradient Descent Bias in Overcomplete Autoencoders

**Question**: Does gradient descent have an implicit bias toward semantically meaningful representations, even when there's no information-theoretic pressure to compress?

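We train autoencoders whose hidden dimension is at least the input dimension (784 for MNIST). With no bottleneck, many different weight configurations reconstruct the input perfectly — including trivial identity-like mappings — so the reconstruction loss alone does not favor any particular representation. What does gradient descent actually find?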

In [None]:
# Setup for Colab
# When running inside Google Colab, clone the project repository (only if it is
# not already present) and put it on sys.path so `shared.utils.data` imports.
# NOTE: the `!` and `%` lines are IPython shell/magic syntax; this cell only runs
# inside a notebook, where `get_ipython()` is defined.
import os
if 'google.colab' in str(get_ipython()):
    if not os.path.exists('/content/MNIST_AI'):
        !git clone https://github.com/Caleb-Briggs/MNIST_AI.git
    %cd /content/MNIST_AI
    import sys
    sys.path.append('/content/MNIST_AI')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from typing import Optional
import warnings
warnings.filterwarnings('ignore')

from shared.utils.data import load_mnist, get_device

# Pick the compute device via the project helper and report it.
device = get_device()
print("Using device: {}".format(device))

## Model Definition

In [None]:
class OvercompleteAutoencoder(nn.Module):
    """One-hidden-layer ReLU autoencoder with a freely configurable width.

    Computes ``z = ReLU(W_enc x + b_enc)`` and ``x_hat = W_dec z + b_dec``.
    When ``hidden_dim >= input_dim`` there is no bottleneck.

    Args:
        input_dim: Length of each flattened input vector (784 for 28x28 MNIST).
        hidden_dim: Number of hidden (latent) units.
    """

    def __init__(self, input_dim: int = 784, hidden_dim: int = 1024):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim

        # Registration order (encoder, decoder, activation) is preserved so that
        # parameters()/state_dict() ordering and seeded initialisation match.
        self.encoder = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.decoder = nn.Linear(in_features=hidden_dim, out_features=input_dim)
        self.activation = nn.ReLU()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map inputs to non-negative latent codes."""
        pre_activation = self.encoder(x)
        return self.activation(pre_activation)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Linearly map latent codes back to input space (no output nonlinearity)."""
        return self.decoder(z)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, latent)`` for a batch of inputs."""
        latent = self.encode(x)
        return self.decode(latent), latent

## Visualization Tools

In [None]:
def visualize_encoder_weights(model: OvercompleteAutoencoder, 
                               num_units: int = 64,
                               figsize: tuple = (12, 12),
                               title: str = "Encoder Weights"):
    """
    Visualize encoder weight vectors reshaped as 28x28 images.

    Each hidden unit's incoming weights (one row of ``model.encoder.weight``,
    784 values) are shown as an image with a symmetric diverging colormap, so
    zero is always the neutral colour.

    Interpretation guide:
    - single bright pixels -> identity-like
    - edges/strokes -> learning features
    - noise -> random solution

    Args:
        model: Autoencoder whose encoder rows have 784 = 28*28 entries.
        num_units: Maximum number of units to show (the first ones by index).
        figsize: Matplotlib figure size.
        title: Figure super-title.

    Returns:
        The matplotlib Figure.
    """
    weights = model.encoder.weight.detach().cpu().numpy()  # (hidden_dim, 784)

    n_display = min(num_units, weights.shape[0])
    grid_size = max(1, int(np.ceil(np.sqrt(n_display))))

    # squeeze=False always returns a 2-D array of Axes, so flattening also
    # works for a 1x1 grid (plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(grid_size, grid_size, figsize=figsize, squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < n_display:
            w = weights[i].reshape(28, 28)
            # Symmetric limits around zero; an all-zero tile falls back to 1 so
            # zero maps to the neutral middle colour instead of one extreme.
            vmax = float(np.abs(w).max()) or 1.0
            ax.imshow(w, cmap='RdBu_r', vmin=-vmax, vmax=vmax)
        ax.axis('off')

    plt.suptitle(title, fontsize=14)
    plt.tight_layout()
    return fig

In [None]:
def visualize_reconstructions(model: OvercompleteAutoencoder,
                               images: torch.Tensor,
                               num_examples: int = 10,
                               figsize: tuple = (15, 3)):
    """Show original images (top row) above their reconstructions (bottom row).

    Args:
        model: Autoencoder taking flattened 784-dim inputs.
        images: Batch of images, each holding 28*28 values; must be on the same
            device as ``model``.
        num_examples: Number of columns to show (capped at ``len(images)``).
        figsize: Matplotlib figure size.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if there is nothing to show.

    Note: leaves ``model`` in eval mode.
    """
    num_examples = min(num_examples, len(images))
    if num_examples < 1:
        raise ValueError("visualize_reconstructions needs at least one image")

    model.eval()
    with torch.no_grad():
        flat = images[:num_examples].reshape(num_examples, -1)
        recon, _ = model(flat)

    # squeeze=False keeps ``axes`` 2-D even when num_examples == 1.
    fig, axes = plt.subplots(2, num_examples, figsize=figsize, squeeze=False)

    for i in range(num_examples):
        axes[0, i].imshow(images[i].cpu().reshape(28, 28), cmap='gray')
        axes[1, i].imshow(recon[i].cpu().reshape(28, 28), cmap='gray')

    for ax in axes.flat:
        # Hide ticks and frame rather than calling axis('off'): axis('off') also
        # hides the y-labels, so the row titles below would never appear.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    axes[0, 0].set_ylabel('Original', fontsize=12)
    axes[1, 0].set_ylabel('Reconstructed', fontsize=12)
    plt.tight_layout()
    return fig

In [None]:
def compute_linear_probe_accuracy(representations: np.ndarray,
                                   labels: np.ndarray,
                                   n_components: Optional[int] = 50,
                                   train_size: int = 5000,
                                   test_size: int = 1000,
                                   random_state: Optional[int] = None) -> float:
    """
    Train a linear classifier on representations and return test accuracy.

    The first ``train_size`` rows train the classifier and the following
    ``test_size`` rows test it (no shuffling). PCA, fitted on the training
    split only, optionally projects every representation to the same number
    of dimensions to control for width differences between raw pixels and
    latent representations of different sizes.

    Args:
        representations: (N, D) feature matrix.
        labels: (N,) class labels aligned with ``representations``.
        n_components: PCA dimensionality, or None to skip PCA. PCA is skipped
            when it would not reduce dimensionality, and the component count is
            capped at the number of training rows (PCA's upper limit).
        train_size: Number of leading rows used to fit PCA and the classifier.
        test_size: Number of subsequent rows used for scoring.
        random_state: Seed forwarded to PCA, whose automatically chosen solver
            may be randomized for large inputs. None keeps the previous
            behaviour (NumPy's global RNG).

    Returns:
        Mean accuracy on the test split.

    Raises:
        ValueError: if no rows remain for the test split.
    """
    X_train, y_train = representations[:train_size], labels[:train_size]
    X_test = representations[train_size:train_size + test_size]
    y_test = labels[train_size:train_size + test_size]
    if len(X_test) == 0:
        raise ValueError(
            f"Need more than train_size={train_size} rows for a test split; "
            f"got {len(representations)}"
        )

    # Apply PCA if requested (to control for dimensionality)
    if n_components is not None:
        # PCA cannot extract more than min(n_samples, n_features) components.
        n_components = min(n_components, X_train.shape[0])
        if n_components < representations.shape[1]:
            pca = PCA(n_components=n_components, random_state=random_state)
            X_train = pca.fit_transform(X_train)
            X_test = pca.transform(X_test)

    # Train linear classifier
    clf = LogisticRegression(max_iter=1000, n_jobs=-1)
    clf.fit(X_train, y_train)

    return clf.score(X_test, y_test)

In [None]:
def compute_weight_statistics(model: "OvercompleteAutoencoder") -> dict:
    """Summarize the spectra and norms of the encoder/decoder weight matrices.

    Args:
        model: Autoencoder exposing ``encoder.weight`` and ``decoder.weight``.

    Returns:
        Dict with the singular values of each matrix (as returned by
        ``np.linalg.svd``), their entropy-based effective ranks, and their
        Frobenius norms.
    """
    enc_weights = model.encoder.weight.detach().cpu().numpy()
    dec_weights = model.decoder.weight.detach().cpu().numpy()

    # Singular values (for rank analysis)
    enc_svd = np.linalg.svd(enc_weights, compute_uv=False)
    dec_svd = np.linalg.svd(dec_weights, compute_uv=False)

    def effective_rank(s):
        """exp(Shannon entropy) of the normalized singular values.

        Returns 0.0 for an all-zero matrix instead of dividing by zero (NaN).
        """
        total = s.sum()
        if total <= 0:
            return 0.0
        p = s / total
        p = p[p > 1e-10]  # avoid log(0)
        return np.exp(-np.sum(p * np.log(p)))

    return {
        'encoder_singular_values': enc_svd,
        'decoder_singular_values': dec_svd,
        'encoder_effective_rank': effective_rank(enc_svd),
        'decoder_effective_rank': effective_rank(dec_svd),
        'encoder_frobenius_norm': np.linalg.norm(enc_weights, 'fro'),
        'decoder_frobenius_norm': np.linalg.norm(dec_weights, 'fro'),
    }

## Training

In [None]:
class TrainingHistory:
    """In-memory record of a training run: losses, weight snapshots and probes.

    Snapshots are stored as ``(epoch, array)`` pairs, where "epoch" is the epoch
    number in multi-epoch training or the number of images seen in single-pass
    training.
    """

    def __init__(self):
        self.losses = []               # training loss per recorded step
        self.val_losses = []           # validation loss per recorded step
        self.epochs = []               # x-axis values for losses / val_losses
        self.encoder_weights = []      # list of (epoch, (hidden_dim, input_dim) array)
        self.decoder_weights = []      # list of (epoch, (input_dim, hidden_dim) array)
        self.encoder_bias = []         # list of (epoch, (hidden_dim,) array)
        self.decoder_bias = []         # list of (epoch, (input_dim,) array)
        self.linear_probe_latent = []  # list of (epoch, accuracy)
        self.linear_probe_raw = None   # raw-pixel probe accuracy, if computed
        self._data_sample = None       # inputs used to compute mean activations

    def save_weights(self, model: "OvercompleteAutoencoder", epoch: int):
        """Append copies of the model's weights and biases, tagged with ``epoch``."""
        enc_w = model.encoder.weight.detach().cpu().numpy().copy()
        dec_w = model.decoder.weight.detach().cpu().numpy().copy()
        enc_b = model.encoder.bias.detach().cpu().numpy().copy()
        dec_b = model.decoder.bias.detach().cpu().numpy().copy()
        self.encoder_weights.append((epoch, enc_w))
        self.decoder_weights.append((epoch, dec_w))
        self.encoder_bias.append((epoch, enc_b))
        self.decoder_bias.append((epoch, dec_b))

    def set_data_sample(self, data: np.ndarray):
        """Store a sample of inputs (rows) used to compute mean activations."""
        self._data_sample = data

    @staticmethod
    def _lookup(snapshots, epoch):
        """Return the first array stored under ``epoch`` in ``snapshots``, or None."""
        return next((arr for e, arr in snapshots if e == epoch), None)

    def get_importance_scores(self, epoch: int) -> Optional[np.ndarray]:
        """
        Per-unit importance at ``epoch``, or None if the snapshot is incomplete.

        Importance = ||encoder_weights|| × mean(activation) × ||decoder_weights||

        - ||encoder_weights||: L2 norm of the unit's incoming weights
        - mean(activation): mean ReLU activation over the stored data sample
          (all ones if no sample was set)
        - ||decoder_weights||: L2 norm of the unit's outgoing weights
        """
        enc_w = self._lookup(self.encoder_weights, epoch)
        dec_w = self._lookup(self.decoder_weights, epoch)
        enc_b = self._lookup(self.encoder_bias, epoch)
        if enc_w is None or dec_w is None or enc_b is None:
            return None

        enc_norm = np.linalg.norm(enc_w, axis=1)  # enc_w: (hidden_dim, input_dim)
        dec_norm = np.linalg.norm(dec_w, axis=0)  # dec_w: (input_dim, hidden_dim)

        if self._data_sample is not None:
            activations = np.maximum(self._data_sample @ enc_w.T + enc_b, 0)  # ReLU
            mean_act = np.mean(activations, axis=0)
        else:
            mean_act = np.ones(enc_w.shape[0])

        return enc_norm * mean_act * dec_norm

    def get_importance_order(self, epoch: int) -> Optional[np.ndarray]:
        """
        Unit indices sorted by importance, most important first.

        See ``get_importance_scores`` for the definition. Returns None when no
        complete snapshot exists for ``epoch``.
        """
        importance = self.get_importance_scores(epoch)
        if importance is None:
            return None
        return np.argsort(importance)[::-1]

In [None]:
def _latent_probe_accuracy(model: "OvercompleteAutoencoder",
                           images_flat: torch.Tensor,
                           labels_np: np.ndarray,
                           n_components: Optional[int]) -> float:
    """Encode ``images_flat`` (eval mode, no grad) and return the latents' linear-probe accuracy.

    Leaves the model in eval mode; callers switch back to train mode themselves.
    """
    model.eval()
    with torch.no_grad():
        _, latent = model(images_flat)
    return compute_linear_probe_accuracy(
        latent.cpu().numpy(), labels_np, n_components=n_components
    )


def train_autoencoder(
    hidden_dim: int = 1024,
    epochs: int = 50,
    lr: float = 1e-3,
    batch_size: int = 256,
    snapshot_every: int = 1,
    compute_probes: bool = True,
    probe_every: int = 5,
    probe_n_components: Optional[int] = 50,
    seed: int = 42,
    verbose: bool = True,
    single_pass: bool = False,
    snapshots_per_pass: int = 50,
    val_split: float = 0.1,
) -> tuple[OvercompleteAutoencoder, TrainingHistory]:
    """
    Train an overcomplete autoencoder on MNIST and track metrics in memory.

    A random ``val_split`` fraction of the MNIST training set is held out for
    the validation loss. Weight snapshots, train/validation losses and
    (optionally) linear-probe accuracies are recorded in the returned
    ``TrainingHistory``; the untrained state is always recorded under key 0
    (with training loss ``inf``, since none exists yet).

    Args:
        hidden_dim: Width of the hidden layer.
        epochs: Number of passes over the training split (ignored if single_pass).
        lr: Adam learning rate.
        batch_size: Mini-batch size. Only full batches are used, so the last
            ``n_train % batch_size`` shuffled images of each pass are skipped.
        snapshot_every: Save weights every this many epochs (multi-epoch mode).
            The final epoch is always saved.
        compute_probes: Run linear probes (raw pixels once, latents periodically).
        probe_every: Probe the latents every this many epochs (multi-epoch mode).
        probe_n_components: PCA dimensionality used by the probes (None = no PCA).
        seed: Seed for the torch and NumPy global RNGs.
        verbose: Print progress.
        single_pass: If True, each training image is seen at most once (no
            repetition, hence no memorisation). The 'epochs' parameter is
            ignored; instead 'snapshots_per_pass' evenly spaced snapshots are
            taken, keyed by the number of images seen so far.
        snapshots_per_pass: Number of weight snapshots during single-pass
            training (capped at the number of batches).
        val_split: Fraction of data to hold out for validation (default 10%).

    Returns:
        ``(model, history)``.

    Raises:
        ValueError: if the training split is smaller than one batch.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Load data
    images, labels = load_mnist(device, train=True)
    images_flat = images.view(images.size(0), -1)  # (60000, 784)
    labels_np = labels.cpu().numpy()

    # Split into train and validation
    n_total = len(images_flat)
    n_val = int(n_total * val_split)
    n_train = n_total - n_val

    n_batches = n_train // batch_size
    if n_batches == 0:
        raise ValueError(
            f"batch_size={batch_size} is larger than the {n_train} training "
            "images; there is no full batch to train on"
        )

    perm = torch.randperm(n_total, device=device)
    train_images = images_flat[perm[:n_train]]
    val_images = images_flat[perm[n_train:]]

    if verbose:
        print(f"Training set: {n_train} images")
        print(f"Validation set: {n_val} images")

    # Initialize model
    model = OvercompleteAutoencoder(input_dim=784, hidden_dim=hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    history = TrainingHistory()
    history.val_losses = []  # set explicitly in case TrainingHistory does not define it

    # Data sample for importance computation (subset to bound memory)
    history.set_data_sample(train_images[:5000].cpu().numpy())

    # Raw pixel baseline, computed once
    if compute_probes:
        history.linear_probe_raw = compute_linear_probe_accuracy(
            images_flat.cpu().numpy(), labels_np, n_components=probe_n_components
        )
        if verbose:
            print(f"Raw pixel linear probe accuracy: {history.linear_probe_raw:.4f}")

    def compute_val_loss() -> float:
        """MSE on the whole validation split (no grad); restores train mode."""
        model.eval()
        with torch.no_grad():
            recon, _ = model(val_images)
            val_loss = criterion(recon, val_images).item()
        model.train()
        return val_loss

    def train_step(batch: torch.Tensor) -> float:
        """Run one Adam update on ``batch`` and return its loss."""
        optimizer.zero_grad()
        recon, _ = model(batch)
        loss = criterion(recon, batch)
        loss.backward()
        optimizer.step()
        return loss.item()

    # Initial (untrained) state under key 0
    history.save_weights(model, 0)
    history.epochs.append(0)
    history.losses.append(float('inf'))
    history.val_losses.append(compute_val_loss())

    if compute_probes:
        acc = _latent_probe_accuracy(model, images_flat, labels_np, probe_n_components)
        history.linear_probe_latent.append((0, acc))
        if verbose:
            print(f"Epoch 0 - Val loss: {history.val_losses[-1]:.6f} - Latent probe: {acc:.4f}")

    if single_pass:
        # Never request more snapshots than there are batches: otherwise some
        # boundaries round to batch 0 (never reached) and are silently lost.
        n_snapshots = max(1, min(snapshots_per_pass, n_batches))
        # 1-based batch counts after which to snapshot; the last is n_batches.
        snapshot_batches = {i * n_batches // n_snapshots for i in range(1, n_snapshots + 1)}
        images_used = n_batches * batch_size  # images actually trained on

        if verbose:
            print(f"\nSingle-pass mode: {n_train} train images, {n_batches} batches")
            print(f"Taking {n_snapshots} snapshots during training\n")

        # Shuffle training data once at the start
        train_perm = torch.randperm(n_train, device=device)

        model.train()
        running_loss = 0.0
        loss_count = 0

        for batch_idx in range(n_batches):
            idx = train_perm[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            running_loss += train_step(train_images[idx])
            loss_count += 1

            if batch_idx + 1 in snapshot_batches:
                # Images seen so far is the x-axis key in this mode
                images_seen = (batch_idx + 1) * batch_size
                avg_loss = running_loss / loss_count
                val_loss = compute_val_loss()

                history.losses.append(avg_loss)
                history.val_losses.append(val_loss)
                history.epochs.append(images_seen)
                history.save_weights(model, images_seen)

                if verbose:
                    print(f"Images: {images_seen:,} / {n_train:,} - Train: {avg_loss:.6f} - Val: {val_loss:.6f}")

                running_loss = 0.0
                loss_count = 0

        if compute_probes:
            acc = _latent_probe_accuracy(model, images_flat, labels_np, probe_n_components)
            # Key the final probe like the final snapshot/loss entry so all
            # curves share the same x-axis values.
            history.linear_probe_latent.append((images_used, acc))
            if verbose:
                print(f"\nFinal latent probe accuracy: {acc:.4f}")

    else:
        # Standard multi-epoch training
        for epoch in range(1, epochs + 1):
            model.train()
            epoch_loss = 0.0

            # Reshuffle every epoch
            train_perm = torch.randperm(n_train, device=device)
            for i in range(n_batches):
                idx = train_perm[i * batch_size:(i + 1) * batch_size]
                epoch_loss += train_step(train_images[idx])

            epoch_loss /= n_batches
            val_loss = compute_val_loss()

            history.losses.append(epoch_loss)
            history.val_losses.append(val_loss)
            history.epochs.append(epoch)

            # Always keep the final weights: later analysis treats the last
            # snapshot as the trained model.
            if epoch % snapshot_every == 0 or epoch == epochs:
                history.save_weights(model, epoch)

            # Linear probe is expensive, so run it less often
            if compute_probes and epoch % probe_every == 0:
                acc = _latent_probe_accuracy(model, images_flat, labels_np, probe_n_components)
                history.linear_probe_latent.append((epoch, acc))
                if verbose:
                    print(f"Epoch {epoch}/{epochs} - Train: {epoch_loss:.6f} - Val: {val_loss:.6f} - Probe: {acc:.4f}")
            elif verbose:
                print(f"Epoch {epoch}/{epochs} - Train: {epoch_loss:.6f} - Val: {val_loss:.6f}")

    return model, history

## Analysis Tools

In [None]:
def plot_training_summary(history: TrainingHistory, figsize: tuple = (12, 4)):
    """Two-panel training summary.

    Left: train (and, when recorded, validation) MSE on a log scale.
    Right: latent linear-probe accuracy against the raw-pixel baseline,
    drawn only when latent probes were recorded.

    Returns:
        The matplotlib Figure.
    """
    fig, (loss_ax, probe_ax) = plt.subplots(1, 2, figsize=figsize)

    loss_ax.plot(history.epochs, history.losses, label='Train')
    recorded_val = getattr(history, 'val_losses', None)
    if recorded_val:
        loss_ax.plot(history.epochs, recorded_val, label='Validation')

    # Large x-values indicate single-pass mode, where x counts images seen.
    x_label = 'Images seen' if history.epochs[-1] > 100 else 'Epoch'
    loss_ax.set_xlabel(x_label)
    loss_ax.set_ylabel('MSE Loss')
    loss_ax.set_title('Reconstruction Loss')
    loss_ax.set_yscale('log')
    loss_ax.legend()

    if history.linear_probe_latent:
        probe_x, probe_acc = zip(*history.linear_probe_latent)
        probe_ax.plot(probe_x, probe_acc, 'b-o', label='Latent')
        probe_ax.axhline(y=history.linear_probe_raw, color='r', linestyle='--', label='Raw pixels')
        probe_ax.set_xlabel(x_label)
        probe_ax.set_ylabel('Accuracy')
        probe_ax.set_title('Linear Probe Accuracy')
        probe_ax.legend()

    plt.tight_layout()
    return fig

In [None]:
def visualize_weight_evolution(history: TrainingHistory,
                                num_units: int = 8,
                                sort_by_importance: bool = True,
                                figsize_per_unit: float = 1.5):
    """
    Visualize how hidden units evolve across ALL saved snapshots.

    Creates a grid: rows = hidden units, columns = snapshots (epochs, or
    images seen in single-pass mode).

    If sort_by_importance=True, units are re-ranked at EACH snapshot by
    importance = ||encoder row|| x mean ReLU activation x ||decoder column||
    (``history.get_importance_order``). The same row may therefore show
    different units across columns - you are watching "the most important
    unit", "the second most important", etc.

    If sort_by_importance=False, shows fixed unit indices 0, 1, 2, ...

    Args:
        history: Training history with at least one weight snapshot.
        num_units: Number of rows (capped at the hidden width).
        sort_by_importance: See above.
        figsize_per_unit: Size in inches of each cell.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``history`` has no weight snapshots.
    """
    n_snapshots = len(history.encoder_weights)
    if n_snapshots == 0:
        raise ValueError("history has no weight snapshots")
    num_units = min(num_units, history.encoder_weights[0][1].shape[0])

    # squeeze=False keeps ``axes`` 2-D for a single row and/or a single column.
    fig, axes = plt.subplots(
        num_units, n_snapshots,
        figsize=(figsize_per_unit * n_snapshots, figsize_per_unit * num_units),
        squeeze=False,
    )

    for col, (epoch, enc_weights) in enumerate(history.encoder_weights):
        # Unit ordering for this snapshot (index order as fallback)
        order = history.get_importance_order(epoch) if sort_by_importance else None
        if order is None:
            order = np.arange(enc_weights.shape[0])

        for row in range(num_units):
            ax = axes[row, col]
            w = enc_weights[order[row]].reshape(28, 28)
            # Symmetric limits; an all-zero tile falls back to 1 so zero stays neutral
            vmax = float(np.abs(w).max()) or 1.0
            ax.imshow(w, cmap='RdBu_r', vmin=-vmax, vmax=vmax)
            # Hide ticks/frame instead of axis('off'), which would also hide
            # the row labels set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            if row == 0:
                ax.set_title(f'E{epoch}', fontsize=8)
            if col == 0:
                label = f'Rank {row+1}' if sort_by_importance else f'Unit {row}'
                ax.set_ylabel(label, fontsize=8)

    title = 'Weight Evolution (sorted by importance per epoch)' if sort_by_importance else 'Weight Evolution (fixed unit indices)'
    plt.suptitle(title, fontsize=10)
    plt.tight_layout()
    return fig


def visualize_all_weights_at_epoch(history: TrainingHistory,
                                    epoch: int,
                                    num_units: int = 64,
                                    sort_by_importance: bool = True,
                                    figsize: tuple = (12, 12)):
    """
    Show a grid of encoder weights (as 28x28 images) at one snapshot.

    Args:
        history: Training history holding weight snapshots.
        epoch: Snapshot key (epoch, or images seen in single-pass mode).
        num_units: Maximum number of units to display.
        sort_by_importance: If True, the most important units are shown first
            (top-left to bottom-right) using ``history.get_importance_order``;
            otherwise units appear in index order.
        figsize: Matplotlib figure size.

    Returns:
        The matplotlib Figure, or None (after printing a message) when no
        snapshot exists for ``epoch``.
    """
    enc_weights = next((w for e, w in history.encoder_weights if e == epoch), None)
    if enc_weights is None:
        print(f"No snapshot for epoch {epoch}")
        return None

    order = history.get_importance_order(epoch) if sort_by_importance else None
    if order is None:
        order = np.arange(enc_weights.shape[0])

    n_display = min(num_units, enc_weights.shape[0])
    grid_size = max(1, int(np.ceil(np.sqrt(n_display))))

    # squeeze=False keeps a 2-D Axes array even for a 1x1 grid (num_units=1).
    fig, axes = plt.subplots(grid_size, grid_size, figsize=figsize, squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < n_display:
            w = enc_weights[order[i]].reshape(28, 28)
            # Symmetric limits; an all-zero tile falls back to 1 so zero stays neutral
            vmax = float(np.abs(w).max()) or 1.0
            ax.imshow(w, cmap='RdBu_r', vmin=-vmax, vmax=vmax)
        ax.axis('off')

    sort_label = " (by importance)" if sort_by_importance else ""
    plt.suptitle(f'Encoder Weights at Epoch {epoch}{sort_label}', fontsize=14)
    plt.tight_layout()
    return fig


def visualize_epoch_grid(history: TrainingHistory,
                         epochs_to_show: list[int] = None,
                         num_units: int = 16,
                         sort_by_importance: bool = True,
                         figsize_scale: float = 3):
    """
    Place one mosaic of encoder weights per requested snapshot side by side.

    Each mosaic tiles up to ``num_units`` 28x28 weight images, each min-max
    scaled to [0, 1] and drawn in grayscale. Units are ordered by importance
    at that snapshot unless ``sort_by_importance`` is False. Snapshots that
    are missing get a "(no data)" placeholder panel. By default every saved
    snapshot is shown.
    """
    if epochs_to_show is None:
        epochs_to_show = [e for e, _ in history.encoder_weights]

    n_cols = len(epochs_to_show)
    side = int(np.ceil(np.sqrt(num_units)))

    fig, axes = plt.subplots(1, n_cols, figsize=(figsize_scale * n_cols, figsize_scale))
    if n_cols == 1:
        axes = [axes]

    for ax, target_epoch in zip(axes, epochs_to_show):
        enc_weights = next((w for e, w in history.encoder_weights if e == target_epoch), None)
        if enc_weights is None:
            ax.set_title(f'Epoch {target_epoch}\n(no data)')
            ax.axis('off')
            continue

        order = history.get_importance_order(target_epoch) if sort_by_importance else None
        if order is None:
            order = np.arange(enc_weights.shape[0])

        mosaic = np.zeros((side * 28, side * 28))
        for slot in range(min(num_units, enc_weights.shape[0])):
            r, c = divmod(slot, side)
            tile = enc_weights[order[slot]].reshape(28, 28)
            mosaic[r*28:(r+1)*28, c*28:(c+1)*28] = (tile - tile.min()) / (tile.max() - tile.min() + 1e-8)

        ax.imshow(mosaic, cmap='gray')
        ax.set_title(f'Epoch {target_epoch}')
        ax.axis('off')

    plt.tight_layout()
    return fig

In [None]:
def create_weight_animation(history: TrainingHistory,
                             num_units: int = 16,
                             sort_by_importance: bool = True,
                             interval: int = 200):
    """
    Build an inline HTML animation with one frame per saved weight snapshot.

    Each frame is a grayscale mosaic of up to ``num_units`` min-max-scaled
    28x28 encoder weight images, re-ranked by importance at that snapshot
    unless ``sort_by_importance`` is False. ``interval`` is the frame delay in
    milliseconds. The figure is closed so that only the HTML player is shown.
    """
    side = int(np.ceil(np.sqrt(num_units)))
    fig, ax = plt.subplots(figsize=(6, 6))

    def build_frame(frame_idx):
        epoch, enc_weights = history.encoder_weights[frame_idx]
        order = history.get_importance_order(epoch) if sort_by_importance else None
        if order is None:
            order = np.arange(enc_weights.shape[0])

        mosaic = np.zeros((side * 28, side * 28))
        for slot in range(min(num_units, enc_weights.shape[0])):
            r, c = divmod(slot, side)
            tile = enc_weights[order[slot]].reshape(28, 28)
            mosaic[r*28:(r+1)*28, c*28:(c+1)*28] = (tile - tile.min()) / (tile.max() - tile.min() + 1e-8)
        return mosaic, epoch

    first_mosaic, first_epoch = build_frame(0)
    im = ax.imshow(first_mosaic, cmap='gray', animated=True)
    title = ax.set_title(f'Epoch {first_epoch}')
    ax.axis('off')

    def update(frame):
        mosaic, epoch = build_frame(frame)
        im.set_array(mosaic)
        title.set_text(f'Epoch {epoch}')
        return [im, title]

    anim = animation.FuncAnimation(
        fig,
        update,
        frames=len(history.encoder_weights),
        interval=interval,
        blit=True,
    )
    plt.close(fig)
    return HTML(anim.to_jshtml())

In [None]:
def plot_importance_distribution(history: TrainingHistory,
                                  epochs_to_show: list[int] = None,
                                  figsize: tuple = (12, 4)):
    """
    Bar-plot the per-unit importances, sorted in descending order, at a few snapshots.

    Importance = ||encoder_weights|| × mean(activation) × ||decoder_weights||.
    Mean activation uses the history's stored data sample, or all ones if
    there is none. By default the first, middle and last snapshots are shown.
    Snapshots without complete weights/biases are left blank.
    """
    def first_match(snapshots, epoch):
        return next((arr for e, arr in snapshots if e == epoch), None)

    if epochs_to_show is None:
        all_epochs = [e for e, _ in history.decoder_weights]
        epochs_to_show = [all_epochs[i] for i in (0, len(all_epochs) // 2, -1)]

    fig, axes = plt.subplots(1, len(epochs_to_show), figsize=figsize)
    if len(epochs_to_show) == 1:
        axes = [axes]

    for ax, target_epoch in zip(axes, epochs_to_show):
        enc_w = first_match(history.encoder_weights, target_epoch)
        dec_w = first_match(history.decoder_weights, target_epoch)
        enc_b = first_match(history.encoder_bias, target_epoch)
        if any(arr is None for arr in (enc_w, dec_w, enc_b)):
            continue

        if history._data_sample is None:
            mean_act = np.ones(enc_w.shape[0])
        else:
            mean_act = np.maximum(history._data_sample @ enc_w.T + enc_b, 0).mean(axis=0)

        importance = np.linalg.norm(enc_w, axis=1) * mean_act * np.linalg.norm(dec_w, axis=0)
        ranked = np.sort(importance)[::-1]

        ax.bar(range(len(ranked)), ranked, width=1.0)
        ax.set_xlabel('Unit rank')
        ax.set_ylabel('Importance')
        ax.set_title(f'Epoch {target_epoch}')

    plt.suptitle('Importance = ||enc|| × activation × ||dec||', fontsize=12)
    plt.tight_layout()
    return fig


def track_top_units(history: TrainingHistory, top_k: int = 10):
    """
    List which unit indices are among the top-k by importance at each snapshot.

    Returns:
        ``[(epoch, [unit_idx, ...]), ...]`` with the most important unit first.
        Snapshots for which no ordering can be computed are omitted.
    """
    ranked = ((epoch, history.get_importance_order(epoch))
              for epoch, _ in history.decoder_weights)
    return [(epoch, order[:top_k].tolist()) for epoch, order in ranked if order is not None]

In [None]:
# Export data as JSON for the HTML explorer
import json

def export_explorer_data(history: "TrainingHistory",
                         epochs: list[int],
                         sample_images: np.ndarray,
                         sample_labels: np.ndarray = None):
    """
    Serialize everything the HTML explorer needs into a JSON string.

    For each requested snapshot the payload includes importance scores and
    ordering (importance = ||encoder row|| x mean ReLU activation x
    ||decoder column||), full weights and biases, per-unit norms and mean
    activations. Decoder weights are exported transposed (one row per hidden
    unit).

    Args:
        history: Training history holding weight snapshots.
        epochs: Snapshot keys (epochs, or images-seen counts in single-pass
            mode) to export. NumPy integers are accepted. Keys without a
            complete snapshot are skipped with a message.
        sample_images: Example inputs embedded verbatim.
        sample_labels: Optional labels aligned with ``sample_images``.

    Returns:
        The JSON document as a string.

    Raises:
        ValueError: if ``history`` contains no weight snapshots.
    """
    if not history.encoder_weights:
        raise ValueError("history has no weight snapshots to export")

    def first_match(snapshots, epoch):
        return next((arr for e, arr in snapshots if e == epoch), None)

    epochs_data = []

    for epoch in epochs:
        enc_w = first_match(history.encoder_weights, epoch)
        dec_w = first_match(history.decoder_weights, epoch)
        enc_b = first_match(history.encoder_bias, epoch)
        dec_b = first_match(history.decoder_bias, epoch)

        if any(arr is None for arr in (enc_w, dec_w, enc_b, dec_b)):
            print(f"No data for epoch {epoch}, skipping")
            continue

        # Importance scores for ALL units
        enc_norm = np.linalg.norm(enc_w, axis=1)
        dec_norm = np.linalg.norm(dec_w, axis=0)
        if history._data_sample is not None:
            activations = np.maximum(history._data_sample @ enc_w.T + enc_b, 0)
            mean_act = np.mean(activations, axis=0)
        else:
            mean_act = np.ones(enc_w.shape[0])
        importance = enc_norm * mean_act * dec_norm
        # Same formula as TrainingHistory.get_importance_order; derive the
        # ordering from these scores instead of recomputing all activations.
        order = np.argsort(importance)[::-1]

        epoch_data = {
            'epoch': int(epoch),  # NumPy ints are not JSON-serializable
            'importance_order': order.tolist(),
            'importance_scores': importance.tolist(),
            'enc_weights': enc_w.tolist(),
            'dec_weights': dec_w.T.tolist(),
            'enc_bias': enc_b.tolist(),
            'dec_bias': dec_b.tolist(),
            'enc_norms': enc_norm.tolist(),
            'dec_norms': dec_norm.tolist(),
            'mean_activations': mean_act.tolist(),
        }
        epochs_data.append(epoch_data)
        print(f"  Epoch {epoch}: done")

    first_enc = history.encoder_weights[0][1]  # (hidden_dim, input_dim)
    data = {
        'epochs': epochs_data,
        'available_epochs': [ed['epoch'] for ed in epochs_data],
        'hidden_dim': int(first_enc.shape[0]),
        'input_dim': int(first_enc.shape[1]),
        'sample_images': sample_images.tolist(),
        'sample_labels': sample_labels.tolist() if sample_labels is not None else None,
    }

    json_str = json.dumps(data)
    print(f"\nJSON size: {len(json_str) / 1024 / 1024:.2f} MB")
    print(f"Epochs saved: {data['available_epochs']}")
    print(f"Hidden units: {data['hidden_dim']}")
    print(f"Sample images: {len(sample_images)}")

    return json_str

## Run Experiment

In [None]:
# Experiment configuration, passed straight through as train_autoencoder(**CONFIG)
CONFIG = dict(
    hidden_dim=1024,        # Try: 784, 1024, 2048
    lr=1e-3,
    batch_size=256,
    seed=17,
    val_split=0.1,          # 10% of the training set held out for validation

    # Single-pass mode: every training image is used at most once
    single_pass=True,
    snapshots_per_pass=50,  # weight checkpoints taken during that single pass

    # Only consulted when single_pass=False:
    epochs=50,
    snapshot_every=1,
    probe_every=5,
)

In [None]:
# Train with the settings above; returns the trained model and its in-memory history.
model, history = train_autoencoder(**CONFIG)

In [None]:
# Training summary: train/validation loss curves and linear-probe accuracy
plot_training_summary(history)

In [None]:
# Watch hidden units evolve across ALL saved snapshots.
# Each column is a snapshot; with the default importance sorting each row is an
# importance rank (Rank 1 = most important unit at that snapshot), so a row may
# show different units over time.
visualize_weight_evolution(history, num_units=8)

In [None]:
# Side-by-side comparison at six roughly evenly spaced checkpoints: the first,
# the last, and four in between, chosen from whatever snapshots were saved.
available = [e for e, _ in history.encoder_weights]
n = len(available)
indices = [k * n // 5 for k in range(5)] + [n - 1]
epochs_to_show = [available[i] for i in indices]
visualize_epoch_grid(history, epochs_to_show=epochs_to_show, num_units=16)

In [None]:
# Detailed view: the 64 most important units at the last saved snapshot
final_epoch = history.encoder_weights[-1][0]
visualize_all_weights_at_epoch(history=history, epoch=final_epoch, num_units=64)

In [None]:
# Distribution of unit importance at the first, ~1/3, ~2/3 and last checkpoints
available = [epoch for epoch, _ in history.encoder_weights]
n = len(available)
epochs_to_show = [available[i] for i in (0, n // 3, 2 * n // 3, -1)]
plot_importance_distribution(history, epochs_to_show=epochs_to_show)

In [None]:
# Animated view: the top-16 units (re-ranked by importance in every frame)
# evolving over training, rendered as an inline HTML player.
create_weight_animation(history, num_units=16, interval=150)

In [None]:
# Inspect an early checkpoint; change the index to explore other points.
# Index 5 is the sixth snapshot (index 0 holds the untrained weights); fall back
# to the last snapshot if fewer were saved.
available = [e for e, _ in history.encoder_weights]
early_checkpoint = available[min(5, len(available) - 1)]
print("Showing checkpoint: {}".format(early_checkpoint))
visualize_all_weights_at_epoch(history, epoch=early_checkpoint, num_units=64)

In [None]:
# Reconstruction quality on MNIST test images (train=False).
# NOTE: rebinds the notebook globals `images` and `labels` to the test set.
images, labels = load_mnist(device, train=False)
visualize_reconstructions(model, images)

## Quick Comparison Across Hidden Sizes

In [None]:
def run_size_comparison(hidden_dims: Optional[list[int]] = None,
                        epochs: int = 30,
                        **kwargs) -> dict[int, dict]:
    """Train one autoencoder per hidden size and collect the results.

    Args:
        hidden_dims: Hidden-layer sizes to train, in order. ``None`` (the
            default) means ``[784, 1024, 2048]``. Any iterable of ints works.
        epochs: Forwarded to ``train_autoencoder`` as ``epochs``.
            NOTE(review): CONFIG documents that ``epochs`` is only used when
            ``single_pass=False``. If ``train_autoencoder`` defaults to
            single-pass mode, pass ``single_pass=False`` via ``kwargs`` for
            this value to take effect - confirm.
        **kwargs: Extra keyword arguments forwarded unchanged to every
            ``train_autoencoder`` call.

    Returns:
        Dict mapping each hidden size to ``{'model': model, 'history': history}``,
        the pair returned by ``train_autoencoder`` for that size.
    """
    # None sentinel instead of a list literal default: a default list is
    # created once and shared across calls, so any mutation would leak.
    if hidden_dims is None:
        hidden_dims = [784, 1024, 2048]

    banner = '=' * 50
    results = {}
    for dim in hidden_dims:
        print(f"\n{banner}")
        print(f"Training with hidden_dim={dim}")
        print(banner)

        model, history = train_autoencoder(
            hidden_dim=dim,
            epochs=epochs,
            **kwargs
        )
        results[dim] = {'model': model, 'history': history}

    return results

In [None]:
# Uncomment to run comparison (takes longer)
# NOTE(review): the CONFIG dict above passes 'snapshot_every', not
# 'checkpoint_every'. Confirm that train_autoencoder accepts checkpoint_every
# (and seed) before uncommenting. CONFIG also notes that `epochs` is only used
# when single_pass=False.
# results = run_size_comparison(
#     hidden_dims=[784, 1024, 2048],
#     epochs=30,
#     checkpoint_every=10,
#     seed=17
# )

In [None]:
# Compare results if run
# `results` is only defined once the run_size_comparison call in the previous
# cell has been uncommented and executed.
# for dim, res in results.items():
#     print(f"\nhidden_dim={dim}:")
#     print(f"  Final loss: {res['history'].losses[-1]:.6f}")
#     print(f"  Raw probe: {res['history'].linear_probe_raw:.4f}")
#     print(f"  Final latent probe: {res['history'].linear_probe_latent[-1]:.4f}")

## Export for HTML Explorer

Export trained weights as JSON for use with `explorer.html`.

In [None]:
# Export JSON data for the HTML explorer
# In single_pass mode, "epochs" are actually images_seen counts

# Get test images for the explorer (the first 100 of the test split)
test_images, test_labels = load_mnist(device, train=False)
sample_batch = test_images[:100]
# Flatten each image to one row. The row count comes from the slice itself
# rather than a hard-coded 100, so a smaller split cannot be silently
# mis-shaped. reshape also accepts non-contiguous tensors, unlike view.
test_flat = sample_batch.reshape(sample_batch.shape[0], -1).cpu().numpy()
test_labels_np = test_labels[:100].cpu().numpy()

# Get all available epochs/checkpoints from history
available_epochs = [e for e, _ in history.encoder_weights]
if not available_epochs:
    raise ValueError("history has no encoder-weight checkpoints to export")
print(f"Available checkpoints: {len(available_epochs)}")
print(f"Range: {available_epochs[0]} to {available_epochs[-1]}")

# Select a subset for export (to keep JSON size manageable): at most 15 checkpoints
if len(available_epochs) > 15:
    n = len(available_epochs)
    # Dense early, sparse later: the first 5 checkpoints, then one at each 10%
    # mark of the run (linear spacing), then the final one. The set merges any
    # overlaps, giving at most 5 + 9 + 1 = 15 indices, all < n since n > 15.
    indices = sorted({0, 1, 2, 3, 4} | {n * i // 10 for i in range(1, 10)} | {n - 1})
    epochs_to_export = [available_epochs[i] for i in indices]
else:
    epochs_to_export = available_epochs

print(f"Exporting {len(epochs_to_export)} checkpoints: {epochs_to_export}")

json_str = export_explorer_data(
    history,
    epochs=epochs_to_export,
    sample_images=test_flat,
    sample_labels=test_labels_np
)

In [None]:
# Save the explorer JSON and, when running in Colab, trigger a browser download.
filename = "explorer_data.json"

# Explicit encoding so the output does not depend on the platform's locale default.
with open(filename, 'w', encoding='utf-8') as f:
    f.write(json_str)

# Size in bytes as written (len(json_str) counts characters, which differs from
# bytes if the JSON contains non-ASCII text).
size_mb = len(json_str.encode('utf-8')) / 1024 / 1024
print(f"Saved {filename} ({size_mb:.2f} MB)")

# Download via Colab. The setup cell also supports non-Colab runs, where this
# import fails; in that case just leave the file in the working directory.
try:
    from google.colab import files
except ImportError:
    print(f"Not running in Colab; skipping download ({filename} is in the working directory)")
else:
    files.download(filename)