# Variational Autoencoders on multiple datasets

A convolutional VAE is used to generate images from several of the image datasets in `tensorflow_datasets`. The model handles both RGB and gray-scale images. Example GIFs of generation and reconstruction for a few datasets are shown at the end of the notebook.

In [0]:
!pip install -q imageio # gif generation

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

try:
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
import tensorflow_datasets as tfds
import tensorflow.keras.layers as layers

from google.colab import files

from IPython import display



An image dataset is loaded into training and test sets. If the dataset has no test split, 500 random training samples are used instead.
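Note that the fallback draws its 500 samples from the training split itself, so the "test" ELBO reported during training is not measured on held-out data. The minimal NumPy sketch below shows a disjoint split instead, assuming the samples can be addressed by integer index; `holdout_split` is a hypothetical helper, not part of `tensorflow_datasets`. With `tensorflow_datasets` itself, split slicing such as `split='train[:90%]'` and `split='train[90%:]'` gives a similar effect.

```python
import numpy as np

def holdout_split(n_samples, n_test, seed=0):
    """Hypothetical helper: return disjoint (train_idx, test_idx) index arrays."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)    # shuffle all indices once, with a fixed seed
    return perm[n_test:], perm[:n_test]  # first n_test shuffled indices form the test set

# Illustrative sizes only
train_idx, test_idx = holdout_split(n_samples=2000, n_test=500)
```

Because the two index sets never overlap, no test image can also be seen during training.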

In [0]:
DATASET = 'horses_or_humans'

train_dataset = tfds.load(name=DATASET, split="train")

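try:
  test_dataset = tfds.load(name=DATASET, split="test")
  print("Loaded test set.")

except (AssertionError, ValueError):
  # A missing split raises AssertionError or ValueError depending on the tfds version.
  # reshuffle_each_iteration=False keeps the same 500 samples on every pass.
  test_dataset = train_dataset.shuffle(1000, reshuffle_each_iteration=False).take(500)
  print("No test set, generating from train.")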

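Images are resized to a square whose side is preferably a power of 2, which keeps the convolution arithmetic simple. In particular, the decoder reproduces the input size only if both sides are divisible by 4, because it upsamples twice by a factor of 2.

`CHANNELS = 3` for RGB, `CHANNELS = 1` for gray-scale; it must match the number of channels in the chosen dataset.

`preprocess` resizes the images and scales pixel values to [0, 1].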

In [0]:
IMG_X, IMG_Y = 128, 128  # desired X, Y size for images

CHANNELS = 3

def preprocess(element):
  # Resize and normalize
  element['image'] = tf.image.resize(element['image'], size=(IMG_X,IMG_Y))  
  element['image'] = tf.cast(element['image'], dtype = tf.float32) / 255.
  
  return element

In [0]:
train_dataset = train_dataset.map(preprocess)
test_dataset = test_dataset.map(preprocess)

In [0]:
train_dataset = train_dataset.shuffle(1000).batch(100)
test_dataset = test_dataset.batch(100)

In [0]:
def save_images(fig, init_str, epoch):
  # Saves a pyplot figure as '<init_str>_at_epoch_<epoch>.png' (epoch zero-padded to 4 digits)
  fig.savefig('{}_at_epoch_{:04d}.png'.format(init_str, epoch))

### Encoder model

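The encoder receives the input and applies sequential convolutions to reduce its spatial size. A kernel size of 3 with a stride of 2 roughly halves the size along each axis: with the default `'valid'` padding used here, a 64 x 64 image becomes 31 x 31 (`(64 - 3) // 2 + 1`), whereas `padding='same'` would give exactly 32 x 32.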
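The output size of a Keras `Conv2D` layer along one axis follows a simple formula. The sketch below traces the encoder's two stride-2 convolutions for a 128 x 128 input under both padding modes; `conv_output_size` is our own helper (assuming dilation 1), not a Keras function.

```python
def conv_output_size(size, kernel=3, stride=2, padding="valid"):
    """Spatial output size of a Conv2D layer along one axis (dilation 1)."""
    if padding == "same":
        return -(-size // stride)           # ceil(size / stride)
    return (size - kernel) // stride + 1    # floor((size - kernel) / stride) + 1

# Trace the encoder's two stride-2 convolutions for a 128 x 128 image
for pad in ("valid", "same"):
    s = 128
    for _ in range(2):
        s = conv_output_size(s, padding=pad)
    print(pad, s)  # valid 31, then same 32
```

Since the encoder flattens its features before the dense layer, the exact size only changes the number of inputs to `dense_in`; it does not break the model.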

The extracted features are flattened and passed to a fully connected layer that outputs `2 * latent_dim` values: the mean and log-variance of a Gaussian over the low-dimensional latent space. Unlike a traditional autoencoder, which maps each input to a single point, a VAE encodes each input as a distribution over the latent space, which encourages the latent space to be regular.

In [0]:
class encoder(tf.keras.Model):
  
  def __init__(self, X, Y, Channels):
    super(encoder, self).__init__()

    self.x = X
    self.y = Y
    self.channels = Channels

    # tfds image features always include a channel axis (e.g. (28, 28, 1) for
    # gray-scale), so the input shape is (X, Y, Channels) in both cases
    self.inputLayer = layers.InputLayer(input_shape=(self.x, self.y, self.channels))

    self.conv1 = layers.Conv2D(
              filters=32, kernel_size=3, strides=(2,2), activation='relu')
    self.conv2 = layers.Conv2D(
              filters=64, kernel_size=3, strides=(2,2), activation='relu')
    self.flatten = layers.Flatten()
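    # Outputs mean and log-variance concatenated; relies on the global
    # latent_dim, which must be defined before the model is built
    self.dense_in = layers.Dense(latent_dim + latent_dim)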

  def call(self, x): # forward pass
    x = self.conv1(self.inputLayer(x))
    x = self.conv2(x)
    x = self.dense_in(self.flatten(x))
    return x


### Decoder Model

The decoder receives a point sampled from the latent space and attempts to reproduce the original input through sequential transposed convolutions (sometimes called deconvolutions). It outputs logits, and its output has the same shape as the input provided that `IMG_X` and `IMG_Y` are divisible by 4.

In [0]:
class decoder(tf.keras.Model):
  
  def __init__(self, X, Y, Channels):
    super(decoder, self).__init__()

    self.x = X
    self.y = Y
    self.channels = Channels

    self.inputLayer = layers.InputLayer(input_shape=(latent_dim,))  # uses the global latent_dim

    # Start at 1/4 of the image size; the two stride-2 transposed
    # convolutions below upsample back to (X, Y)
    self.dense_out = layers.Dense(units=(self.x//4)*(self.y//4)*32,
                                  activation=tf.nn.relu)
    
    self.reshape = layers.Reshape(target_shape=(self.x//4, self.y//4, 32))

    self.convT1 = layers.Conv2DTranspose(filters=64, kernel_size=3, 
                                         strides=(2,2), padding="SAME", 
                                         activation='relu')
    
    self.convT2 = layers.Conv2DTranspose(filters=32, kernel_size=3,
                                         strides=(2,2), padding="SAME", 
                                         activation='relu')
    
    # Output same number of input channels (3-RGB 1-Gray)
    self.convT3 = layers.Conv2DTranspose(filters=self.channels, kernel_size=3, 
                                         strides=(1, 1), padding="SAME")

  def call(self, x): # forward pass
    x = self.dense_out(self.inputLayer(x))
    x = self.convT1(self.reshape(x))
    x = self.convT2(x)
    x = self.convT3(x)

    return x
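With `padding="SAME"`, a `Conv2DTranspose` layer with stride `s` multiplies each spatial side by `s`. The sketch below traces `(height, width, channels)` through the decoder's layers to show why the sides must be divisible by 4. `decoder_shapes` is our own illustrative helper; it mirrors the layer settings above rather than calling Keras.

```python
def decoder_shapes(x, y, channels):
    """Trace (height, width, channels) through the decoder's layers."""
    h, w = x // 4, y // 4            # Reshape target after dense_out
    shapes = [(h, w, 32)]
    for filters in (64, 32):         # convT1, convT2: stride 2, 'SAME' padding doubles h and w
        h, w = h * 2, w * 2
        shapes.append((h, w, filters))
    shapes.append((h, w, channels))  # convT3: stride 1 keeps the size, sets the channels
    return shapes

print(decoder_shapes(128, 128, 3)[-1])  # (128, 128, 3)
print(decoder_shapes(30, 30, 1)[-1])    # (28, 28, 1): does not match a 30 x 30 input
```

A size such as 30 loses information in `x // 4`, so the reconstruction no longer matches the input shape and the loss cannot be computed.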

### VAE

The VAE is simply the encoder-decoder pair. An input passed to the VAE is encoded as a distribution over the latent space; a point is sampled from this distribution and decoded, so that the reconstruction error can be computed.

The network is regularised by encouraging the encoded distributions to be close to a [standard normal distribution](https://stattrek.com/statistics/dictionary.aspx?definition=standard_normal_distribution), which results in a continuous latent space. This space can be visualised nicely by setting the latent dimension to 2.

To save a visualization of the input and reconstructed outputs, set `save_reconstruction` to `True`.

In [0]:
class CVAE(tf.keras.Model):

  def __init__(self, X, Y, Channels, latent_dim, save_reconstruction):
    super(CVAE, self).__init__()

    self.latent_dim = latent_dim  # Size of latent space
    self.x = X  # X size of image
    self.y = Y  # Y size of image
    self.channels = Channels  # Image channels
    self.save = save_reconstruction # Save reconstruction of inputs

    self.encoder = encoder(self.x, self.y, self.channels)
    self.decoder = decoder(self.x, self.y, self.channels)

  @tf.function
  def sample(self, eps=None):
    # Generate random vector from latent space and decode 
    # for generating new images
    if eps is None:
      eps = tf.random.normal(shape=(100, self.latent_dim))
    return self.decode(eps, apply_sigmoid=True)

  def encode(self, x):
    # Mean and logvariance of the encoded input are taken for sampling
    mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
    return mean, logvar

  def reparameterize(self, mean, logvar):
    # Reparameterization trick: z = mean + std * eps with eps ~ N(0, I), so that
    # gradients can flow through mean and logvar despite the random sampling
    eps = tf.random.normal(shape=tf.shape(mean))
    return eps * tf.exp(logvar * .5) + mean

  def decode(self, z, apply_sigmoid=False):
    logits = self.decoder(z)
    if apply_sigmoid:
      probs = tf.sigmoid(logits)
      return probs
    return logits
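`encode` splits the encoder output into a mean and a log-variance, and `reparameterize` turns them into a sample. The NumPy sketch below mirrors those two steps for one made-up encoder output with `latent_dim = 2` (the numbers are illustrative, not produced by the model) and checks that the samples have the intended mean and standard deviation `exp(0.5 * logvar)`.

```python
import numpy as np

def reparameterize(mean, logvar, rng):
    # eps ~ N(0, I) carries all the randomness; given eps, z is a
    # deterministic, differentiable function of mean and logvar
    eps = rng.standard_normal(mean.shape)
    return mean + np.exp(0.5 * logvar) * eps

rng = np.random.default_rng(0)

# Made-up encoder output for latent_dim = 2: [mean_1, mean_2, logvar_1, logvar_2]
encoder_out = np.array([[1.0, -2.0, np.log(4.0), 0.0]])
mean, logvar = np.split(encoder_out, 2, axis=1)  # same split as CVAE.encode

# Draw many samples for this one input to inspect the sampling distribution
n = 100_000
z = reparameterize(np.repeat(mean, n, axis=0), np.repeat(logvar, n, axis=0), rng)
print(z.mean(axis=0), z.std(axis=0))  # approximately [1, -2] and [2, 1]
```

Predicting the log-variance rather than the variance lets the dense layer output any real number while `exp` keeps the standard deviation positive.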

### [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) (KL divergence)

Let `X` be our data, `P(X)` its probability distribution, and `P(z)` the prior distribution of the latent variable `z`. The posterior `P(z|X)`, which maps our data into the latent space, is intractable, so it is approximated by a distribution `Q(z|X)`. During training, the VAE aims to learn a `Q(z|X)` that is as close as possible to the true posterior `P(z|X)`. The KL divergence measures the difference between these two distributions; it is always non-negative and is zero only when they are equal.
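For reference, the standard identity connecting this KL divergence to the training objective, written in the notation above (a textbook derivation, not specific to this notebook):

$$
D_{KL}\big(Q(z|X)\,\|\,P(z|X)\big) = \mathbb{E}_{z\sim Q(z|X)}\big[\log Q(z|X) - \log P(z|X)\big]
$$

$$
\log P(X) = \underbrace{\mathbb{E}_{z\sim Q(z|X)}\big[\log P(X|z) + \log P(z) - \log Q(z|X)\big]}_{\text{ELBO}} + D_{KL}\big(Q(z|X)\,\|\,P(z|X)\big)
$$

Since $\log P(X)$ does not depend on $Q$ and the KL term is non-negative, maximizing the ELBO is equivalent to minimizing the KL divergence. In `compute_loss`, `logpx_z`, `logpz` and `logqz_x` are single-sample estimates of $\log P(X|z)$, $\log P(z)$ and $\log Q(z|X)$ respectively.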

We now define the loss function and the gradient update step; this code is taken from the TensorFlow [Convolutional Variational Autoencoder tutorial](https://www.tensorflow.org/tutorials/generative/cvae).

In [0]:
optimizer = tf.keras.optimizers.Adam(1e-4)

def log_normal_pdf(sample, mean, logvar, raxis=1):
  log2pi = tf.math.log(2. * np.pi)
  return tf.reduce_sum(
      -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
      axis=raxis)

@tf.function
def compute_loss(model, x):
  mean, logvar = model.encode(x)
  z = model.reparameterize(mean, logvar)
  x_logit = model.decode(z)

  # Single-sample log-density estimates: log p(x|z), log p(z), log q(z|x)

  cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
  logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
  logpz = log_normal_pdf(z, 0., 0.)
  logqz_x = log_normal_pdf(z, mean, logvar)

  # The loss is the negative ELBO; maximizing the ELBO minimizes KL(q(z|x) || p(z|x))

  return -tf.reduce_mean(logpx_z + logpz - logqz_x)  # single-sample Monte Carlo estimate, averaged over the batch

@tf.function
def compute_apply_gradients(model, x, optimizer):

  # Backpropagate the loss and apply one optimizer step

  with tf.GradientTape() as tape:
    loss = compute_loss(model, x)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
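`compute_loss` estimates the KL part of the ELBO, `logqz_x - logpz`, from a single sample of `z`. For a diagonal Gaussian `Q(z|X)` and a standard normal prior, the same quantity also has a well-known closed form, which some VAE implementations use directly (a swapped-in alternative, not what this notebook does). The NumPy sketch below ports `log_normal_pdf` and checks that the Monte Carlo average agrees with the closed form for one illustrative mean and log-variance.

```python
import numpy as np

def log_normal_pdf(sample, mean, logvar):
    # NumPy port of log_normal_pdf above, summed over the last axis
    log2pi = np.log(2.0 * np.pi)
    return np.sum(-0.5 * ((sample - mean) ** 2 * np.exp(-logvar) + logvar + log2pi), axis=-1)

def analytic_kl(mean, logvar):
    # KL(N(mean, exp(logvar)) || N(0, I)) for a diagonal Gaussian
    return 0.5 * np.sum(np.exp(logvar) + mean ** 2 - 1.0 - logvar, axis=-1)

rng = np.random.default_rng(0)
mean = np.array([0.5, -1.0])     # illustrative values
logvar = np.array([0.2, -0.5])

# Monte Carlo: average log q(z|x) - log p(z) over many reparameterized samples
z = mean + np.exp(0.5 * logvar) * rng.standard_normal((200_000, 2))
mc_kl = np.mean(log_normal_pdf(z, mean, logvar) - log_normal_pdf(z, 0.0, 0.0))

print(mc_kl, analytic_kl(mean, logvar))  # both close to 0.689
```

A single sample per input, as in `compute_loss`, gives a noisy but unbiased estimate of the same value; averaging over the batch reduces the noise.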

We are now ready to train the model.

`epochs` is the number of training epochs, `latent_dim` is the size of the latent space, and `num_examples` is the number of new images to generate (if you change it from 16, edit the subplot grid in `generate_images`).

`generate_images` passes the fixed random vector `gen_vect` to the model to generate new images. These are plotted and saved for the animation later.

`reconstruct_save` plots and saves input images next to their reconstructions.

In [0]:
epochs = 200  
latent_dim = 100
num_examples = 16

# Random vect for generation visualization

gen_vect = tf.random.normal(shape=[num_examples, latent_dim]) 

model = CVAE(IMG_X, IMG_Y, CHANNELS, latent_dim, True)


In [0]:
def generate_images(model, epoch, test_input):
  
  # Generates new images with model, plots and saves them

  preds = model.sample(test_input)
  fig = plt.figure(figsize=(10,10))

  for i in range(preds.shape[0]):
      plt.subplot(4, 4, i+1)
      if CHANNELS==1:
        image = preds[i, :, :, 0]
        plt.imshow(image, cmap='gray')
      else:
        image = preds[i, :, :, :]
        plt.imshow(image)  # RGB data: a colormap would be ignored
      plt.axis('off')

  save_images(fig, 'gen', epoch)
  plt.show()
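`generate_images` hard-codes a 4 x 4 grid, which is why `num_examples` must stay at 16. A minimal sketch of deriving a near-square grid from the number of images instead; `grid_shape` is a hypothetical helper, not part of matplotlib.

```python
import math

def grid_shape(n):
    """Hypothetical helper: (rows, cols) for a near-square subplot grid holding n images."""
    rows = math.ceil(math.sqrt(n))
    cols = math.ceil(n / rows)
    return rows, cols

print(grid_shape(16))  # (4, 4)
print(grid_shape(10))  # (4, 3)
```

Inside `generate_images`, `plt.subplot(4, 4, i+1)` could then become `rows, cols = grid_shape(preds.shape[0])` followed by `plt.subplot(rows, cols, i+1)`.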

In [0]:
def reconstruct_save(images, samples, epoch):

  # Save a batch of input images side by side with their reconstructions;
  # `samples` is the number of rows in the figure (one image pair per row)

  x = images
  mean, logvar = model.encode(x)
  z = model.reparameterize(mean, logvar)
  preds = model.decode(z, apply_sigmoid=True)
  fig = plt.figure(figsize=(10,10))
  idx = 0

  for i in range(0, x.shape[0]*2, 2): # To plot input and reconstruction side by side
    plt.subplot(samples, 2, i+1)
    if CHANNELS == 1:
      plt.imshow(tf.squeeze(x[idx]), cmap='gray')
      plt.axis('off')
      plt.subplot(samples, 2, i+2)
      plt.imshow(tf.squeeze(preds[idx]), cmap='gray')
    else:
      plt.imshow(x[idx])
      plt.axis('off')
      plt.subplot(samples, 2, i+2)
      plt.imshow(preds[idx])

    idx+=1
    plt.axis('off')

  save_images(fig, 'reconstructed', epoch)

We now generate our own images as the model trains. 

`samples_to_save` specifies how many reconstruction samples to save.

In [0]:
samples_to_save = 10

generate_images(model, 0, gen_vect)

if model.save:
  for batch in test_dataset.take(1):
    # To visualize reconstruction as training progresses
    visual_set = batch['image'][:samples_to_save] 

for epoch in range(1, epochs + 1):
  start_time = time.time()
  for train_x in train_dataset:
    # train
    compute_apply_gradients(model, train_x['image'], optimizer)
  end_time = time.time()

  if epoch % 1 == 0:  # always true: evaluates every epoch; raise the modulus to evaluate less often
    loss = tf.keras.metrics.Mean()
    for test_x in test_dataset:
      # test
      loss(compute_loss(model, test_x['image']))
    elbo = -loss.result() 
    display.clear_output(wait=False)
    print('Epoch: {}, Test set ELBO: {}, '
          'time elapsed for current epoch: {:.2f}s'.format(epoch, elbo,
                                                          end_time - start_time))
    generate_images(model, epoch, gen_vect)

  if model.save:
    reconstruct_save(visual_set, samples_to_save, epoch)

We can create a GIF of the saved images for better visualization. `save_gif` saves a GIF named `name` from the images whose filenames begin with `lead_str` (`'reconstructed'` for the reconstructed images, `'gen'` for the generated images).

In [0]:
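def save_gif(name, lead_str):

  # Saves a GIF from the sequence of PNG images whose names start with lead_str

  anim_file = name+'.gif'

  with imageio.get_writer(anim_file, mode='I') as writer:
    filenames = glob.glob(lead_str+'*.png')
    filenames = sorted(filenames)  # zero-padded epoch numbers sort chronologically
    last = -1
    for i,filename in enumerate(filenames):
      # Use fewer frames as i grows: frame i is kept only when round(2*sqrt(i)) increases
      frame = 2*(i**0.5)
      if round(frame) > round(last):
        last = frame
      else:
        continue
      image = imageio.imread(filename)
      writer.append_data(image)
    # Always end on the most recent image
    image = imageio.imread(filename)
    writer.append_data(image)

  import IPython
  if IPython.version_info >= (6,2,0,''):
    # A bare display.Image(...) inside a function is never rendered, so display it explicitly
    display.display(display.Image(filename=anim_file))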

  files.download(anim_file)
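`save_gif` does not use every saved image: image `i` is kept only when `round(2 * sqrt(i))` increases, so early epochs, where the images change fastest, are sampled densely and later epochs sparsely, and the final image is appended after the loop. The selection rule on its own, as a pure-Python sketch (`kept_frame_indices` is our own name for it):

```python
import math

def kept_frame_indices(n_files):
    """Indices of the images save_gif writes inside its loop."""
    kept, last = [], -1.0
    for i in range(n_files):
        frame = 2 * math.sqrt(i)
        if round(frame) > round(last):  # same test as in save_gif
            last = frame
            kept.append(i)
    return kept

print(kept_frame_indices(12))  # [0, 1, 2, 4, 6, 8, 11]
```

For 200 epochs this keeps roughly 2 * sqrt(200), about 28, frames, which keeps the GIF small.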

In [0]:
save_gif('horseGen', 'gen')

In [0]:
save_gif('horseRecon', 'reconstructed')

### Some generations and reconstructions on different datasets

Models were trained for 50 to 200 epochs, depending on the dataset.

Reconstruction of flowers from the 'tf_flowers' dataset:

![tf_flowers_recon](https://i.imgur.com/HYSNRhL.gif)

Generation and reconstruction of Kuzushiji characters from the 'kmnist' dataset:

![kmnist_gen](https://imgur.com/81OzBLU.gif)

![kmnist_recon](https://i.imgur.com/GkT9kBi.gif)

Generation and reconstruction of horses and humans from the 'horses_or_humans' dataset:

![horse_gen](https://imgur.com/vKs2w7D.gif)

![horse_recon](https://imgur.com/4j0dgjt.gif)

We can even train the model on our own images. Using a script, 200 images of Donald Trump were downloaded from Google Images. After some training, the model generated the following image, an orange face in a suit:

![trump](https://imgur.com/Wh1jKMx.png)