In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Set up plotting style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (10, 4)

torch.manual_seed(42)


# Variational Autoencoders (VAE)

<span style="font-size: 15px;">

In the previous notebook (`Latent_Variable_Models_and_Autoencoders.ipynb`), we introduced the standard autoencoder, a neural network that learns compressed representations of data through an encoder-decoder architecture. While autoencoders are powerful for dimensionality reduction and feature learning, they have a fundamental limitation: **they are not true generative models**.

The problem lies in the latent space structure. In a standard autoencoder:
- The encoder maps each input $x$ to a **deterministic** point $z$ in latent space
- The latent space has "holes": regions to which no training data is mapped
- Sampling random vectors from the latent space often produces meaningless outputs

The **Variational Autoencoder (VAE)**, introduced by Kingma and Welling (2014), addresses these limitations by introducing **probability** and **stochasticity** into the autoencoder framework, **transforming** it into a proper generative model.

</span>

## The Natural Image Manifold and Latent Representations

<span style="font-size: 15px;">

Before diving into the VAE architecture, let's build intuition about what we're trying to learn.

**The Pixel Space $\mathcal{X}$**

Consider the space of all possible images. For a $28 \times 28$ grayscale image, this is $\mathcal{X} = [0, 1]^{784}$, a 784-dimensional hypercube. However, not all points in this space correspond to "natural" or meaningful images. Random noise is technically a valid point in pixel space, but it doesn't represent anything recognizable.

**The Natural Image Manifold**

Real-world images (faces, digits, cats, etc.) occupy only a tiny subset of the full pixel space. This subset is often called the **natural image manifold**, a complex surface of much lower intrinsic dimension than the pixel space it is embedded in, where:
- Similar images are close together (cats near cats, dogs near dogs)
- The manifold captures the underlying structure that makes images "natural"

**The Latent Space $\mathcal{Z}$**

The goal of generative models is to learn a **latent space** $\mathcal{Z}$ that:
1. Is lower-dimensional than pixel space ($D_z \ll D_x$)
2. Captures the essential features of the data ("two eyes," "four legs," "curved stroke")
3. Is **smooth** and **continuous**: nearby points produce similar outputs
4. Can be easily sampled from for generation

A well-trained VAE learns a latent space where:
$$
\text{Simple distribution in } \mathcal{Z} \xrightarrow{\text{Decoder}} \text{Complex distribution in } \mathcal{X}
$$

</span>

## From Autoencoder to Variational Autoencoder

<span style="font-size: 15px;">

**The Limitation of Deterministic Autoencoders**

Recall that a standard autoencoder performs the following mapping:
$$
x \xrightarrow{E_\phi} z \xrightarrow{D_\theta} \hat{x}
$$

where the encoder $E_\phi$ maps each input to a **single, fixed point** in latent space. This deterministic mapping means:
- The same input always produces the same latent vector
- The latent space has no inherent structure: points are scattered wherever reconstruction error is minimized
- We cannot meaningfully sample from the latent space

**The VAE Solution: Probabilistic Encoding**

The key insight of the VAE is to replace the deterministic encoder with a **probabilistic** one:

$$
\boxed{
\text{Autoencoder: } x \to z \quad \text{vs} \quad \text{VAE: } x \to q_\phi(z|x) = \mathcal{N}(\mu, \text{diag}(\sigma^2))
}
$$

Instead of encoding to a point, the VAE encoder outputs the **parameters of a probability distribution**:
- A mean vector $\mu \in \mathbb{R}^{D_z}$
- A variance vector $\sigma^2 \in \mathbb{R}_{>0}^{D_z}$ (or equivalently, the unconstrained log-variance $\log \sigma^2 \in \mathbb{R}^{D_z}$)

The latent vector $z$ is then **sampled** from this distribution:
$$
z \sim \mathcal{N}(\mu, \text{diag}(\sigma^2))
$$

This simple change has profound consequences:
1. **Stochasticity**: The same input can produce different latent vectors (and thus different reconstructions)
2. **Structured latent space**: The latent space is regularized to approximate a known distribution
3. **Generative capability**: We can sample $z \sim \mathcal{N}(0, I)$ and decode to generate new images
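
The sampling step can be sketched as follows. The values of `mu` and `logvar` are hypothetical stand-ins for what a trained encoder might output for one input; the point is only that a fixed $(\mu, \sigma^2)$ yields a different $z$ on each draw.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder output for a single input x (D_z = 4); illustrative values only
mu = torch.tensor([0.5, -1.0, 0.0, 2.0])
logvar = torch.tensor([-2.0, -2.0, 0.0, -4.0])
std = torch.exp(0.5 * logvar)  # sigma = exp(0.5 * log sigma^2)

# Two independent draws z ~ N(mu, diag(sigma^2)) for the same input
q = torch.distributions.Normal(mu, std)
z1 = q.sample()
z2 = q.sample()

print(z1)
print(z2)
print(torch.equal(z1, z2))  # the draws differ although mu and sigma are fixed
```

Dimensions with a small variance (here the last one) stay close to their mean, while dimensions with a larger variance move more from draw to draw.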

</span>

## Mathematical Framework of VAEs

**The Generative Model**

<span style="font-size: 15px;">

The VAE assumes the following generative process for data:

1. Sample a latent variable from a **prior** distribution:
   $$z \sim p(z) = \mathcal{N}(0, I)$$

2. Generate data from a **conditional likelihood**:
   $$x \sim p_\theta(x|z)$$

where $\theta$ are the parameters of the decoder network. The decoder defines $p_\theta(x|z)$, which for continuous data is often a Gaussian with mean given by the decoder output.
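
As a rough illustration of this two-step process, the sketch below draws samples with an untrained linear layer standing in for the decoder. The names `toy_decoder` and `sigma_x` and the dimensions are assumptions made for the example, not part of the model built later in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

latent_dim, data_dim = 2, 5

# Stand-in decoder (untrained); a real decoder network would take its place
toy_decoder = nn.Linear(latent_dim, data_dim)

with torch.no_grad():
    # Step 1: z ~ p(z) = N(0, I)
    z = torch.randn(3, latent_dim)

    # Step 2: x ~ p_theta(x|z) = N(decoder(z), sigma_x^2 I), with an assumed fixed sigma_x
    sigma_x = 0.1
    x_mean = toy_decoder(z)
    x = x_mean + sigma_x * torch.randn_like(x_mean)

print(x.shape)  # torch.Size([3, 5])
```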

**The Inference Problem**

Given observed data $x$, we want to infer the latent variable $z$ that generated it. By Bayes' theorem:

$$
p(z|x) = \frac{p_\theta(x|z) \, p(z)}{p(x)}
$$

However, the **evidence** $p(x) = \int p_\theta(x|z) \, p(z) \, dz$ is intractable: it requires integrating over all possible latent vectors.
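
To make the integral concrete, the sketch below estimates it by naive Monte Carlo in a one-dimensional toy model chosen (as an assumption for this example) so that the exact answer is known: $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$ give $p(x) = \mathcal{N}(x; 0, 2)$.

```python
import math
import random

random.seed(0)

def normal_pdf(x, mean, var):
    """Density of a univariate Gaussian N(mean, var) at x."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)

# Toy model (assumption): p(z) = N(0, 1), p(x|z) = N(z, 1)  =>  p(x) = N(x; 0, 2)
x = 1.5
exact = normal_pdf(x, 0.0, 2.0)

# Naive Monte Carlo: p(x) ~= (1/S) * sum_s p(x | z_s), with z_s ~ p(z)
S = 100_000
estimate = sum(normal_pdf(x, random.gauss(0.0, 1.0), 1.0) for _ in range(S)) / S

print(f"exact p(x)  = {exact:.4f}")
print(f"MC estimate = {estimate:.4f}")
```

In one dimension this works well. With a neural-network decoder and many latent dimensions, most prior samples typically explain a given $x$ poorly, so the same estimator would need an impractical number of samples.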

**Variational Inference**

Since we cannot compute $p(z|x)$ exactly, we **approximate** it with a simpler distribution $q_\phi(z|x)$ parameterized by the encoder network:

$$
q_\phi(z|x) = \mathcal{N}(z; \mu_\phi(x), \text{diag}(\sigma^2_\phi(x)))
$$

The encoder neural network outputs $\mu_\phi(x)$ and $\sigma^2_\phi(x)$ (or $\log \sigma^2_\phi(x)$) for each input $x$.

**Goal**: Find parameters $\phi$ such that $q_\phi(z|x)$ is close to the true posterior $p(z|x)$.

</span>


### The KL Divergence

<span style="font-size: 15px;">

To measure how close $q_\phi(z|x)$ is to $p(z|x)$, we use the **Kullback-Leibler (KL) divergence**:

$$
D_{\text{KL}}(q \| p) = \mathbb{E}_{z \sim q}\left[ \log \frac{q(z)}{p(z)} \right] = \int q(z) \log \frac{q(z)}{p(z)} \, dz
$$

**Key properties of KL divergence:**

1. $D_{\text{KL}}(q \| p) \geq 0$ always (non-negative)
2. $D_{\text{KL}}(q \| p) = 0$ if and only if $q = p$ almost everywhere
3. **Asymmetric**: $D_{\text{KL}}(q \| p) \neq D_{\text{KL}}(p \| q)$ in general

The KL divergence can be interpreted as:
- The expected **extra bits** needed to encode samples from $q$ using a code optimized for $p$
- A measure of **information loss** when $p$ is used to approximate $q$

**Intuition**: If $q$ and $p$ are both Gaussians, $D_{\text{KL}}(q \| p)$ penalizes differences in both mean and variance.
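
For univariate Gaussians the KL divergence has a well-known closed form, which makes the properties above easy to check numerically. The helper name `kl_gauss` and the example parameters are illustrative choices.

```python
import math

def kl_gauss(mu_q, var_q, mu_p, var_p):
    """D_KL( N(mu_q, var_q) || N(mu_p, var_p) ) for univariate Gaussians, in nats."""
    return 0.5 * (math.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)

same = kl_gauss(0.0, 1.0, 0.0, 1.0)     # identical distributions
forward = kl_gauss(0.0, 1.0, 1.0, 4.0)  # D_KL(q || p)
reverse = kl_gauss(1.0, 4.0, 0.0, 1.0)  # D_KL(p || q)

print(f"D_KL(q || q) = {same:.4f}")
print(f"D_KL(q || p) = {forward:.4f}")
print(f"D_KL(p || q) = {reverse:.4f}")
```

The first value is zero, the other two are positive, and they differ from each other, matching the non-negativity and asymmetry properties listed above.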

</span>

### The Evidence Lower Bound (ELBO)

<span style="font-size: 15px;">

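We want to maximize the log-likelihood of our data, $\log p(x)$. Expanding $D_{\text{KL}}(q_\phi(z|x) \| p(z|x))$ with Bayes' theorem and rearranging gives an exact decomposition (the bound that follows from it can equivalently be derived with Jensen's inequality):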

$$
\log p(x) = \underbrace{\mathbb{E}_{z \sim q_\phi(z|x)}\left[ \log p_\theta(x|z) \right]}_{\text{Reconstruction term}} - \underbrace{D_{\text{KL}}(q_\phi(z|x) \| p(z))}_{\text{Regularization term}} + \underbrace{D_{\text{KL}}(q_\phi(z|x) \| p(z|x))}_{\geq 0}
$$

Since $D_{\text{KL}}(q_\phi(z|x) \| p(z|x)) \geq 0$, we have:

$$
\log p(x) \geq \underbrace{\mathbb{E}_{z \sim q_\phi(z|x)}\left[ \log p_\theta(x|z) \right] - D_{\text{KL}}(q_\phi(z|x) \| p(z))}_{\text{ELBO}(x; \phi, \theta)}
$$

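This lower bound is called the **Evidence Lower Bound (ELBO)**. Because $\log p(x)$ does not depend on $\phi$, maximizing the ELBO simultaneously:
1. Pushes up a lower bound on the data log-likelihood with respect to $\theta$ (generative quality)
2. Minimizes $D_{\text{KL}}(q_\phi(z|x) \| p(z|x))$ with respect to $\phi$ (inference quality)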
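
The decomposition can be verified numerically in a toy model where every term is available in closed form. The sketch assumes $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$, so that $p(x) = \mathcal{N}(0, 2)$ and $p(z|x) = \mathcal{N}(x/2, 1/2)$; the variational parameters `m` and `s2` are arbitrary illustrative values.

```python
import math

def kl_gauss(mu_q, var_q, mu_p, var_p):
    """D_KL( N(mu_q, var_q) || N(mu_p, var_p) ) for univariate Gaussians."""
    return 0.5 * (math.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)

x = 1.5          # observed data point
m, s2 = 0.3, 0.7  # q(z|x) = N(m, s2), deliberately not the true posterior

# Reconstruction term: E_q[log N(x; z, 1)] = -0.5*log(2*pi) - 0.5*((x - m)^2 + s2)
recon = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)

kl_prior = kl_gauss(m, s2, 0.0, 1.0)          # D_KL(q || p(z))
kl_posterior = kl_gauss(m, s2, x / 2.0, 0.5)  # D_KL(q || p(z|x))

elbo = recon - kl_prior
log_px = -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0  # log N(x; 0, 2)

print(f"log p(x)         = {log_px:.5f}")
print(f"ELBO             = {elbo:.5f}")
print(f"ELBO + KL to post = {elbo + kl_posterior:.5f}")
```

The ELBO sits below $\log p(x)$, and the gap is exactly the KL divergence between $q$ and the true posterior.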

</span>

### The VAE Loss Function

<span style="font-size: 15px;">

The VAE is trained by **maximizing the ELBO**, or equivalently, **minimizing the negative ELBO**:

$$
\mathcal{L}_{\text{VAE}}(x; \phi, \theta) = -\text{ELBO} = \underbrace{-\mathbb{E}_{z \sim q_\phi(z|x)}\left[ \log p_\theta(x|z) \right]}_{\mathcal{L}_{\text{recon}}} + \underbrace{D_{\text{KL}}(q_\phi(z|x) \| p(z))}_{\mathcal{L}_{\text{KL}}}
$$

**Reconstruction Loss $\mathcal{L}_{\text{recon}}$**:
- Measures how well the decoder reconstructs the input from sampled latent vectors
- For Gaussian $p_\theta(x|z)$ with fixed variance: equivalent to MSE loss, up to a constant scale factor and an additive constant
- For Bernoulli $p_\theta(x|z)$ (binary images): equivalent to BCE loss
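
Both correspondences can be checked numerically. The sketch below uses random tensors in place of real images and decoder outputs, and assumes a fixed variance `sigma2 = 1.0` for the Gaussian case.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = (torch.rand(4, 784) > 0.5).float()       # stand-in binary "images"
x_hat = torch.sigmoid(torch.randn(4, 784))   # stand-in decoder output in (0, 1)

# Bernoulli likelihood: -log p(x|z) is the summed binary cross-entropy
bce = F.binary_cross_entropy(x_hat, x, reduction="sum")
bernoulli_nll = -(x * torch.log(x_hat) + (1 - x) * torch.log(1 - x_hat)).sum()

# Gaussian likelihood with fixed variance: -log p(x|z) = SSE / (2 * sigma2) + constant
sigma2 = 1.0
sse = F.mse_loss(x_hat, x, reduction="sum")
gaussian_nll = -torch.distributions.Normal(x_hat, math.sqrt(sigma2)).log_prob(x).sum()
const = 0.5 * x.numel() * math.log(2 * math.pi * sigma2)

print(bce.item(), bernoulli_nll.item())
print(gaussian_nll.item(), (sse / (2 * sigma2) + const).item())
```

Since the additive constant does not depend on the network parameters, dropping it leaves the gradients unchanged.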

**KL Divergence Loss $\mathcal{L}_{\text{KL}}$**:
- Regularizes the encoder to produce distributions close to the prior $p(z) = \mathcal{N}(0, I)$
- Prevents the encoder from collapsing to deterministic mappings
- Ensures the latent space is smooth and continuous

**Closed-form KL for Gaussians**:

When $q_\phi(z|x) = \mathcal{N}(\mu, \text{diag}(\sigma^2))$ and $p(z) = \mathcal{N}(0, I)$, the KL divergence has a closed form:

$$
D_{\text{KL}}(q_\phi(z|x) \| p(z)) = -\frac{1}{2} \sum_{j=1}^{D_z} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
$$

Or equivalently, using log-variance $\gamma_j = \log \sigma_j^2$:

$$
D_{\text{KL}} = -\frac{1}{2} \sum_{j=1}^{D_z} \left( 1 + \gamma_j - \mu_j^2 - e^{\gamma_j} \right)
$$
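
A minimal sketch of this formula in PyTorch, compared against `torch.distributions.kl_divergence` as a reference. The function name `kl_to_standard_normal` and the random inputs are illustrative.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

def kl_to_standard_normal(mu, logvar):
    """Closed-form D_KL( N(mu, diag(exp(logvar))) || N(0, I) ), one value per row."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

mu = torch.randn(8, 16)      # batch of 8, D_z = 16
logvar = torch.randn(8, 16)

closed_form = kl_to_standard_normal(mu, logvar)

# Reference: per-dimension Gaussian KL from torch.distributions, summed over dimensions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(q, p).sum(dim=1)

print(torch.allclose(closed_form, reference, atol=1e-4))
```

Working with the log-variance avoids having to constrain the encoder output to be positive, since $e^{\gamma_j} > 0$ for any real $\gamma_j$.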

</span>

### The Reparameterization Trick

<span style="font-size: 15px;">

There's a problem: training requires backpropagating through the sampling operation $z \sim q_\phi(z|x)$, but a random draw is **not a differentiable** function of the parameters $\mu$ and $\sigma$ that define the distribution.

**The Solution: Reparameterization**

Instead of sampling $z$ directly from $\mathcal{N}(\mu, \sigma^2)$, we reparameterize as:

$$
z = \mu + \sigma \odot \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}(0, I)
$$

Here $\odot$ denotes element-wise multiplication.

**Why this works:**
- The randomness is now in $\epsilon$, which doesn't depend on any parameters
- $z$ is a deterministic function of $\mu$, $\sigma$, and $\epsilon$
- Gradients can flow through $\mu$ and $\sigma$ to the encoder

**In PyTorch:**
```python
def reparameterize(self, mu, logvar):
    std = torch.exp(0.5 * logvar)  # σ = exp(0.5 * log(σ²))
    eps = torch.randn_like(std)     # ε ~ N(0, I)
    return mu + eps * std           # z = μ + ε ⊙ σ
```
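
The effect on gradients can be checked directly. The standalone sketch below applies the same three lines outside of a model, using zero-initialized `mu` and `logvar` as illustrative inputs, and contrasts the result with `Normal(...).sample()`, which is not connected to the computation graph.

```python
import torch

torch.manual_seed(0)

mu = torch.zeros(4, requires_grad=True)
logvar = torch.zeros(4, requires_grad=True)

# Reparameterized sample: the noise eps is drawn independently of the parameters
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std

z.sum().backward()

print(mu.grad)      # dz/dmu = 1 for every component
print(logvar.grad)  # dz/dlogvar = 0.5 * eps * std

# By contrast, .sample() draws without tracking gradients
z_detached = torch.distributions.Normal(mu, std).sample()
print(z_detached.requires_grad)  # False
```

`torch.distributions.Normal` also provides `rsample()`, which applies the same reparameterization internally.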

</span>

## VAE Architecture Summary

<span style="font-size: 15px;">

The complete VAE architecture can be summarized as:

$$
\boxed{
x \xrightarrow{\text{Encoder}} (\mu, \log\sigma^2) \xrightarrow{\text{Reparameterize}} z = \mu + \sigma \odot \epsilon \xrightarrow{\text{Decoder}} \hat{x}
}
$$

**Components:**

| Component | Input | Output | Function |
|-----------|-------|--------|----------|
| Encoder $E_\phi$ | $x \in \mathbb{R}^{D_x}$ | $\mu, \log\sigma^2 \in \mathbb{R}^{D_z}$ | Learn approximate posterior |
| Reparameterization | $\mu, \sigma, \epsilon$ | $z \in \mathbb{R}^{D_z}$ | Enable backpropagation |
| Decoder $D_\theta$ | $z \in \mathbb{R}^{D_z}$ | $\hat{x} \in \mathbb{R}^{D_x}$ | Generate reconstruction |

**Training:**
$$
\mathcal{L}_{\text{VAE}} = \underbrace{\|x - \hat{x}\|^2}_{\text{Reconstruction}} + \underbrace{\beta \cdot D_{\text{KL}}(q_\phi(z|x) \| p(z))}_{\text{Regularization}}
$$

where $\beta$ is a hyperparameter (often $\beta = 1$, but $\beta$-VAE uses different values).
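
One way this objective might be written in PyTorch is sketched below. The function name `vae_loss`, the sum-over-pixels reduction, and the averaging over the batch are assumptions of this sketch; other reductions only rescale the two terms relative to each other.

```python
import torch
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, logvar, beta=1.0):
    """Sketch of the negative ELBO: squared-error reconstruction + beta * KL, batch-averaged."""
    batch_size = x.size(0)
    # Reconstruction: ||x - x_hat||^2 summed over pixels, averaged over the batch
    recon = F.mse_loss(x_hat, x, reduction="sum") / batch_size
    # Regularization: closed-form KL to N(0, I), averaged over the batch
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return recon + beta * kl, recon, kl

# Usage with random stand-in tensors (no trained model involved)
torch.manual_seed(0)
x = torch.rand(16, 1, 28, 28)
x_hat = torch.rand(16, 1, 28, 28)
mu = torch.randn(16, 32)
logvar = torch.randn(16, 32)

loss, recon, kl = vae_loss(x_hat, x, mu, logvar, beta=1.0)
print(loss.item(), recon.item(), kl.item())
```

Returning the two terms separately makes it easier to monitor how `beta` trades reconstruction quality against regularization during training.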

**Generation:**
$$
z \sim \mathcal{N}(0, I) \xrightarrow{D_\theta} \hat{x}_{\text{new}}
$$

</span>

## VAE Implementation in PyTorch

<span style="font-size: 15px;">

We now implement a convolutional VAE in PyTorch. The architecture uses:
- **Encoder**: Convolutional layers to extract features, then linear layers to output $\mu$ and $\log\sigma^2$
- **Decoder**: Linear layer followed by transposed convolutions to reconstruct the image

We'll design this for 28×28 grayscale images (MNIST/Fashion-MNIST).
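
The spatial sizes in such an architecture follow from the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d`. The sketch below evaluates them for the layer settings used in the implementation (kernel 3, stride 2, padding 1); the helper names are illustrative.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output length of nn.Conv2d along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=0):
    """Output length of nn.ConvTranspose2d along one spatial dimension (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder: three stride-2 convolutions
enc = [28]
for _ in range(3):
    enc.append(conv_out(enc[-1]))
print(enc)  # [28, 14, 7, 4]

# Decoder: output_padding is chosen per layer so that the encoder sizes are retraced
dec = [4]
for op in (0, 1, 1):
    dec.append(conv_transpose_out(dec[-1], output_padding=op))
print(dec)  # [4, 7, 14, 28]

print(32 * enc[-1] ** 2)  # 512 flattened features with 32 output channels
```

Because a stride-2 convolution maps both 7 and 8 to size 4, the transposed convolution needs `output_padding` to say which of the two sizes to restore.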

</span>

### Building the VAE

In [2]:
class VAE(nn.Module):
    """
    Convolutional Variational Autoencoder.
    
    Architecture for 28x28 input:
        Encoder: (1, 28, 28) -> Conv layers -> Flatten -> (μ, log σ²)
        Decoder: z -> FC -> Unflatten -> ConvTranspose layers -> (1, 28, 28)
    
    The encoder outputs parameters of q(z|x) = N(μ, diag(σ²))
    The decoder defines p(x|z)
    """
    
    def __init__(self, latent_dim: int = 32):
        super().__init__()
        
        self.latent_dim = latent_dim
        
        # ============ Encoder ============
        # Convolutional layers: (1, 28, 28) -> (32, 4, 4)
        self.encoder = nn.Sequential(
            # (1, 28, 28) -> (8, 14, 14)
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # (8, 14, 14) -> (16, 7, 7)
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # (16, 7, 7) -> (32, 4, 4)  (note: 7->4 with stride=2, padding=1)
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # Flatten: (32, 4, 4) -> (512,)
            nn.Flatten()
        )
        
        # Latent space projections
        # The flattened encoder output has size 32 * 4 * 4 = 512
        self.fc_mu = nn.Linear(512, latent_dim)      # μ
        self.fc_logvar = nn.Linear(512, latent_dim)  # log(σ²)
        
        # ============ Decoder ============
        # Project from latent space back to feature maps
        self.fc_decode = nn.Linear(latent_dim, 512)
        
        # Transposed convolutions: (32, 4, 4) -> (1, 28, 28)
        self.decoder = nn.Sequential(
            # Unflatten: (512,) -> (32, 4, 4)
            nn.Unflatten(1, (32, 4, 4)),
            # (32, 4, 4) -> (16, 7, 7)
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=0),
            nn.ReLU(),
            # (16, 7, 7) -> (8, 14, 14)
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            # (8, 14, 14) -> (1, 28, 28)
            nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  # Output in [0, 1] for pixel values
        )
    
    def encode(self, x):
        """
        Encode input to distribution parameters.
        
        Args:
            x: Input tensor of shape (B, 1, 28, 28)
        
        Returns:
            mu: Mean of q(z|x), shape (B, latent_dim)
            logvar: Log-variance of q(z|x), shape (B, latent_dim)
        """
        h = self.encoder(x)  # (B, 512)
        mu = self.fc_mu(h)        # (B, latent_dim)
        logvar = self.fc_logvar(h)  # (B, latent_dim)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: z = μ + σ ⊙ ε, where ε ~ N(0, I)
        
        This allows backpropagation through the sampling operation.
        
        Args:
            mu: Mean of the latent distribution (B, latent_dim)
            logvar: Log-variance of the latent distribution (B, latent_dim)
        
        Returns:
            z: Sampled latent vector (B, latent_dim)
        """
        # Compute standard deviation: σ = exp(0.5 * log(σ²))
        std = torch.exp(0.5 * logvar)
        
        # Sample ε ~ N(0, I)
        eps = torch.randn_like(std)
        
        # z = μ + ε ⊙ σ
        return mu + eps * std
    
    def decode(self, z):
        """
        Decode latent vector to reconstruction.
        
        Args:
            z: Latent vector of shape (B, latent_dim)
        
        Returns:
            x_hat: Reconstructed image of shape (B, 1, 28, 28)
        """
        h = self.fc_decode(z)  # (B, 512)
        return self.decoder(h)  # (B, 1, 28, 28)
    
    def forward(self, x):
        """
        Full forward pass: encode -> reparameterize -> decode
        
        Args:
            x: Input tensor of shape (B, 1, 28, 28)
        
        Returns:
            x_hat: Reconstruction of shape (B, 1, 28, 28)
            mu: Mean of q(z|x) for KL computation
            logvar: Log-variance of q(z|x) for KL computation
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar