# Convolutional Variational Autoencoder

## Setup

In [None]:
!pip install matplotlib
!pip install tensorflow
!pip install tensorflow-probability

# to generate gifs
!pip install imageio

!pip install git+https://github.com/tensorflow/docs

In [None]:
from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers
import time
import tensorflow_docs.vis.embed as embed

## GLOBAL SETTINGS

In [None]:
# GENERAL
# Number of epochs to train in this run.
EPOCHS = 20
# The training and display cells read EPOCHS_TO_LEARN; define it from EPOCHS
# so those cells do not fail with a NameError.
EPOCHS_TO_LEARN = EPOCHS

# Dimensionality of the latent space: the encoder emits twice this many values
# (mean and log-variance) and the decoder consumes a vector of this size.
LATENT_DIM_SIZE = 50
# Number of test images reconstructed in the per-epoch preview figure.
EXAMPLE_COUNT_PREVIEW = 16

# TRAINING PARAMETER
# NOTE(review): TRAIN_SIZE and TEST_SIZE are not referenced by the cells
# visible in this notebook -- confirm whether they are still needed.
TRAIN_SIZE = 60000
BATCH_SIZE = 32
TEST_SIZE = 10000

# WEIGHT LOADING / SAVING
# Run the training loop (also gates dataset loading and the preview sample).
TRAIN_WEIGHTS = True
# Restore weights from a previous run before training.
LOAD_WEIGHTS = True
# Epoch number and name suffix that identify the checkpoint to load.
LOAD_FROM_EPOCHS = 1
LOAD_FROM_NAME = "DEFAULT"
# Name suffix for checkpoints written by this run.
SAVE_TO_NAME = "DEFAULT"
# Write a checkpoint every SAVE_INTERVAL epochs.
SAVE_INTERVAL = 1

## METHODS / CLASSES
### PREPROCESSING

In [None]:
def preprocess_images(images):
    """Return ``images`` with every element multiplied by 1/255.

    Args:
        images: Object exposing ``map`` (here a dataset produced by
            ``image_dataset_from_directory``).

    Returns:
        The result of mapping a Keras ``Rescaling`` layer over ``images``.
    """
    rescale = layers.experimental.preprocessing.Rescaling(scale=1. / 255)

    def _normalize(batch):
        # Apply the shared rescaling layer to one dataset element.
        return rescale(batch)

    return images.map(_normalize)

### MODEL

In [None]:
class CVAE(tf.keras.Model):
  """Convolutional variational autoencoder.

  The encoder takes 256x256x3 inputs and ends in a dense layer of size
  ``2 * LATENT_DIM_SIZE``; `encode` splits that output into a mean and a
  log-variance vector. The decoder takes a latent vector of size
  ``LATENT_DIM_SIZE``, reshapes a dense projection to 8x8x64 and applies five
  stride-2 transposed convolutions, ending in 3 output channels (logits).
  """

  def __init__(self, LATENT_DIM_SIZE):
    """Build the encoder and decoder networks.

    Args:
      LATENT_DIM_SIZE: Size of the latent vector.
    """
    super().__init__()
    self.LATENT_DIM_SIZE = LATENT_DIM_SIZE
    self.encoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(256, 256, 3)),
            tf.keras.layers.Conv2D(
                filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Flatten(),
            # No activation: the output holds mean and log-variance, both of
            # which are unconstrained real values.
            tf.keras.layers.Dense(LATENT_DIM_SIZE + LATENT_DIM_SIZE),
        ]
    )

    self.decoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(LATENT_DIM_SIZE,)),
            tf.keras.layers.Dense(units=8*8*64, activation=tf.nn.relu),
            tf.keras.layers.Reshape(target_shape=(8, 8, 64)),
            tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            # No activation: the decoder returns logits; `decode` applies the
            # sigmoid on request.
            tf.keras.layers.Conv2DTranspose(
                filters=3, kernel_size=3, strides=2, padding='same'),
        ]
    )

  @tf.function
  def sample(self, eps=None):
    """Decode latent vectors into images with sigmoid applied.

    Args:
      eps: Latent vectors to decode. When None, 100 vectors are drawn from a
        standard normal distribution.

    Returns:
      The decoder output passed through a sigmoid.
    """
    if eps is None:
      eps = tf.random.normal(shape=(100, self.LATENT_DIM_SIZE))
    return self.decode(eps, apply_sigmoid=True)

  def encode(self, x):
    """Run the encoder and split its output into (mean, logvar)."""
    mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
    return mean, logvar

  def reparameterize(self, mean, logvar):
    """Draw ``z = mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, 1)``.

    Args:
      mean: Mean tensor from `encode`.
      logvar: Log-variance tensor from `encode`, same shape as ``mean``.

    Returns:
      A tensor with the same shape as ``mean``.
    """
    # Use the dynamic shape: the static `mean.shape` can have an unknown
    # (None) batch dimension inside a traced graph, which tf.random.normal
    # cannot accept.
    eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
    return eps * tf.exp(logvar * .5) + mean

  def decode(self, z, apply_sigmoid=False):
    """Run the decoder on ``z``.

    Args:
      z: Latent vectors.
      apply_sigmoid: When True, return ``sigmoid(logits)`` instead of logits.

    Returns:
      Decoder logits, or their sigmoid when ``apply_sigmoid`` is True.
    """
    logits = self.decoder(z)
    if apply_sigmoid:
      return tf.sigmoid(logits)
    return logits

## HELPER LOSS COMPUTATION

In [None]:
def log_normal_pdf(sample, mean, logvar, raxis=1):
  """Log-density of a normal distribution, summed over ``raxis``.

  Args:
    sample: Point(s) at which the density is evaluated.
    mean: Mean of the distribution.
    logvar: Log-variance of the distribution.
    raxis: Axis over which the per-element log-densities are summed.

  Returns:
    The per-element log-densities reduced with ``tf.reduce_sum`` over
    ``raxis``.
  """
  log2pi = tf.math.log(2. * np.pi)
  squared_error = (sample - mean) ** 2.
  per_element = -.5 * (squared_error * tf.exp(-logvar) + logvar + log2pi)
  return tf.reduce_sum(per_element, axis=raxis)


def compute_loss(model, x):
  """Return the negative single-sample ELBO estimate, averaged over the batch.

  Args:
    model: Object providing ``encode``, ``reparameterize`` and ``decode``.
    x: Input batch; also used as the labels of the sigmoid cross-entropy.

  Returns:
    Scalar tensor ``-mean(log p(x|z) + log p(z) - log q(z|x))``.
  """
  mean, logvar = model.encode(x)
  z = model.reparameterize(mean, logvar)

  # Reconstruction term: negative cross-entropy summed over axes 1, 2 and 3.
  cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=x, logits=model.decode(z))
  logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])

  # Prior (mean 0, log-variance 0) and approximate posterior log-densities.
  logpz = log_normal_pdf(z, 0., 0.)
  logqz_x = log_normal_pdf(z, mean, logvar)

  elbo_per_example = logpx_z + logpz - logqz_x
  return -tf.reduce_mean(elbo_per_example)

## TRAINING

In [None]:
@tf.function
def train_step(model, x, optimizer):
  """Executes one training step and returns the loss.

  This function computes the loss and gradients, and uses the latter to
  update the model's parameters.

  Args:
    model: Model whose ``trainable_variables`` are updated.
    x: Input batch passed to ``compute_loss``.
    optimizer: Optimizer providing ``apply_gradients``.

  Returns:
    The loss computed for ``x`` before the parameter update.
  """
  with tf.GradientTape() as tape:
    loss = compute_loss(model, x)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Return the loss as the docstring promises; existing callers that ignore
  # the return value are unaffected.
  return loss

## IMAGE HELPER

In [None]:
def generate_and_save_images(model, epoch, test_sample):
  """Reconstruct ``test_sample`` and save the results as one PNG grid.

  The images are encoded, a latent vector is sampled for each, and the decoded
  reconstructions are drawn into a square grid that is saved as
  ``image_at_epoch_<epoch, zero-padded to 4 digits>.png`` and then shown.

  Args:
    model: Model providing ``encode``, ``reparameterize`` and ``sample``.
    epoch: Integer used in the output file name.
    test_sample: Batch of input images to reconstruct.
  """
  mean, logvar = model.encode(test_sample)
  z = model.reparameterize(mean, logvar)
  # Prints the latent vector of the first image (kept from the original).
  print(z[0])
  predictions = model.sample(z)

  image_count = predictions.shape[0]
  # Keep the original 4x4 layout for up to 16 images, and grow the grid for
  # larger batches instead of letting add_subplot fail on index 17.
  grid_size = max(4, int(np.ceil(np.sqrt(image_count))))
  fig = plt.figure(figsize=(4, 4))

  for i in range(image_count):
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.axis('off')
    # Scale the sigmoid output by 255 and convert to 8-bit for display.
    pred = np.array(predictions[i, :, :, :] * 255).astype(np.uint8)
    ax.imshow(pred)

  fig.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()
  # Release the figure so repeated calls (one per epoch) do not accumulate
  # open figures.
  plt.close(fig)

def display_image(epoch_no):
  """Open the preview PNG saved for epoch ``epoch_no`` and return it.

  Args:
    epoch_no: Integer epoch number used in the file name.

  Returns:
    The image object returned by ``PIL.Image.open``.
  """
  image_path = 'image_at_epoch_{:04d}.png'.format(epoch_no)
  return PIL.Image.open(image_path)


def plot_latent_images(model, n, digit_size=256):
  """Plots an n x n grid of images decoded from the latent space.

  Args:
    model: Model providing ``sample`` and a ``LATENT_DIM_SIZE`` attribute.
    n: Number of grid cells per side.
    digit_size: Side length in pixels of each decoded (square, 3-channel)
      image.
  """

  norm = tfp.distributions.Normal(0, 1)
  # Evenly spaced standard-normal quantiles used as latent coordinates.
  grid_x = norm.quantile(np.linspace(0.05, 0.95, n))
  grid_y = norm.quantile(np.linspace(0.05, 0.95, n))
  image_width = digit_size*n
  image_height = image_width
  # float32 halves the memory of the canvas compared with the float64 default,
  # which matters because the canvas has (digit_size * n)^2 * 3 entries.
  image = np.zeros((image_height, image_width, 3), dtype=np.float32)
  # Read the latent size from the model instead of hard-coding 50, so the
  # function works for any model configuration.
  latent_dim = model.LATENT_DIM_SIZE

  # NOTE(review): yi is never used, so every row of the grid shows the same
  # images. Every latent coordinate is set to xi -- confirm the intended
  # mapping from (xi, yi) to a latent vector.
  for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
      z = np.full((1, latent_dim), xi)
      x_decoded = model.sample(z)
      digit = tf.reshape(x_decoded[0], (digit_size, digit_size, 3))
      image[i * digit_size: (i + 1) * digit_size,
            j * digit_size: (j + 1) * digit_size] = digit.numpy()

  plt.figure(figsize=(10, 10))
  plt.imshow(image)
  plt.axis('Off')
  plt.show()

# SCRIPT
## LOAD AND PREPROCESS DATA

In [None]:
if TRAIN_WEIGHTS:
    # Unlabelled training images, resized to 256x256 and batched, then
    # rescaled by preprocess_images.
    train_images = tf.keras.preprocessing.image_dataset_from_directory(
        'data/cartoonset10k',
        label_mode=None,
        batch_size=BATCH_SIZE,
        image_size=(256, 256),
    )
    train_dataset = preprocess_images(train_images)

    # Same pipeline for the held-out test images.
    test_images = tf.keras.preprocessing.image_dataset_from_directory(
        'data/cartoonsetTest',
        label_mode=None,
        batch_size=BATCH_SIZE,
        image_size=(256, 256),
    )
    test_dataset = preprocess_images(test_images)

## CREATE AND PRINT MODEL

In [None]:
# Epoch offset for this run; replaced by LOAD_FROM_EPOCHS when existing
# weights are loaded.
starting_epoch = 0

# Build the model and print the layer summary of both sub-networks.
model = CVAE(LATENT_DIM_SIZE)
model.encoder.summary()
model.decoder.summary()

## LOAD EXISTING WEIGHTS

In [None]:
if LOAD_WEIGHTS:
    # The checkpoint path encodes the epoch count, the latent size and a
    # name suffix, mirroring the path used when weights are saved.
    model.load_weights(
        "./model_weights/model_weights_epochs_" + str(LOAD_FROM_EPOCHS)
        + "_latentDimSize_" + str(LATENT_DIM_SIZE)
        + "_" + LOAD_FROM_NAME
    )
    # Continue epoch numbering from the loaded checkpoint.
    starting_epoch = LOAD_FROM_EPOCHS

## DEFINE OPTIMIZER

In [None]:
# Adam optimizer with a learning rate of 1e-4, shared by all training steps.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

## ENSURE ENOUGH IMAGES IN BATCH FOR EXAMPLE PREVIEW

In [None]:
# The preview sample is sliced out of a single test batch, so one batch must
# hold at least EXAMPLE_COUNT_PREVIEW images. Raise explicitly instead of using
# assert: assert statements are skipped when Python runs with -O and give no
# explanation of what is wrong.
if BATCH_SIZE < EXAMPLE_COUNT_PREVIEW:
    raise ValueError(
        "BATCH_SIZE ({}) must be at least EXAMPLE_COUNT_PREVIEW ({})".format(
            BATCH_SIZE, EXAMPLE_COUNT_PREVIEW))

## Pick a sample of the test set for generating output images

In [None]:
if TRAIN_WEIGHTS:
  # Take the first test batch, keep its first EXAMPLE_COUNT_PREVIEW images as
  # the fixed preview sample, and render them once as "epoch 0".
  for test_batch in test_dataset.take(1):
    test_sample = test_batch[:EXAMPLE_COUNT_PREVIEW]
    generate_and_save_images(model, 0, test_sample)

## TRAIN WEIGHTS

In [None]:
if TRAIN_WEIGHTS:
  # Epoch numbers continue after starting_epoch, so checkpoints and preview
  # images from a resumed run do not reuse earlier epoch numbers.
  for epoch in range(starting_epoch + 1, starting_epoch + EPOCHS_TO_LEARN + 1):
    # One timed pass over the training data.
    start_time = time.time()
    for train_x in train_dataset:
      train_step(model, train_x, optimizer)
    end_time = time.time()

    # Average the loss over the test set; the ELBO is its negation.
    loss = tf.keras.metrics.Mean()
    for test_x in test_dataset:
      loss.update_state(compute_loss(model, test_x))
    elbo = -loss.result()

    # Replace the previous epoch's output with the current report and preview.
    display.clear_output(wait=False)
    print('Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}'
          .format(epoch, elbo, end_time - start_time))
    generate_and_save_images(model, epoch, test_sample)

    # Save a checkpoint every SAVE_INTERVAL epochs.
    if epoch % SAVE_INTERVAL == 0:
      model.save_weights(
          "./model_weights/model_weights_epochs_" + str(epoch)
          + "_latentDimSize_" + str(LATENT_DIM_SIZE)
          + "_" + SAVE_TO_NAME
      )

## Display a generated image from the last training epoch

In [None]:
# Show the preview grid saved for the last epoch of this run, with the axes
# hidden.
plt.imshow(display_image(EPOCHS_TO_LEARN + starting_epoch))
plt.axis('off')

## Display an animated GIF of all the saved images

In [None]:
anim_file = 'cvae.gif'

# Collect the frames before opening the writer. Preview images are saved with
# a zero-padded epoch number, so sorting the names orders them by epoch.
filenames = sorted(glob.glob('image*.png'))
if not filenames:
  # Previously an empty match left `filename` undefined and failed with a
  # NameError after the loop; fail early with a clear message instead.
  raise FileNotFoundError(
      "No files matching 'image*.png' found; nothing to write to " + anim_file)

with imageio.get_writer(anim_file, mode='I') as writer:
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  # Append the last frame a second time (as the original did), reusing the
  # already-loaded image instead of reading the file again.
  writer.append_data(image)

embed.embed_file(anim_file)

## Display a grid of images decoded from the latent space

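Running the code below decodes an n × n grid of latent vectors and tiles the decoded images into a single figure, showing how the generated image changes as the latent values change. [TensorFlow Probability](https://www.tensorflow.org/probability) is used to take evenly spaced quantiles of a standard normal distribution as the latent values. Note that the latent space of this model has `LATENT_DIM_SIZE` dimensions rather than two, so the grid covers only a small part of the latent space and is not a complete 2D manifold.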

In [None]:
# Render a 20 x 20 grid of decoded images.
plot_latent_images(model, n=20)

## TESTING TO GENERATE AN INTERPOLATION BETWEEN TWO LATENT VECTORS

In [None]:
# First hard-coded latent vector, shape (1, 50).
# NOTE(review): presumably copied from the `print(z[0])` output of
# generate_and_save_images -- confirm the source of these values.
latent_vector_1 = np.array([[1.02536201e+00,-2.94102907e-01,1.64790154e+00,5.03614378e+00
,-1.36321330e+00,2.76926208e+00,2.62857294e+00,-2.85130680e-01
,5.45753288e+00,-5.88454783e-01,-3.84467030e+00,-1.64728820e+00
,3.35834241e+00,2.38665032e+00,4.59312290e-01,4.03293180e+00
,-2.09909391e+00,1.86549067e+00,2.92326331e+00,1.43550122e+00
,1.09568393e+00,6.20597601e-01,1.06155002e+00,-1.24831879e+00
,2.43738413e+00,2.81176567e+00,-2.76000261e+00,4.34566402e+00
,2.15681386e+00,3.06898761e+00,1.10142100e+00,-9.84009326e-01
,-2.62474108e+00,5.17143011e-01,-2.81945992e+00,-1.52866042e+00
,-1.58329403e+00,5.65068102e+00,2.48366165e+00,9.66846406e-01
,2.71668248e-02,-3.77461410e+00,-1.64389789e+00,-1.25075352e+00
,1.16503072e+00,1.71287203e+00,8.78006816e-01,3.15090984e-01
,3.11333919e+00,3.72572333e-01]])
# Second hard-coded latent vector, shape (1, 50); blended with the first one
# further down in this cell.
latent_vector_2 = np.array([[3.11055994e+00,2.77322316e+00,-2.35439253e+00,2.41457534e+00
,1.65251720e+00,-1.03567481e+00,2.88300347e+00,-1.77975857e+00
,2.07237649e+00,1.63250697e+00,-2.08729815e+00,-2.19439089e-01
,3.46109211e-01,5.82732618e-01,9.14461792e-01,-2.96111321e+00
,4.46764648e-01,3.61388826e+00,5.34072161e+00,2.61096501e+00
,2.48561978e+00,-2.73324752e+00,-2.11599898e+00,-1.75424963e-02
,-1.14708841e+00,-4.10792542e+00,-3.23143697e+00,2.21529269e+00
,2.15655327e+00,-3.28153896e+00,2.36182594e+00,-9.16283727e-01
,-1.41235209e+00,-6.11949027e-01,-8.87000144e-01,-2.13673401e+00
,-1.98448431e+00,9.77515399e-01,2.29458719e-01,4.20075655e-02
,7.69120634e-01,2.98409581e+00,1.02206600e+00,2.46658057e-01
,-3.60222936e-01,1.56427526e+00,-1.47916779e-01,2.49203157e+00
,1.60648060e+00,1.21420026e+00]])
# Two random latent vectors: one normal (std 0.75), one uniform on [-2, 2).
# Only random_vector_2 is printed; neither is decoded in this cell.
random_vector = np.random.normal(loc=0., scale=0.75, size=(1, 50))
random_vector_2 = 4 * (np.random.rand(1, 50) - 0.5)
print(random_vector_2)

# Weighted blend of the two hard-coded vectors: 20% of the first, 80% of the
# second.
combined_vector = 0.2 * latent_vector_1 + 0.8 * latent_vector_2
print(latent_vector_1)

# Decode the blend with sigmoid applied, take the single image of the batch,
# and convert it to 8-bit for display.
pred = model.decode(combined_vector, apply_sigmoid=True)[0] * 255
pred = np.array(pred).astype(np.uint8)

plt.imshow(pred)