<a href="https://colab.research.google.com/github/adewoleopeyemi/Variational-Autoencoder/blob/master/Variational_Auto_Encoders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#Import the required libraries
import keras 
from keras import layers
from keras import backend as k
from keras.models import Model
import numpy as np


# Shape of one input image: 28 x 28 pixels with a single channel
img_shape = (28, 28, 1)

# Dimensionality of the latent space: a 2D plane
latent_dim = 2

batch_size = 16

# Symbolic input tensor that the encoder is built on
input_img = keras.Input(shape=img_shape)

Using TensorFlow backend.


In [0]:
# Encoder: four 3x3 convolutions followed by a small dense layer
x = layers.Conv2D(filters=32, kernel_size=3, padding="same", activation="relu")(input_img)
# The strided convolution halves the spatial resolution (28x28 -> 14x14)
x = layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2), padding="same", activation="relu")(x)
x = layers.Conv2D(filters=64, kernel_size=3, padding="same", activation="relu")(x)
x = layers.Conv2D(filters=64, kernel_size=3, padding="same", activation="relu")(x)

# Remember the feature-map shape so the decoder can rebuild it later
shape_before_flattening = k.int_shape(x)

x = layers.Flatten()(x)
x = layers.Dense(units=32, activation="relu")(x)

# The input image ends up being encoded into these two parameters
z_mean = layers.Dense(units=latent_dim)(x)
z_log_var = layers.Dense(units=latent_dim)(x)

In [0]:
def sampling(args):
  """Draw a latent point z from N(z_mean, exp(z_log_var)).

  Uses the reparameterization z = mean + sigma * epsilon with
  epsilon ~ N(0, 1), so the sample stays differentiable with respect to
  the encoder outputs.

  Args:
    args: a pair (z_mean, z_log_var) of tensors; z_log_var holds the log of
      the variance, matching how the KL term of the loss uses it.

  Returns:
    A tensor of latent samples, one row per row of z_mean.
  """
  z_mean, z_log_var = args
  epsilon = k.random_normal(shape = (k.shape(z_mean)[0], latent_dim), mean = 0., stddev = 1.)

  # z_log_var is log(sigma^2), so the standard deviation is
  # exp(0.5 * z_log_var). Scaling by exp(z_log_var) would multiply the noise
  # by the variance and disagree with the KL term in the loss.
  return z_mean + k.exp(0.5 * z_log_var) * epsilon

z = layers.Lambda(sampling)([z_mean, z_log_var])


In [0]:
# The decoder takes a latent point with the same per-sample shape as z
decoder_input = layers.Input(shape=k.int_shape(z)[1:])

# Upsample the latent vector to as many units as the encoder's last feature map
x = layers.Dense(
    units=np.prod(shape_before_flattening[1:]),
    activation="relu",
)(decoder_input)

# Reshape into a feature map of the same shape as the one just before the
# Flatten layer in the encoder
x = layers.Reshape(target_shape=shape_before_flattening[1:])(x)

# A strided Conv2DTranspose followed by a Conv2D decodes the feature map
# back to the size of the original input image
x = layers.Conv2DTranspose(
    filters=32,
    kernel_size=3,
    strides=(2, 2),
    padding="same",
    activation="relu",
)(x)
x = layers.Conv2D(filters=1, kernel_size=3, padding="same", activation="sigmoid")(x)

# Standalone decoder model, then applied to the sampled latent tensor
decoder = Model(decoder_input, x)
z_decoded = decoder(z)


In [0]:
class CustomVariationalLayer(keras.layers.Layer):
  """Pass-through layer whose purpose is to register the VAE loss.

  It receives [input image, decoded image], adds the combined
  reconstruction + KL loss to the model via add_loss, and returns the
  input image unchanged.
  """

  def vae_loss(self, x, z_decoded):
    """Return the mean of the reconstruction loss plus the weighted KL term.

    Args:
      x: the original input images.
      z_decoded: the decoder's reconstruction of those images.
    """
    flat_target = k.flatten(x)
    flat_reconstruction = k.flatten(z_decoded)
    reconstruction_loss = keras.metrics.binary_crossentropy(flat_target, flat_reconstruction)

    # NOTE: z_mean and z_log_var are not arguments; they are the module-level
    # encoder output tensors.
    kl_term = 1 + z_log_var - k.square(z_mean) - k.exp(z_log_var)
    kl_loss = -5e-4 * k.mean(kl_term, axis = -1)

    return k.mean(reconstruction_loss + kl_loss)

  def call(self, inputs):
    """Register the loss for inputs = [x, z_decoded] and return x."""
    original, reconstruction = inputs[0], inputs[1]
    self.add_loss(self.vae_loss(original, reconstruction), inputs = inputs)
    return original

y = CustomVariationalLayer()([input_img, z_decoded])

In [6]:
from keras.datasets import mnist
# End-to-end model. The loss is attached inside CustomVariationalLayer via
# add_loss, so no external loss is given to compile().
vae = Model(inputs=input_img, outputs=y)
vae.compile(
    optimizer="rmsprop",
    loss=None,
)
vae.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 64)   18496       conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 14, 14, 64)   36928       conv2d_2[0][0]                   
____________________________________________________________________________________________

  'be expecting any data to be passed to {0}.'.format(name))


In [7]:
# Training labels are not needed; only the images are used for training
(x_train, _), (x_test, y_test) = mnist.load_data()

# Scale pixel values to [0, 1] as float32 and append a trailing channel axis
x_train = np.expand_dims(x_train.astype("float32") / 255., axis=-1)
x_test = np.expand_dims(x_test.astype("float32") / 255., axis=-1)

# No targets are passed (y=None): the loss is added inside the model itself
vae.fit(
    x=x_train,
    y=None,
    shuffle=True,
    epochs=10,
    batch_size=128,
    validation_data=(x_test, None),
)


Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.callbacks.History at 0x7f24503414a8>

In [0]:
#visualization of images generated from latent space

import matplotlib.pyplot as plt
from scipy.stats import norm

# Set up a 15 x 15 grid of tiles, each digit_size x digit_size pixels
n = 15
digit_size = 28
# Blank canvas large enough to hold all n * n tiles
figure = np.zeros((digit_size * n, digit_size * n))
# Grid coordinates: n evenly spaced probabilities between 0.05 and 0.95,
# mapped through the inverse CDF (ppf) of the standard normal distribution.
# Presumably chosen because the latent prior is a standard normal - confirm.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
  for j, xi in enumerate(grid_y):
    z_sample = np.array([xi, yi])
    z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
    x_decoded = decoder.predict