# Variational Autoencoders

[![Open in Colab](https://lab.aef.me/files/assets/colab-badge.svg)](https://colab.research.google.com/github/adamelliotfields/lab/blob/main/files/tf/vae.ipynb)
[![Open in Kaggle](https://lab.aef.me/files/assets/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/adamelliotfields/lab/blob/main/files/tf/vae.ipynb)
[![Render nbviewer](https://lab.aef.me/files/assets/nbviewer_badge.svg)](https://nbviewer.org/github/adamelliotfields/lab/blob/main/files/tf/vae.ipynb)
[![W&B](https://img.shields.io/badge/Weights_&_Biases-FFCC33?logo=WeightsAndBiases&logoColor=black)](https://wandb.ai/adamelliotfields/vae-mnist)

This notebook includes implementations of fully-connected and convolutional variational autoencoders (VAEs) using TensorFlow and Keras, trained on MNIST. We'll use the [trainer pattern](https://keras.io/examples/keras_recipes/trainer_pattern/) to override the VAE's `train_step` method to implement a custom loss function in the training loop. We'll then sample points from the latent space at a regular interval and feed them through the decoder to transform them into a grid of new images.

![Latent digits](https://lab.aef.me/files/assets/vae_mlp_digits_v50.png)

## Introduction

An [autoencoder](https://en.wikipedia.org/wiki/Autoencoder) is a type of neural network that learns to encode data into a _latent space_ representation and then decode it back to the original space. The primary goal of an autoencoder is to learn an efficient representation (encoding) of the data.

In statistics, latent or "hidden" variables are ones that can only be inferred indirectly by modeling observed (visible) data. In the context of autoencoders, the latent space is a lower-dimensional representation of the data that captures the most important features.

A [variational autoencoder](https://en.wikipedia.org/wiki/Variational_autoencoder) (VAE) is an extension of the autoencoder that learns a probability distribution over the latent space rather than a single point per input. This allows the VAE to generate new data similar to the training data by sampling from the learned distribution. VAEs are generative models, and more specifically _latent variable models_.

## Bayes' Theorem

To understand VAEs, it's important to grasp some key concepts from probability theory.

[Bayes' Theorem](https://en.wikipedia.org/wiki/Bayes%27_theorem) is a cornerstone of [probability theory](https://en.wikipedia.org/wiki/Probability_theory) and provides a way to update the probability of a hypothesis based on new observations. It is defined as:

$p(z|x) = \frac{p(x|z) \cdot p(z)}{p(x)}$

Where:

* $p(z|x)$ is the posterior probability: the probability of the latent variable $z$ given the observed data $x$
* $p(x|z)$ is the likelihood: the probability of observing the data $x$ given the latent variable $z$
* $p(z)$ is the prior probability: the initial belief about the latent variable $z$ before observing the data
* $p(x)$ is the marginal likelihood: the total probability of observing the data $x$ under all possible values of $z$

In a VAE, $z$ is a lower-dimensional representation of the data $x$.
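
As a concrete illustration, here is a minimal sketch of Bayes' theorem with a discrete latent variable. The two-valued latent $z$ and all of the probabilities are made-up numbers chosen for the example; they have nothing to do with the MNIST models trained later in this notebook.

```python
# Toy example: a discrete latent z (stroke style) and a binary observation x
# ("a given pixel is on"). All numbers are illustrative assumptions.
prior = {"thin": 0.7, "thick": 0.3}       # p(z)
likelihood = {"thin": 0.2, "thick": 0.9}  # p(x = on | z)

# Marginal likelihood p(x): sum over every possible value of z.
evidence = sum(likelihood[z] * prior[z] for z in prior)

# Posterior p(z | x) from Bayes' theorem.
posterior = {z: likelihood[z] * prior[z] / evidence for z in prior}

print(f"p(x) = {evidence:.2f}")
for z, p in posterior.items():
    print(f"p(z={z} | x) = {p:.3f}")
```

With only two values of $z$ the marginal likelihood is a two-term sum. The difficulty discussed next comes from replacing that sum with an integral over a continuous latent space.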

### Intractability of the Posterior

The marginal likelihood $p(x)$ (the denominator) requires [integrating](https://en.wikipedia.org/wiki/Integral) over all possible values of the latent variable $z$: $p(x) = \int p(x|z) \, p(z) \, dz$. In most cases this integral cannot be computed directly, so the posterior $p(z|x)$ is said to be intractable. To address this, variational inference (VI) is used to approximate the true posterior distribution with a simpler, tractable distribution $q(z|x)$. The goal is to find the distribution $q(z|x)$ that is as close as possible to $p(z|x)$.
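
To make the integral concrete, here is a sketch in a one-dimensional toy model where the marginal likelihood happens to have a closed form. The model, $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$, is an assumption chosen for illustration and is not the VAE in this notebook.

```python
import numpy as np

def normal_pdf(x, mean, var):
    """Density of a univariate normal distribution."""
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

x = 1.5  # a single observed value (arbitrary)

# p(x) = integral of p(x|z) p(z) dz, approximated with a Riemann sum on a grid.
z, dz = np.linspace(-10.0, 10.0, 2001, retstep=True)
p_x_numeric = np.sum(normal_pdf(x, z, 1.0) * normal_pdf(z, 0.0, 1.0)) * dz

# For this toy model the marginal is known exactly: x ~ N(0, 2).
p_x_exact = normal_pdf(x, 0.0, 2.0)
print(f"numeric: {p_x_numeric:.8f}  exact: {p_x_exact:.8f}")

# The same grid resolution in d latent dimensions needs 2001 ** d evaluations.
grid_points = {d: 2001**d for d in (1, 2, 8)}
print(grid_points)
```

A grid works in one dimension, but the number of evaluations grows exponentially with the latent dimension, and with a neural network as the likelihood there is no closed form to compare against.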

To measure the difference between two probability distributions, the [Kullback-Leibler](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) (KL) divergence is used. However, directly minimizing the KL divergence is still a challenge because it depends on $p(z|x)$, which is intractable. Instead, VAEs optimize a different objective known as the Evidence Lower Bound (ELBO):

$\text{ELBO} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \text{KL}[q(z|x) || p(z)]$

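The first term is the expected log-likelihood of the data under the decoder; its negative is the reconstruction loss, which measures how well the model reconstructs the input data. The second term is the KL divergence between the approximate posterior $q(z|x)$ and the prior $p(z)$, which acts as a regularizer on the latent space. The ELBO is tied to the true posterior by the identity $\log p(x) = \text{ELBO} + \text{KL}[q(z|x) || p(z|x)]$. Since $\log p(x)$ does not depend on $q$, maximizing the ELBO with respect to $q$ is equivalent to minimizing the KL divergence between $q(z|x)$ and $p(z|x)$.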
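
The link between the ELBO and the true posterior can be checked numerically in a toy model where every term has a closed form. This is a sketch under assumed settings (prior $\mathcal{N}(0, 1)$, likelihood $\mathcal{N}(z, 1)$, and a Gaussian $q$ with arbitrary mean and variance), not the loss used to train the networks below. The value printed as `gap` is $\text{KL}[q(z|x) || p(z|x)]$, and `ELBO + gap` should equal $\log p(x)$ for every choice of $q$.

```python
import math

def kl_normal(m1, v1, m2, v2):
    """KL[N(m1, v1) || N(m2, v2)] for univariate normals (v = variance)."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def elbo(x, m, v):
    """ELBO for prior N(0, 1), likelihood N(x; z, 1), and q(z|x) = N(m, v)."""
    expected_log_lik = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - m) ** 2 + v)
    return expected_log_lik - kl_normal(m, v, 0.0, 1.0)

x = 1.5
log_p_x = -0.5 * math.log(2.0 * math.pi * 2.0) - x**2 / 4.0  # log N(x; 0, 2)
print(f"log p(x) = {log_p_x:.4f}")

# Three candidate q's; the last one is the true posterior N(x / 2, 1 / 2).
for m, v in [(0.0, 1.0), (0.5, 0.8), (x / 2.0, 0.5)]:
    value = elbo(x, m, v)
    gap = kl_normal(m, v, x / 2.0, 0.5)  # KL from q to the true posterior
    print(f"q=N({m}, {v}): ELBO={value:.4f}  gap={gap:.4f}  sum={value + gap:.4f}")
```

The ELBO never exceeds $\log p(x)$, and the gap closes only when $q$ matches the true posterior.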

In practice, the encoder neural network parameterizes the approximate posterior $q(z|x)$, and the decoder models the likelihood $p(x|z)$. The encoder net maps the input data $x$ to the parameters of $q(z|x)$, while the decoder net reconstructs $x$ from a latent sample $z$. During training, both networks are optimized simultaneously to maximize the ELBO.

## Reparameterization Trick

In traditional neural networks trained with gradient descent, the weights are updated based on the gradients of the loss function. Computing those gradients relies on every operation in the network being _deterministic_ and differentiable. In a VAE, the encoder outputs the parameters of the distribution $q(z|x)$, which are then used to sample a latent variable $z$. This sampling operation is _stochastic_, so gradients cannot be backpropagated through it to the encoder's weights.

To address this issue, the _reparameterization trick_ is used to express the random variable $z$ as a deterministic function of the input data $x$ and some additional random noise $\epsilon$. In practice, the encoder outputs the mean $\mu$ and the logarithm of the variance $\log \sigma^2$ of $q(z|x)$.

Instead of sampling $z$ directly from this distribution, we sample $\epsilon$ from a standard normal distribution $\mathcal{N}(0, I)$ (mean of `0` and standard deviation of `1`). Then, $z$ is computed as:

$z = \mu + \sigma \cdot \epsilon$

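Here, $\sigma = \exp(\frac{1}{2} \log \sigma^2)$. Predicting the log-variance lets the encoder output any real number while keeping $\sigma$ positive. Because the randomness now lives entirely in $\epsilon$, $z$ is a differentiable function of $\mu$ and $\sigma$, allowing the model's parameters to be learned using standard backpropagation methods. In code, the encoder returns $\mu$, $\log \sigma^2$, and the sampled $z$, and the loss (the negative ELBO) looks like this: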

```py
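from keras import losses, ops

# `encoder`, `decoder`, and `data` (a batch of images) are defined later in this notebook
z_mean, z_log_var, z = encoder(data)
reconstruction = decoder(z)

# reconstruction term: binary cross-entropy summed over pixels, averaged over the batch
reconstruction_loss = losses.binary_crossentropy(data, reconstruction)
reconstruction_loss = ops.mean(ops.sum(reconstruction_loss, axis=(1, 2)))

# KL term: closed form of KL[N(mu, sigma^2) || N(0, I)], summed over latent dimensions
kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
kl_loss = ops.mean(ops.sum(kl_loss, axis=1))

# negative ELBO
total_loss = reconstruction_loss + kl_loss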
```
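
The sampling step itself is not shown in the snippet above, so here is a NumPy sketch of it alongside the KL term. The values of `mu` and `log_var` stand in for a hypothetical encoder output and are assumptions for the example. The sketch checks two things: reparameterized samples have the intended mean and standard deviation, and the closed-form KL expression agrees with a Monte Carlo estimate of $\mathbb{E}_{q}[\log q(z|x) - \log p(z)]$.

```python
import numpy as np

rng = np.random.default_rng(42)

def reparameterize(z_mean, z_log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I); also return eps."""
    eps = rng.standard_normal(z_mean.shape)
    sigma = np.exp(0.5 * z_log_var)  # exp keeps sigma positive
    return z_mean + sigma * eps, eps

def kl_to_standard_normal(z_mean, z_log_var):
    """Closed-form KL[N(mu, sigma^2) || N(0, I)], summed over latent dims."""
    return -0.5 * np.sum(1.0 + z_log_var - z_mean**2 - np.exp(z_log_var), axis=-1)

# One hypothetical encoder output, repeated so sample statistics can be inspected.
mu = np.array([0.5, -1.0])
log_var = np.array([np.log(0.25), np.log(4.0)])  # sigma = [0.5, 2.0]
z_mean = np.tile(mu, (200_000, 1))
z_log_var = np.tile(log_var, (200_000, 1))

z, eps = reparameterize(z_mean, z_log_var, rng)
print("sample mean:", z.mean(axis=0))  # should be close to mu
print("sample std: ", z.std(axis=0))   # should be close to sigma

# Monte Carlo estimate of the same KL, using log-densities of diagonal Gaussians.
log_q = -0.5 * np.sum(np.log(2 * np.pi) + z_log_var + eps**2, axis=1)
log_p = -0.5 * np.sum(np.log(2 * np.pi) + z**2, axis=1)
kl_mc = np.mean(log_q - log_p)
kl_exact = kl_to_standard_normal(mu, log_var)
print("KL closed form:", kl_exact, " Monte Carlo:", kl_mc)
```

Since $\epsilon$ is treated as a constant input, $\partial z / \partial \mu = 1$ and $\partial z / \partial \sigma = \epsilon$, which is what lets gradients flow back into the encoder.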

## Resources

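[lyeoni/pytorch-mnist-VAE](https://github.com/lyeoni/pytorch-mnist-VAE)

[rcantini/CNN-VAE-MNIST](https://github.com/rcantini/CNN-VAE-MNIST)

[Convolutional Variational Autoencoder](https://www.tensorflow.org/tutorials/generative/cvae) (TensorFlow tutorial)

[Variational AutoEncoder](https://keras.io/examples/generative/vae) (Keras example)

[Trainer pattern](https://keras.io/examples/keras_recipes/trainer_pattern) (Keras recipe)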

[Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114) (Kingma & Welling, 2013)

[Tutorial on Variational Autoencoders](https://arxiv.org/abs/1606.05908) (Doersch, 2016)

In [None]:
import os
import subprocess

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["KERAS_BACKEND"] = "tensorflow"

try:
    from google.colab import userdata

    subprocess.run(["pip", "install", "-qU", "keras", "wandb"])
    os.environ["WANDB_DISABLE_GIT"] = "true"
    os.environ["WANDB_API_KEY"] = userdata.get("WANDB_API_KEY")
    os.environ["TFDS_DATA_DIR"] = "/content/drive/MyDrive/tensorflow_datasets"
except ImportError:
    pass

assert os.environ.get("WANDB_API_KEY"), "missing WANDB_API_KEY"

In [None]:
import wandb
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

from PIL import Image as PILImage
from wandb.integration.keras import WandbMetricsLogger

from keras import (
    Input,
    Model,
    backend,
    callbacks,
    layers,
    losses,
    metrics,
    ops,
    optimizers,
    random,
)

In [None]:
SEED = 42  # @param {type:"integer"}
EPOCHS = 50  # @param {type:"integer"}
VERBOSE = 1  # @param {type:"integer"}
LATENT_DIM = 2  # @param {type:"integer"}
BATCH_SIZE = 128  # @param {type:"integer"}
ACTIVATION = "gelu"  # @param ["relu", "leaky_relu", "swish", "gelu"] {type:"string"}
WEIGHT_DECAY = 0.004  # @param {type:"number"}
LEARNING_RATE = 0.001  # @param {type:"number"}

WANDB_PROJECT = "vae-mnist"  # @param {type:"string"}
WANDB_ENTITY = "adamelliotfields"  # @param {type:"string"}

In [None]:
# @title Functions
def get_latent_digits(decoder, n=20):
    """Decode an n x n grid of 2D latent points into one tiled image (assumes LATENT_DIM == 2)."""
    scale = 1.0
    figure = np.zeros((28 * n, 28 * n))
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # iterate over the grid
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample, verbose=0)
            digit = x_decoded[0].reshape(28, 28)
            figure[
                i * 28 : (i + 1) * 28,
                j * 28 : (j + 1) * 28,
            ] = digit
    return figure
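
The loop above calls the decoder once per grid point. As a sketch of an alternative, the grid can be built with `np.meshgrid`, decoded in a single batch, and tiled with a reshape. `latent_grid`, `tile_images`, and `toy_decoder` are hypothetical names introduced here, and `toy_decoder` is a stand-in (not the Keras decoder) so the example runs without a trained model. Like the original function, it assumes a 2D latent space.

```python
import numpy as np

def latent_grid(n=20, scale=1.0):
    """Return (n * n, 2) latent points, row-major from top-left to bottom-right."""
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]  # largest y on the top row
    xx, yy = np.meshgrid(grid_x, grid_y)
    return np.stack([xx.ravel(), yy.ravel()], axis=1)

def tile_images(images, n, size=28):
    """Arrange (n * n, size, size, 1) images into one (n * size, n * size) array."""
    tiles = images.reshape(n, n, size, size)
    return tiles.transpose(0, 2, 1, 3).reshape(n * size, n * size)

def toy_decoder(z, size=28):
    """Hypothetical decoder: fills each image with a value derived from z."""
    values = 0.5 + 0.25 * (z[:, 0] + z[:, 1]) / 2.0
    return np.ones((len(z), size, size, 1)) * values[:, None, None, None]

n = 4
z_grid = latent_grid(n)
figure = tile_images(toy_decoder(z_grid), n)
print(z_grid.shape, figure.shape)
```

With a trained Keras decoder, the call would presumably become `tile_images(decoder.predict(latent_grid(n), verbose=0), n)`, which replaces `n * n` separate `predict` calls with one.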

## Data

In [None]:
(mnist_train, mnist_test), mnist_info = tfds.load(
    "mnist",
    with_info=True,
    as_supervised=True,
    split=["train", "test"],
)
X_train = (
    mnist_train.concatenate(mnist_test)
    .map(lambda X, y: (tf.cast(X, tf.float32) / 255.0, y))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
for X_batch, _ in X_train.take(1):
    for index in range(12):
        plt.subplot(3, 4, index + 1)
        plt.imshow(
            # index, height, width, channels[0]
            X_batch[index, :, :, 0],
            cmap="binary",
        )
        plt.axis("off")

## MLP VAE

In [None]:
# sampling layer
class MLPSampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def __init__(self, name="MLPSampling", **kwargs):
        super().__init__(name=name, **kwargs)
        self.seed_generator = random.SeedGenerator(SEED)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        epsilon = random.normal(shape=(batch, dim), seed=self.seed_generator)
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon


# encoder
def get_encoder_mlp(name="MLPEncoder"):
    inputs = Input(shape=(28, 28, 1))
    x = layers.Flatten()(inputs)
    x = layers.Dense(512, activation=ACTIVATION)(x)
    x = layers.Dense(256, activation=ACTIVATION)(x)
    z_mean = layers.Dense(LATENT_DIM)(x)
    z_log_var = layers.Dense(LATENT_DIM)(x)
    z = MLPSampling()([z_mean, z_log_var])
    return Model(inputs, [z_mean, z_log_var, z], name=name)


# decoder
def get_decoder_mlp(name="MLPDecoder"):
    inputs = Input(shape=(LATENT_DIM,))
    x = layers.Dense(256, activation=ACTIVATION)(inputs)
    x = layers.Dense(512, activation=ACTIVATION)(x)
    x = layers.Dense(28 * 28, activation="sigmoid")(x)
    outputs = layers.Reshape((28, 28, 1))(x)
    return Model(inputs, outputs, name=name)


# models
encoder_mlp = get_encoder_mlp()
decoder_mlp = get_decoder_mlp()


# VAE
class VAE_MLP(Model):
    def __init__(self, encoder, decoder, name="VAE_MLP", **kwargs):
        super().__init__(name=name, **kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss = metrics.Mean(name="total_loss")
        self.reconstruction_loss = metrics.Mean(name="reconstruction_loss")
        self.kl_loss = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [self.total_loss, self.reconstruction_loss, self.kl_loss]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = losses.binary_crossentropy(data, reconstruction)
            reconstruction_loss = ops.mean(ops.sum(reconstruction_loss, axis=(1, 2)))
            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss.update_state(total_loss)
        self.reconstruction_loss.update_state(reconstruction_loss)
        self.kl_loss.update_state(kl_loss)

        return {
            "loss": self.total_loss.result(),
            "reconstruction_loss": self.reconstruction_loss.result(),
            "kl_loss": self.kl_loss.result(),
        }


backend.clear_session()
vae_mlp = VAE_MLP(encoder_mlp, decoder_mlp)
vae_mlp.compile(optimizer=optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=LEARNING_RATE))
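
The reconstruction term in `train_step` averages the binary cross-entropy over the channel axis, sums it over the 28x28 pixels, and then averages over the batch. The NumPy sketch below mirrors that reduction on random stand-in images (an assumption for the example, not MNIST) to give a feel for the scale of the logged values: a decoder that outputs 0.5 for every pixel scores $784 \cdot \ln 2 \approx 543.4$ per image.

```python
import numpy as np

def reconstruction_loss(data, reconstruction, eps=1e-7):
    """Per-pixel binary cross-entropy, summed over pixels, averaged over the batch."""
    p = np.clip(reconstruction, eps, 1.0 - eps)  # avoid log(0)
    bce = -(data * np.log(p) + (1.0 - data) * np.log(1.0 - p))
    bce = bce.mean(axis=-1)             # average over channels -> (batch, 28, 28)
    return bce.sum(axis=(1, 2)).mean()  # sum over pixels, mean over batch

rng = np.random.default_rng(0)
data = (rng.random((8, 28, 28, 1)) > 0.8).astype(float)  # stand-in binary images

uninformed = np.full_like(data, 0.5)   # predicts 0.5 for every pixel
confident = np.clip(data, 0.01, 0.99)  # nearly perfect reconstruction

print("uninformed:", reconstruction_loss(data, uninformed))
print("confident: ", reconstruction_loss(data, confident))
```

Summing over pixels rather than averaging keeps the reconstruction term on a scale comparable to the KL term, which is summed over latent dimensions.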

In [None]:
with wandb.init(
    tags=["L4"],
    group="mlp",
    job_type="train",
    entity=WANDB_ENTITY,
    project=WANDB_PROJECT,
    config={
        "epochs": EPOCHS,
        "model": "MLP_VAE",
        "batch_size": BATCH_SIZE,
        "activation": ACTIVATION,
        "weight_decay": WEIGHT_DECAY,
        "learning_rate": LEARNING_RATE,
    },
) as run:
    vae_mlp.fit(
        X_train.map(lambda X, _: X),
        epochs=EPOCHS,
        verbose=VERBOSE,
        callbacks=[
            callbacks.ReduceLROnPlateau(mode="min", patience=5, monitor="reconstruction_loss"),
            callbacks.EarlyStopping(mode="min", patience=10, monitor="reconstruction_loss"),
            WandbMetricsLogger(log_freq="epoch"),
        ],
    )

    fig = get_latent_digits(decoder_mlp, n=20)
    img = PILImage.fromarray(np.uint8(fig * 255))
    img.save("latent_digits.png")
    run.log({"Latent Digits": wandb.Image("latent_digits.png")})

## CNN VAE

In [None]:
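# sampling layer (same logic as MLPSampling)
class CNNSampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
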
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = random.SeedGenerator(SEED)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        epsilon = random.normal(shape=(batch, dim), seed=self.seed_generator)
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon


# encoder
def get_encoder_cnn(name="CNNEncoder"):
    inputs = Input(shape=(28, 28, 1))
    x = layers.Conv2D(32, 3, activation=ACTIVATION, strides=2, padding="same")(inputs)
    x = layers.Conv2D(64, 3, activation=ACTIVATION, strides=2, padding="same")(x)
    x = layers.Flatten()(x)
    x = layers.Dense(16, activation=ACTIVATION)(x)
    z_mean = layers.Dense(LATENT_DIM)(x)
    z_log_var = layers.Dense(LATENT_DIM)(x)
    z = CNNSampling()([z_mean, z_log_var])
    return Model(inputs, [z_mean, z_log_var, z], name=name)


# decoder
def get_decoder_cnn(name="CNNDecoder"):
    inputs = Input(shape=(LATENT_DIM,))
    x = layers.Dense(7 * 7 * 64, activation=ACTIVATION)(inputs)
    x = layers.Reshape((7, 7, 64))(x)
    x = layers.Conv2DTranspose(64, 3, activation=ACTIVATION, strides=2, padding="same")(x)
    x = layers.Conv2DTranspose(32, 3, activation=ACTIVATION, strides=2, padding="same")(x)
    outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
    return Model(inputs, outputs, name=name)


# models
encoder_cnn = get_encoder_cnn()
decoder_cnn = get_decoder_cnn()


# VAE
class VAE_CNN(Model):
    def __init__(self, encoder, decoder, name="VAE_CNN", **kwargs):
        super().__init__(name=name, **kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss = metrics.Mean(name="total_loss")
        self.reconstruction_loss = metrics.Mean(name="reconstruction_loss")
        self.kl_loss = metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [self.total_loss, self.reconstruction_loss, self.kl_loss]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = losses.binary_crossentropy(data, reconstruction)
            reconstruction_loss = ops.mean(ops.sum(reconstruction_loss, axis=(1, 2)))
            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss.update_state(total_loss)
        self.reconstruction_loss.update_state(reconstruction_loss)
        self.kl_loss.update_state(kl_loss)

        return {
            "loss": self.total_loss.result(),
            "reconstruction_loss": self.reconstruction_loss.result(),
            "kl_loss": self.kl_loss.result(),
        }


backend.clear_session()
vae_cnn = VAE_CNN(encoder_cnn, decoder_cnn)
vae_cnn.compile(optimizer=optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=LEARNING_RATE))
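
The `Dense(7 * 7 * 64)` and `Reshape((7, 7, 64))` layers in the decoder mirror the size of the encoder's last feature map. With `padding="same"`, a strided convolution produces `ceil(input / stride)` outputs per spatial axis, and a transposed convolution produces `input * stride`. A small sketch of that arithmetic follows; the helper names are made up for illustration.

```python
import math

def conv_same(size, stride):
    """Spatial output size of Conv2D with padding="same"."""
    return math.ceil(size / stride)

def conv_transpose_same(size, stride):
    """Spatial output size of Conv2DTranspose with padding="same"."""
    return size * stride

# Encoder: two stride-2 convolutions, then Flatten over 64 channels.
size = 28
encoder_sizes = [size]
for stride in (2, 2):
    size = conv_same(size, stride)
    encoder_sizes.append(size)
flattened = size * size * 64

# Decoder: two stride-2 transposed convolutions, then a stride-1 output layer.
decoder_sizes = [size]
for stride in (2, 2, 1):
    size = conv_transpose_same(size, stride)
    decoder_sizes.append(size)

print(encoder_sizes, flattened, decoder_sizes)
```

The decoder therefore ends at the same 28x28 resolution as the input, which the pixel-wise reconstruction loss requires.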

In [None]:
with wandb.init(
    tags=["L4"],
    group="cnn",
    job_type="train",
    entity=WANDB_ENTITY,
    project=WANDB_PROJECT,
    config={
        "epochs": EPOCHS,
        "model": "CNN_VAE",
        "batch_size": BATCH_SIZE,
        "activation": ACTIVATION,
        "weight_decay": WEIGHT_DECAY,
        "learning_rate": LEARNING_RATE,
    },
) as run:
    vae_cnn.fit(
        X_train.map(lambda X, _: X),
        epochs=EPOCHS,
        verbose=VERBOSE,
        callbacks=[
            callbacks.ReduceLROnPlateau(mode="min", patience=5, monitor="reconstruction_loss"),
            callbacks.EarlyStopping(mode="min", patience=10, monitor="reconstruction_loss"),
            WandbMetricsLogger(log_freq="epoch"),
        ],
    )

    fig = get_latent_digits(decoder_cnn, n=20)
    img = PILImage.fromarray(np.uint8(fig * 255))
    img.save("latent_digits_cnn.png")
    run.log({"Latent Digits (CNN)": wandb.Image("latent_digits_cnn.png")})