In [1]:
import os
import tensorflow as tf
from tensorflow import keras
from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, ReLU, Dropout
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint,LearningRateScheduler, EarlyStopping
from keras.utils import plot_model
import numpy as np
import pickle
import matplotlib.pyplot as plt

In [2]:
# Reset Keras' global session state before any model in this notebook is built.
K.clear_session()

In [3]:
from tensorflow.keras.datasets import mnist
# Load MNIST: uint8 images of shape (N, 28, 28) plus integer labels.
(trainX, trainy), (testX, testy) = mnist.load_data()

# Scale pixels from [0, 255] to [0, 1]. Casting to float32 first matches Keras'
# default float type and halves memory compared with the implicit float64
# produced by `uint8_array / 255`.
trainX = trainX.astype('float32') / 255
testX = testX.astype('float32') / 255

# summarize loaded dataset
print('Train: X=%s, y=%s' % (trainX.shape, trainy.shape))
print('Test: X=%s, y=%s' % (testX.shape, testy.shape))

Train: X=(60000, 28, 28), y=(60000,)
Test: X=(10000, 28, 28), y=(10000,)


In [7]:
def step_decay(epoch):
    """Step-wise learning-rate schedule, intended for ``LearningRateScheduler``.

    The rate starts at 1e-3 and is halved every 10 epochs:
    ``lr = 0.001 * 0.5 ** floor(epoch / 10)``.
    The computed rate and the epoch are printed before returning.
    """
    base_lr, drop, epochs_per_drop = 0.001, 0.5, 10
    n_drops = np.floor(epoch / epochs_per_drop)
    lr = base_lr * drop ** n_drops
    print(lr, epoch)
    return lr

