In [1]:
import numpy as np
import tensorflow as tf
from keras.layers import BatchNormalization, Conv2D, Dense, Flatten, Lambda, LeakyReLU
from keras.layers import Conv2DTranspose, Reshape
import keras.backend as k

In [2]:
# Load MNIST train/test splits; the label arrays (y_*) are not used by the autoencoder cells shown here.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path="mnist.npz")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
# Scaling: divide by 255 so 8-bit pixel values land in [0, 1] (assumes uint8 input).
# Stored as float32: half the memory of the float64 that `/ 255.0` yields, and the
# dtype Keras layers compute in by default.
x_train = (x_train / 255.0).astype(np.float32)
x_test = (x_test / 255.0).astype(np.float32)

# Add a trailing channel axis: (n, height, width) -> (n, height, width, 1).
# Each array is reshaped with its OWN dimensions (previously x_test borrowed
# x_train's height/width).
x_train = x_train.reshape((*x_train.shape[:3], 1))
x_test = x_test.reshape((*x_test.shape[:3], 1))

In [4]:
#model
def build_encoder(image_shape=(28, 28, 1), latent_dim=2, show_summary=True):
  """Build the convolutional VAE encoder.

  The network predicts a latent mean ``mu`` and log-variance ``logvar`` and
  outputs a sample ``z = mu + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``
  (reparameterization), so ``z`` stays differentiable w.r.t. ``mu``/``logvar``.

  Args:
    image_shape: Shape of one input image, ``(height, width, channels)``.
    latent_dim: Size of the latent vector.
    show_summary: If True (default, the original behaviour), print
      ``model.summary()``.

  Returns:
    Tuple ``(model, mu, logvar)``: the Keras model whose single output is the
    sampled latent vector, plus the symbolic ``mu``/``logvar`` tensors from
    inside that model's graph.
  """

  def sampling(mu_logvar):
    # Reparameterization: z = mu + sigma * eps, with sigma = exp(logvar / 2).
    mu, logvar = mu_logvar
    epsilon = k.random_normal(shape=k.shape(mu), mean=0, stddev=1)
    return mu + k.exp(0.5 * logvar) * epsilon

  def conv_block(x, filters, strides):
    # Conv2D -> BatchNorm -> LeakyReLU(0.2); 'same' padding, so only the stride
    # changes the spatial size.
    x = Conv2D(filters, kernel_size=(3, 3), strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    return LeakyReLU(0.2)(x)

  image_input = tf.keras.Input(shape=image_shape)

  # NOTE(review): the first block has a single filter, squeezing the input through
  # one channel before widening to 32; kept as-is to preserve the architecture.
  x = conv_block(image_input, 1, (1, 1))
  x = conv_block(x, 32, (1, 1))
  x = conv_block(x, 64, (2, 2))  # halves height/width
  x = conv_block(x, 64, (2, 2))  # halves height/width
  x = conv_block(x, 64, (1, 1))

  flatten = Flatten()(x)
  mu = Dense(latent_dim)(flatten)
  logvar = Dense(latent_dim)(flatten)
  z = Lambda(sampling)((mu, logvar))

  model = tf.keras.Model(inputs=image_input, outputs=z, name="encoder_model")
  if show_summary:
    model.summary()
  return model, mu, logvar


In [5]:
# Preview the encoder architecture; displays the (model, mu, logvar) tuple.
build_encoder(image_shape=(28, 28, 1), latent_dim=2)

Model: "encoder_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 28, 28, 1)    10          ['input_1[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 28, 28, 1)   4           ['conv2d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 leaky_re_lu (LeakyReLU)        (None, 28, 28, 1)    0           ['batch_normalization

(<keras.engine.functional.Functional at 0x7f1b9049a2d0>,
 <KerasTensor: shape=(None, 2) dtype=float32 (created by layer 'dense')>,
 <KerasTensor: shape=(None, 2) dtype=float32 (created by layer 'dense_1')>)

In [6]:
def build_decoder(latent_dim=2, show_summary=True):
  """Build the convolutional VAE decoder.

  Projects a latent vector to a 7x7x64 feature map, then upsamples it with
  transposed convolutions (two of them stride 2) down to a single channel.

  Args:
    latent_dim: Size of the latent vector fed to the decoder.
    show_summary: If True (default, the original behaviour), print
      ``model.summary()``.

  Returns:
    The decoder ``tf.keras.Model``.
  """
  start_shape = (7, 7, 64)  # two stride-2 upsamplings: 7 -> 14 -> 28
  start_units = start_shape[0] * start_shape[1] * start_shape[2]  # 3136

  def deconv_block(x, filters, strides):
    # Conv2DTranspose -> BatchNorm -> LeakyReLU(0.2), 'same' padding.
    x = Conv2DTranspose(filters, kernel_size=(3, 3), strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    return LeakyReLU(0.2)(x)

  # Input shape is documented as a tuple; a bare int only works on some Keras versions.
  latent_input = tf.keras.Input(shape=(latent_dim,))
  dense = Dense(start_units)(latent_input)
  x = Reshape(target_shape=start_shape)(dense)

  x = deconv_block(x, 64, (1, 1))
  x = deconv_block(x, 64, (2, 2))
  x = deconv_block(x, 64, (2, 2))

  conv_trans = Conv2DTranspose(1, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
  # NOTE(review): LeakyReLU leaves the output unbounded; if targets are pixel
  # intensities in [0, 1], a sigmoid output may be a better fit.
  output = LeakyReLU(0.2)(conv_trans)

  model = tf.keras.Model(inputs=latent_input, outputs=output, name="decoder_model")
  if show_summary:
    model.summary()
  return model


In [7]:
# Preview the decoder architecture; displays the returned model.
build_decoder(latent_dim=2)

Model: "decoder_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 2)]               0         
                                                                 
 dense_2 (Dense)             (None, 3136)              9408      
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 7, 7, 64)         36928     
 nspose)                                                         
                                                                 
 batch_normalization_5 (Batc  (None, 7, 7, 64)         256       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 7, 7, 64)        

<keras.engine.functional.Functional at 0x7f1b903f0d10>

In [8]:
def loss_func(encoder_mu, encoder_logvar):
    def vae_reconstruction_loss(y_true, y_predict):
        reconstruction_loss_factor = 1000
        reconstruction_loss = k.mean(k.square(y_true-y_predict), axis=[1, 2, 3])
        return reconstruction_loss_factor * reconstruction_loss

    def vae_kl_loss(encoder_mu, encoder_logvar):
        kl_loss = -0.5 * k.sum(1.0 + encoder_logvar - k.square(encoder_mu) - k.exp(encoder_logvar), axis=1)
        return kl_loss

    def vae_kl_loss_metric(y_true, y_predict):
        kl_loss = -0.5 * k.sum(1.0 + encoder_logvar - k.square(encoder_mu) - k.exp(encoder_logvar), axis=1)
        return kl_loss

    def vae_loss(y_true, y_predict):
        reconstruction_loss = vae_reconstruction_loss(y_true, y_predict)
        kl_loss = vae_kl_loss(y_true, y_predict)

        loss = reconstruction_loss + kl_loss
        return loss

    return vae_loss

In [9]:
def vae(encoder, decoder, image_shape = (28, 28, 1)):

  input = tf.keras.Input(shape= image_shape, name="VAE_input")
  _, mu, logvar = build_encoder()

  encoder_output = encoder(input)
  decoder_output = decoder(encoder_output)

  vae = tf.keras.models.Model(inputs= input, outputs= decoder_output, name="VAE_model")
  vae.summary()

  vae.compile(optimizer=tf.keras.optimizers.Adam(lr=0.0005), loss=loss_func(mu, logvar))
  return vae


In [10]:
# Build the networks and train the VAE to reconstruct its own input
# (targets == inputs), validating on the held-out test images.
encoder, _, _ = build_encoder()
decoder = build_decoder()
vae_model = vae(encoder=encoder, decoder=decoder)

vae_model.fit(
    x=x_train,
    y=x_train,
    validation_data=(x_test, x_test),
    epochs=20,
    batch_size=32,
    shuffle=True,
)

Model: "encoder_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d_5 (Conv2D)              (None, 28, 28, 1)    10          ['input_3[0][0]']                
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 28, 28, 1)   4           ['conv2d_5[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 leaky_re_lu_9 (LeakyReLU)      (None, 28, 28, 1)    0           ['batch_normalization

  super(Adam, self).__init__(name, **kwargs)


Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7f1b9022ff10>

In [11]:
# Testing
# Encode the held-out images to latent vectors, then decode them back to images.
# NOTE(review): the encoder outputs a random sample (see its sampling Lambda), so
# these arrays differ between runs; use the encoder's mu for deterministic codes.
encoded_data = encoder.predict(x=x_test)
decoded_data = decoder.predict(x=encoded_data)