## Introduction

<a href="https://colab.research.google.com/github/ntt123/pax/blob/main/examples/notebooks/VAE.ipynb" target="_top"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" style="vertical-align:text-bottom"></a>

This is an example notebook showing how to train and test a Variational AutoEncoder (VAE).

It closely follows the Keras VAE example at
https://keras.io/examples/generative/vae

## Setting up

In [1]:
# uncomment to install PAX
# !pip3 install -q git+https://github.com/NTT123/pax.git#egg=pax[test]

In [None]:
from typing import NamedTuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import opax
import pax
import tensorflow as tf
import tensorflow_datasets as tfds

### config ###
# Number of examples per batch; passed to `Dataset.batch` for both training
# and the labelled evaluation data.
batch_size = 128
# Size of the latent space, passed to the model as `latent_dim`. The plotting
# cells at the end of the notebook index exactly two latent coordinates.
vae_dim = 2
# Learning rate handed to the `opax.adam` optimizer.
learning_rate = 1e-3
# Seed PAX's RNG key (presumably makes parameter initialisation and the
# model's internal key reproducible -- confirm against the PAX docs).
pax.seed_rng_key(42)


class LossInfo(NamedTuple):
    """A record of training losses.

    Returned by `loss_fn` for every step and also used in the training loop
    as a running accumulator (initialised with plain floats).
    """

    # Total objective: `reconstruction_loss + kl_loss`.
    loss: jnp.ndarray
    # Sigmoid binary cross-entropy between decoder output and input,
    # summed over all elements and divided by the batch size.
    reconstruction_loss: jnp.ndarray
    # KL-divergence term computed from the encoder's mean and log-variance,
    # summed over all elements and divided by the batch size.
    kl_loss: jnp.ndarray

## Convolutional VAE model

In [None]:
# Use `leaky_relu` instead of `relu` for better gradient signals.
from functools import partial

# Bind the negative slope to 0.1 so the activation is a single-argument
# callable that can be dropped directly into `pax.nn.Sequential`.
leaky_relu = partial(jax.nn.leaky_relu, negative_slope=0.1)


class VariationalAutoEncoder(pax.Module):
    """Convolutional variational autoencoder.

    The encoder maps an image batch to `2 * latent_dim` values (mean and
    log-variance side by side); a latent sample is drawn with the
    reparameterisation trick and decoded back to image logits.

    The hard-coded flatten size (7 * 7 * 64) implies 28x28 single-channel,
    channels-last inputs -- NOTE(review): confirm if other data is used.
    """

    def __init__(self, latent_dim: int):
        """Build the encoder/decoder stacks and the internal random key.

        Arguments:
            latent_dim: int, number of latent dimensions.
        """
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: two stride-2 convolutions, flatten, small MLP head.
        encoder_layers = [
            pax.nn.Conv2D(1, 32, 3, (2, 2), padding="SAME"),
            leaky_relu,
            pax.nn.Conv2D(32, 64, 3, (2, 2), padding="SAME"),
            leaky_relu,
            lambda x: jnp.reshape(x, (x.shape[0], -1)),
            pax.nn.Linear(7 * 7 * 64, 16),
            leaky_relu,
            # mean and log-variance are emitted together and split later.
            pax.nn.Linear(16, 2 * latent_dim),
        ]
        self.encoder = pax.nn.Sequential(*encoder_layers)

        # Decoder: project to a 7x7 feature map, then upsample twice.
        decoder_layers = [
            pax.nn.Linear(latent_dim, 7 * 7 * 32),
            leaky_relu,
            lambda x: jnp.reshape(x, (x.shape[0], 7, 7, 32)),
            pax.nn.Conv2DTranspose(32, 64, 3, 2, padding="SAME"),
            leaky_relu,
            pax.nn.Conv2DTranspose(64, 32, 3, 2, padding="SAME"),
            leaky_relu,
            # final layer outputs logits; no activation is applied here.
            pax.nn.Conv2DTranspose(32, 1, 3, 1, padding="SAME"),
        ]
        self.decoder = pax.nn.Sequential(*decoder_layers)

        # register an internal random key used to sample the latent noise.
        self.register_state("rng_key", pax.next_rng_key())

    def __call__(self, x):
        """Encode `x`, sample a latent vector and decode it.

        Returns:
            A tuple `(decoder_output, vae_mean, vae_logvar)`.
        """
        encoded = self.encoder(x)
        vae_mean, vae_logvar = jnp.split(encoded, 2, axis=-1)

        if not self.training:
            # evaluation: reuse the stored key, so the noise is repeatable.
            rng_key = self.rng_key
        else:
            # training: refresh the internal key so every call sees new noise.
            self.rng_key, rng_key = jax.random.split(self.rng_key)

        noise = jax.random.normal(rng_key, shape=vae_mean.shape, dtype=vae_mean.dtype)
        std = jnp.exp(0.5 * vae_logvar)
        latent = noise * std + vae_mean
        return self.decoder(latent), vae_mean, vae_logvar

