# Problem 4.3 – Variational Autoencoder (VAE)

## 4.3.1 Conceptual introduction to Variational Autoencoders (VAEs)

A Variational Autoencoder (VAE) is a generative model that learns a probabilistic latent representation of data.
It consists of:
- an encoder $q_\phi(\mathbf{z}\mid\mathbf{x})$ that maps data $\mathbf{x}$ to a distribution over latent variables $\mathbf{z}$,
- a decoder $p_\theta(\mathbf{x}\mid\mathbf{z})$ that maps latent variables back to a distribution over data.

### Notation and assumptions
- Prior on latent variables: $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$.
- Variational posterior (encoder): $q_\phi(\mathbf{z}\mid\mathbf{x}) = \mathcal{N}\!\big(\boldsymbol\mu_\phi(\mathbf{x}), \mathrm{diag}(\boldsymbol\sigma^2_\phi(\mathbf{x}))\big)$.
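  In practice the encoder predicts $\boldsymbol\mu$ and $\log\boldsymbol\sigma^2$ (called `logvar`) rather than $\boldsymbol\sigma^2$. `logvar` is unconstrained, and $\exp(\text{logvar})$ is always positive, which is numerically more stable.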
- Likelihood (decoder): $p_\theta(\mathbf{x}\mid\mathbf{z})$.
  - If we use mean squared error (MSE) as reconstruction loss, this corresponds to a Gaussian likelihood with fixed variance: $p_\theta(\mathbf{x}\mid\mathbf{z}) = \mathcal{N}(\hat{\mathbf{x}}_\theta(\mathbf{z}), \beta \mathbf{I})$ (for some $\beta>0$).
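  - If we use binary cross-entropy (BCE) on $[0,1]$ images, this corresponds to a Bernoulli likelihood with mean $\hat{\mathbf{x}}_\theta(\mathbf{z})$. The correspondence is exact for binary pixels and a common approximation for grayscale values in $[0,1]$.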
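
As a quick numerical check of the Bernoulli correspondence, the sketch below compares the summed BCE with the Bernoulli negative log-likelihood computed directly from the per-pixel probabilities. It assumes a random binary vector `x` and random decoder means `x_hat`, both made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical binary "image" x and decoder means x_hat in (0, 1), flattened to 784 pixels
x = (rng.random(784) > 0.5).astype(np.float64)
x_hat = rng.uniform(0.05, 0.95, size=784)

# Binary cross-entropy summed over pixels (the usual BCE reconstruction loss)
bce = -np.sum(x * np.log(x_hat) + (1.0 - x) * np.log(1.0 - x_hat))

# Bernoulli negative log-likelihood: each pixel has probability x_hat if x = 1, else 1 - x_hat
bernoulli_nll = -np.sum(np.log(np.where(x == 1.0, x_hat, 1.0 - x_hat)))

print(np.isclose(bce, bernoulli_nll))  # True
```

For grayscale pixels in $[0,1]$ that are not binary, BCE is still widely used as the reconstruction loss, but it is then no longer the log-probability of a normalized distribution over $\mathbf{x}$.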

### Objective: ELBO
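The log marginal likelihood $\log p_\theta(\mathbf{x}) = \log \int p_\theta(\mathbf{x}\mid\mathbf{z})\,p(\mathbf{z})\,d\mathbf{z}$ is intractable for a neural-network decoder, so instead we maximize the Evidence Lower BOund (ELBO), which satisfies $\mathcal{L}_{\text{ELBO}} \le \log p_\theta(\mathbf{x})$: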
$$
\mathcal{L}_{\text{ELBO}}(\theta,\phi;\mathbf{x})
= \mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}\big[\log p_\theta(\mathbf{x}\mid\mathbf{z})\big]
- \mathrm{KL}\!\big(q_\phi(\mathbf{z}\mid\mathbf{x}) \,\|\, p(\mathbf{z})\big).
$$
Training conventionally minimizes the negative ELBO:
$$
\mathcal{L}_{\text{VAE}}(\mathbf{x})
= -\,\mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}\big[\log p_\theta(\mathbf{x}\mid\mathbf{z})\big]
+ \mathrm{KL}\!\big(q_\phi(\mathbf{z}\mid\mathbf{x}) \,\|\, p(\mathbf{z})\big).
$$
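
The bound can be seen in a toy model where $\log p_\theta(\mathbf{x})$ is tractable. The sketch below assumes a 1-D linear-Gaussian model (not the VAE used later): $p(z)=\mathcal{N}(0,1)$ and $p(x\mid z)=\mathcal{N}(z,1)$, so that $p(x)=\mathcal{N}(0,2)$ and $p(z\mid x)=\mathcal{N}(x/2,1/2)$. It evaluates the ELBO analytically for random Gaussian $q$ and checks that it never exceeds $\log p(x)$, with equality when $q$ is the true posterior.

```python
import numpy as np

# Toy model (assumption, chosen because log p(x) is tractable):
#   p(z) = N(0, 1),  p(x | z) = N(z, 1)   =>   p(x) = N(0, 2),  p(z | x) = N(x/2, 1/2)

def log_marginal(x):
    """log N(x; 0, 2)."""
    return -0.5 * np.log(2.0 * np.pi * 2.0) - x**2 / 4.0

def elbo(x, m, s2):
    """Analytic ELBO for q(z) = N(m, s2)."""
    # E_q[log N(x; z, 1)] = -1/2 log(2 pi) - 1/2 ((x - m)^2 + s2)
    expected_loglik = -0.5 * np.log(2.0 * np.pi) - 0.5 * ((x - m) ** 2 + s2)
    # KL(N(m, s2) || N(0, 1)), the d = 1 case of the diagonal-Gaussian closed form
    kl = 0.5 * (m**2 + s2 - np.log(s2) - 1.0)
    return expected_loglik - kl

x = 1.3
rng = np.random.default_rng(0)
gaps = []
for _ in range(1000):
    m, s2 = rng.normal(), rng.uniform(0.05, 3.0)  # a random Gaussian q
    gaps.append(log_marginal(x) - elbo(x, m, s2))

print(min(gaps) >= -1e-12)                                 # True: ELBO never exceeds log p(x)
print(np.isclose(elbo(x, x / 2.0, 0.5), log_marginal(x)))  # True: tight at the true posterior
```

The gap $\log p(x)-\mathcal{L}_{\text{ELBO}}$ equals $\mathrm{KL}\big(q\,\|\,p(z\mid x)\big)$, which is why the bound is tight exactly at the posterior.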

For a Gaussian decoder with fixed variance $\beta\mathbf{I}$ and $\mathbf{x}\in\mathbb{R}^D$, the first term equals, up to an additive constant, a scaled sum of squared errors between $\mathbf{x}$ and $\hat{\mathbf{x}}=\hat{\mathbf{x}}_\theta(\mathbf{z})$:
$$
-\,\mathbb{E}_{q}\big[\log p_\theta(\mathbf{x}\mid\mathbf{z})\big]
= \frac{1}{2\beta}\,\mathbb{E}_{q}\big[\|\mathbf{x}-\hat{\mathbf{x}}\|_2^2\big] + \frac{D}{2}\log(2\pi\beta).
$$
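In practice the expectation is typically estimated with a single sample of $\mathbf{z}$ per input, and the squared error is summed over pixels/channels and averaged over the batch. Summing squared errors without the $1/(2\beta)$ factor corresponds to $\beta = 1/2$.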
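
The sketch below checks the Gaussian correspondence numerically on random stand-in arrays (hypothetical data, with an assumed $\beta=0.5$). It evaluates the log-density pixel by pixel and compares the total with $\frac{1}{2\beta}\|\mathbf{x}-\hat{\mathbf{x}}\|_2^2 + \frac{D}{2}\log(2\pi\beta)$. It also shows that a mean-reduced MSE differs from the summed squared error by the factor $D$.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((28, 28, 1))       # a stand-in "image" in [0, 1]
x_hat = rng.random((28, 28, 1))   # a stand-in reconstruction
beta = 0.5                        # hypothetical fixed decoder variance
D = x.size                        # 784 pixels

# Gaussian NLL evaluated pixel by pixel from the density itself
density = np.exp(-(x - x_hat) ** 2 / (2 * beta)) / np.sqrt(2 * np.pi * beta)
nll = -np.sum(np.log(density))

sse = np.sum((x - x_hat) ** 2)    # sum of squared errors
mse = np.mean((x - x_hat) ** 2)   # per-pixel MSE, as returned by a mean-reduced MSE loss

const = 0.5 * D * np.log(2 * np.pi * beta)
print(np.isclose(nll, sse / (2 * beta) + const))  # True
print(np.isclose(sse, D * mse))                   # True
```

With $\beta=1/2$ the scale factor $1/(2\beta)$ is exactly 1, so summing squared errors and adding the KL term matches this likelihood. A mean-reduced MSE instead weights reconstruction $D=784$ times less relative to the KL term.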

### Closed-form KL for diagonal Gaussians
With $q_\phi(\mathbf{z}\mid\mathbf{x})=\mathcal{N}(\boldsymbol\mu, \mathrm{diag}(\boldsymbol\sigma^2))$ and $p(\mathbf{z})=\mathcal{N}(\mathbf{0},\mathbf{I})$:
$$
\mathrm{KL}\!\big(q \,\|\, p\big)
= \frac{1}{2}\sum_{i=1}^d \big(\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1\big).
$$
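With $\text{logvar} = \log\boldsymbol\sigma^2$ as predicted by the encoder, $\sigma_i^2 = \exp(\text{logvar}_i)$ and the same formula reads
$\mathrm{KL} = \frac{1}{2}\sum_{i=1}^d \big(\mu_i^2 + e^{\text{logvar}_i} - \text{logvar}_i - 1\big)$.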
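
A small sketch of the closed form written in terms of `logvar`, checked against a Monte Carlo estimate of $\mathbb{E}_q[\log q(\mathbf{z}) - \log p(\mathbf{z})]$. The function name `kl_diag_gaussian` and the example values are illustrative assumptions, and NumPy stands in for the TensorFlow ops used in the notebook.

```python
import numpy as np

def kl_diag_gaussian(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over the last axis."""
    return 0.5 * np.sum(mu**2 + np.exp(logvar) - logvar - 1.0, axis=-1)

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])
logvar = np.array([-0.7, 0.3])
sigma = np.exp(0.5 * logvar)

# Monte Carlo check: KL = E_q[log q(z) - log p(z)], estimated with samples z ~ q
z = mu + sigma * rng.standard_normal((200_000, 2))
log_q = np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * logvar - 0.5 * ((z - mu) / sigma) ** 2, axis=-1)
log_p = np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * z**2, axis=-1)
kl_mc = np.mean(log_q - log_p)

print(kl_diag_gaussian(mu, logvar), kl_mc)          # the two values should agree to about 2 decimals
print(kl_diag_gaussian(np.zeros(2), np.zeros(2)))   # 0.0 when q equals the prior
```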

### Reparameterization trick
To backpropagate through sampling from $q_\phi(\mathbf{z}\mid\mathbf{x})$, we write
$$
\mathbf{z} = \boldsymbol\mu + \boldsymbol\sigma \odot \boldsymbol\epsilon,
\quad \boldsymbol\epsilon \sim \mathcal{N}(\mathbf{0},\mathbf{I}),
\quad \boldsymbol\sigma = \exp\!\big(\tfrac{1}{2}\,\text{logvar}\big).
$$
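For a given $\boldsymbol\epsilon$, $\mathbf{z}$ is a deterministic, differentiable function of $(\boldsymbol\mu,\text{logvar})$, so gradients of the loss flow back through $\mathbf{z}$ into the encoder. The randomness is isolated in $\boldsymbol\epsilon$, which does not depend on $\phi$.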
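
The sketch below (NumPy, with a hypothetical helper `reparameterize`) holds $\boldsymbol\epsilon$ fixed and compares the analytic derivatives $\partial z_i/\partial\mu_i = 1$ and $\partial z_i/\partial\,\text{logvar}_i = \tfrac12\sigma_i\epsilon_i$ with central finite differences. It then checks that the samples still have mean $\boldsymbol\mu$ and standard deviation $\boldsymbol\sigma$.

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterize(mu, logvar, eps):
    """z = mu + sigma * eps with sigma = exp(logvar / 2); deterministic given eps."""
    return mu + np.exp(0.5 * logvar) * eps

mu, logvar = np.array([0.3, -1.2]), np.array([0.1, -0.5])
eps = rng.standard_normal(2)   # the noise is drawn once and held fixed

# Analytic elementwise partial derivatives of z
dz_dmu = np.ones_like(mu)
dz_dlogvar = 0.5 * np.exp(0.5 * logvar) * eps

# Central finite differences; perturbing every coordinate at once is valid here
# because z_i depends only on mu_i and logvar_i
h = 1e-6
fd_mu = (reparameterize(mu + h, logvar, eps) - reparameterize(mu - h, logvar, eps)) / (2 * h)
fd_logvar = (reparameterize(mu, logvar + h, eps) - reparameterize(mu, logvar - h, eps)) / (2 * h)
print(np.allclose(fd_mu, dz_dmu), np.allclose(fd_logvar, dz_dlogvar))  # True True

# Fresh noise for each row: the samples have mean mu and standard deviation sigma
z = reparameterize(mu, logvar, rng.standard_normal((100_000, 2)))
print(z.mean(axis=0), z.std(axis=0), np.exp(0.5 * logvar))
```

In the Keras model the same computation would typically sit in the custom `Sampling` layer, drawing $\boldsymbol\epsilon$ with `tf.random.normal` inside `call`. Automatic differentiation then supplies these derivatives.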

### Practical implementation notes (for the next steps)
- Encoder outputs: `mu`, `logvar`; use a `Sampling` layer to produce `z`.
- Decoder outputs: reconstruction $\hat{\mathbf{x}}$ in $[0,1]$ via a final `sigmoid` when inputs are normalized to $[0,1]$.
- Loss per batch:
  - Reconstruction: sum over pixels/channels per sample, then mean over the batch, so it is on the same per-sample scale as the KL term.
  - KL: sum over latent dims per sample, then mean over batch.
  - Total: `loss = recon_loss + kl_loss` (matching the exercise statement).
- Architectures for 28×28 images:
  - Encoder: two stride-2 Conv2D blocks (with `padding="same"`) reduce 28×28 to 14×14 and then 7×7, followed by Dense layers for the latent parameters.
  - Decoder: Dense to 7×7×C, then stride-2 Conv2DTranspose layers upsample 7×7 to 14×14 and back to 28×28.
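- A 2D latent space ($d=2$) can be visualized directly, as scatter plots of the latent codes and as decodes of a regular grid of latent points.
- Uncertainty maps: decoding several samples $\mathbf{z}\sim q_\phi(\mathbf{z}\mid\mathbf{x})$ for the same input and taking the per-pixel variance of the reconstructions gives a variance heatmap.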
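
A NumPy sketch of the per-batch loss with the reductions listed above: sum over pixels/channels or latent dims per sample, then mean over the batch. The function name `vae_loss` and the random test batch are hypothetical. In the notebook the same arithmetic would be written with TensorFlow ops on the encoder and decoder outputs.

```python
import numpy as np

def vae_loss(x, x_hat, mu, logvar):
    """Negative-ELBO style loss (sketch). x, x_hat: (batch, 28, 28, 1); mu, logvar: (batch, d)."""
    # Reconstruction: squared error summed over pixels/channels per sample (beta = 1/2), then batch mean
    recon_per_sample = np.sum((x - x_hat) ** 2, axis=(1, 2, 3))
    recon_loss = np.mean(recon_per_sample)
    # KL: closed form summed over latent dims per sample, then batch mean
    kl_per_sample = 0.5 * np.sum(mu**2 + np.exp(logvar) - logvar - 1.0, axis=1)
    kl_loss = np.mean(kl_per_sample)
    return recon_loss + kl_loss, recon_loss, kl_loss

rng = np.random.default_rng(0)
x = rng.random((8, 28, 28, 1))
latent_dim = 2
zeros = np.zeros((8, latent_dim))

# Perfect reconstruction and q(z|x) equal to the prior: both terms vanish
total, recon, kl = vae_loss(x, x, zeros, zeros)
print(total, recon, kl)  # 0.0 0.0 0.0

# A uniform per-pixel error of 0.1 costs 784 * 0.1**2 = 7.84 per sample under the sum convention
_, recon, _ = vae_loss(x, x + 0.1, zeros, zeros)
print(round(float(recon), 6))  # 7.84
```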

### What to remember
- A VAE optimizes a trade-off between accurate reconstructions and latent regularity (the KL term pulls $q_\phi(\mathbf{z}\mid\mathbf{x})$ toward a standard normal).
- Using MSE corresponds to a Gaussian decoder with fixed variance; BCE corresponds to a Bernoulli decoder.
- The reparameterization trick is what makes sampling differentiable with respect to $\boldsymbol\mu$ and `logvar`.
- For diagonal Gaussians, the KL term is analytic and cheap to compute.


---

## 4.3.2 Fashion-MNIST: load, normalize, and visualize one sample per class

What we will do:
- Download Fashion-MNIST (60k train, 10k test), grayscale 28×28 images.
- Normalize to [0,1] and add a channel dimension -> shape (N, 28, 28, 1).
- Plot one randomly selected sample for each of the 10 classes.
- Optionally restrict training to the first 10,000 samples for speed (as allowed by the exercise).

Why:
- Normalization to [0,1] matches the range of the decoder's sigmoid output and stabilizes optimization.
- The channel dimension is required by Conv2D layers.
- Per-class samples help us visually inspect the dataset.

In [None]:
# Apple Silicon Macs only; on other platforms install the standard `tensorflow` package instead
%pip install tensorflow-macos tensorflow-metal


In [1]:
# Imports and basic setup for this section
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Reproducibility (subject to GPU/cuDNN determinism limits)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Plot style
sns.set(context="notebook", style="whitegrid", palette="deep")
plt.rcParams["figure.figsize"] = (5.5, 5.0)
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["axes.labelsize"] = 11

In [None]:
# Load Fashion-MNIST
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Normalize to [0,1] and add channel dimension
x_train = (x_train.astype("float32") / 255.0)[..., None]  # (N, 28, 28, 1)
x_test  = (x_test.astype("float32")  / 255.0)[..., None]

class_names = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]

print("Train:", x_train.shape, y_train.shape)
print("Test: ", x_test.shape, y_test.shape)

# Optionally limit to first 10k training samples for speed
USE_FIRST_10K = True  # set to False for the full 60k
if USE_FIRST_10K:
    x_train = x_train[:10000]
    y_train = y_train[:10000]
    print("Using subset of training data:", x_train.shape, y_train.shape)

In [None]:
# Plot one random sample per class from the (possibly reduced) training set
rng = np.random.default_rng(SEED)
fig, axes = plt.subplots(2, 5, figsize=(10, 4.2))
picked_indices = []

for c in range(10):
    indices = np.where(y_train == c)[0]
    idx = rng.choice(indices)
    picked_indices.append(idx)

for ax, idx, c in zip(axes.ravel(), picked_indices, range(10)):
    ax.imshow(x_train[idx].squeeze(), cmap="gray", vmin=0, vmax=1)
    ax.set_title(class_names[c])
    ax.axis("off")

fig.suptitle("Fashion-MNIST: one random training sample per class", y=1.02)
plt.tight_layout()
plt.show()