# Variational autoencoders

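A variational autoencoder (VAE) encodes each input as a probability distribution over a low-dimensional latent space and decodes points sampled from that space back into data. It can therefore both compress inputs into a compact representation and reconstruct (or generate) simplified versions of them.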

In [1]:
import keras
import numpy as np

Using TensorFlow backend.


In [2]:
from keras import layers
from keras import backend as K
from keras.models import Model

In [3]:
# Shape of a single input image: height x width x channels.
img_shape = (28, 28, 1)
# Number of samples per mini-batch.
batch_size = 16
# The latent space is two-dimensional, i.e. a plane.
latent_dim = 2

In [4]:
# batch_shape (rather than shape) pins the batch dimension to batch_size, so
# every tensor derived from this input has a static leading dimension.
input_img = keras.Input(batch_shape=(batch_size,) + img_shape)

# The encoder network

In [5]:
# Four 3x3 convolutions; the second one uses a stride of 2 and therefore
# halves the spatial resolution of the feature maps.
x = layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu')(input_img)
x = layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2), padding='same', activation='relu')(x)
x = layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')(x)
x = layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')(x)
# Keep the static feature-map shape; the decoder needs it to undo the flatten.
shape_before_flattening = K.int_shape(x)

In [6]:
# Static shape of the last convolutional feature map:
# (batch, height, width, channels).
shape_before_flattening

(16, 14, 14, 64)

In [7]:
# Collapse the feature maps into one vector per sample, then project that
# vector down to 32 units.
x = layers.Flatten()(x)
x = layers.Dense(units=32, activation='relu')(x)

In [8]:
# Static shape after the dense layer: (batch, units).
shape_after_flattening = K.int_shape(x)
shape_after_flattening

(16, 32)

In [9]:
# Two parallel linear heads on the same features produce the parameters of
# the latent distribution: its mean and the log of its variance.
z_mean = layers.Dense(units=latent_dim)(x)
z_log_var = layers.Dense(units=latent_dim)(x)

# The sampling function

In [10]:
def sampling(args):
    """Draw a latent point z with the reparameterization trick.

    Args:
        args: a pair ``(z_mean, z_log_var)`` of tensors holding the mean and
            the log-variance of the latent Gaussian for each sample.

    Returns:
        A tensor ``z_mean + sigma * epsilon`` where ``epsilon`` is standard
        normal noise of shape ``(batch_size, latent_dim)``.
    """
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0., stddev=1.)
    # z_log_var is log(sigma^2), so the standard deviation is
    # exp(0.5 * z_log_var). Scaling the noise by exp(z_log_var) would use the
    # variance instead and disagree with the KL term of the loss, which
    # treats exp(z_log_var) as the variance.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Wrap the sampling function in a Lambda layer so that drawing z is part of
# the model graph.
z = layers.Lambda(sampling)([z_mean, z_log_var])

# The decoder network

In [11]:
# Number of values in one encoder feature map (14 * 14 * 64); the decoder's
# first dense layer has to produce this many units.
np.prod(shape_before_flattening[1:])

12544

In [12]:
# The decoder is a stand-alone model whose input has the per-sample shape of z.
decoder_input = layers.Input(shape=K.int_shape(z)[1:])

# Expand the latent vector to as many units as the encoder's last feature map
# holds (12544), then fold the vector back into that feature-map shape.
x = layers.Dense(units=np.prod(shape_before_flattening[1:]), activation='relu')(decoder_input)
x = layers.Reshape(target_shape=shape_before_flattening[1:])(x)

# A stride-2 transposed convolution doubles the spatial size; the final
# single-filter sigmoid convolution yields one channel with values in (0, 1).
x = layers.Conv2DTranspose(filters=32, kernel_size=3, strides=(2, 2), padding='same', activation='relu')(x)
x = layers.Conv2D(filters=1, kernel_size=3, padding='same', activation='sigmoid')(x)

decoder = Model(decoder_input, x)

# Decode the sampled latent point to obtain the reconstruction.
z_decoded = decoder(z)

# Custom loss layer

In [13]:
class CustomVariationalLayer(keras.layers.Layer):
    """Pass-through layer that attaches the VAE loss to the model.

    The layer is called on ``[original, reconstruction]``. It registers the
    loss with ``add_loss`` and returns the original input unchanged.
    """

    def vae_loss(self, x, z_decoded):
        """Return the scalar VAE loss for one batch.

        The loss is the binary cross-entropy between the flattened input and
        the flattened reconstruction, plus a KL regularization term weighted
        by 5e-4.

        Note: ``z_mean`` and ``z_log_var`` are not parameters; they are read
        from the enclosing (notebook-level) scope.
        """
        original = K.flatten(x)
        reconstruction = K.flatten(z_decoded)
        reconstruction_loss = keras.metrics.binary_crossentropy(original, reconstruction)
        kl_terms = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        regularization_loss = -5e-4 * K.mean(kl_terms, axis=-1)
        return K.mean(reconstruction_loss + regularization_loss)

    def call(self, inputs):
        """Register the VAE loss for ``inputs`` and return the first input."""
        original, reconstruction = inputs[0], inputs[1]
        self.add_loss(self.vae_loss(original, reconstruction), inputs=inputs)
        # The output itself is not used for training; only the added loss is.
        return original
    

In [14]:
# Calling the layer registers the VAE loss on the graph; y is input_img
# returned unchanged by the layer.
y = CustomVariationalLayer()([input_img, z_decoded])

# Training

In [16]:
from keras.models import Model
from keras.datasets import mnist
import numpy as np

In [17]:
# The loss is attached inside CustomVariationalLayer via add_loss, so no
# external loss function is supplied at compile time (loss=None).
vae = Model(input_img, y)
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (16, 28, 28, 1)       0                                            
____________________________________________________________________________________________________
conv2d_1 (Conv2D)                (16, 28, 28, 32)      320         input_1[0][0]                    
____________________________________________________________________________________________________
conv2d_2 (Conv2D)                (16, 14, 14, 64)      18496       conv2d_1[0][0]                   
____________________________________________________________________________________________________
conv2d_3 (Conv2D)                (16, 14, 14, 64)      36928       conv2d_2[0][0]                   
___________________________________________________________________________________________

  


In [18]:
# Load MNIST; the training labels are discarded because the model is trained
# on the images alone.
(x_train, _), (x_test, y_test) = mnist.load_data()

In [19]:
#scale
x_train = x_train.astype('float32')/255.
x_train = x_train.reshape(x_train.shape + (1,))

x_test = x_test.astype('float32')/255.
x_test = x_test.reshape(x_test.shape + (1,))

In [21]:
# The model was compiled with loss=None and computes its loss from the input
# itself (via add_loss in CustomVariationalLayer), so it takes no target
# arrays. Pass None as the target for both training and validation data;
# supplying an array there is rejected by Keras versions that validate
# targets.
# The input layer has a fixed batch dimension, so the number of training and
# validation samples must be a multiple of batch_size.
vae.fit(x=x_train,
        y=None,
        shuffle=True,
        epochs=10,
        batch_size=batch_size,
        validation_data=(x_test, None))

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f9884df6e80>

In [22]:
import matplotlib.pyplot as plt
%matplotlib inline
from scipy.stats import norm

In [23]:
# Preview canvas: an n x n grid of tiles, each digit_size x digit_size pixels.
n = 15
digit_size = 28
figure = np.zeros(shape=(n * digit_size, n * digit_size))

In [24]:
# Push n evenly spaced probabilities in [0.05, 0.95] through the inverse CDF
# of the standard normal to obtain the latent coordinates of the grid.
grid_x = norm.ppf(np.linspace(start=0.05, stop=0.95, num=n))
grid_y = norm.ppf(np.linspace(start=0.05, stop=0.95, num=n))


In [None]:
for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi,yi]])
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = decoder.predict