## Loss function

The loss function is the sum of the `reconstruction` loss and the `kl-divergence` loss; each term is summed over all elements and then averaged over the batch.

In [None]:
def sigmoid_binary_cross_entropy(logits, targets):
    """Element-wise binary cross-entropy between `sigmoid(logits)` and `targets`.

    Works on logits directly through `log_sigmoid`, which avoids computing
    `log(sigmoid(x))` explicitly.
    """
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    return -log_p * targets - log_not_p * (1.0 - targets)


@pax.pure
def loss_fn(model: VariationalAutoEncoder, inputs):
    """Compute the VAE training loss for one batch.

    Arguments:
        model: the VAE to evaluate.
        inputs: batch of images; also used as the reconstruction target.

    Returns:
        `(loss, (LossInfo, model))`. The model is returned as well because
        calling it in training mode refreshes its internal random key.
    """
    logits, vae_mean, vae_logvar = model(inputs)
    num_examples = inputs.shape[0]  # batch size

    # reconstruction term: summed over every element, averaged over the batch.
    bce = sigmoid_binary_cross_entropy(logits, inputs)
    reconstruction_loss = jnp.sum(bce) / num_examples

    # KL term between N(mean, exp(logvar)) and the standard normal.
    kl_terms = -0.5 * (1 + vae_logvar - jnp.square(vae_mean) - jnp.exp(vae_logvar))
    kl_loss = jnp.sum(kl_terms) / num_examples

    loss = reconstruction_loss + kl_loss
    info = LossInfo(loss, reconstruction_loss, kl_loss)
    return loss, (info, model)

## TensorFlow Dataloader

In [None]:
def load_dataset(with_label=False):
    """Return MNIST (train and test splits combined) as a tensorflow dataset.

    Images are cast to float32 and divided by 255. The result is cached and
    shuffled with a buffer as large as the dataset itself.

    Arguments:
        with_label: bool, return `(data, label)`.

    """
    splits = tfds.load("mnist:3.*.*")
    full = splits["train"].concatenate(splits["test"])

    def to_float_image(example):
        # `example` is the tfds feature dict with "image" and "label" keys.
        return tf.cast(example["image"], tf.float32) / 255.0, example["label"]

    def keep_pair(image, label):
        return image, label

    def keep_image(image, _label):
        return image

    select = keep_pair if with_label else keep_image

    # a shuffle buffer of `len(full)` shuffles the whole dataset.
    return full.map(to_float_image).map(select).cache().shuffle(len(full))

## Train model

