# Variational Autoencoders: the Basics

Variational autoencoders (VAEs) extend ordinary autoencoders (AEs). A VAE can make content generation more robust by regularising the distribution of the encodings in the latent space. In this notebook, we will go through the fundamentals of VAEs (motivation, theory and Keras-based implementation) using the `mnist-digits` dataset. We will also learn two useful extensions of VAEs: the disentangled VAEs ($\beta$-VAEs) and the conditional VAEs.

This notebook uses only the minimal math needed to understand VAEs. See [06_VAE_theory.ipynb](06_VAE_theory.ipynb) for a more detailed theory.

In [None]:
# tensorflow
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# check version
print('Using TensorFlow v%s' % tf.__version__)
acc_str = 'accuracy' if tf.__version__[:2] == '2.' else 'acc'

# helpers
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')


# fix the random seeds so that the results discussed below are reproducible
import random as python_random
python_random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)

---

# The Dataset

In this notebook, we will use the `mnist-digits` dataset. It is simpler than the `mnist-fashion` dataset, allowing us to use only two latent features so that we can conveniently visualise and examine the encodings distribution in the latent space.

In [None]:
# load dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# normalise images
train_images = train_images / 255.0
test_images = test_images / 255.0

# print info
print("Number of training data: %d" % len(train_labels))
print("Number of test data: %d" % len(test_labels))
print("Image pixels: %s" % str(train_images[0].shape))
print("Number of classes: %d" % (np.max(train_labels) + 1))

In [None]:
# function to plot an image in a subplot
def subplot_image(image, label, nrows=1, ncols=1, iplot=0, label2='', label2_color='r'):
    plt.subplot(nrows, ncols, iplot + 1)
    plt.imshow(image, cmap=plt.cm.binary)
    plt.xlabel(label, c='k', fontsize=12)
    plt.title(label2, c=label2_color, fontsize=12, y=-0.33)
    plt.xticks([])
    plt.yticks([])
    
# randomly plot some images
nrows = 4
ncols = 20
plt.figure(dpi=100, figsize=(ncols * 2, nrows * 2.2))
for iplot, idata in enumerate(np.random.choice(len(train_labels), nrows * ncols)):
    subplot_image(1 - train_images[idata], '', nrows, ncols, iplot)
plt.show()

---

# Autoencoders and Regularity of Latent Space

Why do we need a VAE? To answer this question, let us start with an ordinary AE and see what is unsatisfactory when we use it to generate new images.

## 1. Build and train an autoencoder

Based on what we have learnt in [05_autoencoder_basics.ipynb](05_autoencoder_basics.ipynb), we can quickly build an AE with `Dense` layers. First, we specify the latent dimension or the size of the bottleneck; for `mnist-digits`, we can use 2.

In [None]:
# latent dimension
latent_dim = 2

### The encoder

The encoder contains four layers: an input layer of size 28$\times$28 (flattened to 784 values), two hidden layers of sizes 128 and 16, and the latent output layer:

In [None]:
# build the encoder
image_input = keras.Input(shape=(28, 28))
x = layers.Flatten()(image_input)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dense(16, activation="relu")(x)
latent_output = layers.Dense(latent_dim)(x)
encoder_AE = keras.Model(image_input, latent_output)
encoder_AE.summary()
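
As a quick cross-check of what `summary()` reports, the parameter count of a `Dense` layer is one weight per (input, unit) pair plus one bias per unit. The sketch below applies this to the encoder's layer sizes; `dense_params` is a helper name made up for this illustration.

```python
# Number of trainable parameters in a fully connected layer:
# one weight per (input, unit) pair plus one bias per unit.
def dense_params(n_inputs, n_units):
    return n_inputs * n_units + n_units

latent_dim = 2
# layer widths of the encoder: flattened image -> 128 -> 16 -> latent
encoder_sizes = [28 * 28, 128, 16, latent_dim]

per_layer = [dense_params(n_in, n_out)
             for n_in, n_out in zip(encoder_sizes[:-1], encoder_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)
```

Almost all of the parameters sit in the first `Dense` layer, which connects the 784 pixels to the 128 hidden units.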

### The decoder

The decoder also contains four layers that mirror those of the encoder, taking the latent representation as the input:

In [None]:
# build the decoder
latent_input = keras.Input(shape=(latent_dim,))
x = layers.Dense(16, activation="relu")(latent_input)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dense(28 * 28, activation="sigmoid")(x)
image_output = layers.Reshape((28, 28))(x)
decoder_AE = keras.Model(latent_input, image_output)
decoder_AE.summary()

### The autoencoder

Joining up the encoder and the decoder, we obtain the AE network:

In [None]:
# build the AE
image_input = keras.Input(shape=(28, 28))
latent = encoder_AE(image_input)
image_output = decoder_AE(latent)
ae_model = keras.Model(image_input, image_output)
ae_model.summary()

# compile the AE
ae_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

### Train the autoencoder

Now we can train our AE with the `mnist-digits` dataset:

In [None]:
# train the AE
ae_model.fit(train_images, train_images, epochs=50, batch_size=128, 
             validation_data=(test_images, test_images))

## 2. Inspect the latent space

Let us inspect how the images are distributed in the latent space. 

### Encode images

First, we encode the images with our AE. After that, each image becomes a 2D point (because `latent_dim=2`) in the latent space.

In [None]:
# encode images by AE
train_encodings_AE = encoder_AE.predict(train_images)

### Scatter plot

We can plot the points in the latent space and colour them by their true labels:

In [None]:
# scatter plot of encodings in the latent space
def scatter_plot_encodings_latent(encodings, labels):
    plt.figure(dpi=100)
    scat = plt.scatter(encodings[:, 0], encodings[:, 1], c=labels, s=.5, cmap='Paired')
    plt.gca().add_artist(plt.legend(*scat.legend_elements(), 
                         title='Image labels', bbox_to_anchor=(1.5, 1.)))
    plt.xlabel('Feature X')
    plt.ylabel('Feature Y')
    plt.gca().set_aspect(1)
    plt.show()
    
# scatter plot of encodings by AE
scatter_plot_encodings_latent(train_encodings_AE, train_labels)

### Histogram plot

Also, for each digit, we can plot the density histograms of the encodings along the two latent dimensions -- note that we are using the same feature range ($x$-axis) in all the histograms:

In [None]:
# histogram plot of encodings in the latent space
def hist_plot_encodings_latent(encodings, labels, digit, dim, ax):
    # extract
    encodings_digit = encodings[labels == digit, dim]
    # histogram
    ax.hist(encodings_digit, bins=60, density=True, color=['g', 'b'][dim], alpha=.5)
    # mean and std dev
    mean = np.mean(encodings_digit)
    std = np.std(encodings_digit)
    ax.axvline(mean, c='r')
    # the LaTeX part is a raw string so that \c, \m and \s are not read as escapes
    ax.set_xlabel('Digit %d, Feature %s\n' r'~${\cal N}(\mu=%.1f, \sigma=%.1f)$' % 
                  (digit, ['X', 'Y'][dim], mean, std), c='k')
    
# histogram plot of encodings by AE
fig, axes = plt.subplots(5, 4, dpi=100, figsize=(15, 12), sharex=True)
plt.subplots_adjust(hspace=.4)
for digit in range(10):
    hist_plot_encodings_latent(train_encodings_AE, train_labels, digit, 0, 
                               axes[digit // 2, digit % 2 * 2 + 0])
    hist_plot_encodings_latent(train_encodings_AE, train_labels, digit, 1, 
                               axes[digit // 2, digit % 2 * 2 + 1])
plt.show()

### Regularity of the latent space

Both the scatter plot and the histogram plots show that the data distributions in the latent space are rather *irregular*. Some of the digits have very wide distributions (such as 1 and 7) and some very narrow distributions (such as 2 and 3). 
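
The spread visible in the histograms can also be summarised numerically. The following is a minimal sketch, assuming encodings of shape `(n, 2)` and integer labels; `latent_stats` is a hypothetical helper, and the two synthetic clusters only stand in for real encodings.

```python
import numpy as np

def latent_stats(encodings, labels):
    """Return {label: (mean, std)} with one mean/std entry per latent dimension."""
    stats = {}
    for label in np.unique(labels):
        points = encodings[labels == label]
        stats[int(label)] = (points.mean(axis=0), points.std(axis=0))
    return stats

# synthetic stand-in: a wide cluster (label 0) and a narrow one (label 1)
rng = np.random.default_rng(0)
demo_encodings = np.concatenate([rng.normal(0.0, 5.0, size=(1000, 2)),
                                 rng.normal(3.0, 0.5, size=(1000, 2))])
demo_labels = np.repeat([0, 1], 1000)

demo_stats = latent_stats(demo_encodings, demo_labels)
for label, (mean, std) in demo_stats.items():
    print(label, np.round(mean, 2), np.round(std, 2))
```

With the notebook's arrays, the call would be `latent_stats(train_encodings_AE, train_labels)`; a large difference between the per-digit standard deviations is the irregularity described above.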

Remember that our goal of training this AE is neither dimensionality reduction nor denoising but to generate new images out of the original dataset. Image generation is done by the decoder, taking the latent representation (`X` and `Y` in the plots) as the input. An irregular latent space makes image generation less controllable and less robust. In our case, two shortcomings are likely to emerge:

1. **Controllability**: if we sample the entire latent space, we will generate many more of the widely distributed digits than of the narrowly distributed ones; if we instead limit the range of the latent space, we will lose some characteristics of the widely distributed ones;

2. **Robustness**: images that do not resemble any of the digits will be generated from the gaps between the distributions of the digits; such gaps grow with the range of the latent space.

## 3. Generate new images

The following function generates new images by uniformly sampling the latent space within a specified range (`[x0, x1]`, `[y0, y1]`). 

In [None]:
# generate images from the latent space
def generate_images_latent(decoder, x0, x1, dx, y0, y1, dy):
    # uniformly sample the latent space
    nx = round((x1 - x0) / dx) + 1
    ny = round((y1 - y0) / dy) + 1
    grid_x = np.linspace(x0, x1, nx)
    grid_y = np.linspace(y1, y0, ny)
    latent = np.array(np.meshgrid(grid_x, grid_y)).reshape(2, nx * ny).T

    # decode images
    decodings = decoder.predict(latent)
    
    # display a (nx, ny) 2D manifold of digits
    figure = np.zeros((28 * ny, 28 * nx))
    for iy in np.arange(ny):
        for ix in np.arange(nx):
            figure[iy * 28 : (iy + 1) * 28, ix * 28 : (ix + 1) * 28] = decodings[iy * nx + ix]
            
    # plot figure
    plt.figure(dpi=100, figsize=(nx / 3, ny / 3))
    plt.xticks(np.arange(28 // 2, nx * 28 + 28 // 2, 28), np.round(grid_x, 1), rotation=90)
    plt.yticks(np.arange(28 // 2, ny * 28 + 28 // 2, 28), np.round(grid_y, 1))
    plt.xlabel('Feature X')
    plt.ylabel('Feature Y')
    plt.imshow(figure, cmap="Greys_r")
    plt.grid(False)
    plt.show()
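
To make the grid construction in `generate_images_latent` easier to follow, here is the same sampling step in isolation on a tiny grid. `latent_grid` is a name introduced for this sketch only; the body repeats the lines from the function above.

```python
import numpy as np

def latent_grid(x0, x1, dx, y0, y1, dy):
    # same construction as in generate_images_latent
    nx = round((x1 - x0) / dx) + 1
    ny = round((y1 - y0) / dy) + 1
    grid_x = np.linspace(x0, x1, nx)
    grid_y = np.linspace(y1, y0, ny)  # descending, so the top image row has the largest Y
    latent = np.array(np.meshgrid(grid_x, grid_y)).reshape(2, nx * ny).T
    return latent, nx, ny

latent, nx, ny = latent_grid(x0=0, x1=1, dx=1, y0=0, y1=2, dy=1)
print(latent.shape)  # (ny * nx, 2)
print(latent)
```

Row `iy * nx + ix` of `latent` holds the point `(grid_x[ix], grid_y[iy])`, which is the indexing the plotting loop relies on when it places each decoded image in the figure.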

Let us see what the generated images look like. We choose a range of $-20<X<20$ and $-20<Y<20$ (the values used in the cell below), which encompasses all the digits and most of the data points. The two shortcomings can be observed:

1. Only very few instances of the narrowly distributed digits are generated, such as 2 and 3;
2. Many images do not resemble any of the digits; note that the severely rotated digits should be recognised as non-digits in this context.

Feel free to try some other ranges. 

In [None]:
# generate images by AE
generate_images_latent(decoder_AE, x0=-20, x1=20, dx=1, y0=-20, y1=20, dy=1)

---

# Variational Autoencoders

Overfitting is the essential reason behind the irregular latent space of a naive AE: the encoding and decoding networks try their best to fit the data from end to end, without caring about how the latent space is organised with respect to the original data. A VAE regularises the latent space by imposing additional distributional properties on it.

The following figure summarises **the two extensions** from an AE to a VAE:

1. Unlike a naive AE, which encodes an input $x$ as a single point $z$ in the latent space, a VAE encodes it as a normal distribution $\mathcal{N}(\mu, \sigma)$; the latent representation $z$ is sampled from this distribution and then passed to the decoder;

2. An AE only minimises the reconstruction error $\lVert x-x'\rVert^2$ to fit the data, whereas a VAE minimises the sum of the reconstruction error and the KL divergence (Kullback–Leibler divergence) between the latent distribution $\mathcal{N}(\mu, \sigma)$ and the standard normal distribution $\mathcal{N}(0, 1)$.

How does a VAE regularise the latent space? The loss function provides a straightforward answer: in addition to fitting the data by minimising the reconstruction error, it also drags the latent distribution to a standard normal distribution. The final model is a trade-off between the two effects. Also, because each input image is encoded as a Gaussian blob instead of a single point, the gaps in the latent space can be filled by such blurring so that meaningless decodings can be largely avoided.

![ae-vae.png](https://i.ibb.co/DDTgq7Z/ae-vae.png)
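
For a single latent dimension with mean $\mu$ and standard deviation $\sigma$, the KL divergence to the standard normal has the closed form $-\tfrac{1}{2}\left(1+\ln\sigma^2-\mu^2-\sigma^2\right)$. The sketch below evaluates this expression and compares it with a Monte Carlo estimate; both function names are illustrative and not part of any library.

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ) with log_var = ln(sigma^2)."""
    # same as -0.5 * (1 + log_var - mu^2 - exp(log_var)), with the sign distributed
    return 0.5 * (mu ** 2 + np.exp(log_var) - log_var - 1.0)

def kl_monte_carlo(mu, log_var, n=200_000, seed=0):
    """Estimate the same KL as the average of log q(z) - log p(z) over z ~ q."""
    rng = np.random.default_rng(seed)
    sigma = np.exp(0.5 * log_var)
    z = mu + sigma * rng.standard_normal(n)
    log_q = -0.5 * (np.log(2 * np.pi) + log_var + ((z - mu) / sigma) ** 2)
    log_p = -0.5 * (np.log(2 * np.pi) + z ** 2)
    return np.mean(log_q - log_p)

print(kl_to_standard_normal(0.0, 0.0))  # zero when the two distributions coincide
print(kl_to_standard_normal(1.0, 0.0), kl_monte_carlo(1.0, 0.0))
```

The divergence grows when the mean moves away from 0 or the variance moves away from 1, which is how this term pulls every encoding towards the standard normal distribution.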



## 1. Build and train a VAE

Now we will implement a VAE for `mnist-digits`. The rigorous theory, which is more complicated than explained above, can be found in [06_VAE_theory.ipynb](06_VAE_theory.ipynb).


### The encoder

To implement the probabilistic encoder, we first need a custom layer to sample from the latent distribution, as implemented by the `Sampling` class. Note that the network outputs the log-variance $\ln\sigma^2$ (`z_log_var`) instead of $\sigma$; otherwise, the implementation would be more complicated because we would have to enforce the positivity of $\sigma$.

In [None]:
# sampling z with (z_mean, z_log_var)
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        # tf.random.normal is used because keras.backend.random_normal is absent in Keras 3
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
    
    
# build the encoder
image_input = keras.Input(shape=(28, 28))
x = layers.Flatten()(image_input)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z_output = Sampling()([z_mean, z_log_var])
encoder_VAE = keras.Model(image_input, [z_mean, z_log_var, z_output])
encoder_VAE.summary()
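
The sampling step can be checked outside TensorFlow. Below is a NumPy sketch of the same formula as in `Sampling.call`; `sample_latent` is an illustrative name and the mean and variance values are arbitrary. Because any real number is a valid log-variance, `exp(0.5 * z_log_var)` is always a positive standard deviation, so no positivity constraint is needed.

```python
import numpy as np

def sample_latent(z_mean, z_log_var, rng):
    """Reparameterised sample: z = mean + sigma * epsilon, with epsilon ~ N(0, 1)."""
    epsilon = rng.standard_normal(np.shape(z_mean))
    return z_mean + np.exp(0.5 * z_log_var) * epsilon

rng = np.random.default_rng(0)
z_mean = np.full((100_000, 2), 1.5)
z_log_var = np.full((100_000, 2), np.log(0.25))  # variance 0.25, so sigma = 0.5

z = sample_latent(z_mean, z_log_var, rng)
print(z.mean(axis=0), z.std(axis=0))  # close to 1.5 and 0.5 in both dimensions
```

All randomness comes from `epsilon`, so `z` is a deterministic function of `z_mean` and `z_log_var` given the noise, and gradients can flow through both during training.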

### The decoder

The decoder is the same as that of the AE.

In [None]:
# build the decoder
z_input = keras.Input(shape=(latent_dim,))
x = layers.Dense(16, activation="relu")(z_input)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dense(28 * 28, activation="sigmoid")(x)
image_output = layers.Reshape((28, 28))(x)
decoder_VAE = keras.Model(z_input, image_output)
decoder_VAE.summary()

### The VAE

To add the KL divergence to the loss, we create a class `VAE` derived from `keras.Model` and override its `train_step()` method:

In [None]:
# VAE class
class VAE(keras.Model):
    # constructor
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    # customise train_step() to implement the loss 
    def train_step(self, x):
        if isinstance(x, tuple):
            x = x[0]
        with tf.GradientTape() as tape:
            # encoding
            z_mean, z_log_var, z = self.encoder(x)
            # decoding
            x_prime = self.decoder(z)
            # reconstruction error by binary crossentropy loss
            reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * 28 * 28
            # KL divergence: sum over the latent dimensions, then mean over the batch
            kl_loss = tf.reduce_mean(tf.reduce_sum(
                -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)), axis=1))
            # loss = reconstruction error + KL divergence
            loss = reconstruction_loss + kl_loss
        # apply gradient
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # return loss for metrics log
        return {"loss": loss,
                "reconstruction_loss": reconstruction_loss,
                "kl_loss": kl_loss}

Now we can build, compile and train our VAE:

In [None]:
# build the VAE
vae_model = VAE(encoder_VAE, decoder_VAE)

# compile the VAE
vae_model.compile(optimizer=keras.optimizers.Adam())

# train the VAE
vae_model.fit(train_images, train_images, epochs=50, batch_size=128)
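
The factor `28 * 28` in the reconstruction term of `train_step` converts the mean cross-entropy per pixel into a sum over the pixels of an image, averaged over the batch. A NumPy sketch with random stand-in data (not the notebook's images) suggests that the two forms agree; `bce` is a hand-written helper here, not the Keras function.

```python
import numpy as np

def bce(x, x_prime, eps=1e-7):
    """Element-wise binary cross-entropy between targets x and predictions x_prime."""
    x_prime = np.clip(x_prime, eps, 1.0 - eps)
    return -(x * np.log(x_prime) + (1.0 - x) * np.log(1.0 - x_prime))

rng = np.random.default_rng(0)
x = rng.random((8, 28, 28))        # stand-in batch of 8 "images" in [0, 1]
x_prime = rng.random((8, 28, 28))  # stand-in reconstructions

per_pixel = bce(x, x_prime)
mean_times_pixels = per_pixel.mean() * 28 * 28     # form used in train_step
sum_per_image = per_pixel.sum(axis=(1, 2)).mean()  # sum over pixels, mean over batch
print(mean_times_pixels, sum_per_image)
```

Without this factor the reconstruction term would be 784 times smaller, which would change its balance against the KL term.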

## 2. Inspect the latent space

Next, we can inspect the latent space following the same steps we did for the AE. Clearly, the latent distributions become much more regular than before.

In [None]:
# encode images by VAE
train_encodings_VAE = encoder_VAE.predict(train_images)

# scatter plot of encodings by VAE
scatter_plot_encodings_latent(train_encodings_VAE[2], train_labels)

# histogram plot of encodings by VAE
fig, axes = plt.subplots(5, 4, dpi=100, figsize=(15, 12), sharex=True)
plt.subplots_adjust(hspace=.4)
for digit in range(10):
    hist_plot_encodings_latent(train_encodings_VAE[2], train_labels, digit, 0, 
                               axes[digit // 2, digit % 2 * 2 + 0])
    hist_plot_encodings_latent(train_encodings_VAE[2], train_labels, digit, 1, 
                               axes[digit // 2, digit % 2 * 2 + 1])
plt.show()

## 3. Generate new images

Finally, we can generate new images with our VAE. The result shows that, compared to the AE, the generated digits are more evenly represented and the number of non-digit images has been greatly reduced.

In [None]:
# generate images by VAE
generate_images_latent(decoder_VAE, x0=-2, x1=2, dx=.1, y0=-2, y1=2, dy=.1)
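
The much smaller range used here (from -2 to 2 in both features) can be motivated by the standard normal distribution that the KL term pulls the encodings towards. The sketch below computes how much of that distribution falls inside the range, assuming the two latent features are independent and exactly standard normal, which the trained encodings only approximate; `prob_within` is an illustrative helper.

```python
import math

def prob_within(a):
    """P(|z| < a) for z ~ N(0, 1)."""
    return math.erf(a / math.sqrt(2.0))

p_1d = prob_within(2.0)
p_2d = p_1d ** 2  # two independent latent dimensions
print(round(p_1d, 3), round(p_2d, 3))
```

Under these assumptions roughly 95% of the mass lies within the range along one feature and about 91% within the square, so the grid covers most of the region where the decoder has seen training samples.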

---