class VariationalAutoEncoder():
    """Convolutional variational autoencoder (VAE) built with the Keras functional API.

    Encoder: a stack of Conv2D -> ReLU (-> BatchNormalization) blocks, flattened
    into two Dense heads ``mu`` and ``log_var`` of size ``_LATENT_SPACE``, then a
    sampling layer drawing ``z = mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``.

    Decoder: Dense -> Reshape to the encoder's pre-flatten feature-map shape, then
    a stack of Conv2DTranspose layers; hidden ones use ReLU, the last a sigmoid.

    Training objective: per-pixel mean squared reconstruction error (compiled
    loss) plus ``_KL_WEIGHT`` times the KL divergence between ``N(mu, exp(log_var))``
    and the standard normal prior (registered by the sampling layer via ``add_loss``).

    Args:
        input_dim: per-sample input shape without the batch axis, e.g. ``(28, 28, 1)``.
    """

    _ENC_FILTERS = [32,64,64,64]
    _ENC_STRIDES = [1,2,2,1]
    _FILTER_SIZE = [3,3,3,3]   # kernel size of encoder layer i AND decoder layer i
    _LATENT_SPACE = 32
    _DEC_FILTERS = [64,64,32,1]
    _DEC_STRIDES = [1,2,2,1]
    _BATCH_NORM = True
    _DROPOUT = 0.2   # NOTE(review): declared but not applied by any layer in this class.
    # Weight on the KL term. The compiled MSE is averaged over every pixel, while
    # the KL is summed over latent dimensions. A small weight keeps the KL from
    # overwhelming reconstruction (equivalently: reconstruction loss x 1000).
    _KL_WEIGHT = 1e-3

    def __init__(self,input_dim):
        self.input_dim = input_dim

    def _build_encoder(self):
        """Build the encoder model that maps an input batch to a sampled latent ``z``.

        Side effects: sets ``self.mu`` and ``self.log_var`` (symbolic outputs of the
        two Dense heads) and ``self.shape_before_flattening`` (conv feature-map shape
        without the batch axis, which the decoder reshapes back to).

        The sampling layer adds ``_KL_WEIGHT * mean_batch(KL)`` through ``add_loss``,
        so any model that calls this encoder picks up the KL term in its ``losses``.

        Returns:
            A Keras ``Model`` named ``'Encoder'`` from ``encoder_input`` to ``z``.
        """
        # Imported here so the custom layer comes from the same Keras package as
        # the other layers used in this notebook.
        from keras.layers import Layer

        kl_weight = self._KL_WEIGHT

        class _SamplingWithKL(Layer):
            """Reparameterised draw z = mu + exp(log_var / 2) * eps; adds the KL penalty.

            Defined locally (like a Lambda closure); saving a model that contains it
            would need it passed via ``custom_objects``.
            """

            def call(self, inputs):
                mu, log_var = inputs
                # Closed-form KL(N(mu, exp(log_var)) || N(0, I)) per sample,
                # summed over latent dims, then averaged over the batch.
                kl = -0.5 * K.sum(1. + log_var - K.square(mu) - K.exp(log_var), axis=1)
                self.add_loss(kl_weight * K.mean(kl))
                epsilon = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
                return mu + K.exp(log_var / 2) * epsilon

        encoder_input = Input(shape=self.input_dim, name='encoder_input')
        x = encoder_input

        for lyr in range(len(self._ENC_FILTERS)):
            x = Conv2D(filters=self._ENC_FILTERS[lyr],
                       kernel_size=self._FILTER_SIZE[lyr],
                       strides=self._ENC_STRIDES[lyr],
                       name='encoder_conv'+str(lyr),
                       padding='same')(x)
            x = ReLU()(x)
            if self._BATCH_NORM:
                x = BatchNormalization()(x)

        # Recorded for the decoder, so get_model need not locate this layer by index.
        self.shape_before_flattening = K.int_shape(x)[1:]
        x = Flatten()(x)

        self.mu = Dense(self._LATENT_SPACE, name='mu')(x)
        self.log_var = Dense(self._LATENT_SPACE, name='log_var')(x)

        encoder_output = _SamplingWithKL(name='encoder_output')([self.mu, self.log_var])
        return Model(encoder_input, encoder_output, name='Encoder')

    def _build_decoder(self,shape_before_flattening):
        """Build the decoder model that maps a latent vector back to input space.

        Args:
            shape_before_flattening: encoder feature-map shape (no batch axis); the
                Dense projection of the latent vector is reshaped to it.

        Returns:
            A Keras ``Model`` named ``'Decoder'``. Hidden Conv2DTranspose layers use
            ReLU (followed by BatchNormalization when enabled); the last uses a sigmoid.
        """
        decoder_input = Input(shape=(self._LATENT_SPACE,), name='decoder_input')
        x = Dense(int(np.prod(shape_before_flattening)))(decoder_input)
        x = Reshape(tuple(shape_before_flattening))(x)

        last = len(self._DEC_FILTERS) - 1
        for lyr in range(len(self._DEC_FILTERS)):
            x = Conv2DTranspose(filters=self._DEC_FILTERS[lyr],
                                kernel_size=self._FILTER_SIZE[lyr],
                                strides=self._DEC_STRIDES[lyr],
                                name='decoder_conv_t'+str(lyr),
                                padding='same')(x)
            if lyr < last:
                x = ReLU()(x)
                if self._BATCH_NORM:
                    x = BatchNormalization()(x)
            else:
                # Output layer: squash to [0, 1] to match pixel intensities.
                x = Activation('sigmoid')(x)

        return Model(decoder_input, x, name='Decoder')

    def _add_channel_axis(self, X):
        """Return ``X`` with a trailing size-1 channel axis appended when it is missing.

        The axis is added only when ``X`` has exactly one dimension fewer than
        ``(batch,) + self.input_dim`` and ``input_dim`` ends in 1, e.g. MNIST's
        ``(N, 28, 28)`` against ``input_dim=(28, 28, 1)``. Otherwise ``X`` is
        returned unchanged.
        """
        if np.ndim(X) == len(self.input_dim) and self.input_dim[-1] == 1:
            return np.expand_dims(X, axis=-1)
        return X

    def fit(self,Xtr,Xval,**kwargs):
        """Build a fresh VAE and train it to reconstruct its inputs.

        Args:
            Xtr: training images; also used as the reconstruction target.
            Xval: validation images; also used as the validation target.
            **kwargs: optional ``initial_lr`` (default 0.001), ``batch_size`` (32),
                ``callbacks_list`` ([]), ``epochs`` (100), ``shuffle`` (True).

        Returns:
            The Keras ``History`` from ``model.fit``. The trained ``model``,
            ``decoder`` and ``encoder`` (and the history) are also stored on ``self``.
        """
        self.initial_lr = kwargs.get("initial_lr", 0.001)
        self.batch_size = kwargs.get("batch_size", 32)
        self.callbacks_list = kwargs.get("callbacks_list", [])
        self.epochs = kwargs.get("epochs", 100)
        self.shuffle = kwargs.get("shuffle", True)

        Xtr = self._add_channel_axis(Xtr)
        Xval = self._add_channel_axis(Xval)

        # initial_lr now actually reaches the optimizer (it used to be ignored).
        model, decoder, encoder = self.get_model(learning_rate=self.initial_lr)

        history = model.fit(Xtr, Xtr,
                            batch_size=self.batch_size,
                            validation_data=(Xval, Xval),
                            shuffle=self.shuffle,
                            epochs=self.epochs,
                            callbacks=self.callbacks_list)

        self.model = model
        self.decoder = decoder
        self.encoder = encoder
        self.history = history
        return history

    def get_model(self, learning_rate=0.001):
        """Build and compile the full VAE (encoder followed by decoder).

        Args:
            learning_rate: Adam learning rate; defaults to 0.001.

        Returns:
            ``(model, decoder, encoder)``. ``model`` is compiled with Adam and an MSE
            reconstruction loss; the KL term comes from the encoder's sampling layer.
        """
        encoder = self._build_encoder()
        # _build_encoder records the pre-flatten shape, so no positional layer lookup.
        decoder = self._build_decoder(self.shape_before_flattening)

        model_input = Input(shape=self.input_dim, name='encoder_input')
        decoder_op = decoder(encoder(model_input))
        model = Model(model_input, decoder_op, name='ae_model')

        model.compile(optimizer=Adam(learning_rate=learning_rate),
                      loss='mean_squared_error')

        return model, decoder, encoder

In [8]:
# Build (but do not train) the VAE for 28x28 single-channel images.
# NOTE(review): vae.fit(...) calls get_model() again and trains a freshly built
# model, so the model/decoder/encoder returned here are not the ones fit() trains.
vae = VariationalAutoEncoder(input_dim=(28,28,1))
model, decoder, encoder = vae.get_model()

In [9]:
# Inspect the encoder: conv stack -> Flatten -> mu / log_var heads -> sampled latent.
encoder.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 encoder_input (InputLayer)     [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 encoder_conv0 (Conv2D)         (None, 28, 28, 32)   320         ['encoder_input[0][0]']          
                                                                                                  
 re_lu_4 (ReLU)                 (None, 28, 28, 32)   0           ['encoder_conv0[0][0]']          
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 28, 28, 32)  128         ['re_lu_4[0][0]']                
 rmalization)                                                                               