<a href="https://colab.research.google.com/github/EiffL/Tutorials/blob/master/GenerativeModels/IntroToVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to Variational Auto-Encoders


Author: [@EiffL](https://github.com/EiffL) (Francois Lanusse)

### Overview

In this tutorial, we will progressively learn how to build a Variational Auto-Encoder, starting from a classical Auto-Encoder. We will use a simple convolutional architecture on the MNIST dataset, the goal being to understand all of the basic mechanisms.

Learning objectives:
  - Use TensorFlow Dataset to load MNIST digits
  - Use Keras to build and train an Auto-Encoder (AE)
  - Build insight into AE latent spaces, identifying the limitations of this model
  - Learn how to use TensorFlow Probability probabilistic layers
  - Use Keras & TFP to build a Variational Auto-Encoder (VAE)
  - Sample new digits from trained VAE


### Instructions for enabling GPU access

By default, notebooks are started without acceleration. To make sure that the runtime is configured for using GPUs, go to `Runtime > Change runtime type`, and select GPU in `Hardware Accelerator`.

### Imports and setup

In [None]:
%pylab inline
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
tfpl = tfp.layers
tfkl = tf.keras.layers
tfd = tfp.distributions

### Checking for GPU access

In [None]:
#Checking for GPU access
if tf.test.gpu_device_name() != '/device:GPU:0':
  print('WARNING: GPU device not found.')
else:
  print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))

## Loading data and creating an Input Pipeline

Our first step will be to load the MNIST dataset using the extremely convenient library [TensorFlow Datasets](https://www.tensorflow.org/datasets). All sorts of common datasets are directly available through that library and can be accessed in just one line. You can see the full list of available datasets [here](https://www.tensorflow.org/datasets/catalog).




In [None]:
mnist_dset = tfds.load(name="mnist", # Name of the dataset
                       split="train") # split, "train" or "test"

This creates an instance of `tf.data.Dataset`:

In [None]:
mnist_dset

We see that the dataset contains a dictionary with images of size (28,28,1) of type uint8 and an associated label.
Examples can be drawn from the dataset like this:

In [None]:
for example in mnist_dset.take(1): # We take only one example from the dataset 
  print("Our example contains the following keys", example.keys())
  imshow(example['image'][:,:,0],cmap='gray'); colorbar()
  title("%d"%example['label'])


For our purpose of generative modeling, we only need to grab the image; we don't care about the label. We are going to convert these images to floats and rescale them between 0 and 1. This can all be done by a preprocessing function:

In [None]:
def normalize_img(example):
  """ Preprocessing function that rescales an image between 0,1
  This pre-processing function will return twice the image because for an 
  autoencoder the target is the same as the input.
  """
  im = tf.cast(example['image'], tf.float32) / 255.
  return im, im

We can now create a full input pipeline for our dataset using this pre-processing function.

In [None]:
dset = mnist_dset.map(normalize_img) # Apply the pre-processing function
dset = dset.cache()                  # Cache the results
dset = dset.shuffle(60000)           # Shuffle the data over a given buffer 
dset = dset.batch(128)               # Batch the data
dset = dset.prefetch(tf.data.experimental.AUTOTUNE) # Pre-fetch the data in parallel

To learn more about how to use the `tf.data.Dataset` API, check out [this documentation](https://www.tensorflow.org/guide/data).
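
As a quick sanity check on the pipeline above, here is a small sketch of how many batches one epoch contains. It assumes the training split holds 60,000 images and that the final partial batch is kept, which is the default behavior of `batch` when `drop_remainder` is not set.

```python
import math

n_examples = 60000   # assumed size of the MNIST training split
batch_size = 128     # batch size used in the pipeline above

# Without drop_remainder, the final partial batch is kept
n_batches = math.ceil(n_examples / batch_size)
last_batch = n_examples - (n_batches - 1) * batch_size

print("batches per epoch:", n_batches)
print("size of the last batch:", last_batch)
```

This is the number of steps Keras should report per epoch when fitting on `dset`.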

Let's sample a batch from our newly created dataset:

In [None]:
for im, target in dset.take(1):
  print("We now have a batch of images of size", im.shape)
  imshow(im[0,:,:,0],cmap='gray'); colorbar();

## Building a Keras Auto-Encoder


Now that we have access to some data, our first goal will be to create a Convolutional Auto-Encoder, which can compress images (in our case of size 28x28) down to some low dimensional latent representation (for instance 2).

### Building an encoder

We begin with the encoder. We want to build a function that can create an encoder to compress images down to some dimensionality.

In [None]:
def get_encoder(latent_dim=2):
  """ Creates a small convolutional encoder for the requested latent dimension
  """
  return tf.keras.Sequential([ 
      tfkl.Input(shape=(28,28,1)),
      tfkl.Conv2D(32, kernel_size=3, activation='relu', strides=2, padding='same'),
      tfkl.Conv2D(64, kernel_size=3, activation='relu', strides=2, padding='same'),
      tfkl.Flatten(),
      tfkl.Dense(latent_dim)
      ])

In [None]:
encoder = get_encoder() # Instantiate a Keras model using our function

This has instantiated our encoder. We can print out a summary of the model like so:

In [None]:
encoder.summary()

We see that through the different layers of the model, the tensors change as follows:
 - images 28x28x1 at the input level (not shown)
 - images 14x14x32
 - images 7x7x64
 - vector 3136
 - vector 2 at the output level
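
The shapes listed above can be recovered by hand. Here is a minimal sketch, assuming the usual rule that a convolution with `padding='same'` produces an output of spatial size `ceil(input / stride)`. The helper name `conv_same_out` is ours and not part of Keras.

```python
import math

def conv_same_out(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)

size, channels = 28, 1
shapes = [(size, size, channels)]
for filters in (32, 64):              # the two Conv2D layers, both with strides=2
    size = conv_same_out(size, stride=2)
    shapes.append((size, size, filters))

# Flatten simply multiplies the three dimensions together
flat = shapes[-1][0] * shapes[-1][1] * shapes[-1][2]

print(shapes)
print("flattened size:", flat)
```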

Even though the model is not trained yet, we can already transform images with the encoder:

In [None]:
for batch_im, batch_target in dset.take(1): # Sample only one batch of images
  batch_encoded = encoder(batch_im)         # Apply the encoder on images 

# And we recover the encoding for all images of the batch
print(batch_encoded.shape) 

In [None]:
scatter(batch_encoded[:,0], batch_encoded[:,1])
xlabel('z1')
ylabel('z2')


### Building a decoder

The next step is to build a decoder that mirrors the encoder and transforms a vector of low dimensionality back to an image.

In [None]:
def get_decoder(latent_dim=2):
  """ Creates a small convolutional decoder for the requested latent dimension
  """
  return tf.keras.Sequential([
      tfkl.Input(shape=(latent_dim,)),
      tfkl.Dense(7*7*64, activation='relu'),
      tfkl.Reshape((7,7,64)),
      tfkl.Conv2DTranspose(64, kernel_size=3, activation='relu', strides=2, padding='same'),
      tfkl.Conv2DTranspose(32, kernel_size=3, activation='relu', strides=2, padding='same'),
      tfkl.Conv2DTranspose(1, kernel_size=3, activation='sigmoid', strides=1, padding='same')                 
  ])

In [None]:
decoder = get_decoder() # Instantiate decoder

In [None]:
decoder.summary() # Print out the summary

We see that now the following happens in the model:

 - vector 2 at the input level (not shown)
 - vector 3136
 - images 7x7x64
 - images 14x14x64
 - images 28x28x32
 - images 28x28x1 at the output

This more or less reflects our encoder but in reverse.
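
The same kind of bookkeeping works for the decoder. This is a small sketch, assuming that a transposed convolution with `padding='same'` multiplies the spatial size by its stride. The helper name `conv_transpose_same_out` is ours.

```python
def conv_transpose_same_out(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride

# (filters, stride) for the three Conv2DTranspose layers of the decoder
layers = [(64, 2), (32, 2), (1, 1)]

size = 7                              # spatial size after the Reshape layer
decoder_shapes = []
for filters, stride in layers:
    size = conv_transpose_same_out(size, stride)
    decoder_shapes.append((size, size, filters))

print(decoder_shapes)
```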

Although not trained yet, we can already run the model to decode our encoded images:

In [None]:
batch_decoded = decoder(batch_encoded) # Runs the decoder on our previously 
                                       # encoded images

And just for fun, we can try to see what the decoded images look like:

In [None]:
subplot(121)
imshow(batch_im[0,:,:,0],cmap='gray'); title('Input image')
subplot(122)
imshow(batch_decoded[0,:,:,0],cmap='gray'); title('Auto-Encoded image')

Unsurprisingly, we get a bunch of garbage :-D. Let's try to do some training and see what happens after that.

### Training the Auto-Encoder

We want to train the encoder and decoder simultaneously so that they learn the identity, i.e. decoder(encoder(x)) = x.

To do this, let's define a new Keras model that just concatenates both models:

In [None]:
auto_encoder = tf.keras.Sequential([
      tfkl.InputLayer([28,28,1]),                              
      encoder,
      decoder])

This has created an auto-encoder by concatenation of individual models. To see what has happened we can look at the model summary:

In [None]:
auto_encoder.summary()

This model does the following:
- images 28x28x1 at the input (not shown)
- vector 2
- images 28x28x1 at the output

The last step is to "compile" the Keras model, i.e. specifying an optimizer and a loss function. 

We are going to use the extremely popular `Adam` optimizer. As a loss function, since our data is nearly binary (pixel values lie between 0 and 1), we are going to use a binary cross-entropy.

In [None]:
auto_encoder.compile(optimizer=tf.keras.optimizers.Adam(),
                    loss=tf.keras.losses.binary_crossentropy)
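
To make this loss a bit more concrete, here is a NumPy sketch of the per-pixel binary cross-entropy. It is an illustration rather than the Keras implementation: the function below and its `eps` clipping value are our own assumptions, and Keras additionally averages the per-pixel values.

```python
import numpy as np

def binary_crossentropy(target, pred, eps=1e-7):
    """Per-pixel binary cross-entropy between a target and a prediction in [0, 1]."""
    pred = np.clip(pred, eps, 1.0 - eps)   # avoid log(0)
    return -(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred))

# A white pixel (target 1) predicted confidently right, then confidently wrong
bce_good = binary_crossentropy(np.array(1.0), np.array(0.9))
bce_bad = binary_crossentropy(np.array(1.0), np.array(0.1))
# A grey pixel (target 0.5): even the best prediction does not reach zero
bce_grey = binary_crossentropy(np.array(0.5), np.array(0.5))

print(bce_good, bce_bad, bce_grey)
```

Note that for grey pixels the loss cannot go all the way down to zero, so the absolute value of the training loss matters less than how it evolves.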

Now that the model is compiled, it can be fitted to the data using the `.fit()` method. We will just have to provide our dataset and Keras will take care of the rest.

Note that our dataset returns tuples of data `(batch_im, batch_target)`. Keras interprets `batch_im` as the input of the model, and the second entry in the tuple, `batch_target`, as the desired output that the model should learn.

In [None]:
history = auto_encoder.fit(dset, epochs=20)  # Starts training both encoder and decoder
                                             # for 20 epochs

And that's it, our model should more or less be trained by now. We can check the model history to see what the loss function looks like as a function of training epochs:

In [None]:
plot(history.history['loss'])
xlabel('epoch')
ylabel('reconstruction loss');

And much more interestingly, we can apply the model on a batch of images and see what comes out:

In [None]:
codes = encoder(batch_im)
decoded_images = decoder(codes)
# Here we use the encoder/decoder separately but we could do just the same:
decoded_images = auto_encoder(batch_im)

In [None]:
figure(figsize=(15,5))
subplot(131)
imshow(batch_im[0,:,:,0],cmap='gray')
title('First input image of batch')
subplot(132)
scatter(codes[:,0], codes[:,1])
scatter(codes[0,0], codes[0,1])
title('Latent encoding of batch')
subplot(133)
imshow(decoded_images[0,:,:,0],cmap='gray')
title('Decoded image');

### Exploring the Auto-Encoder

In this last sub-section, we will try to build a little bit more insight into the Auto-Encoder, its latent space, and how it is behaving. This will serve as motivation for going to Variational Auto-Encoders.


We begin by defining a new dataset using the `test` split of the MNIST data, and this time we also want to keep the labels of each digit:

In [None]:
mnist_test_dset = tfds.load(name="mnist",
                            split='test')

def normalize_img_test(example):
  """ Normalize images, like during training, but also returns label
  """
  im = tf.cast(example['image'], tf.float32) / 255.
  return im, example['label']

# We build a simplified pipeline for testing
dset_test = mnist_test_dset.map(normalize_img_test)
dset_test = dset_test.batch(1024) # We use a large batch of 1024 examples

In [None]:
for batch_im, batch_labels in dset_test.take(1):
  # This extracts one batch of the test dset and shows the first example
  imshow(batch_im[0,:,:,0],cmap='gray')
  title('This is a %d'%batch_labels[0])

#### Auto-Encoding quality

Let us compare input and output images for a few examples.

In [None]:
autoencoded_im = auto_encoder(batch_im)

Let's first draw a few images from the input dataset:

In [None]:
figure(figsize=(5,5))
for i in range(4):
  for j in range(4):
    subplot(4,4, i*4+j+1)
    imshow(batch_im[i*4+j,:,:,0],cmap='gray')
    axis('off')

And let's see how the model is able to represent these images:

In [None]:
figure(figsize=(5,5))
for i in range(4):
  for j in range(4):
    subplot(4,4, i*4+j+1)
    imshow(autoencoded_im[i*4+j,:,:,0],cmap='gray')
    axis('off')

We more or less recognize the digits, but the quality is not excellent. What is interesting is that sometimes the representation is more semantic than a reconstruction, i.e. a 0 gets auto-encoded as a 0, but not necessarily as the same 0.

The main reason why the quality is not excellent is that the model is not powerful enough to map 28x28 down to 2 without losing information. There would be 2 solutions to improve on this:
 - Implement a more complex auto-encoder (more layers)
 - Increase the dimensionality of the latent space, to make the problem simpler.

#### Visualizing the latent space

We can now have a look at how different digits are encoded in the latent space.

In [None]:
codes = encoder(batch_im) # Encodes the images

In [None]:
scatter(codes[:,0],codes[:,1],c=batch_labels,cmap='tab10'); colorbar() # Plot the encoding

We see that the model tries to naturally place different digits in different regions of latent space without overlapping too much. 

But we see that the overall distribution of codes is not regular, it has gaps, arbitrary extent, and weird and non-trivial shapes. We can try to sample this latent space on a regular grid and see what the learned manifold looks like.


In [None]:
n = 30
scale = 2.
# Find out mean and std of encoding
x_mean, x_std = np.mean(codes[:,0]), np.std(codes[:,0])
y_mean, y_std = np.mean(codes[:,1]), np.std(codes[:,1])
# Create uniform grid
grid_x = np.linspace(x_mean - scale*x_std, x_mean + scale*x_std, n)
grid_y = np.linspace(y_mean - scale*y_std, y_mean + scale*y_std, n)[::-1]
# Reshape into batch of coordinates
batch_latents = np.stack(meshgrid(grid_x,grid_y),axis=-1)
batch_latents = batch_latents.reshape((-1,2))
# Run through decoder
batch_samples = tf.reshape(decoder(batch_latents), (30,30,28,28))
# Reshape into one giant image
fig = batch_samples.numpy().transpose((0,2,1,3)).reshape((30*28,30*28))

# Plot the figure with corresponding latent ticks
figure(figsize=(10, 10))
imshow(fig, cmap='gray')
start_range = 28 // 2
end_range = n * 28 + start_range 
pixel_range = np.arange(start_range, end_range, 28)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
xticks(pixel_range, sample_range_x)
yticks(pixel_range, sample_range_y)
xlabel("z[0]");
ylabel("z[1]");
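
The `transpose` followed by `reshape` used above to build one giant image can look a bit magical. Here is a toy NumPy sketch of the same trick on a 2x2 grid of tiny 3x3 tiles, where each tile is filled with a unique value so that we can check where it lands. The sizes and values are made up for illustration.

```python
import numpy as np

n, h, w = 2, 3, 3                       # 2x2 grid of 3x3 tiles
tiles = np.zeros((n, n, h, w))          # axes: (tile row, tile col, pixel row, pixel col)
for i in range(n):
    for j in range(n):
        tiles[i, j] = 10 * i + j        # fill each tile with a unique value

# Put the tile-row axis next to the pixel-row axis (and likewise for columns),
# then merge each pair of axes into a single image axis
mosaic = tiles.transpose((0, 2, 1, 3)).reshape((n * h, n * w))

print(mosaic)
```

Tile `(i, j)` ends up in rows `i*h` to `(i+1)*h` and columns `j*w` to `(j+1)*w` of the mosaic.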

#### Trying to sample new digits

Now that we have an embedding, we can *try* to use it to create a simple generative model. For this, we will look at the distribution of latent space codes and randomly sample new points from a similar distribution.

In [None]:
hist(codes[:,0], 64, label='z0', alpha=0.6);
hist(codes[:,1], 64, label='z1', alpha=0.6);
legend();

We see that the distribution of latent variables is very irregular, but we can still try to fit it with a Gaussian, and then sample new digits from that Gaussian.

In [None]:
z_mu = np.median(codes,axis=0); z_std = np.std(codes,axis=0)

In [None]:
# Here we sample some new digits
latent_samples = z_mu +  z_std * randn(16,2)

In [None]:
# Let's see where our new samples fall in the latent distribution
scatter(codes[:,0],codes[:,1],c=batch_labels,cmap='tab10'); colorbar()
scatter(latent_samples[:,0], latent_samples[:,1], marker='+',c='r');
# The red points are our new samples

In [None]:
# and now we can decode them
batch_samples = tf.reshape(decoder(latent_samples), (4,4,28,28))
# Reshape into one giant image
fig = batch_samples.numpy().transpose((0,2,1,3)).reshape((4*28,4*28))
# Let's see the result
imshow(fig, cmap='gray'); axis('off');

This is not too bad! We recognize some digits, and they are of roughly similar quality to the auto-encoded results. **This means that the latent space is fairly regular**: we can sample from it without too much care, but the quality of samples is not great.

#### Increasing the latent space dimensionality

As mentioned previously, one option to improve the quality of samples is to increase the dimensionality of the latent space. In this section we are going to try to increase it to 10, and see two things: 
  - How is the image quality affected
  - How is the regularity of the latent space affected

First step is to create a new auto-encoder with larger latent space:

In [None]:
encoder = get_encoder(10)
decoder = get_decoder(10)
auto_encoder = tf.keras.Sequential([
    tfkl.InputLayer([28,28,1]),
    encoder,
    decoder])

In [None]:
# Let's check what our new auto-encoder looks like
auto_encoder.summary()

We see that our new encoder has a latent space of dimension 10.

Just as before, let's train it on our MNIST training set.

In [None]:
auto_encoder.compile(optimizer=tf.keras.optimizers.Adam(),
                    loss=tf.keras.losses.binary_crossentropy)

In [None]:
history10 = auto_encoder.fit(dset, epochs=20)  # Starts training both encoder and decoder
                                             # for 20 epochs

In [None]:
plot(history.history['loss'], label='latent_dim=2')
plot(history10.history['loss'], label='latent_dim=10')
xlabel('epoch')
ylabel('reconstruction loss')
legend();

We see that the loss now goes much lower than before. This is because the auto-encoder can preserve more information about the input at the latent level.

Let's see what the reconstruction quality looks like.

In [None]:
autoencoded_im = auto_encoder(batch_im)

In [None]:
# These are the input images
figure(figsize=(5,5))
for i in range(4):
  for j in range(4):
    subplot(4,4, i*4+j+1)
    imshow(batch_im[i*4+j,:,:,0],cmap='gray')
    axis('off')

In [None]:
# These are the reconstructed images
figure(figsize=(5,5))
for i in range(4):
  for j in range(4):
    subplot(4,4, i*4+j+1)
    imshow(autoencoded_im[i*4+j,:,:,0],cmap='gray')
    axis('off')

![](https://media1.tenor.com/images/c7b80e4cf9004b58ad9f3cfd5a3ab345/tenor.gif?itemid=15669873)

This is a looooot better than before :-D Great!

Now, let's see what the latent space looks like.


In [None]:
codes = encoder(batch_im) # Encodes the images
print("Our latent representation now has dimension d = %d"%codes.shape[1])
# Just like before, we can have a look at the latent space encoding, for instance
# along the 2 first dimensions
scatter(codes[:,0],codes[:,1],c=batch_labels,cmap='tab10'); colorbar();

Things are different compared to the d=2 case:
  - The overall distribution of latent samples seems a lot more regular, i.e. looks like a Gaussian.
  - The encodings for different digits appear to be overlapping.

So... **what will happen if we try to fit a Gaussian model to this latent space and sample from it**, just like we did in the previous section? Let's find out.

In [None]:
z_mu = np.median(codes,axis=0); z_std = np.std(codes,axis=0)
latent_samples = z_mu +  z_std * randn(16,10)

In [None]:
# Let's see where our new samples fall in the latent distribution
scatter(codes[:,0],codes[:,1],c=batch_labels,cmap='tab10'); colorbar()
scatter(latent_samples[:,0], latent_samples[:,1], marker='+',c='r',s=100);
# The red points are our new samples

In [None]:
# and now we can decode them
batch_samples = tf.reshape(decoder(latent_samples), (4,4,28,28))
# Reshape into one giant image
fig = batch_samples.numpy().transpose((0,2,1,3)).reshape((4*28,4*28))
# Let's see the result
imshow(fig, cmap='gray'); axis('off');

We get some digits.... but also some garbage...

![](https://media1.tenor.com/images/b386fbb5c9c59b3f7d690e6cdc9bb8fb/tenor.gif?itemid=14214249)

The quality of these samples is far from the quality of reconstructed images. 

**What we have gained in quality of auto-encoding, we have lost in regularity of latent space!** We can no longer sample decent digits just by drawing from a Gaussian.

Ideally, we would want a way to train the model with a penalty that would force it to make the latent space look like a Gaussian. This is exactly what a Variational Auto-Encoder does!



## Building a Keras & TFP Variational Auto-Encoder

Having built some insight into what happens in an auto-encoder in the previous section, we will now try to improve on a simple Auto-Encoder by implementing a Variational Auto-Encoder.

A key library that we will use in this section is the **excellent** [TensorFlow Probability](https://www.tensorflow.org/probability) library. In particular we are going to use the TFP probabilistic Keras layers.

### Implementing the recognition model

In the VAE framework, the encoder is also known as a `recognition model`. It is very similar to our traditional encoder, but instead of outputting a single code, it outputs a **distribution over possible codes**, a.k.a. a posterior distribution. During training, we will explicitly penalize departures of this distribution from a standard Gaussian.
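
To get a feeling for what this penalty measures, here is a NumPy sketch of the closed-form KL divergence between a Gaussian with mean `mu` and covariance `L L^T` and a standard Gaussian. The function name and its arguments are ours, and this is only an illustration: the TFP regularizer used in this notebook may compute or estimate the KL term differently.

```python
import numpy as np

def kl_to_standard_normal(mu, scale_tril):
    """KL( N(mu, L L^T) || N(0, I) ), with L lower triangular with positive diagonal."""
    d = mu.shape[-1]
    cov = scale_tril @ scale_tril.T
    # log det(L L^T) = 2 * sum(log(diag(L)))
    log_det = 2.0 * np.sum(np.log(np.diag(scale_tril)))
    return 0.5 * (np.trace(cov) + mu @ mu - d - log_det)

# A posterior identical to the prior costs nothing
kl_zero = kl_to_standard_normal(np.zeros(2), np.eye(2))
# Shifting the mean away from zero is penalized
kl_shift = kl_to_standard_normal(np.array([1.0, 0.0]), np.eye(2))
# So is collapsing the posterior to a very narrow distribution
kl_narrow = kl_to_standard_normal(np.zeros(2), 0.1 * np.eye(2))

print(kl_zero, kl_shift, kl_narrow)
```

The penalty therefore pulls each posterior both towards the origin and towards unit variance, which is what keeps the latent space regular.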

Here is how we can modify our original encoder with some TFP magic to turn the output into a distribution:

In [None]:
def get_probabilistic_encoder(latent_dim=2):
  """ Creates a small convolutional encoder for the requested latent dimension
  """
  # We choose a prior distribution for the latent codes
  prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_dim))

  return tf.keras.Sequential([ 
      tfkl.Input(shape=(28,28,1)),
      tfkl.Conv2D(32, kernel_size=3, activation='relu', strides=2, padding='same'),
      tfkl.Conv2D(64, kernel_size=3, activation='relu', strides=2, padding='same'),
      tfkl.Flatten(),
      tfkl.Dense(128, activation='relu'),
      # We ask this layer to output a vector of size equal to the number of
      # parameters required to define a Multivariate Gaussian
      tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(latent_dim)),
      # At the last layer, we ask the model to output a **distribution**
      # In this case, a Multivariate Normal
      tfpl.MultivariateNormalTriL(latent_dim, 
              # And we specify a regularization for this distribution, used
              # during training, we want the KL divergence with the prior 
              # to be small, i.e. the encoded distribution should be close to a 
              # standard Gaussian
              activity_regularizer=tfpl.KLDivergenceRegularizer(prior))
      ])
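
The size of the last `Dense` layer above is set by the number of parameters of the Gaussian. As a sketch, assuming a mean vector plus the lower triangle (diagonal included) of a scale matrix, the count can be reproduced by hand. The helper name `tril_gaussian_params` is ours, and its result should agree with `tfpl.MultivariateNormalTriL.params_size`.

```python
def tril_gaussian_params(latent_dim):
    """Number of parameters of a Gaussian with a lower-triangular scale matrix."""
    n_loc = latent_dim                              # one entry per mean component
    n_tril = latent_dim * (latent_dim + 1) // 2     # lower triangle incl. diagonal
    return n_loc + n_tril

params_2 = tril_gaussian_params(2)
params_10 = tril_gaussian_params(10)
print(params_2, params_10)
```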

Let's try to instantiate this encoder, and encode some images, to see what happens.

In [None]:
prob_encoder = get_probabilistic_encoder(latent_dim=10)

In [None]:
prob_encoder.summary()

In [None]:
for batch_im, batch_target in dset.take(1): # Sample only one batch of images
  batch_encoded = prob_encoder(batch_im)         # Apply the encoder on images 

Let's inspect what `batch_encoded` is:

In [None]:
batch_encoded

We see that this is an instance of a `tfp.distributions.MultivariateNormalTriL`. This is a distribution!

We can manipulate it in different ways:


In [None]:
# We can draw samples from it
batch_encoded.sample()

In [None]:
# We can retrieve the mean
batch_encoded.mean()
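
To illustrate what drawing a sample from such a distribution amounts to, here is a NumPy sketch using made-up values for the mean `mu` and the lower-triangular scale matrix `scale_tril`. A sample is obtained by transforming standard Gaussian noise, and we check that the empirical mean and covariance come out as expected. This is an illustration, not the TFP implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical parameters of a 2-d Gaussian posterior
mu = np.array([1.0, -2.0])
scale_tril = np.array([[1.0, 0.0],
                       [0.5, 0.5]])

# z = mu + L @ eps, with eps drawn from a standard Gaussian
eps = rng.standard_normal((100_000, 2))
z = mu + eps @ scale_tril.T

emp_mean = z.mean(axis=0)
emp_cov = np.cov(z, rowvar=False)

print(emp_mean)
print(emp_cov)          # should be close to scale_tril @ scale_tril.T
```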

### Implementing the generator

Now that we have a recognition model, we want to implement the VAE equivalent of the decoder, aka the `generator`.

The model will be very similar to the decoder, but the difference is that we are going to need to assume a likelihood $p(x | z)$ for our generator. In the case of binary data, an obvious choice is to use a [Bernoulli Distribution](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Bernoulli).

Here is how we can do that using TFP magic again:

In [None]:
def get_probabilistic_decoder(latent_dim=2):
  """ Creates a small convolutional decoder for the requested latent dimension
  """
  return tf.keras.Sequential([
      tfkl.Input(shape=(latent_dim,)),
      tfkl.Dense(7*7*64, activation='relu'),
      tfkl.Reshape((7,7,64)),
      tfkl.Conv2DTranspose(64, kernel_size=3, activation='relu', strides=2, padding='same'),
      tfkl.Conv2DTranspose(32, kernel_size=3, activation='relu', strides=2, padding='same'),
      tfkl.Conv2DTranspose(1, kernel_size=3, activation=None, strides=1, padding='same'),
      tfkl.Flatten(),
      # We ask the model to output a Bernoulli distribution with shape [28x28x1]
      tfpl.IndependentBernoulli([28,28,1])                
  ])
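
Note that the last convolution has no activation: its outputs are used as logits of the Bernoulli distribution. Here is a NumPy sketch of the log-likelihood of an image under independent Bernoulli pixels parameterized by logits. The function name is ours, and this is an illustration of the formula rather than the TFP code.

```python
import numpy as np

def bernoulli_log_prob(x, logits):
    """Log-likelihood of binary pixels x under independent Bernoulli(sigmoid(logits))."""
    # log p(x) = x * logit - softplus(logit), with softplus(l) = log(1 + exp(l)),
    # summed over all pixels since they are treated as independent
    return np.sum(x * logits - np.logaddexp(0.0, logits))

x = np.zeros((28, 28, 1))

# With all logits at 0, every pixel has probability 1/2
logp_uninformed = bernoulli_log_prob(x, np.zeros((28, 28, 1)))
# Strongly negative logits correctly predict the black pixels
logp_confident = bernoulli_log_prob(x, -10.0 * np.ones((28, 28, 1)))

print(logp_uninformed, logp_confident)
```

Up to the sum over pixels instead of an average, the negative of this quantity is the binary cross-entropy we used for the plain auto-encoder.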

In [None]:
# Let's instantiate the decoder
prob_decoder = get_probabilistic_decoder(latent_dim=10)
# And check its summary
prob_decoder.summary()

And we see that the model outputs a Bernoulli distribution, so a distribution of images, not a single image. 

Let's try to decode a random sample of our encoded images:

In [None]:
# Draw a random sample of the code
code_sample = batch_encoded.sample()
# And decode that sample
decoded_im = prob_decoder(code_sample)

And let's inspect what we obtain:

In [None]:
decoded_im

This is again a distribution, so for instance, we might want to retrieve the mean, or a random sample.

In [None]:
figure(figsize=(9,3))
subplot(131)
imshow(batch_im[0,:,:,0],cmap='gray'); axis('off')
title("Input Image")
subplot(132)
imshow(decoded_im.sample()[0,:,:,0],cmap='gray'); axis('off')
title("Sample from generator output")
subplot(133)
imshow(decoded_im.mean()[0,:,:,0],cmap='gray'); axis('off')
title("Mean of generator output");

### Putting it all together: building a VAE

Now that we have a *regularized* recognition model and a generator, we can combine them into a single Keras VAE.


Let's start by building the model, by concatenating both models:


In [None]:
vae = tf.keras.Sequential([
          tfkl.InputLayer([28,28,1]),
          prob_encoder,
          prob_decoder])

In [None]:
vae.summary()

Let's just check what happens if we feed the VAE the same batch of images several times:

In [None]:
figure(figsize=(9,3))
subplot(131)
samples = vae(batch_im)
imshow(samples.mean()[0,:,:,0],cmap='gray'); axis('off');
title('run 1');
subplot(132)
samples = vae(batch_im)
imshow(samples.mean()[0,:,:,0],cmap='gray'); axis('off');
title('run 2');
subplot(133)
samples = vae(batch_im)
imshow(samples.mean()[0,:,:,0],cmap='gray'); axis('off');
title('run 3');

We obtain different images, because every time we run the VAE model, a different sample from latent space is used.

Ok, let's try to train the model, so that we can do more interesting things. We need to compile it with appropriate losses. The KL divergence on the recognition model will automatically be applied, as we specified it during construction. We will just need to tell Keras how to compute the *reconstruction loss*, i.e. the negative log-likelihood of the input data under the generator.



In [None]:
# We define the reconstruction loss as the negative log likelihood
negloglik = lambda x, rv_x: -rv_x.log_prob(x)
# And use it to compile the VAE
vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),
            loss=negloglik)
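
For reference, here is a sketch of the quantity that training minimizes once Keras adds the KL regularization term to the reconstruction loss, assuming the regularizer weight is left at 1. The notations $q(z | x)$ for the recognition model and $p(z)$ for the prior are ours:

$$\mathcal{L}(x) = -\mathbb{E}_{q(z | x)}\left[\log p(x | z)\right] + \mathrm{KL}\left(q(z | x) \,\|\, p(z)\right)$$

This is the negative of the evidence lower bound (ELBO), so a decreasing loss means an increasing ELBO. In this setup the expectation should be estimated with a single latent sample per image.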

Ok, now we train:

In [None]:
historyVAE = vae.fit(dset, epochs=30)

In [None]:
plot(historyVAE.history['loss']);
ylabel('negative ELBO')
xlabel('epoch')

Ok, neat, it's training. Let's try to see an example of auto-encoding on the testing set:

In [None]:
for batch_im, batch_labels in dset_test.take(1):
  # This extracts one batch of the test dset and shows the first example
  imshow(batch_im[0,:,:,0],cmap='gray')
  title('This is a %d'%batch_labels[0])

In [None]:
autoencoded_im = vae(batch_im) # Run the input batch through the model

In [None]:
subplot(121)
# Plot the mean of the output Bernoulli distribution
imshow(autoencoded_im.mean()[0,:,:,0],cmap='gray'); axis('off'); 
title('Mean output')
subplot(122)
# Plot a random sample of the output Bernoulli distribution
imshow(autoencoded_im.sample()[0,:,:,0],cmap='gray'); axis('off');
title('Sample output');

We see that the model has indeed learned to auto-encode images. The plot above illustrates once again that the output of the model is a distribution: we may choose to look at its mean, or at a sample from it.

To assess the general quality of the auto-encoded images, let's look at a few examples:

In [None]:
# These are the input images
figure(figsize=(5,5))
for i in range(4):
  for j in range(4):
    subplot(4,4, i*4+j+1)
    imshow(batch_im[i*4+j,:,:,0],cmap='gray')
    axis('off')

In [None]:
# These are the reconstructed images
figure(figsize=(5,5))
for i in range(4):
  for j in range(4):
    subplot(4,4, i*4+j+1)
    imshow(autoencoded_im.mean()[i*4+j,:,:,0],cmap='gray')
    axis('off')

Pretty good!

#### Sampling from the generative model

The main reason for using a VAE is that we can use it as a proper generative model. For that, we draw from the latent space prior, and forward the samples through the generator.

Let's start by sampling some latent codes from the prior:


In [None]:
latent_samples = tfd.MultivariateNormalDiag(loc=tf.zeros(10)).sample(16)

Then, we forward these samples through the model:

In [None]:
image_samples = prob_decoder(latent_samples)

In [None]:
# Grab the mean images, and reshape them into one giant image
fig = image_samples.mean().numpy().reshape((4,4,28,28))
fig = fig.transpose((0,2,1,3)).reshape((4*28,4*28))
# Let's see the result
imshow(fig, cmap='gray'); axis('off');

And that's it! We have sampled new digits from a VAE. Compared to what we obtained in the AE case, these samples contain far less garbage, illustrating that the latent space of the VAE is a lot more regular, even in the case d=10.

Also note that we didn't even have to look at the distribution of the latent space to draw these samples. Thanks to the KL regularization, the latent space naturally tries to follow a standard Gaussian. We will take a look at that in the next section.

#### Investigating the VAE latent space

Let's check what the latent space of the VAE looks like. First, let's encode some images:

In [None]:
codes = prob_encoder(batch_im) # Run the input batch through the model
codes_smpl = codes.sample() # Remember, latent codes are distributions, we draw one example

# Just like before, we can have a look at the latent space encoding, for instance
# along the 2 first dimensions
scatter(codes_smpl[:,0], codes_smpl[:,1], 
        c=batch_labels,cmap='tab10'); colorbar();

This looks very Gaussian. We can also look at the marginal distributions along each latent space axis:

In [None]:
for i in range(10):
  hist(codes_smpl[:,i], 64, range=[-3,3],alpha=0.2);

All of our latent space dimensions look Gaussian as expected. One last thing that we can look at, is the latent space distribution predicted by the recognition model for a single image:


In [None]:
codes_smpls = codes.sample(1000)

In [None]:
# The shape of these samples is [n_samples, batch_size, d]
codes_smpls.shape

In [None]:
figure(figsize=[5,5])

subplot(221)
imshow(batch_im[0,:,:,0],cmap='gray'); axis('off')
title('im 1')
subplot(222)
imshow(batch_im[1,:,:,0],cmap='gray'); axis('off')
title('im 2')
subplot(223)
hist2d(codes_smpls[:,0,0], codes_smpls[:,0,1],64, range=[[-3,3],[-3,3]]); gca().set_aspect('equal');
xlabel('z0')
ylabel('z1')
title('posterior im 1')
subplot(224)
hist2d(codes_smpls[:,1,0], codes_smpls[:,1,1],64, range=[[-3,3],[-3,3]]); gca().set_aspect('equal');
xlabel('z0')
ylabel('z1')
title('posterior im 2');

This plot shows that 2 different **images get encoded into entire regions** of latent space. We are only looking at the first two dimensions of the latent space here, but you can have a look at the other dimensions.

## Conclusion

In this notebook we have highlighted the fundamental difference between an auto-encoder and variational auto-encoder, that is the regularisation of latent space.

We have seen all of the fundamentals of how to build a VAE. To go beyond this toy example, you would just need to add more convolution layers to both encoder and decoder, but everything else remains the same.
