In [1]:
from keras.layers import Input,Conv2D,Conv2DTranspose,Dense,Flatten,LeakyReLU,Activation,Reshape,Dropout,Lambda,BatchNormalization
from keras.models  import Model
from keras import backend as K
import numpy as np
from keras.optimizers import Adam
from keras.datasets import mnist

Using TensorFlow backend.


In [2]:
class VarAutoEnc():
    """Convolutional variational autoencoder assembled from Keras layers.

    Constructing an instance immediately builds four Keras models:

    * ``encoder_mu_log_var`` -- input image -> (``mu``, ``log_var``)
    * ``encoder``            -- input image -> sampled latent point
    * ``decoder``            -- latent point -> reconstructed image
    * ``model``              -- encoder and decoder chained end to end

    Call :meth:`compile` and then :meth:`train` to fit ``model``.
    """

    def __init__(self, input_dims, 
                encoder_conv_filters,
                encoder_kernel_size,
                encoder_strides,
                decoder_conv_t_filtres,
                decoder_kernel_size,
                decoder_striders,
                latent_dims,
                use_batch_norm=True,
                dropout_rate=0.25):
        """Store the architecture description and build the models.

        Args:
            input_dims: shape of one input sample, passed to ``Input(shape=...)``.
            encoder_conv_filters: filter count for each encoder ``Conv2D`` layer.
            encoder_kernel_size: kernel size for each encoder ``Conv2D`` layer.
            encoder_strides: strides for each encoder ``Conv2D`` layer.
            decoder_conv_t_filtres: filter count for each decoder
                ``Conv2DTranspose`` layer.
            decoder_kernel_size: kernel size for each decoder layer.
            decoder_striders: strides for each decoder layer.
            latent_dims: number of units in the ``mu`` and ``log_var`` layers.
            use_batch_norm: when true (default) a ``BatchNormalization`` layer
                follows every convolution except the last decoder layer.
            dropout_rate: rate of the ``Dropout`` layer added after each
                activation (default 0.25); ``None`` or ``0`` omits the layer.

        Raises:
            ValueError: if the three encoder lists (or the three decoder
                lists) differ in length, or ``dropout_rate`` is outside [0, 1).
        """
        # Fail fast with a readable message instead of an IndexError (or a
        # silently truncated network) halfway through graph construction.
        no_enc_layers = self._validate_layer_spec(
            'encoder', encoder_conv_filters, encoder_kernel_size, encoder_strides)
        no_dec_layers = self._validate_layer_spec(
            'decoder', decoder_conv_t_filtres, decoder_kernel_size, decoder_striders)
        if dropout_rate is not None and not 0 <= dropout_rate < 1:
            raise ValueError(
                'dropout_rate must be in [0, 1), got {}'.format(dropout_rate))

        self.input_dims = input_dims
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_kernel_size = encoder_kernel_size
        self.encoder_strides = encoder_strides
        self.decoder_conv_t_filtres = decoder_conv_t_filtres
        self.decoder_kernel_size = decoder_kernel_size
        self.decoder_striders = decoder_striders
        self.latent_dims = latent_dims
        self.use_batch_norm = use_batch_norm
        self.dropout_rate = dropout_rate

        self.no_enc_layers = no_enc_layers
        self.no_dec_layers = no_dec_layers

        self._build()

    @staticmethod
    def _validate_layer_spec(part, filters, kernel_sizes, strides):
        """Return the layer count after checking the per-layer lists agree in length."""
        lengths = (len(filters), len(kernel_sizes), len(strides))
        if len(set(lengths)) != 1:
            raise ValueError(
                '{} filters, kernel sizes and strides must have the same '
                'length, got {}, {} and {}'.format(part, *lengths))
        return lengths[0]

    def load_mnist(self):
        """Load MNIST and prepare the images for the network.

        Pixel values are converted to ``float32`` and divided by 255, and a
        trailing axis of size 1 is appended to each image array. Labels are
        returned unchanged. No instance state is read or modified.

        Returns:
            ``(x_train, y_train), (x_test, y_test)``
        """
        (x_train, y_train), (x_test, y_test) = mnist.load_data()

        x_train = x_train.astype('float32') / 255.
        x_train = x_train.reshape(x_train.shape + (1,))
        x_test = x_test.astype('float32') / 255.
        x_test = x_test.reshape(x_test.shape + (1,))

        return (x_train, y_train), (x_test, y_test)

    def _regularise(self, x):
        """Apply the optional BatchNorm -> LeakyReLU -> optional Dropout stack."""
        if self.use_batch_norm:
            x = BatchNormalization()(x)
        x = LeakyReLU()(x)
        if self.dropout_rate:
            x = Dropout(self.dropout_rate)(x)
        return x

    def _build(self):
        """Create the encoder, decoder and combined models.

        Sets ``self.mu``, ``self.log_var``, ``self.encoder_mu_log_var``,
        ``self.encoder``, ``self.decoder`` and ``self.model``.
        """
        # ---- Encoder -------------------------------------------------------
        encoder_input = Input(shape=self.input_dims, name='encoder_input')
        x = encoder_input

        for i in range(self.no_enc_layers):
            x = Conv2D(filters=self.encoder_conv_filters[i],
                       kernel_size=self.encoder_kernel_size[i],
                       strides=self.encoder_strides[i],
                       padding='same',
                       name='encoder_conv_' + str(i))(x)
            x = self._regularise(x)

        # Remember the feature-map shape (without the batch axis) so the
        # decoder can reshape its dense output back to it.
        shape_before_flatten = K.int_shape(x)[1:]
        x = Flatten()(x)

        # The flattened features feed two parallel dense layers of width
        # latent_dims rather than the latent space directly:
        #   mu      = mean of the distribution
        #   log_var = log of the variance of each dimension
        self.mu = Dense(self.latent_dims, name='mu')(x)
        self.log_var = Dense(self.latent_dims, name='log_var')(x)

        self.encoder_mu_log_var = Model(encoder_input, (self.mu, self.log_var))

        def sampling(args):
            # point = mu + exp(log_var / 2) * epsilon, epsilon ~ N(0, 1)
            mu, log_var = args
            epsilon = K.random_normal(shape=K.shape(mu))  # default mean=0, std=1
            return mu + K.exp(log_var / 2) * epsilon

        encoder_output = Lambda(sampling, name='encoder_output')([self.mu, self.log_var])
        self.encoder = Model(encoder_input, encoder_output)

        # ---- Decoder -------------------------------------------------------
        # The decoder input has one unit per latent dimension.
        decoder_input = Input(shape=(self.latent_dims,), name='decoder_input')

        # int(): np.prod returns a NumPy scalar; hand Dense a plain Python int.
        x = Dense(int(np.prod(shape_before_flatten)))(decoder_input)
        x = Reshape(shape_before_flatten)(x)

        for i in range(self.no_dec_layers):
            x = Conv2DTranspose(
                filters=self.decoder_conv_t_filtres[i],
                kernel_size=self.decoder_kernel_size[i],
                strides=self.decoder_striders[i],
                padding='same',
                name='dec_conv_transpose_' + str(i)
            )(x)
            if i < self.no_dec_layers - 1:
                x = self._regularise(x)
            else:
                # Last layer: sigmoid keeps every output value between 0 and 1.
                x = Activation('sigmoid')(x)
        decoder_output = x

        self.decoder = Model(decoder_input, decoder_output)

        # ---- Full autoencoder ----------------------------------------------
        # Same input as the encoder; output is the decoder applied to the
        # sampled latent point.
        model_input = encoder_input
        model_output = self.decoder(encoder_output)

        self.model = Model(model_input, model_output)

    def compile(self, lrate, loss_factor):
        """Compile ``self.model`` with Adam and the VAE loss.

        The loss is ``loss_factor * reconstruction_loss + kl_loss``; both
        terms are also reported as metrics.

        Args:
            lrate: learning rate passed to ``Adam``; also stored as ``self.lrate``.
            loss_factor: multiplier applied to the reconstruction loss.
        """
        self.lrate = lrate

        def vae_r_loss(y_true, y_pred):
            # Mean squared error over axes 1, 2 and 3 of each sample.
            r_loss = K.mean(K.square(y_true - y_pred), axis=[1, 2, 3])
            return loss_factor * r_loss

        def vae_kl_loss(y_true, y_pred):
            # KL-divergence term computed from the encoder's mu / log_var
            # tensors; the arguments exist only to satisfy the Keras loss
            # signature.
            kl_loss = -0.5 * K.sum(
                1 + self.log_var - K.square(self.mu) - K.exp(self.log_var), axis=1)
            return kl_loss

        def vae_loss(y_true, y_pred):
            return vae_r_loss(y_true, y_pred) + vae_kl_loss(y_true, y_pred)

        optimizer = Adam(lr=lrate)
        self.model.compile(optimizer=optimizer, loss=vae_loss,
                           metrics=[vae_r_loss, vae_kl_loss])

    def train(self, x_train, batch_size, epochs):
        """Fit ``self.model`` to reconstruct ``x_train`` from itself.

        Args:
            x_train: training samples, used as both input and target.
            batch_size: batch size passed to ``fit``.
            epochs: number of epochs passed to ``fit``.

        Returns:
            Whatever ``self.model.fit`` returns, so callers can inspect the
            training run.
        """
        return self.model.fit(
            x_train,
            x_train,
            batch_size=batch_size,
            shuffle=True,
            epochs=epochs)