# PyTorch Variational Autoencoder Lab

### Why this notebook
- Guide you through training and sampling from the VAE implementation in `../src`.
- Highlight how the KL divergence term interacts with reconstruction loss.
- Provide ready-to-run code for reconstruction, sampling, and latent exploration.
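
For reference, the interaction between the two terms is usually written as a weighted sum. The form below is a sketch that assumes a diagonal-Gaussian encoder $q_\phi(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and a standard normal prior $p(z) = \mathcal{N}(0, I)$; this is the common setup, but the code in `../src` may differ. The symbol $\beta$ stands for the KL weight and $d$ for the latent dimension.

```latex
\mathcal{L}(x)
  = \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\bigl[-\log p_\theta(x \mid z)\bigr]}_{\text{reconstruction loss}}
  + \beta \, \underbrace{D_{\mathrm{KL}}\bigl(q_\phi(z \mid x) \,\|\, p(z)\bigr)}_{\text{KL divergence}},
\qquad
D_{\mathrm{KL}}\bigl(q_\phi(z \mid x) \,\|\, p(z)\bigr)
  = -\frac{1}{2} \sum_{j=1}^{d} \bigl(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\bigr)
```

Raising $\beta$ pulls the posterior towards the prior at the cost of reconstruction accuracy; lowering it does the opposite.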

### Learning objectives
- Train the VAE on Fashion-MNIST while logging KL and reconstruction metrics.
- Reconstruct held-out samples and generate new images by decoding samples drawn from the latent prior.
- Adjust KL weight and latent dimension to study their impact on sample quality.

### Prerequisites
- PyTorch 2.x, torchvision, and Matplotlib installed.
- Understanding of basic autoencoder concepts and Gaussian distributions.
- Optional: familiarity with the reparameterisation trick.

### Notebook workflow
1. Import config, training, and inference helpers from the PyTorch `src` package.
2. Execute `train(CONFIG)` and inspect the returned metrics.
3. Load the trained model, reconstruct samples, and draw fresh latent samples via `sample`.
4. Extend with latent traversal plots, KL annealing experiments, or evaluation metrics (FID, Inception Score).
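
As a starting point for the KL annealing experiments in step 4, here is a minimal sketch of a linear warm-up schedule. The helper `linear_kl_weight` and its parameters are hypothetical and not part of the `src` package; how the weight is passed into training depends on what `CONFIG` and `train` actually expose.

```python
def linear_kl_weight(epoch, warmup_epochs=10, max_weight=1.0):
    """Hypothetical helper: KL weight for a zero-based epoch index.

    The weight rises linearly during warm-up and then stays at max_weight.
    """
    if warmup_epochs <= 0:
        return max_weight
    # Fraction of warm-up completed, capped at 1.0 once warm-up ends.
    progress = min(1.0, (epoch + 1) / warmup_epochs)
    return max_weight * progress


schedule = [linear_kl_weight(e, warmup_epochs=4) for e in range(6)]
print(schedule)  # [0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
```

Starting with a small weight lets the decoder learn to use the latent code before the KL term starts pushing the posterior towards the prior.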


In [None]:
import sys
from pathlib import Path

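# `__file__` is not defined inside a notebook, so resolve paths from the
# working directory (assumes Jupyter was started in this notebook's folder).
PROJECT_ROOT = Path.cwd().resolve().parents[1]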
SRC_ROOT = PROJECT_ROOT / 'pytorch' / 'src'
if str(SRC_ROOT) not in sys.path:
    sys.path.append(str(SRC_ROOT))

from config import CONFIG
from train import train
from inference import reconstruct, load_model, sample

In [None]:
metrics = train(CONFIG)
metrics

### Interpret the metrics
- `metrics` logs reconstruction loss and KL divergence per epoch.
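- Use the ratio of KL to reconstruction loss as a diagnostic: a KL term that falls towards zero suggests posterior collapse (the decoder ignores the latent code), while a KL term that keeps growing as reconstruction loss falls suggests under-regularisation.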
- Plot both terms to visualise the effect of KL annealing or weight tuning.
- Compare with other runs (for example, a different KL weight or latent dimension) to understand the trade-off between sample diversity and fidelity.
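
The ratio check above can be sketched as follows. The layout of `metrics` is not documented here, so this example assumes you have already extracted two per-epoch lists from it; the helper name `kl_to_recon_ratio` and the numbers are illustrative only.

```python
def kl_to_recon_ratio(recon_losses, kl_losses):
    """Hypothetical helper: per-epoch ratio of KL to reconstruction loss."""
    if len(recon_losses) != len(kl_losses):
        raise ValueError("both lists need one entry per epoch")
    return [kl / recon for recon, kl in zip(recon_losses, kl_losses)]


# Made-up values for illustration, not results from this notebook.
recon = [300.0, 260.0, 250.0]
kl = [3.0, 13.0, 25.0]

ratios = kl_to_recon_ratio(recon, kl)
print(ratios)
```

A ratio that stays close to zero across epochs is a sign to check for posterior collapse. What counts as a healthy range depends on how the losses are reduced (sum or mean over pixels), so compare runs rather than relying on a fixed threshold.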

In [None]:
import torch

model = load_model(CONFIG)
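# Smoke test with random tensors shaped like Fashion-MNIST images (1x28x28);
# swap in real held-out images to judge reconstruction quality.
dummy = torch.randn(8, 1, 28, 28)
outputs = reconstruct(list(dummy), model=model, config=CONFIG)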
generated = sample(4, model=model, config=CONFIG)
len(outputs), generated.shape

### Next experiments
- Perform latent traversals by linearly interpolating between encoded points and plotting outputs.
- Experiment with different prior distributions (e.g., VampPrior) by modifying `model.py`.
- Compute quantitative metrics like FID using generated samples.
- Mirror the configuration in the TensorFlow notebook to compare training dynamics.
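
The latent traversal in the first bullet can be sketched as a straight-line interpolation between two latent vectors. The function `interpolate_latents` is a hypothetical helper, and the latent size of 16 is an arbitrary choice for the demo. How to encode images and decode the interpolated points depends on the model's interface in `model.py`, which is not shown here.

```python
import torch


def interpolate_latents(z_start, z_end, steps=8):
    """Return `steps` points on the straight line from z_start to z_end."""
    # Weights run from 0 (pure z_start) to 1 (pure z_end), shape (steps, 1),
    # so they broadcast against latent vectors of shape (latent_dim,).
    weights = torch.linspace(0.0, 1.0, steps).unsqueeze(1)
    return (1.0 - weights) * z_start + weights * z_end


# Stand-in latent codes; in practice these would come from the encoder.
z_a = torch.zeros(16)
z_b = torch.ones(16)

path = interpolate_latents(z_a, z_b, steps=5)
print(path.shape)  # torch.Size([5, 16])
```

Each row of `path` can then be passed through the decoder and plotted side by side to see how the output changes along the line.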