# Variational Autoencoders: the Basics

Variational autoencoders (VAEs) are an extension of autoencoders (AEs). A VAE makes content generation more robust by regularising the distribution of the encodings in the latent space. In this notebook, we will go through the fundamentals of VAEs (motivation, theory and Keras-based implementation) using the `mnist-digits` dataset. We will also learn two useful extensions of VAEs: the disentangled VAEs ($\beta$-VAEs) and the conditional VAEs.

This notebook involves only the minimal math needed to understand VAEs. See [06_VAE_theory.ipynb](06_VAE_theory.ipynb) for a more detailed treatment of the theory.

In [None]:
# tensorflow
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# check version
print('Using TensorFlow v%s' % tf.__version__)
acc_str = 'accuracy' if tf.__version__[:2] == '2.' else 'acc'

# helpers
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

# fix the random seeds so that the results discussed below are reproducible
import random as python_random
python_random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)

---

# The Dataset

In this notebook, we will use the `mnist-digits` dataset. It is simpler than the `mnist-fashion` dataset, allowing us to use only two features in the latent space so that we can conveniently visualise and examine the encodings distribution in the latent space.

In [None]:
# load dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# normalise images
train_images = train_images / 255.0
test_images = test_images / 255.0

# print info
print("Number of training data: %d" % len(train_labels))
print("Number of test data: %d" % len(test_labels))
print("Image pixels: %s" % str(train_images[0].shape))
print("Number of classes: %d" % (np.max(train_labels) + 1))

In [None]:
# function to plot an image in a subplot
def subplot_image(image, label, nrows=1, ncols=1, iplot=0, label2='', label2_color='r'):
    plt.subplot(nrows, ncols, iplot + 1)
    plt.imshow(image, cmap=plt.cm.binary)
    plt.xlabel(label, c='k', fontsize=12)
    plt.title(label2, c=label2_color, fontsize=12, y=-0.33)
    plt.xticks([])
    plt.yticks([])
    
# randomly plot some images
nrows = 4
ncols = 20
plt.figure(dpi=100, figsize=(ncols * 2, nrows * 2.2))
for iplot, idata in enumerate(np.random.choice(len(train_labels), nrows * ncols)):
    subplot_image(1 - train_images[idata], '', nrows, ncols, iplot)
plt.show()

---

# Autoencoders and Regularity of Latent Space

Why do we need a VAE? To answer this question, let us start with an ordinary AE and see what is unsatisfactory when we use it to generate new images.

## 1. Build and train an autoencoder

Based on what we have learnt in [05_autoencoder_basics.ipynb](05_autoencoder_basics.ipynb), we can quickly build an AE with `Dense` layers. First, we need to specify the latent dimension or the size of the bottleneck; for `mnist-digits`, we can use 2.

**NOTE**: The code in this notebook supports any latent dimension, but some of the descriptions are only for `latent_dim=2`.  

In [None]:
# latent dimension
latent_dim = 2

# other constants
n_digits = 10
n_img = 28

### The encoder

The encoder contains four layers, an input layer with size 28$\times$28, two hidden layers with sizes 128 and 16, respectively, and the latent output layer:

In [None]:
# build the encoder
image_input = keras.Input(shape=(n_img, n_img))
x = layers.Flatten()(image_input)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dense(16, activation="relu")(x)
latent_output = layers.Dense(latent_dim)(x)
encoder_AE = keras.Model(image_input, latent_output)
encoder_AE.summary()
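
As a quick sanity check on the summary above, the number of trainable parameters in each `Dense` layer is `inputs * units + units` (weights plus biases). The sketch below redoes that arithmetic in plain Python for the layer sizes used here; `dense_params` is a helper name introduced only for this illustration.

```python
# number of trainable parameters in a fully connected layer:
# one weight per (input, unit) pair plus one bias per unit
def dense_params(n_inputs, n_units):
    return n_inputs * n_units + n_units

n_img, latent_dim = 28, 2                        # same values as in this notebook
sizes = [n_img * n_img, 128, 16, latent_dim]     # flattened input -> hidden -> latent

per_layer = [dense_params(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
total_params = sum(per_layer)
print(per_layer, total_params)
```

If the total reported by `summary()` differs, the layer sizes or `latent_dim` are probably not the ones assumed in this sketch.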

### The decoder

The decoder contains four layers that mirror those of the encoder, taking the latent representation as the input:

In [None]:
# build the decoder
latent_input = keras.Input(shape=(latent_dim,))
x = layers.Dense(16, activation="relu")(latent_input)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dense(n_img * n_img, activation="sigmoid")(x)
image_output = layers.Reshape((n_img, n_img))(x)
decoder_AE = keras.Model(latent_input, image_output)
decoder_AE.summary()

### The autoencoder

Joining up the encoder and the decoder, we obtain the AE network:

In [None]:
# build the AE
image_input = keras.Input(shape=(n_img, n_img))
latent = encoder_AE(image_input)
image_output = decoder_AE(latent)
ae_model = keras.Model(image_input, image_output)
ae_model.summary()

# compile the AE
ae_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

### Train the autoencoder

Now we can train our AE with the `mnist-digits` dataset:

In [None]:
# train the AE
ae_model.fit(train_images, train_images, epochs=50, batch_size=128, 
             validation_data=(test_images, test_images))

## 2. Inspect the latent space

Let us inspect how the images are distributed in the latent space. 

### Encode images

First, we encode the images by our AE. After that, each image becomes a 2D point (when `latent_dim=2`) in the latent space. 

In [None]:
# encode images by AE
train_encodings_AE = encoder_AE.predict(train_images)

### Scatter plot

We can plot the 2D points and colour them by their true labels. 

**NOTE**: If `latent_dim>2`, use `feature_x` and `feature_y` to specify the two latent dimensions for plotting.

In [None]:
# scatter plot of encodings in the latent space
def scatter_plot_encodings_latent(encodings, labels, feature_x=0, feature_y=1):
    plt.figure(dpi=100)
    scat = plt.scatter(encodings[:, feature_x], encodings[:, feature_y], c=labels, s=.5, cmap='Paired')
    plt.gca().add_artist(plt.legend(*scat.legend_elements(), 
                         title='Image labels', bbox_to_anchor=(1.5, 1.)))
    plt.xlabel('Feature %d' % feature_x)
    plt.ylabel('Feature %d' % feature_y)
    plt.gca().set_aspect(1)
    plt.show()
    
# scatter plot of encodings by AE
scatter_plot_encodings_latent(train_encodings_AE, train_labels)

### Histogram plot

Also, for each digit and each feature, we can plot the density histogram of encodings -- note that we are using the same feature range (range of $x$-axis) in all the histograms:

In [None]:
# histogram plot of encodings in the latent space
def hist_plot_encodings_latent(encodings, labels, digit, dim, ax, feature_range, color):
    # extract
    encodings_digit = encodings[labels == digit, dim]
    # histogram
    ax.hist(encodings_digit, bins=60, density=True, color=color, alpha=1)
    # mean and std dev
    mean = np.mean(encodings_digit)
    std = np.std(encodings_digit)
    ax.axvline(mean, c='k', ls='--')
    # raw string for the LaTeX part, so that backslashes are not read as escape sequences
    ax.set_xlabel('Feature %d, Digit %d\n' r'${\cal N}(\mu=%.1f, \sigma^2=%.1f)$' %
                  (dim, digit, mean, std * std), c='k')
    # feature range
    ax.set_xlim(feature_range)
    
# histogram plot of encodings by AE
feature_range = (-np.max(np.abs(train_encodings_AE)), np.max(np.abs(train_encodings_AE)))
for dim in range(latent_dim):
    print('Latent feature %d:' % dim)
    fig, axes = plt.subplots(2, 5, dpi=100, figsize=(15, 4), sharex=True)
    plt.subplots_adjust(hspace=.4)
    for digit in range(10):
        hist_plot_encodings_latent(train_encodings_AE, train_labels, digit, dim, 
                                   axes[digit // 5, digit % 5], feature_range, 'C%d' % dim)
    plt.show()

### Regularity of the latent space

Both the scatter plot and the histogram plots show that the data distributions in the latent space are rather *irregular*. Some of the digits have very wide distributions (such as 1 and 7) and some very narrow distributions (such as 2 and 3). 

Remember that our goal of training this AE is neither dimensionality reduction nor denoising but to generate new images out of the original dataset. Image generation is done by the decoder, taking the latent representation as the input. An irregular latent space makes image generation less controllable and robust. Taking our case for example, two shortcomings are likely to emerge:

1. **Controllability**: sampling the entire latent space, we will generate many more of the widely distributed digits than the narrowly distributed ones; if we instead narrow down the latent space for sampling, we may lose some of the widely distributed ones;

2. **Robustness**: images that do not resemble any of the digits will be generated by sampling the gaps between the distributions of the digits; such gaps increase with the sampled range.
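
Both shortcomings can be illustrated with a toy one-dimensional latent space. The sketch below does not use the trained AE: it assumes two hypothetical digit clusters, one wide and one narrow (centres and widths are made up for illustration), samples the latent axis uniformly and counts where the samples land.

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical 1-D clusters: (centre, std dev); values are illustrative only
clusters = {'wide': (-10.0, 5.0), 'narrow': (10.0, 0.5)}

# uniformly sample the latent axis, as done when generating images
samples = rng.uniform(-25.0, 25.0, size=10000)

# a sample "belongs" to a cluster if it lies within 2 std devs of its centre
counts = {name: int(np.sum(np.abs(samples - c) <= 2 * s))
          for name, (c, s) in clusters.items()}

# everything else falls into the gaps between clusters
counts['gap'] = len(samples) - sum(counts.values())
print(counts)
```

Under these assumptions, the wide cluster receives roughly ten times as many samples as the narrow one, and more than half of the samples fall into the gaps.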

## 3. Generate new images

The following function generates new images by randomly sampling the latent space within a specified feature range. 

In [None]:
# generate images from the latent space
def generate_images_latent(decoder, n_generation, feature_range):
    # randomly sample the latent space
    latent = []
    for dim in range(latent_dim):
        if len(np.array(feature_range).shape) == 1:
            # only one range provided; use it for all dimensions
            latent.append(np.random.uniform(feature_range[0], feature_range[1], 
                                            n_generation * n_generation))
        else:
            # range provided for each dimension
            latent.append(np.random.uniform(feature_range[dim][0], feature_range[dim][1], 
                                            n_generation * n_generation))
    latent = np.array(latent).T
    
    # decode images
    decodings = decoder.predict(latent)
    
    # create and fill a large figure
    figure = np.zeros((n_img * n_generation, n_img * n_generation))
    for iy in np.arange(n_generation):
        for ix in np.arange(n_generation):
            figure[iy * n_img : (iy + 1) * n_img, ix * n_img : (ix + 1) * n_img] = decodings[iy * n_generation + ix]
            
    # plot figure
    plt.figure(dpi=100, figsize=(n_generation / 3, n_generation / 3))
    plt.imshow(figure, cmap="Greys_r")
    plt.axis('off')
    plt.show()
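
The nested loops that fill the large figure can also be written as a reshape followed by a transpose. The sketch below checks on random stand-in data (not real decodings) that both approaches produce the same grid; `tile_images` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def tile_images(images, n_side):
    # images: (n_side * n_side, height, width) -> (n_side * height, n_side * width)
    _, h, w = images.shape
    grid = images.reshape(n_side, n_side, h, w)   # axes: (row, col, y, x)
    grid = grid.transpose(0, 2, 1, 3)             # axes: (row, y, col, x)
    return grid.reshape(n_side * h, n_side * w)

# stand-in for decoder output: 3 x 3 random "images" of size 28 x 28
n_side, n_img = 3, 28
images = np.random.default_rng(0).random((n_side * n_side, n_img, n_img))

# reference: the loop-based fill used in generate_images_latent
figure = np.zeros((n_img * n_side, n_img * n_side))
for iy in range(n_side):
    for ix in range(n_side):
        figure[iy * n_img:(iy + 1) * n_img, ix * n_img:(ix + 1) * n_img] = images[iy * n_side + ix]

print(np.array_equal(figure, tile_images(images, n_side)))
```

The loop version is easier to read; the vectorised version mainly pays off for large grids.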

Let us see what the generated images look like. Here we choose a range of `[-25, 25]` for both features (when `latent_dim=2`), which encompasses all the digits and most of the data points. The two shortcomings can be observed:

1. Only a very few instances of the narrowly distributed digits are generated, such as 2 and 3;
2. Many images do not resemble any of the digits; note that the severely rotated digits should be recognised as non-digits in this context.

Feel free to try some other ranges. 

In [None]:
# generate images by AE
generate_images_latent(decoder_AE, n_generation=40, feature_range=[-25, 25])

---

# Variational Autoencoders

Overfitting is the essential reason behind an irregular latent space of a naive AE, that is, the neural networks for encoding and decoding try their best to fit the data from end to end without caring about how the latent space is organised with respect to the original data. A VAE can regularise the latent space by imposing additional distributional properties on the latent space.

The following figure summarises **the two extensions** from an AE to a VAE:

1. Unlike a naive AE that encodes an input data $x$ as a single point $z$ in the latent space, a VAE encodes it as a normal distribution $\mathcal{N}(\mu, \sigma^2)$, and the latent feature $z$ is sampled from this distribution and then passed to the decoder;

2. An AE only minimises the reconstruction error $\lVert x-x'\rVert^2$ to fit the data, whereas a VAE minimises the sum of the reconstruction error and the KL divergence (Kullback–Leibler divergence) between the latent distribution $\mathcal{N}(\mu, \sigma^2)$ and the standard normal distribution $\mathcal{N}(0, 1)$.

How does a VAE regularise the latent space? The loss function provides a straightforward answer: in addition to fitting the data by minimising the reconstruction error, it also drags the latent distribution to a standard normal distribution. The final model is a trade-off between the two effects. Also, because each input image is encoded as a Gaussian blob instead of a single point, the gaps in the latent space can be filled by such blurring so that meaningless decodings can be reduced.

![ae-vae.png](https://i.ibb.co/mJVqkNy/aevae.png)
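
For a Gaussian with independent latent features, the KL divergence to the standard normal has a closed form, $D_{KL} = -\frac{1}{2}\sum_i \left(1 + \ln\sigma_i^2 - \mu_i^2 - \sigma_i^2\right)$, where the sum runs over the latent features. A minimal NumPy sketch follows; the function name `kl_to_standard_normal` is chosen here for illustration.

```python
import numpy as np

def kl_to_standard_normal(mean, log_var):
    # KL( N(mean, exp(log_var)) || N(0, 1) ), summed over the last (latent) axis
    mean = np.asarray(mean, dtype=float)
    log_var = np.asarray(log_var, dtype=float)
    return -0.5 * np.sum(1.0 + log_var - mean ** 2 - np.exp(log_var), axis=-1)

# identical to the standard normal: no penalty
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))
# shifted mean: the penalty grows with the squared distance from the origin
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))
# too narrow (sigma^2 = 0.01 in the first feature): also penalised
print(kl_to_standard_normal([0.0, 0.0], [np.log(0.01), 0.0]))
```

The penalty vanishes only when the latent distribution is exactly $\mathcal{N}(0, 1)$, which is what drags the encodings towards the origin with unit spread.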



## 1. Build and train a VAE

Now we will implement a VAE for `mnist-digits`. The rigorous theory is more involved than the explanation above; it can be found in [06_VAE_theory.ipynb](06_VAE_theory.ipynb).


### The encoder

To implement the probabilistic encoder, we first need a custom layer to sample from the latent distribution, as implemented by the `Sampling` class. Note that the network outputs the log-variance $\ln\sigma^2$ instead of $\sigma$; otherwise, we would have to constrain $\sigma$ to be positive, which complicates the implementation.

In [None]:
# sampling z with (z_mean, z_log_var)
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.keras.backend.random_normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
    
# build the encoder
image_input = keras.Input(shape=(n_img, n_img))
x = layers.Flatten()(image_input)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z_output = Sampling()([z_mean, z_log_var])
encoder_VAE = keras.Model(image_input, [z_mean, z_log_var, z_output])
encoder_VAE.summary()
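
The sampling step inside `Sampling` computes $z = \mu + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$ and $\sigma = \exp(\frac{1}{2}\ln\sigma^2)$. The NumPy sketch below, using arbitrary example values for the mean and the log-variance, checks that the samples have the intended mean and standard deviation.

```python
import numpy as np

rng = np.random.default_rng(0)

# arbitrary example values for one latent feature
z_mean = 1.5
z_log_var = np.log(0.25)          # sigma^2 = 0.25, so sigma = 0.5

# draw epsilon from N(0, 1), then shift by the mean and scale by sigma
epsilon = rng.standard_normal(100000)
z = z_mean + np.exp(0.5 * z_log_var) * epsilon

print('sample mean: %.3f, sample std: %.3f' % (z.mean(), z.std()))
```

Because the randomness enters only through `epsilon`, the output stays a differentiable function of `z_mean` and `z_log_var`, which is what allows the encoder to be trained by backpropagation.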

### The decoder

The decoder is the same as that of the AE.

In [None]:
# build the decoder
z_input = keras.Input(shape=(latent_dim,))
x = layers.Dense(16, activation="relu")(z_input)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dense(n_img * n_img, activation="sigmoid")(x)
image_output = layers.Reshape((n_img, n_img))(x)
decoder_VAE = keras.Model(z_input, image_output)
decoder_VAE.summary()

### The VAE

To add the KL divergence to the loss, we create a class `VAE` derived from `keras.Model` and override its `train_step()` method:

In [None]:
# VAE class
class VAE(keras.Model):
    # constructor
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    # customise train_step() to implement the loss 
    def train_step(self, x):
        if isinstance(x, tuple):
            x = x[0]
        with tf.GradientTape() as tape:
            # encoding
            z_mean, z_log_var, z = self.encoder(x)
            # decoding
            x_prime = self.decoder(z)
            # reconstruction error by binary crossentropy loss
            reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * n_img * n_img
            # KL divergence
            kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            # loss = reconstruction error + KL divergence
            loss = reconstruction_loss + kl_loss
        # apply gradient
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # return loss for metrics log
        return {"loss": loss,
                "reconstruction_loss": reconstruction_loss,
                "kl_loss": kl_loss}

Now we can build, compile and train our VAE:

In [None]:
# build the VAE
vae_model = VAE(encoder_VAE, decoder_VAE)

# compile the VAE
vae_model.compile(optimizer=keras.optimizers.Adam())

# train the VAE
training_history_VAE = vae_model.fit(train_images, train_images, epochs=50, batch_size=128)
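
Two scaling details in the `train_step()` defined above can be checked numerically. `keras.losses.binary_crossentropy` averages over the last axis, so multiplying the overall mean by `n_img * n_img` gives the cross-entropy summed over the pixels of each image and averaged over the batch. The KL term, by contrast, is averaged over the latent features as well, so it equals the usual summed form divided by `latent_dim`. The sketch below reproduces both in NumPy on random stand-in data; the helper `bce` is written here for illustration and is not the Keras function (Keras additionally clips the predictions away from 0 and 1).

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_img, latent_dim = 4, 28, 2

# stand-in images and reconstructions, with predictions strictly inside (0, 1)
x = rng.uniform(0.0, 1.0, (batch, n_img, n_img))
x_prime = rng.uniform(0.05, 0.95, (batch, n_img, n_img))

def bce(t, p):
    # element-wise binary cross-entropy
    return -(t * np.log(p) + (1.0 - t) * np.log(1.0 - p))

# as in train_step(): mean over everything, rescaled by the number of pixels
rec_scaled_mean = np.mean(bce(x, x_prime)) * n_img * n_img
# sum over the pixels of each image, then mean over the batch
rec_sum_per_image = np.mean(np.sum(bce(x, x_prime), axis=(1, 2)))

# stand-in encoder outputs
z_mean = rng.normal(size=(batch, latent_dim))
z_log_var = rng.normal(size=(batch, latent_dim))
kl_terms = -0.5 * (1.0 + z_log_var - z_mean ** 2 - np.exp(z_log_var))
kl_mean = np.mean(kl_terms)                    # as in train_step()
kl_sum = np.mean(np.sum(kl_terms, axis=1))     # summed over the latent features

print(np.isclose(rec_scaled_mean, rec_sum_per_image))
print(np.isclose(kl_mean * latent_dim, kl_sum))
```

With `latent_dim=2`, the logged `kl_loss` is therefore half of the summed KL divergence, which acts like a constant weight on the regularisation term.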

Note that both `reconstruction_loss` and `kl_loss` are logged, so we can plot them out:

In [None]:
# plot loss
plt.figure(dpi=100, figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(training_history_VAE.history['loss'], label='Loss')
plt.plot(training_history_VAE.history['reconstruction_loss'], label='Reconstruction loss')
plt.legend()
plt.title("Loss")

plt.subplot(1, 2, 2)
plt.plot(training_history_VAE.history['kl_loss'], label='KL loss')
plt.legend()
plt.title("Loss")
plt.show()

## 2. Inspect the latent space

Next, we can inspect the latent space following the same steps we did for the AE. Clearly, the latent distributions become much more regular than before.

In [None]:
# encode images by VAE
train_encodings_VAE = encoder_VAE.predict(train_images)[2]

# scatter plot of encodings by VAE
scatter_plot_encodings_latent(train_encodings_VAE, train_labels)

# histogram plot of encodings by VAE
feature_range = (-np.max(np.abs(train_encodings_VAE)), np.max(np.abs(train_encodings_VAE)))
for dim in range(latent_dim):
    print('Latent feature %d:' % dim)
    fig, axes = plt.subplots(2, 5, dpi=100, figsize=(15, 4), sharex=True)
    plt.subplots_adjust(hspace=.4)
    for digit in range(10):
        hist_plot_encodings_latent(train_encodings_VAE, train_labels, digit, dim, 
                                   axes[digit // 5, digit % 5], feature_range, 'C%d' % dim)
    plt.show()

## 3. Generate new images

Finally, we can generate new images with our VAE. Compared to the AE, the digits are generated in more balanced numbers and fewer non-digit images appear.

In [None]:
# generate images by VAE
generate_images_latent(decoder_VAE, n_generation=40, feature_range=[-2, 2])

---

# Disentangled Variational Autoencoders

Disentangled VAEs or $\beta$-VAEs are an extended type of VAEs. As explained before, the loss of a VAE consists of two parts, the reconstruction error and the KL divergence: the former tries to fit the network to the data and the latter to regularise the latent space. It is thus natural to introduce a hyperparameter to weigh these two forces, which gives rise to a disentangled VAE. As you can see in the following cartoon, a factor $\beta$ is applied to the KL divergence, and this $\beta$ is usually greater than 1. In practice, the value of $\beta$ has to be tuned for the best quality of content generation.

So, why do we call it "disentangled"? Consider the limit $\beta\rightarrow\infty$: in this case, the distributions of all the latent features will simply converge to $\mathcal{N}(0,1)$, regardless of the data, so they are completely disentangled. In other words, the latent features become increasingly disentangled as $\beta$ increases.


![bvae.png](https://i.ibb.co/mSmVrHy/bvaee.png)
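
The pull of $\beta$ towards $\mathcal{N}(0,1)$ can be seen in a one-dimensional toy problem. Assume, purely for illustration, that the reconstruction error is the quadratic $(\mu - a)^2 / 2$, which prefers a latent mean $\mu = a$, and that the mean-dependent part of the KL divergence is $\mu^2 / 2$. Minimising the weighted sum $(\mu - a)^2 / 2 + \beta\mu^2 / 2$ gives $\mu^* = a / (1 + \beta)$. The names `a` and `optimal_mean` are hypothetical and do not appear elsewhere in this notebook.

```python
# toy objective: 0.5 * (mu - a)**2 + beta * 0.5 * mu**2
# setting its derivative (mu - a) + beta * mu to zero gives mu = a / (1 + beta)
def optimal_mean(a, beta):
    return a / (1.0 + beta)

a = 3.0   # latent mean preferred by the (toy) reconstruction term
for beta in [0.0, 1.0, 10.0, 1000.0]:
    print('beta = %6.1f -> optimal mean = %.4f' % (beta, optimal_mean(a, beta)))
```

In this toy setting the optimal mean shrinks towards zero as $\beta$ grows, so the encoding carries less and less information about the data.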

## 1. Implement a $\beta$-VAE

It is straightforward to generalise our VAE implementation: just add $\beta$ as a hyperparameter to the constructor of the `VAE` class and apply it to the KL divergence in the `train_step()` method. The differences are highlighted in the code.


### The encoder and decoder

They are exactly the same as before.

In [None]:
# build the encoder
image_input = keras.Input(shape=(n_img, n_img))
x = layers.Flatten()(image_input)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z_output = Sampling()([z_mean, z_log_var])
encoder_BVAE = keras.Model(image_input, [z_mean, z_log_var, z_output])

# build the decoder
z_input = keras.Input(shape=(latent_dim,))
x = layers.Dense(16, activation="relu")(z_input)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dense(n_img * n_img, activation="sigmoid")(x)
image_output = layers.Reshape((n_img, n_img))(x)
decoder_BVAE = keras.Model(z_input, image_output)

### The $\beta$-VAE

We will use $\beta=10$ for demonstration. Try some other values (particularly a very large value such as 1000) and see what happens. Note that we are logging the KL loss scaled by $\beta$.

In [None]:
# BVAE class
class BVAE(keras.Model):
    # constructor
    ########################################################
    ######## NEW: passing beta as an extra argument ########
    ########################################################
    def __init__(self, encoder, decoder, beta, **kwargs):
        super(BVAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta

    # customise train_step() to implement the loss 
    def train_step(self, x):
        if isinstance(x, tuple):
            x = x[0]
        with tf.GradientTape() as tape:
            # encoding
            z_mean, z_log_var, z = self.encoder(x)
            # decoding
            x_prime = self.decoder(z)
            # reconstruction error by binary crossentropy loss
            reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * n_img * n_img
            # KL divergence
            kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            # loss = reconstruction error + KL divergence
            #######################################
            ######## NEW: scale KL by beta ########
            #######################################
            loss = reconstruction_loss + self.beta * kl_loss
        # apply gradient
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        #######################################
        ######### NEW: log scaled KL ##########
        #######################################
        # return loss for metrics log
        return {"loss": loss,
                "reconstruction_loss": reconstruction_loss,
                "beta_kl_loss": self.beta * kl_loss}

# build the BVAE
########################################
######## NEW: pass beta to BVAE ########
########################################
bvae_model = BVAE(encoder_BVAE, decoder_BVAE, beta=10.)

# compile the BVAE
bvae_model.compile(optimizer=keras.optimizers.Adam())

# train the BVAE
training_history_BAVE = bvae_model.fit(train_images, train_images, epochs=50, batch_size=128)

In [None]:
# plot loss
plt.figure(dpi=100, figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(training_history_BAVE.history['loss'], label='Loss')
plt.plot(training_history_BAVE.history['reconstruction_loss'], label='Reconstruction loss')
plt.legend()
plt.title("Loss")

plt.subplot(1, 2, 2)
plt.plot(training_history_BAVE.history['beta_kl_loss'], label='KL loss scaled by beta')
plt.legend()
plt.title("Loss")
plt.show()

## 2. Evaluate the results

As before, we can evaluate the latent distributions and generate new images using our $\beta$-VAE. As expected, the distributions become closer to normal distributions because $\beta=10$.

In [None]:
# encode images by BVAE
train_encodings_BVAE = encoder_BVAE.predict(train_images)[2]

# scatter plot of encodings by BVAE
scatter_plot_encodings_latent(train_encodings_BVAE, train_labels)

# histogram plot of encodings by BVAE
feature_range = (-np.max(np.abs(train_encodings_BVAE)), np.max(np.abs(train_encodings_BVAE)))
for dim in range(latent_dim):
    print('Latent feature %d:' % dim)
    fig, axes = plt.subplots(2, 5, dpi=100, figsize=(15, 4), sharex=True)
    plt.subplots_adjust(hspace=.4)
    for digit in range(10):
        hist_plot_encodings_latent(train_encodings_BVAE, train_labels, digit, dim, 
                                   axes[digit // 5, digit % 5], feature_range, 'C%d' % dim)
    plt.show()

# generate images by BVAE
generate_images_latent(decoder_BVAE, n_generation=40, feature_range=[-2, 2])

---

# Conditional Variational Autoencoders

In many applications, we hope to generate content based on the labels. For example, we hope to generate images for a given digit from `mnist-digits`. Conditional VAEs serve this purpose. In a conditional VAE, the labels are sent to both the encoder and the decoder as an extra input, as shown in the following figure.

<img src="https://i.ibb.co/0Mj3Q8k/cvaec.png" width=100% height=100% />


## 1. Implement a conditional VAE

The simplest way to implement a conditional VAE is to concatenate the labels to both the input data and the latent features. We will use this approach for our implementation. Each label is concatenated as a one-hot vector of size 10 (the condition dimension).

### The encoder

The encoder is the same as that of the VAE except that

1. the input size is increased by the condition dimension; 
2. we no longer flatten the images inside the network because of the concatenation.

In [None]:
# dimension of condition
condition_dim = n_digits

# encoder
###############################################################
######## NEW: input size is increased by condition_dim ########
###############################################################
image_input = keras.Input(shape=(n_img * n_img + condition_dim,))
x = layers.Dense(128, activation='relu')(image_input)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z_output = Sampling()([z_mean, z_log_var])
encoder_CVAE = keras.Model(image_input, [z_mean, z_log_var, z_output])
encoder_CVAE.summary()

### The decoder

As with the encoder, the input size of the decoder is increased by the condition dimension. The decoder outputs flattened images, which are reshaped externally.

In [None]:
# decoder
###############################################################
######## NEW: input size is increased by condition_dim ########
###############################################################
z_input = keras.Input(shape=(latent_dim + condition_dim,))
x = layers.Dense(16, activation="relu")(z_input)
x = layers.Dense(128, activation="relu")(x)
image_output = layers.Dense(n_img * n_img, activation="sigmoid")(x)
decoder_CVAE = keras.Model(z_input, image_output)
decoder_CVAE.summary()

### The conditional VAE

The conditional VAE is the same as the VAE except that

1. the condition dimension is passed to the constructor;
2. the labels are concatenated to the latent features before they are decoded;
3. the labels are stripped from the input data before computing the reconstruction error.

In [None]:
# CVAE class
class CVAE(keras.Model):
    # constructor
    #################################################################
    ######## NEW: passing condition_dim as an extra argument ########
    #################################################################
    def __init__(self, encoder, decoder, condition_dim, **kwargs):
        super(CVAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.condition_dim = condition_dim

    # customise train_step() to implement the loss 
    def train_step(self, x):
        if isinstance(x, tuple):
            x = x[0]
        with tf.GradientTape() as tape:
            # encoding
            z_mean, z_log_var, z = self.encoder(x)
            ####################################################################
            ######## NEW: apply conditions to encodings before decoding ########
            ####################################################################
            z_cond = tf.concat([z, x[:, -self.condition_dim:]], axis=1)
            # decoding
            x_prime = self.decoder(z_cond)
            ###################################################################
            ######## NEW: truncate conditions for reconstruction error ########
            ###################################################################
            # reconstruction error by binary crossentropy loss
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(x[:, :-self.condition_dim], x_prime)) * n_img * n_img
            # KL divergence
            kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            # loss = reconstruction error + KL divergence
            loss = reconstruction_loss + kl_loss
        # apply gradient
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # return loss for metrics log
        return {"loss": loss,
                "reconstruction_loss": reconstruction_loss,
                "kl_loss": kl_loss}

# build the CVAE
cvae_model = CVAE(encoder_CVAE, decoder_CVAE, condition_dim)

# compile the CVAE
cvae_model.compile(optimizer=keras.optimizers.Adam())
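
The two slices used in the `train_step()` above, `x[:, -condition_dim:]` and `x[:, :-condition_dim]`, split a conditioned input back into its one-hot label and its flattened image. The sketch below checks this in NumPy on stand-in data: tiny 4-pixel "images" and made-up labels, chosen only to keep the output readable.

```python
import numpy as np

condition_dim = 10
n_pixels = 4                                   # stand-in for n_img * n_img

# three flattened stand-in images and their (made-up) labels
images = np.arange(3 * n_pixels, dtype=float).reshape(3, n_pixels)
labels = np.array([7, 0, 3])
labels_onehot = np.eye(condition_dim)[labels]

# concatenate the one-hot labels to the flattened images
x = np.concatenate((images, labels_onehot), axis=1)

# recover the two parts by slicing
conditions = x[:, -condition_dim:]
pixels = x[:, :-condition_dim]

print(x.shape, pixels.shape, conditions.shape)
print(np.argmax(conditions, axis=1))
```

The slicing relies on the labels being the last `condition_dim` columns, so the images must come first when the two arrays are concatenated.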

Before training, we need to concatenate the images and their labels to generate the training data.

In [None]:
# one-hot encoding labels
train_labels_onehot = np.eye(10)[train_labels]

# flatten images
train_images_flattened = train_images.reshape((len(train_images), n_img * n_img))

# concatenate labels to images
train_images_conditioned = np.concatenate((train_images_flattened, train_labels_onehot), axis=1)

# train the CVAE
cvae_model.fit(train_images_conditioned, train_images_flattened, 
               epochs=50, batch_size=128)

## 2. Evaluate the results

The latent distributions can be assessed in the same way. Note that *each digit now has its own latent distributions because of the conditioning* (all resemble a $\mathcal{N}(0,1)$), so the latent space is highly regular and disentangled. 

In [None]:
# encode images by CVAE
train_encodings_CVAE = encoder_CVAE.predict(train_images_conditioned)[2]

# scatter plot of encodings by CVAE
scatter_plot_encodings_latent(train_encodings_CVAE, train_labels)

# histogram plot of encodings by CVAE
feature_range = (-np.max(np.abs(train_encodings_CVAE)), np.max(np.abs(train_encodings_CVAE)))
for dim in range(latent_dim):
    print('Latent feature %d:' % dim)
    fig, axes = plt.subplots(2, 5, dpi=100, figsize=(15, 4), sharex=True)
    plt.subplots_adjust(hspace=.4)
    for digit in range(10):
        hist_plot_encodings_latent(train_encodings_CVAE, train_labels, digit, dim, 
                                   axes[digit // 5, digit % 5], feature_range, 'C%d' % dim)
    plt.show()

Finally, we can generate new images for any given digit:

In [None]:
# generate images from the conditioned latent space
def generate_images_conditioned_latent(decoder, n_generation, feature_range, condition_digit, ax):
    # randomly sample the latent space
    latent = []
    for dim in range(latent_dim):
        if len(np.array(feature_range).shape) == 1:
            # only one range provided; use it for all dimensions
            latent.append(np.random.uniform(feature_range[0], feature_range[1], 
                                            n_generation * n_generation))
        else:
            # range provided for each dimension
            latent.append(np.random.uniform(feature_range[dim][0], feature_range[dim][1], 
                                            n_generation * n_generation))
    latent = np.array(latent).T
    
    # condition the latent space
    condition_one_hot = np.eye(10)[condition_digit]
    latent = np.concatenate((latent, np.repeat([condition_one_hot], len(latent), axis=0)), axis=1)
    
    # decode images
    decodings = decoder.predict(latent).reshape((len(latent), n_img, n_img))
    
    # create and fill a large figure
    figure = np.zeros((n_img * n_generation, n_img * n_generation))
    for iy in np.arange(n_generation):
        for ix in np.arange(n_generation):
            figure[iy * n_img : (iy + 1) * n_img, ix * n_img : (ix + 1) * n_img] = decodings[iy * n_generation + ix]
            
    # plot figure
    ax.imshow(figure, cmap="Greys_r")
    ax.axis('off')
    
# generate images by CVAE
fig, axes = plt.subplots(5, 2, dpi=100, figsize=(20, 50), )
for digit in range(10):
    generate_images_conditioned_latent(decoder_CVAE, n_generation=20, feature_range=[-1, 1], 
                                       condition_digit=digit, ax=axes[digit // 2, digit % 2])
plt.show()

---

## Exercises

1. Repeat the above processes with a larger `latent_dim`; you will find that, compared to `latent_dim=2`,
    * the improvement from AE to VAE and that from VAE to $\beta$-VAE both become more significant;
    * the variability and the resolution of the generated images can be improved (with more epochs for training). 
    

2. Use a VAE, $\beta$-VAE or conditional VAE to generate images from the `mnist-fashion` dataset. 