In [None]:
def train(num_epochs=50):
    """Train a VAE on MNIST, print per-epoch losses and plot the loss curve.

    Arguments:
        num_epochs: int, number of passes over the training data.

    Returns:
        The trained `VariationalAutoEncoder`.
    """
    vae = VariationalAutoEncoder(vae_dim)
    print(f"===== VAE MODEL =====\n{vae.summary()}\n\n")
    optimizer = opax.adam(learning_rate, eps=1e-7)(vae.parameters())

    train_data = load_dataset(with_label=False).batch(batch_size, drop_remainder=True)
    num_batches = len(train_data)
    fast_update_fn = jax.jit(pax.utils.build_update_fn(loss_fn))

    training_losses = []
    for epoch in range(num_epochs):
        # running sum of the loss terms over this epoch.
        losses = LossInfo(0.0, 0.0, 0.0)
        for batch in train_data:
            # `jax.tree_util.tree_map` is the stable spelling; the top-level
            # `jax.tree_map` alias is deprecated/removed in recent JAX.
            batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
            vae, optimizer, loss_info = fast_update_fn(vae, optimizer, batch)
            training_losses.append(loss_info.loss)
            losses = jax.tree_util.tree_map(lambda x, y: x + y, losses, loss_info)

        # turn the sums into per-batch averages for reporting.
        losses = jax.tree_util.tree_map(lambda x: x / num_batches, losses)
        print(
            f"[Epoch {epoch:>2}]  train loss {losses.loss:.3f}  reconstruction loss {losses.reconstruction_loss:.3f}  kl_loss {losses.kl_loss:.3f}"
        )

    plt.plot(training_losses)
    # `grid` expects a boolean; the string "on" only worked by being truthy.
    plt.grid(True)
    plt.ylabel("Loss")
    plt.xlabel("Steps")
    plt.title("Training History")
    plt.show()
    return vae


# Run the training loop; the returned model is reused by the plotting cells below.
vae = train()

## 2D Latent Space

We inspect the 2D latent space learned by our VAE model.

In [None]:
# Source: https://keras.io/examples/generative/vae
def plot_latent_space(vae, n=30, figsize=15, scale=1.0):
    """Display an n*n 2D manifold of digits decoded from a latent grid.

    Arguments:
        vae: model whose `decoder` maps a `(batch, 2)` latent array to logits.
        n: int, number of grid points along each latent axis.
        figsize: side length of the (square) matplotlib figure.
        scale: the grid spans `[-scale, scale]` on both latent axes.
    """
    digit_size = 28
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space; y is reversed so that the
    # largest z[1] ends up in the top image row.
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    fast_decoder = jax.jit(lambda model, z: model.decoder(z))

    # Decode the whole grid in a single batched call instead of n*n separate
    # decoder dispatches. Row i of the mesh holds grid_y[i], column j holds
    # grid_x[j], so flat index i * n + j matches image row i / column j.
    xx, yy = np.meshgrid(grid_x, grid_y)
    z_grid = np.stack([xx.ravel(), yy.ravel()], axis=-1)
    decoded = jax.nn.sigmoid(fast_decoder(vae, z_grid))

    # (n*n, H, W, 1) -> (n, n, H, W) -> (n, H, n, W) -> (n*H, n*W)
    digits = np.asarray(decoded).reshape(n, n, digit_size, digit_size)
    figure = digits.transpose(0, 2, 1, 3).reshape(digit_size * n, digit_size * n)

    plt.figure(figsize=(figsize, figsize))
    # place one tick at the centre of every digit tile.
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


# Plot with `vae.eval()` (presumably the model switched to evaluation mode;
# only its decoder is used by the plot).
plot_latent_space(vae.eval())

In [None]:
# Source: https://keras.io/examples/generative/vae
def plot_label_clusters(vae, data):
    """Display a 2D scatter plot of the digit classes in the latent space.

    Arguments:
        vae: model whose `encoder` output is split in half along the last
            axis; the first half is used as the latent mean.
        data: iterable of `(images, labels)` batches of tensorflow tensors.
    """
    fast_encoder = jax.jit(lambda model, images: model.encoder(images))

    all_z_means = []
    all_labels = []
    for batch in data:
        # Distinct local names: the original rebound `data` inside the loop,
        # shadowing the iterable being looped over.
        images, labels = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
        z = fast_encoder(vae, images)
        z_mean, _ = jnp.split(z, 2, axis=-1)
        all_z_means.append(z_mean)
        all_labels.append(labels)

    all_z_means = jnp.concatenate(all_z_means)
    all_labels = jnp.concatenate(all_labels)
    plt.figure(figsize=(12, 10))
    # colour every point by its digit label.
    plt.scatter(all_z_means[:, 0], all_z_means[:, 1], c=all_labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


# Labelled batches for the scatter plot; `drop_remainder=True` discards the
# final partial batch, so every batch has the same shape.
data_with_label = load_dataset(with_label=True).batch(batch_size, drop_remainder=True)
plot_label_clusters(vae.eval(), data_with_label)