# TensorFlow Variational Autoencoder Lab

### Why this notebook
- Walk through training and sampling with the Keras VAE implementation from `../src`.
- Provide a framework-specific companion to the PyTorch VAE for comparison and study.
- Capture code patterns you can reuse when building generative models.

### Learning objectives
- Train the VAE with KL annealing support controlled by `CONFIG`.
- Inspect the metrics dictionary to balance reconstruction loss and KL divergence.
- Reconstruct inputs and generate novel samples using the decoder.
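KL annealing usually means ramping the weight on the KL term from zero toward its target during early training. This lets the decoder learn to use the latent code before the prior penalty takes full effect. The notebook does not reproduce the schedule that `../src` reads from `CONFIG`. The sketch below assumes a simple linear warm-up, and `linear_kl_weight`, `warmup_epochs`, `max_weight` and `total_loss` are hypothetical names used only for illustration.

```python
def linear_kl_weight(epoch: int, warmup_epochs: int, max_weight: float = 1.0) -> float:
    """Linearly ramp the KL weight from 0 to max_weight over warmup_epochs.

    Hypothetical helper for illustration; not part of ../src.
    """
    if warmup_epochs <= 0:
        # No warm-up requested: apply the full weight from the first epoch.
        return max_weight
    # Fraction of the warm-up completed, clipped to [0, 1].
    progress = min(max(epoch / warmup_epochs, 0.0), 1.0)
    return max_weight * progress


def total_loss(reconstruction_loss: float, kl_divergence: float, kl_weight: float) -> float:
    """Weighted VAE objective: reconstruction term plus weighted KL term."""
    return reconstruction_loss + kl_weight * kl_divergence


# Weight applied at epochs 0..5 with a 4-epoch warm-up.
schedule = [linear_kl_weight(e, warmup_epochs=4) for e in range(6)]
print(schedule)  # [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

Setting `max_weight` above 1 gives the same β-VAE-style weighting listed under the next experiments. Cyclical or sigmoid ramps are common alternatives to the linear schedule.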

### Prerequisites
- TensorFlow 2.x installed (GPU optional but helpful).
- Understanding of Gaussian latent variables and the reparameterisation trick.
- Familiarity with the vanilla TensorFlow autoencoder notebook.
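As a refresher on the Gaussian prerequisites, the sketch below writes out the reparameterisation trick and the closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior, KL = ½ Σ (μ² + σ² − 1 − log σ²). It uses NumPy to show the arithmetic only. It also assumes the encoder outputs a mean and a log-variance, a common convention that this notebook does not confirm for `../src`. In TensorFlow the same lines map to `tf.random.normal`, `tf.exp` and `tf.reduce_sum`. Gradients reach `mu` and `log_var` because all randomness enters through `eps`.

```python
import numpy as np


def reparameterise(mu: np.ndarray, log_var: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Draw z ~ N(mu, exp(log_var)) as a deterministic function of mu, log_var and noise."""
    eps = rng.standard_normal(size=mu.shape)  # noise independent of the encoder outputs
    sigma = np.exp(0.5 * log_var)  # log-variance -> standard deviation
    return mu + sigma * eps


def gaussian_kl(mu: np.ndarray, log_var: np.ndarray) -> np.ndarray:
    """Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, one value per row."""
    return 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var, axis=-1)


rng = np.random.default_rng(0)
mu = np.array([[0.0, 0.0], [1.0, -1.0]])
log_var = np.zeros_like(mu)  # sigma = 1 in every dimension
z = reparameterise(mu, log_var, rng)
print(z.shape)  # (2, 2)
print(gaussian_kl(mu, log_var))  # [0. 1.]
```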

### Notebook workflow
1. Import config, training, and inference utilities from `tensorflow/src`.
2. Run `train(CONFIG)` to fit the VAE and capture metrics.
3. Load the trained model, reconstruct inputs, and call `sample` to generate new images.
4. Extend with latent traversals, KL schedules, or evaluation metrics (FID, log-likelihood estimates).


In [None]:
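import sys
from pathlib import Path

# `__file__` is undefined inside Jupyter, so fall back to the working directory,
# assuming the kernel starts in the notebook's folder (a sibling of tensorflow/src).
try:
    NOTEBOOK_DIR = Path(__file__).resolve().parent
except NameError:
    NOTEBOOK_DIR = Path.cwd().resolve()

PROJECT_ROOT = NOTEBOOK_DIR.parents[1]
SRC_ROOT = PROJECT_ROOT / 'tensorflow' / 'src'
# Prepend so these modules win over any installed packages named config/train/inference.
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

from config import CONFIG
from train import train
from inference import reconstruct, load_model, sample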

In [None]:
metrics = train(CONFIG)
metrics

### Interpret the metrics
- Monitor reconstruction loss and KL divergence captured per epoch.
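- If the KL term falls toward zero (posterior collapse), lower `CONFIG.kl_weight` or enable KL annealing so the weight ramps up from zero and the decoder learns to rely on the latent code first. Raising the weight pushes the posterior even harder toward the prior.
- Plot both curves side by side to see when reconstruction loss dominates the objective and when the KL penalty pulls the approximate posterior toward the prior.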
- Compare with PyTorch metrics to ensure consistent behaviour across frameworks.
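A small helper can turn the per-epoch numbers into a collapse check and a plot. This assumes `metrics` maps metric names to per-epoch lists. The key names `reconstruction_loss` and `kl_divergence` and the `collapse_threshold` value are placeholders. Adjust them after inspecting `metrics.keys()` and the scale of your KL values, since summing rather than averaging over latent dimensions changes the magnitude.

```python
from typing import Dict, List, Mapping, Sequence


def kl_diagnostics(
    metrics: Mapping[str, Sequence[float]],
    rec_key: str = "reconstruction_loss",  # placeholder key; check metrics.keys()
    kl_key: str = "kl_divergence",  # placeholder key; check metrics.keys()
    collapse_threshold: float = 1e-2,  # placeholder; depends on how KL is reduced
) -> List[Dict[str, object]]:
    """Per-epoch share of the loss taken by KL, plus a possible-collapse flag."""
    rows = []
    for epoch, (rec, kl) in enumerate(zip(metrics[rec_key], metrics[kl_key])):
        total = rec + kl
        rows.append(
            {
                "epoch": epoch,
                "kl_share": kl / total if total else 0.0,
                # KL this close to zero suggests the decoder is ignoring the latent code.
                "possible_collapse": kl < collapse_threshold,
            }
        )
    return rows


def plot_metrics(metrics, rec_key="reconstruction_loss", kl_key="kl_divergence"):
    """Plot both curves against epoch, with KL on a second y-axis."""
    import matplotlib.pyplot as plt  # lazy import: kl_diagnostics needs no plotting stack

    fig, ax_rec = plt.subplots(figsize=(6, 3))
    ax_rec.plot(metrics[rec_key], color="tab:blue")
    ax_rec.set_xlabel("epoch")
    ax_rec.set_ylabel(rec_key, color="tab:blue")
    ax_kl = ax_rec.twinx()  # second y-axis sharing the epoch axis
    ax_kl.plot(metrics[kl_key], color="tab:orange")
    ax_kl.set_ylabel(kl_key, color="tab:orange")
    fig.tight_layout()
    return fig


# Made-up numbers purely to exercise the helper.
example = {"reconstruction_loss": [90.0, 80.0, 75.0], "kl_divergence": [10.0, 0.5, 0.001]}
for row in kl_diagnostics(example):
    print(row)
```

In the notebook you would call `kl_diagnostics(metrics)` and `plot_metrics(metrics)` on the dictionary returned by `train(CONFIG)`. The twin y-axis keeps a small KL curve readable next to a much larger reconstruction loss.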

In [None]:
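import numpy as np

model = load_model(CONFIG)
# Placeholder batch of 8 random 28x28 single-channel images in [-1, 1]; swap in real,
# identically scaled images to judge reconstruction quality.
dummy = np.random.uniform(-1.0, 1.0, size=(8, 28, 28, 1)).astype('float32')
outputs = reconstruct(dummy, model=model, config=CONFIG)
# Generate 4 new images with the decoder.
generated = sample(4, model=model, config=CONFIG)
len(outputs), generated.shape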
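Random draws from `sample` show what the decoder produces across the prior. A latent traversal instead holds every coordinate fixed except one and sweeps that one, which reveals what the dimension controls. The sketch below only builds the batch of latent codes. `latent_traversal` is a hypothetical helper, and `latent_dim=16` is a placeholder for the size set in `CONFIG`. The decoder call in the final comment assumes a `model.decoder` attribute that may be named differently in `../src`.

```python
import numpy as np


def latent_traversal(latent_dim: int, dim: int, values, base=None) -> np.ndarray:
    """Return a batch of latent codes that differ only along latent dimension `dim`.

    Each row copies `base` (the prior mean, all zeros, by default) and overwrites
    coordinate `dim` with one entry of `values`. Hypothetical helper.
    """
    if not 0 <= dim < latent_dim:
        raise ValueError(f"dim must be in [0, {latent_dim}), got {dim}")
    if base is None:
        base = np.zeros(latent_dim, dtype="float32")
    else:
        base = np.asarray(base, dtype="float32")
    values = np.asarray(values, dtype="float32")
    codes = np.tile(base, (len(values), 1))  # one copy of the base code per traversal step
    codes[:, dim] = values
    return codes


# Sweep dimension 0 across +/-3 standard deviations of the N(0, I) prior.
codes = latent_traversal(latent_dim=16, dim=0, values=np.linspace(-3.0, 3.0, 7))
print(codes.shape)  # (7, 16)
# Decoding depends on how ../src exposes the decoder, for example:
# images = model.decoder(codes)  # hypothetical attribute name
```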

### Next experiments
- Visualise latent traversals by sampling along individual latent dimensions.
- Experiment with β-VAE objectives by scaling `CONFIG.kl_weight`.
- Export the decoder as a standalone generator and integrate it into downstream tasks.
- Evaluate FID or other generative metrics to benchmark sample quality.
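FID is the Fréchet distance between two Gaussians fitted to feature vectors of real and generated images, conventionally Inception-v3 pool activations. The sketch below computes only the distance, d² = ‖μ_a − μ_b‖² + Tr(Σ_a + Σ_b − 2(Σ_a Σ_b)^{1/2}), from feature arrays you supply. The feature extractor is assumed and not shown, and `frechet_distance` is a hypothetical helper rather than a library API. It uses a NumPy eigendecomposition in place of the more common `scipy.linalg.sqrtm`.

```python
import numpy as np


def _psd_sqrt(matrix: np.ndarray) -> np.ndarray:
    """Symmetric square root of a positive semi-definite matrix via eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(matrix)
    eigvals = np.clip(eigvals, 0.0, None)  # drop tiny negative values caused by round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (n_samples, n_features) arrays."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    # atleast_2d keeps the single-feature case as a 1x1 matrix.
    cov_a = np.atleast_2d(np.cov(feats_a, rowvar=False))
    cov_b = np.atleast_2d(np.cov(feats_b, rowvar=False))
    # Sigma_a Sigma_b has the same eigenvalues as Sigma_a^{1/2} Sigma_b Sigma_a^{1/2},
    # which is symmetric PSD, so Tr((Sigma_a Sigma_b)^{1/2}) is the sum of their square roots.
    sqrt_a = _psd_sqrt(cov_a)
    cross_eigvals = np.clip(np.linalg.eigvalsh(sqrt_a @ cov_b @ sqrt_a), 0.0, None)
    trace_sqrt = np.sqrt(cross_eigvals).sum()
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)


# Stand-in features; real FID would use Inception-v3 activations of real vs generated images.
rng = np.random.default_rng(0)
real = rng.standard_normal((500, 8))
print(frechet_distance(real, real))  # close to 0.0
print(frechet_distance(real, real + 1.0))  # close to 8.0: unit mean shift in each of 8 features
```

FID estimates depend on the number of samples, so only compare scores computed with the same sample size and the same feature extractor.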