# Advanced Statistical Inference: Variational Autoencoders


In this notebook, you will implement a Variational Autoencoder (VAE) for the MNIST dataset.
The VAE is a generative model that learns a probabilistic latent representation of data. It consists of two components:

- An **encoder** network that approximates the posterior distribution $q(\mathbf{z}|\mathbf{x})$, mapping data $\mathbf{x}$ to a distribution in latent space $\mathbf{z}$.
- A **decoder** network that reconstructs the data from latent variables, approximating the likelihood $p(\mathbf{x}|\mathbf{z})$.

We will optimize the Evidence Lower Bound (ELBO), which balances reconstruction fidelity and latent space regularization, using stochastic gradient descent and the reparameterization trick.

The lab leverages JAX and Flax for differentiable programming and efficient computation.

**Important Notes**: It is highly recommended to run this notebook on a GPU or TPU for performance reasons. You can enable GPU/TPU support in Google Colab by going to `Runtime` > `Change runtime type` and selecting `GPU` or `TPU` as the hardware accelerator. Even a small GPU will significantly speed up training compared to a CPU.
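To confirm which accelerator JAX has picked up, a quick check along the following lines can be run before training. This is a minimal sketch; the printed values depend entirely on the runtime.

```python
import jax

# List the devices JAX can see; on a CPU-only runtime this is a single CPU device.
devices = jax.devices()
backend = jax.default_backend()  # platform name, e.g. "cpu", "gpu", or "tpu"

print(f"Backend: {backend}")
print(f"Number of devices: {len(devices)}")
```

If the backend is reported as `cpu` on Colab, the hardware accelerator has probably not been enabled in the runtime settings.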

## Flax Overview

* **Flax Overview**:

  * Flax is a high-performance neural network library built on top of JAX, developed by Google Research.
  * It is designed for flexibility and composability in research settings, with strong support for hardware acceleration via JAX's XLA backend.
  * Flax follows a functional programming model, which separates the definition of computation (pure functions) from data (model parameters and other state).
  * GPU/TPU support is automatic through JAX: computations are compiled via XLA and dispatched to the available device (CPU, GPU, or TPU) without modifying the model code.
  * Example: Unlike PyTorch, where model parameters are typically stored as object attributes, Flax models do not carry their own parameters. Instead, parameters are stored in a dictionary and passed explicitly to functions:

    ```python
    variables = model.init(rng, x)  # Returns {'params': ...}
    y = model.apply(variables, x)
    ```

* **Core Concepts**:

  * The basic building block in Flax is the `flax.linen.Module`, which describes a single layer or component of a model.
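  * A Flax module is written as a Python dataclass: hyperparameters are declared as class attributes, and with `@nn.compact` the submodules (and their parameters) are registered inline the first time they are called inside `__call__`.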
  * Example:

    ```python
    class MLP(nn.Module):
        features: Sequence[int]

        @nn.compact
        def __call__(self, x):
            for feat in self.features:
                x = nn.Dense(feat)(x)
                x = nn.relu(x)
            return x
    ```

    This creates a multilayer perceptron where `nn.Dense` layers are registered in a scope, and their parameters are automatically managed.
  * Initialization is done by calling `init`, which produces a `variables` dict containing parameters:

    ```python
    rng = jax.random.PRNGKey(0)
    x = jnp.ones((1, 32))
    mlp = MLP(features=[64, 64])
    variables = mlp.init(rng, x)
    ```
  * Forward evaluation uses `apply`:

    ```python
    y = mlp.apply(variables, x)
    ```
  * Device placement is automatic: the arrays returned by `init` and `apply` are `jax.Array` objects (called `DeviceArray` in older JAX releases), which are placed on the default device (a GPU or TPU when one is available) transparently. No `.to(device)` or manual device context is required.

* **State Management and Execution**:

  * Flax uses a "scope" system to manage variables and submodules during function calls.
  * Model state is explicitly stored in collections (e.g., `'params'`) and passed around. This encourages reproducibility and transformation compatibility.
  * Integration with JAX transformations is seamless. For example:

    * `jax.jit` compiles training steps for fast execution:

      ```python
      @jax.jit
      def train_step(variables, x, y):
          def loss_fn(params):
              logits = model.apply({'params': params}, x)
              loss = cross_entropy_loss(logits, y)
              return loss
          grads = jax.grad(loss_fn)(variables['params'])
          new_params = apply_gradients(variables['params'], grads)
          return {'params': new_params}
      ```
      A jitted function runs on the device where its inputs live, which by default is the first available accelerator, so no explicit placement is needed.
    * `jax.vmap` automatically vectorizes computations over batch dimensions without rewriting the model:

      ```python
      batched_model = jax.vmap(lambda x: model.apply(variables, x))
      y = batched_model(batch_input)
      ```
  * All model computation and parameters are kept device-agnostic. By default, JAX dispatches operations to the first available accelerator (GPU or TPU) and falls back to the CPU otherwise, so compiled functions run on the accelerator without code changes. This simplifies deployment and accelerates training.


In [None]:
# If we are on colab runtime, we need to install the datasets library
if 'google.colab' in str(get_ipython()):
    %pip install --quiet datasets huggingface_hub fsspec 

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from datasets import load_dataset
from flax import linen as nn
from typing import Any
from functools import partial
import os
from matplotlib import rc

tfd = tfp.distributions


# Plot configuration

rc("font", **{"family": "sans-serif", "sans-serif": "DejaVu Sans"})
rc("text", **{"usetex": False})
rc("figure", **{"dpi": 200})
rc(
    "axes",
    **{"spines.right": False, "spines.top": False, "xmargin": 0.0, "ymargin": 0.05},
)

# Set seed for reproducibility
rng = jax.random.PRNGKey(0)
jax.config.update('jax_threefry_partitionable', True)


## 1. Load and preprocess MNIST

We begin by downloading and loading the MNIST dataset using the `datasets` library from Hugging Face, and normalize the image intensities to [0,1]. Each image is 28x28, and we keep it as a (28, 28, 1) array (height, width, channel) so that it can be fed directly to the convolutional encoder. This preprocessing prepares the data for learning the variational autoencoder latent representation.
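As a small illustration of the preprocessing on its own, the sketch below applies the same scaling and reshaping to synthetic data. The random `raw` array is a hypothetical stand-in for real MNIST images, used only so the snippet runs without downloading anything.

```python
import numpy as np

# Synthetic stand-in for a batch of raw MNIST images: uint8 values in [0, 255].
raw = np.random.default_rng(0).integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Scale to [0, 1] and add a trailing channel axis, giving (N, H, W, C).
images = (raw / 255.0).astype(np.float32).reshape(-1, 28, 28, 1)

print(images.shape, images.min(), images.max())
```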


In [None]:
def load_mnist():
    ds = load_dataset("mnist")

    def normalize(batch):
        batch["image"] = batch["image"] / 255.0  # Normalize pixel values to [0, 1]
        return batch

    train_ds = ds["train"].with_format("numpy").map(normalize, batched=True)
    test_ds = ds["test"].with_format("numpy").map(normalize, batched=True)

    X_train = np.stack(train_ds["image"]).reshape(-1, 28, 28, 1)
    X_test = np.stack(test_ds["image"]).reshape(-1, 28, 28, 1)
    return X_train, X_test


X_train, X_test = load_mnist()
print("Train shape:", X_train.shape)
print("Test shape:", X_test.shape)

In [None]:
def plot_examples(images, nrows=2, ncols=5, figsize=(10, 4), cmap="gray"):
    """
    Plot a grid of example images.

    Args:
      images (np.ndarray): Array of images, shape (N, H * W), (N, H, W), or (N, H, W, 1).
      nrows (int): Number of rows in the grid.
      ncols (int): Number of columns in the grid.
      figsize (tuple): Figure size.
      cmap (str): Colormap for grayscale images.
    """
    if images.ndim == 2:
        images = images.reshape(-1, 28, 28)  # Reshape flattened images to (N, 28, 28)
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        if i < len(images):
            img = images[i]
            if img.ndim == 2 or (img.ndim == 3 and img.shape[-1] == 1):
                ax.imshow(img.squeeze(), cmap=cmap)
            else:
                ax.imshow(img)
            ax.axis("off")
        else:
            ax.axis("off")
    plt.tight_layout()
    plt.show()
    return fig


# Plot some training examples
plot_examples(X_train[:10], nrows=2, ncols=5, figsize=(10, 4), cmap="gray")

## 2. Define encoder and decoder networks

We define the architecture of the encoder and decoder using `flax.linen`. The encoder takes the input image $\mathbf{x} \in \mathbb{R}^{28 \times 28 \times 1}$ and outputs two vectors: $\boldsymbol{\mu}(\mathbf{x}), \log \boldsymbol{\sigma}^2(\mathbf{x})$, which parameterize the variational distribution $q(\mathbf{z}|\mathbf{x})$. The decoder maps latent variables $\mathbf{z} \in \mathbb{R}^d$ to reconstructed logits for computing the likelihood over pixels.




In [None]:
class AmortizationNetwork(nn.Module):
    num_layers: int
    hidden_dim: int
    latent_dim: int
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        for _ in range(self.num_layers - 1):
            x = nn.Conv(
                features=self.hidden_dim,
                kernel_size=(3, 3),
                strides=(1, 1),
                padding="SAME",
                dtype=self.dtype,
            )(x)
            x = nn.LayerNorm(dtype=self.dtype)(x)
            x = nn.relu(x)
        x = nn.Conv(
            features=self.hidden_dim,
            kernel_size=(3, 3),
            strides=(2, 2),
            dtype=self.dtype,
        )(x)
        x = x.reshape(x.shape[0], -1)  # Flatten the output
        mean = nn.Dense(self.latent_dim, dtype=self.dtype)(x)
        logvar = nn.Dense(self.latent_dim, dtype=self.dtype)(x)
        return mean, logvar


class LikelihoodNetwork(nn.Module):
    num_layers: int
    hidden_dim: int
    output_dim: int
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, z):
        z = nn.Dense(self.hidden_dim, dtype=self.dtype)(z)
        z = z.reshape(z.shape[0], 1, 1, self.hidden_dim)
        for _ in range(self.num_layers - 1):
            z = nn.ConvTranspose(
                features=self.hidden_dim,
                kernel_size=(3, 3),
                strides=(2, 2) if z.shape[2] < 28 else (1, 1), # Up-sampling only if needed, otherwise keep same size
                dtype=self.dtype,
            )(z)
            z = nn.LayerNorm(dtype=self.dtype)(z)

        z = nn.ConvTranspose(
            features=1,
            kernel_size=(3, 3),
            strides=(1, 1),
            dtype=self.dtype,
        )(z)
        z = nn.Dense(self.output_dim, dtype=self.dtype)(z.reshape(z.shape[0], -1))
        z = z.reshape(
            z.shape[0], int(np.sqrt(self.output_dim)), int(np.sqrt(self.output_dim)), 1
        )
        return z

**Exercise**: Create the two models `AmortizationNetwork` and `LikelihoodNetwork` as defined above. The `AmortizationNetwork` will be used for the encoder, and the `LikelihoodNetwork` for the decoder. You can specify the number of layers, hidden dimensions, and latent dimensions as needed, but be reasonable.
Afterwards, initialize the model parameters using the `init` method, which requires a random key and an example input. For the encoder, use an input of shape `(1, 28, 28, 1)` (a single MNIST image), and for the decoder, use an input of shape `(1, latent_dim)`, where `latent_dim` is the dimension of the latent space.


In [None]:
latent_dim = 2
encoder_num_layers = 2
encoder_hidden_dim = 32
decoder_num_layers = 5
decoder_hidden_dim = 64
dtype = "bfloat16"

# @@ COMPLETE @@
# encoder = ...
# decoder = ...


fake_input = jnp.ones((1, 28, 28, 1), dtype=dtype)  # Example input for encoder
fake_latent = jnp.ones((1, latent_dim))  # Example input for decoder

# Initialize encoder parameters
encoder_params = encoder.init(rng, fake_input)
print(encoder.tabulate(rng, fake_input))

# Initialize decoder parameters
decoder_params = decoder.init(rng, fake_latent)
print(decoder.tabulate(rng, fake_latent))

# For simplicity, we store parameters in a single dictionary
params = {"encoder": encoder_params, "decoder": decoder_params}

## 3. Reparameterization trick

We want to sample from $\mathbf{z} \sim q(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}, \operatorname{diag}(\boldsymbol{\sigma}^2))$, but to allow gradient-based optimization, we use the reparameterization trick:

$$
\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
$$

This makes the sampling operation differentiable w.r.t. $\boldsymbol{\mu}, \log \boldsymbol{\sigma}^2$.
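To see why this matters, consider a one-dimensional toy case where the exact answer is known: for $z = \mu + \sigma \epsilon$ we have $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, so the gradients are $2\mu$ and $2\sigma$. The sketch below checks that `jax.grad` recovers them through the reparameterized sample. It parameterizes $\sigma$ directly rather than through the log-variance, and the function name `expected_square` is an illustrative assumption.

```python
import jax
import jax.numpy as jnp


def expected_square(mu, sigma, eps):
    # Once eps is fixed, z is a deterministic function of (mu, sigma),
    # so gradients can flow through the sample.
    z = mu + sigma * eps
    return jnp.mean(z**2)


eps = jax.random.normal(jax.random.PRNGKey(0), (100_000,))
mu, sigma = 1.5, 0.5

grad_mu, grad_sigma = jax.grad(expected_square, argnums=(0, 1))(mu, sigma, eps)
print(grad_mu, 2 * mu)        # Monte Carlo estimate vs. exact value 2 * mu
print(grad_sigma, 2 * sigma)  # Monte Carlo estimate vs. exact value 2 * sigma
```

The estimates are noisy because they use a finite number of samples, but they concentrate around the exact gradients as the sample size grows.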

**Exercise**: Implement the reparameterization trick

In [None]:
def reparameterize(rng, mean, logvar):
    """
    Reparameterization trick to sample from a Gaussian distribution.

    Args:
        rng (jax.random.PRNGKey): Random number generator key.
        mean (jnp.ndarray): Mean of the Gaussian.
        logvar (jnp.ndarray): Log variance of the Gaussian.

    Returns:
        jnp.ndarray: Sampled latent variable.
    """
    eps = jax.random.normal(rng, shape=mean.shape, dtype=mean.dtype)
    # @@ COMPLETE @@
    # std = ...
    # sample = ...
    return sample


# Test the reparameterization function
mean = jnp.array([[0.0, 0.0]])
logvar = jnp.array([[0.0, 0.0]])
sampled_z = reparameterize(rng, mean, logvar)
print("Sampled z:", sampled_z)

## 4. Define VAE loss (ELBO)

The loss we minimize is the negative of the Evidence Lower Bound (ELBO). The ELBO itself is:

$$
\mathcal{L}(\theta, \phi; \mathbf{x}) = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}[ \log p_\theta(\mathbf{x}|\mathbf{z}) ] - \operatorname{KL}( q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}) )
$$

Assuming $p(\mathbf{z}) = \mathcal{N}(0, I)$, $q(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}, \operatorname{diag}(\boldsymbol{\sigma}^2))$, the KL term is:

$$
\operatorname{KL}(q \| p) = \frac{1}{2} \sum_j \left( \exp(\log \sigma_j^2) + \mu_j^2 - 1 - \log \sigma_j^2 \right)
$$
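As a sanity check, the closed-form expression can be compared against a Monte Carlo estimate of $\mathbb{E}_q[\log q(\mathbf{z}) - \log p(\mathbf{z})]$. The sketch below does this in plain JAX; the helper names `kl_closed_form` and `kl_monte_carlo` and the test values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def kl_closed_form(mean, logvar):
    # 0.5 * sum_j (sigma_j^2 + mu_j^2 - 1 - log sigma_j^2)
    return 0.5 * jnp.sum(jnp.exp(logvar) + mean**2 - 1.0 - logvar, axis=-1)


def kl_monte_carlo(rng, mean, logvar, num_samples=200_000):
    # Estimate E_q[log q(z) - log p(z)] with samples z ~ q.
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(rng, (num_samples,) + mean.shape)
    z = mean + std * eps
    # Per-dimension log-densities; (z - mean) / std equals eps by construction.
    log_q = -0.5 * (eps**2 + logvar + jnp.log(2 * jnp.pi))
    log_p = -0.5 * (z**2 + jnp.log(2 * jnp.pi))
    return jnp.mean(jnp.sum(log_q - log_p, axis=-1))


mean = jnp.array([1.0, -0.5])
logvar = jnp.array([0.0, jnp.log(0.25)])

print("closed form:", kl_closed_form(mean, logvar))
print("Monte Carlo:", kl_monte_carlo(jax.random.PRNGKey(0), mean, logvar))
```

The two numbers should agree up to sampling noise. The closed form is preferred in practice because it is exact and has no variance.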

For the reconstruction loss, let's start with the Gaussian likelihood.
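Assuming a fixed unit variance for every pixel, the Gaussian log-likelihood reduces to a squared error plus a constant. Writing $\hat{\mathbf{x}}$ for the decoder output and $D$ for the number of pixels (notation introduced here for illustration):

$$
\log p_\theta(\mathbf{x}|\mathbf{z}) = \sum_{i=1}^{D} \log \mathcal{N}(x_i; \hat{x}_i, 1) = -\frac{1}{2} \| \mathbf{x} - \hat{\mathbf{x}} \|_2^2 - \frac{D}{2} \log(2\pi)
$$

Maximizing this term is therefore equivalent to minimizing the squared reconstruction error; the constant does not affect the gradients.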

**Exercise**: Read the code below for computing the ELBO loss function, including the KL divergence and reconstruction loss. Make sure you understand how the loss is computed using the encoder (`AmortizationNetwork`) and decoder (`LikelihoodNetwork`) models. 

In [None]:

@partial(jax.jit, static_argnames=("likelihood",))
def elbo(params, x, rng, kl_weight: float = 1.0, likelihood: str = "gaussian"):
    """
    Compute the Evidence Lower Bound (ELBO) for the VAE.

    Args:
        params (dict): Model parameters containing encoder and decoder.
        x (jnp.ndarray): Input data.
        rng (jax.random.PRNGKey): Random number generator key.
        kl_weight (float): Weight for the KL divergence term.
        likelihood (str): Type of likelihood. Only 'gaussian' is implemented here.
    Returns:
        jnp.ndarray: Negative ELBO averaged over the batch (the loss to minimize).
        tuple: Auxiliary outputs (reconstruction loss, weighted KL divergence),
            each averaged over the batch.
    """

    # Compute mean and log variance from the encoder
    mean, logvar = encoder.apply(params["encoder"], x)

    # Sample from the latent space using reparameterization trick
    z = reparameterize(rng, mean, logvar)

    # Compute logits from the decoder
    preds = decoder.apply(params["decoder"], z)

    # Ensure preds and x have the same dtype
    preds = preds.astype("float32").reshape(preds.shape[0], -1)
    x = x.astype("float32").reshape(x.shape[0], -1)

    # Compute reconstruction loss based on likelihood type
    if likelihood == "gaussian":
        recon_loss = tfd.Normal(loc=preds, scale=1.0).log_prob(x).sum(axis=-1)
    else:
        raise ValueError(f"Unsupported likelihood type: {likelihood}")

    # Compute KL divergence
    kl = tfd.Normal(loc=mean, scale=jnp.exp(0.5 * logvar)).kl_divergence(
        tfd.Normal(loc=0.0, scale=1.0)
    )
    kl_div = kl_weight * jnp.sum(kl, axis=-1)

    # Compute ELBO loss
    elbo_loss = -jnp.mean(recon_loss - kl_div)
    recon_loss = -jnp.mean(recon_loss)
    kl_div = jnp.mean(kl_div)

    return elbo_loss, (recon_loss, kl_div)


# Test the ELBO function with a batch of data

batch_size = 32
x_batch = jnp.ones((batch_size, 28, 28, 1))  # Example batch of data
elbo_loss, (recon_loss, kl_div) = elbo(params, x_batch, rng)
print("ELBO Loss:", elbo_loss)
print("Reconstruction Loss:", recon_loss)
print("KL Divergence:", kl_div)

## 5. Training loop

Great, now we have almost everything we need to train our VAE. We will set up a training loop that iterates over the MNIST dataset, computes the ELBO loss, and updates the model parameters using gradient descent.
We need to set up the Adam optimizer (from `optax`). The `train_step` function will compute the gradients of the ELBO with respect to the encoder and decoder parameters and update them.

**Exercise**: Read the code below for the training loop. Make sure you understand how the training step is performed, including the use of `jax.grad` to compute gradients and how they are applied using the optimizer.

In [None]:
import optax
from flax.training.train_state import TrainState

@partial(jax.jit, static_argnames=("likelihood",), donate_argnums=(0,))
def train_step(state: TrainState, x, rng, kl_weight=1.0, likelihood="gaussian"):
    """
    Perform a single training step for the VAE.

    Args:
        state (TrainState): Current model state containing parameters and optimizer state.
        x (jnp.ndarray): Input data batch.
        rng (jax.random.PRNGKey): Random number generator key.
        kl_weight (float): Weight for the KL divergence term.
        likelihood (str): Type of likelihood ('gaussian', ...)

    Returns:
        state (TrainState): Updated model state after applying gradients.
        tuple: ELBO (negative loss), expected log-likelihood (negative reconstruction loss), and weighted KL divergence.
    """
    # Define the gradient function for the ELBO 
    grads_fn = jax.value_and_grad(elbo, has_aux=True)

    (elbo_loss, (recon_loss, kl_div)), grads = grads_fn(
        state.params, x, rng, kl_weight, likelihood
    )   

    # Update parameters using the computed gradients
    state = state.apply_gradients(grads=grads)

    # Return the updated state and the losses
    return state, (-elbo_loss, -recon_loss, kl_div)


**Exercise**: Run the following training loop for a few epochs to train the VAE on the MNIST dataset. Note that the code is set up to run on a GPU or TPU. A TPU runtime usually exposes several accelerator cores (8 in the setup assumed here), and JAX treats each core as a separate device.
To take full advantage of GPUs and TPUs, we can parallelize the training across multiple devices. For example, we can split each batch of images across the devices, let each device process its portion independently, and then aggregate the results.
This is known as **data parallelism**. The code below handles this automatically. On a single device, it simply runs the training loop on that device. On a TPU, it splits the batch across all available cores.
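The sharding mechanics can be tried in isolation before running the full training loop. The sketch below builds a one-dimensional mesh over whatever devices are present, shards a toy batch along its leading axis, replicates a small weight vector, and runs a jitted reduction. The array contents and the function name `mean_projection` are illustrative assumptions; on a single device the code still runs, with one shard.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

sharded = NamedSharding(mesh, P("batch"))  # split the leading axis across devices
replicated = NamedSharding(mesh, P())      # full copy on every device

# The leading dimension must be divisible by the number of devices.
n_rows = 8 * len(devices)
batch = jnp.arange(n_rows * 4, dtype=jnp.float32).reshape(n_rows, 4)
weights = jnp.ones((4,))

batch = jax.device_put(batch, sharded)
weights = jax.device_put(weights, replicated)


@jax.jit
def mean_projection(batch, weights):
    # Each device works on its shard; the mean is aggregated across devices.
    return jnp.mean(batch @ weights)


result = mean_projection(batch, weights)
print(batch.sharding)
print(result)
```

The jitted function contains no device-specific code: the placement of its inputs determines how the computation is partitioned.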


In [None]:
from tqdm import tqdm
from jax.sharding import PartitionSpec as P, NamedSharding

# Set up sharding for distributed training
# A mesh is an abstraction for the devices we are using
# A mesh can have multiple dimensions, depending on how we want to partition our data/model
# Here we use a single dimension "batch" for data parallelism
mesh = jax.make_mesh((jax.local_device_count(), ), ("batch",)) 

# Set up 2 strategies for sharding:
# - sharded: to distribute the "batch" dimension across devices (e.g. split the batch across devices)
# - replicated: to replicate and synchronize the data across all devices (e.g. to copy the model parameters)
sharded = NamedSharding(mesh, P("batch"))
replicated = NamedSharding(mesh, P())

# Initialize the optimizer
# We use Adam optimizer
optim = optax.adam(learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-8)

# Initialize the train state
# The train state contains the model parameters and the optimizer, 
# and (optionally) the forward function of the model
# Since we have two networks (encoder and decoder), we are not using `apply_fn` here to avoid confusion.
state = TrainState.create(params=params, tx=optim, apply_fn=None)

num_epochs = 1000
batch_size = 1024 * jax.device_count()  # Adjust batch size based on number of devices

n_train = X_train.shape[0]
steps_per_epoch = n_train // batch_size

losses = []

state = jax.device_put(state, replicated) # The state is replicated across devices
with tqdm(total=num_epochs, desc="Training VAE", unit="epoch", colour="blue") as pbar:
    for epoch in range(num_epochs):
        perm = np.random.permutation(n_train)
        _loss = 0.0
        for step in range(steps_per_epoch):
            rng, subkey = jax.random.split(rng)
            
            batch_idx = perm[step * batch_size : (step + 1) * batch_size]
            batch = X_train[batch_idx]
            
            batch = jax.device_put(batch, sharded) # The batch is sharded across devices
            
            state, (loss, recon_loss, kl_div) = train_step(state, batch, subkey)
            _loss += (loss)
        loss = _loss / steps_per_epoch

        losses.append(loss)
        pbar.set_postfix_str(f'loss={loss:.4f}')
        pbar.update(1)

## 6. Generate Samples from the VAE

Once training is complete, we can use the decoder network to generate new samples from the learned generative model.

The procedure follows from the probabilistic formulation of the generative model:

1. Sample $\mathbf{z} \sim p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$, where $\mathbf{z} \in \mathbb{R}^d$.
2. Decode $\mathbf{z}$ to generate $\mathbf{x}_{\text{logits}} = f_{\text{decoder}}(\mathbf{z})$.
3. If using a Gaussian likelihood: sample or visualize $\mathbf{x}_{\text{logits}}$ directly as reconstructions.
   If using Bernoulli likelihood: apply sigmoid to $\mathbf{x}_{\text{logits}}$ and sample binary images or threshold.

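Evaluating the marginal $p(\mathbf{x}) = \int p(\mathbf{x}|\mathbf{z}) p(\mathbf{z}) d\mathbf{z}$ is intractable, but sampling from it is straightforward: draw $\mathbf{z}$ from the prior and pass it through the decoder, which has learned to approximate $p(\mathbf{x}|\mathbf{z})$.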

**Exercise**: Generate and visualize samples from the decoder. Sample $\mathbf{z} \sim \mathcal{N}(0, I)$ and decode.

In [None]:
def generate_samples(
    decoder, decoder_params, rng, num_samples=10, latent_dim=4, likelihood="gaussian"
):
    """
    Generate samples from the VAE by decoding samples from the prior.

    Args:
        decoder: Flax decoder module.
        decoder_params: Decoder parameters.
        rng: PRNGKey for randomness.
        num_samples: Number of samples to generate.
        latent_dim: Dimensionality of latent space.
        likelihood: Likelihood type; a sigmoid is applied to the decoder output unless it is 'gaussian'.

    Returns:
        np.ndarray: Decoded images in shape (num_samples, 28, 28).
    """
    # Sample latent variables from standard Gaussian
    rng, z_key = jax.random.split(rng)
    z_samples = jax.random.normal(z_key, shape=(num_samples, latent_dim))

    # Decode latent vectors into image space
    images = decoder.apply(decoder_params, z_samples)
    if likelihood != "gaussian":
        images = jax.nn.sigmoid(images)
    images_np = np.array(images).reshape(-1, 28, 28).astype(np.float32)
    return images_np


# Generate and plot samples
rng, sample_key = jax.random.split(rng)
sampled_images = generate_samples(
    decoder, state.params["decoder"], sample_key, num_samples=10, latent_dim=latent_dim
)
fig = plot_examples(sampled_images, nrows=2, ncols=5)


## Visualize Latent Space

We can visualize the learned latent space by laying out a grid of points in the latent space and decoding them to generate images. This allows us to see how the model organizes different digits or classes in the latent space.
Remember that the latent space is a continuous representation, so we can interpolate between points to see how the model transitions between latent representations.
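Interpolation between two latent points can be sketched as a simple convex combination. The helper `interpolate_latents` and the endpoints `z_a` and `z_b` below are illustrative assumptions; the resulting rows could then be passed to the decoder in a single batch.

```python
import jax.numpy as jnp


def interpolate_latents(z_start, z_end, num_steps=8):
    # Interpolation weights t in [0, 1], shaped (num_steps, 1) for broadcasting.
    t = jnp.linspace(0.0, 1.0, num_steps)[:, None]
    return (1.0 - t) * z_start + t * z_end


z_a = jnp.array([-1.0, -1.0])
z_b = jnp.array([1.0, 0.5])
z_path = interpolate_latents(z_a, z_b, num_steps=8)

print(z_path.shape)  # one latent vector per interpolation step
# Each row can be decoded like any other latent sample,
# e.g. decoder.apply(decoder_params, z_path) with a trained decoder.
```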


**Exercise**: Create a grid of points in the latent space, decode them, and visualize the generated images. 

In [None]:

def plot_latent_space(decoder, n=30, figsize=15):
    # display a n*n 2D manifold of digits
    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]
    # Create a meshgrid of latent space coordinates
    grid_z = np.dstack(np.meshgrid(grid_x, grid_y)).reshape(-1, 2)
    
    # Decode all points in the grid at once (batch computation)
    x_decoded = decoder.apply(state.params["decoder"], grid_z)
    
    # Reshape and place in the figure
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            idx = i * n + j
            digit = x_decoded[idx].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("$z_0$")
    plt.ylabel("$z_1$")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


# Plot the latent space
plot_latent_space(decoder, n=30, figsize=15)


## Likelihood choices
The choice of likelihood function in the VAE decoder can significantly affect the model's performance and the quality of generated samples. 
So far, we have used a Gaussian likelihood, which is suitable for continuous data. However, we are modeling pixel values in the range [0, 1], which suggests that the Gaussian likelihood may not be the best choice due to its unbounded nature.
Instead, we can consider using three different likelihood functions:
1. **Bernoulli**: This is a discrete distribution suitable for binary data, where each pixel is treated as a Bernoulli random variable. It can be used with a sigmoid activation to model pixel probabilities, but it requires that the data is binary (0 or 1). This is still not ideal for continuous data, but it is used in practice by binarizing the pixel values.
1. **Continuous Bernoulli**: This is a continuous counterpart of the Bernoulli distribution, defined on [0, 1], so it can model pixel intensities directly without binarization. As with the Bernoulli, a sigmoid activation maps the decoder outputs to the distribution's parameter.
1. **Truncated Gaussian**: This is a Gaussian distribution truncated to the range [a, b] (here [0, 1]). It can handle continuous data while respecting the pixel value bounds, which makes it a more appropriate choice for modeling pixel values in images.
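One way to see why the plain Bernoulli is not a proper density for continuous pixels: evaluating $\lambda^x (1-\lambda)^{1-x}$ on $x \in [0, 1]$ does not integrate to one. The continuous Bernoulli fixes this with a normalizing constant $C(\lambda) = \frac{2 \operatorname{artanh}(1 - 2\lambda)}{1 - 2\lambda}$ for $\lambda \neq 0.5$. The sketch below checks this numerically with a midpoint rule; the helper names and the value $\lambda = 0.8$ are illustrative assumptions.

```python
import jax.numpy as jnp


def unnormalized_density(x, lam):
    # Bernoulli probability mass formula evaluated at continuous x in [0, 1].
    return lam**x * (1.0 - lam) ** (1.0 - x)


def cb_normalizer(lam):
    # Normalizing constant of the continuous Bernoulli, valid for lam != 0.5.
    return 2.0 * jnp.arctanh(1.0 - 2.0 * lam) / (1.0 - 2.0 * lam)


lam = 0.8
n = 10_000
xs = (jnp.arange(n) + 0.5) / n  # midpoints of n equal sub-intervals of [0, 1]

# Midpoint-rule estimate of the integral over [0, 1].
integral = jnp.mean(unnormalized_density(xs, lam))

print("integral without normalizer:", integral)
print("integral with normalizer:", integral * cb_normalizer(lam))
```

The first value is noticeably below one, while multiplying by $C(\lambda)$ brings it back to one up to numerical error.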


**Exercise**: Modify the ELBO function above to support these three likelihood functions. Note that all three likelihoods are implemented in `tensorflow_probability`, so directly use them to compute the reconstruction loss. You can use the `likelihood` argument to specify which likelihood function to use. Remember to perform the necessary transformations to ensure the outputs of the decoder are compatible with the chosen likelihood function. For example, you may need to apply a sigmoid activation to the decoder outputs to ensure they are in the range [0, 1].
The KL divergence term remains the same for all likelihoods, as it is independent of the likelihood choice.


**Exercise**: After modifying the ELBO function, train the VAE with each likelihood function and visualize the generated samples. Compare the quality of samples generated by each likelihood function and discuss the differences.