# PyTorch Sparse Autoencoder Lab

### Why this notebook
- Practice training the KL-regularised sparse autoencoder using the modular PyTorch stack.
- Observe how sparsity penalties evolve alongside reconstruction metrics.
- Capture reusable code snippets for reconstruction and latent analysis.

### Learning objectives
- Configure desired sparsity levels and KL weights via `CONFIG`.
- Train the model and monitor KL divergence in the returned metrics.
- Reconstruct samples and inspect how sparsity impacts latent activations.

### Prerequisites
- PyTorch 2.x with torchvision installed.
- Understanding of the vanilla autoencoder workflow.
- Optional: familiarity with KL divergence and sparsity concepts.

### Notebook workflow
1. Import `CONFIG`, `train`, and `inference` helpers from the project `src` folder.
2. Run `train(CONFIG)` and explore the reconstruction/KL metrics.
3. Load the saved checkpoint and reconstruct sample tensors.
4. Extend with latent histograms, activation sparsity plots, or alternative sparsity targets.


In [None]:
import sys
from pathlib import Path

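# `__file__` is not defined inside a notebook kernel, so resolve from the working
# directory instead (assumes the kernel starts in the folder containing this notebook).
PROJECT_ROOT = Path.cwd().resolve().parents[1]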
SRC_ROOT = PROJECT_ROOT / 'pytorch' / 'src'
if str(SRC_ROOT) not in sys.path:
    sys.path.append(str(SRC_ROOT))

from config import CONFIG
from train import train
from inference import reconstruct, load_model

In [None]:
metrics = train(CONFIG)
metrics

### Interpret the metrics
- The returned `metrics` dictionary logs reconstruction loss, PSNR, and KL divergence per epoch.
- Plot the KL term over epochs; it should fall toward zero as the mean latent activation approaches `CONFIG.target_sparsity`.
- Compare against vanilla/denoising runs to assess the cost of sparsity on reconstruction quality.
- Track both train and validation KL values to spot over-regularisation early.
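
The KL term discussed above can be sketched as follows. The penalty used inside `train` is not shown in this notebook, so this is a common formulation rather than the project's own code: it assumes latent activations lie in (0, 1) (for example sigmoid outputs), and the function name `kl_sparsity_penalty` and its parameters are hypothetical.

```python
import torch


def kl_sparsity_penalty(activations: torch.Tensor,
                        target_sparsity: float,
                        eps: float = 1e-8) -> torch.Tensor:
    """Sum over latent units of KL(rho || rho_hat) between Bernoulli distributions.

    activations: tensor of shape (batch, latent_dim) with values in (0, 1).
    target_sparsity: desired mean activation rho for every unit (assumed scalar).
    """
    # Mean activation of each latent unit across the batch (rho_hat).
    rho_hat = activations.mean(dim=0).clamp(eps, 1.0 - eps)
    rho = torch.full_like(rho_hat, target_sparsity)
    # Bernoulli KL divergence per unit, then summed over units.
    kl = rho * torch.log(rho / rho_hat) \
        + (1.0 - rho) * torch.log((1.0 - rho) / (1.0 - rho_hat))
    return kl.sum()


# The penalty vanishes when every unit already averages the target value
# and grows as the mean activation drifts away from it.
on_target = torch.full((16, 4), 0.05)
too_dense = torch.full((16, 4), 0.5)
print(kl_sparsity_penalty(on_target, 0.05).item())
print(kl_sparsity_penalty(too_dense, 0.05).item())
```

Under this formulation a KL curve that flattens well above zero suggests the mean activations have not reached the target, which is the signal to check before comparing reconstruction quality across runs.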

In [None]:
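import torch

# Load the checkpoint saved by `train(CONFIG)`.
model = load_model(CONFIG)
# Random noise stands in for real images here; it only exercises the pipeline.
dummy = torch.randn(8, 1, 28, 28)
# Split the batch into a list of eight (1, 28, 28) tensors before reconstructing.
outputs = reconstruct(list(dummy), model=model, config=CONFIG)
len(outputs)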
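
To compare reconstructions against their inputs outside the training loop, a PSNR helper along these lines may be useful. It is a minimal sketch: the name `psnr` is hypothetical, it assumes both tensors share a shape and a pixel range of `[0, max_value]`, and it makes no claim about how `train` computes its own PSNR metric.

```python
import math

import torch


def psnr(original: torch.Tensor,
         reconstruction: torch.Tensor,
         max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB between two same-shaped tensors."""
    mse = torch.mean((original - reconstruction) ** 2).item()
    if mse == 0.0:
        # Identical tensors: PSNR is unbounded.
        return math.inf
    return 10.0 * math.log10(max_value ** 2 / mse)


# A uniform error of 0.1 gives an MSE of 0.01, i.e. roughly 20 dB.
clean = torch.zeros(2, 1, 28, 28)
shifted = torch.full((2, 1, 28, 28), 0.1)
print(psnr(clean, shifted))
```

If `reconstruct` returns a list of tensors, stacking them with `torch.stack` before calling the helper would be one way to score the whole batch at once; that return type is an assumption, not something this notebook confirms.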

### Next experiments
- Visualise latent activation histograms to confirm sparsity behaviour.
- Sweep the KL weight and track PSNR/KL trade-offs using a simple loop.
- Transfer the encoder representations into a downstream classifier or clustering task.
- Compare results with the TensorFlow sparse autoencoder to validate parity.
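
As a starting point for the latent histogram experiment, the sketch below summarises a batch of latent activations. How to obtain latents from the project's model is not shown in this notebook, so the input here is a hand-written tensor; `summarize_sparsity` and `active_threshold` are hypothetical names, and activations are assumed to lie in [0, 1].

```python
import torch


def summarize_sparsity(latents: torch.Tensor,
                       active_threshold: float = 0.1,
                       bins: int = 10) -> dict:
    """Summarise sparsity for latents of shape (batch, latent_dim) in [0, 1]."""
    return {
        # Mean activation of each unit, comparable to the sparsity target.
        "mean_per_unit": latents.mean(dim=0),
        # Share of all activations that exceed the threshold.
        "active_fraction": (latents > active_threshold).float().mean().item(),
        # Histogram counts over [0, 1], ready for a bar plot.
        "histogram": torch.histc(latents, bins=bins, min=0.0, max=1.0),
    }


# Toy latents: 2 of the 8 activations exceed the 0.1 threshold.
latents = torch.tensor([[0.0, 0.9],
                        [0.0, 0.0],
                        [0.5, 0.0],
                        [0.0, 0.0]])
summary = summarize_sparsity(latents)
print(summary["active_fraction"])
print(summary["histogram"])
```

For a sparse model most of the histogram mass should sit in the lowest bin; the same summary computed at several KL weights gives a compact way to track the trade-off alongside PSNR.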