# Use a CVAE on MNIST

In [5]:
pip install -q tensorflow-probability

Note: you may need to restart the kernel to use updated packages.


In [6]:
from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import time

In [7]:
# Load the MNIST images through Keras. load_data() yields one pair per split
# (train, test); only the first element of each pair (the images) is kept and
# the second element is discarded via `_`.

(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [9]:
#set up preprocessing and training / testing batch sizes

def preprocess_images(images):
  """Scale, binarize and add a channel axis to a batch of 28x28 images.

  Args:
    images: array whose first axis indexes images and which can be reshaped
      to (num_images, 28, 28, 1); pixel values are divided by 255.

  Returns:
    A float32 array of shape (num_images, 28, 28, 1) holding 1.0 where the
    scaled pixel is strictly greater than 0.5 and 0.0 elsewhere.
  """
  num_images = images.shape[0]
  # Add the trailing channel axis and bring pixel values into [0, 1].
  scaled = images.reshape((num_images, 28, 28, 1)) / 255.
  # Threshold at 0.5 so every pixel becomes a binary on/off value.
  is_foreground = scaled > .5
  return np.where(is_foreground, 1.0, 0.0).astype('float32')

# Binarize both splits. NOTE: this overwrites the raw arrays, so the cell is
# not idempotent -- re-running it without reloading the data divides the
# already-binarized pixels by 255 again and thresholds everything to 0.
train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)

# Derive the split sizes from the arrays instead of hard-coding the MNIST
# counts, so the shuffle buffers always span the whole split.
train_size = train_images.shape[0]
batch_size = 32
test_size = test_images.shape[0]

# Shuffle with a buffer as large as the split (a full shuffle), then batch.
train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                 .shuffle(train_size).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)
                .shuffle(test_size).batch(batch_size))

### Define encoder/decoder networks
Two small ConvNets for the encoder and decoder networks.

x: observation variable
z: latent variable

Encoder:
Defines approximate posterior distribution q(z|x)
    Input: observation
    Output: set of parameters - conditional distribution of latent representation z
    
The distribution is modelled here as a diagonal Gaussian: the network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance is used instead of variance for numerical stability).

Decoder:
Defines conditional distribution of observation p(x|z)
    Input: latent sample z
    Output: parameters for distribution of observation
Model latent distribution prior p(z) as a unit gaussian.

Reparametrization trick:
Sampling from the latent distribution creates a bottleneck: backpropagation cannot flow through a random node.

In this example: $z = \mu + \sigma \odot \epsilon$, where $\odot$ denotes element-wise multiplication.

$\mu$ = mean; $\sigma$ = standard deviation of a Gaussian <- both derived from the encoder output
$\epsilon$: random noise (maintains stochasticity of z) drawn from a standard normal distribution

### Network Architecture
encoder: 2 convolutional layers followed by a fully connected layer
decoder: a fully connected layer followed by 3 transposed-convolution layers

In [10]:
class CVAE(tf.keras.Model):
  """Convolutional variational autoencoder.

  The encoder maps a (28, 28, 1) image to `2 * latent_dim` values, which
  `encode` splits into the mean and log-variance of the latent Gaussian.
  The decoder maps a `latent_dim` vector back to (28, 28, 1) logits.

  Args:
    latent_dim: size of the latent vector z.
  """

  def __init__(self, latent_dim):
    # Zero-argument super() keeps working when this cell is re-run and the
    # name `CVAE` is rebound to a new class object.
    super().__init__()
    self.latent_dim = latent_dim
    self.encoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(
                filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Flatten(),
            # No activation: the output holds the mean and log-variance
            # (latent_dim values each), both of which are unconstrained.
            tf.keras.layers.Dense(latent_dim + latent_dim),
        ]
    )

    self.decoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),
            tf.keras.layers.Reshape(target_shape=(7, 7, 32)),
            # Two stride-2 'same' transposed convolutions upsample 7 -> 14 -> 28.
            tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            # No activation: the decoder emits logits; see `decode`.
            tf.keras.layers.Conv2DTranspose(
                filters=1, kernel_size=3, strides=1, padding='same'),
        ]
    )

  @tf.function
  def sample(self, eps=None):
    """Decodes latent vectors into pixel probabilities.

    Args:
      eps: latent vectors to decode. When None, 100 vectors of size
        `latent_dim` are drawn from a standard normal distribution.

    Returns:
      The decoder output passed through a sigmoid.
    """
    if eps is None:
      eps = tf.random.normal(shape=(100, self.latent_dim))
    return self.decode(eps, apply_sigmoid=True)

  def encode(self, x):
    """Runs the encoder on `x` and returns the `(mean, logvar)` halves."""
    mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
    return mean, logvar

  def reparameterize(self, mean, logvar):
    """Samples z = mean + exp(logvar / 2) * eps with eps ~ N(0, I).

    Args:
      mean: mean of the latent Gaussian.
      logvar: log-variance of the latent Gaussian, same shape as `mean`.

    Returns:
      A sample with the same shape as `mean`.
    """
    # Use the dynamic shape: the static `mean.shape` may contain None (e.g.
    # an unknown batch size inside a traced graph), which tf.random.normal
    # cannot accept.
    eps = tf.random.normal(shape=tf.shape(mean))
    return eps * tf.exp(logvar * .5) + mean

  def decode(self, z, apply_sigmoid=False):
    """Runs the decoder on latent vectors `z`.

    Args:
      z: latent vectors fed to the decoder.
      apply_sigmoid: when True, return sigmoid(logits) instead of raw logits.

    Returns:
      The decoder logits, or their sigmoid when `apply_sigmoid` is True.
    """
    logits = self.decoder(z)
    if apply_sigmoid:
      probs = tf.sigmoid(logits)
      return probs
    return logits

### Define loss function and optimizer

A VAE is trained by maximising the evidence lower bound (ELBO) on the marginal log-likelihood.

