In [1]:
from keras.layers import Input, Conv2D,Conv2DTranspose,Flatten,Dense,BatchNormalization,LeakyReLU,Dropout,Activation,Reshape
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.losses import mean_squared_error
from keras import backend as K
import numpy as np

Using TensorFlow backend.


In [2]:
class AutoEncoder():
    """Convolutional autoencoder built from per-layer hyperparameter lists.

    The encoder is a stack of ``Conv2D`` + ``LeakyReLU`` layers, followed by
    ``Flatten`` and a ``Dense`` projection to ``latent_dims`` units. The
    decoder mirrors it. A ``Dense`` layer and a ``Reshape`` restore the
    encoder's last feature-map shape, then a stack of ``Conv2DTranspose``
    layers follows. Every decoder layer except the last is followed by
    ``LeakyReLU``. The last is followed by a sigmoid, so reconstructions lie
    in [0, 1], the same range as the images returned by ``load_mnist``.

    Args:
        input_dims: Shape of one input image, without the batch axis
            (e.g. ``(28, 28, 1)`` for the arrays returned by ``load_mnist``).
        encoder_conv_filters: Number of filters for each encoder ``Conv2D``
            layer. Its length sets the number of encoder layers.
        encoder_kernel_size: Kernel size for each encoder layer, indexed in
            parallel with ``encoder_conv_filters``.
        encoder_strides: Stride for each encoder layer, indexed in parallel
            with ``encoder_conv_filters``.
        decoder_convt_filters: Number of filters for each decoder
            ``Conv2DTranspose`` layer. Its length sets the number of decoder
            layers.
        decoder_convt_kernel_size: Kernel size for each decoder layer, indexed
            in parallel with ``decoder_convt_filters``.
        decoder_convt_strides: Stride for each decoder layer, indexed in
            parallel with ``decoder_convt_filters``.
        latent_dims: Size of the latent vector produced by the encoder.

    Note:
        ``train`` fits the model to reproduce its own input. For that to
        work, the decoder strides must undo the encoder's downsampling, and
        the last decoder filter count must equal the number of input
        channels. This class does not check either condition.

    Attributes:
        encoder: Keras ``Model`` mapping an image to its latent vector.
        decoder: Keras ``Model`` mapping a latent vector to an image.
        model: Keras ``Model`` chaining encoder and decoder end to end.
    """

    def __init__(self, input_dims, encoder_conv_filters, encoder_kernel_size, encoder_strides,
                 decoder_convt_filters, decoder_convt_kernel_size, decoder_convt_strides,
                 latent_dims):
        self.input_dims                 =   input_dims
        self.encoder_conv_filters       =   encoder_conv_filters
        self.encoder_kernel_size        =   encoder_kernel_size
        self.encoder_strides            =   encoder_strides
        self.decoder_convt_filters      =   decoder_convt_filters
        self.decoder_convt_kernel_size  =   decoder_convt_kernel_size
        self.decoder_convt_strides      =   decoder_convt_strides
        self.n_of_encoder_layers        =   len(encoder_conv_filters)
        self.n_of_decoder_layers        =   len(decoder_convt_filters)
        self.latent_dims                =   latent_dims

        self._build()

    def load_mnist(self):
        """Load MNIST with pixel values scaled to [0, 1] and a channel axis.

        Returns:
            ``((x_train, y_train), (x_test, y_test))``. The image arrays are
            converted to ``float32``, divided by 255 and given a trailing
            channel axis of size 1. The labels are returned unchanged from
            ``mnist.load_data()``.
        """
        (x_train, y_train), (x_test, y_test) = mnist.load_data()

        x_train = x_train.astype('float32') / 255.
        x_train = x_train.reshape(x_train.shape + (1,))
        x_test = x_test.astype('float32') / 255.
        x_test = x_test.reshape(x_test.shape + (1,))

        return (x_train, y_train), (x_test, y_test)

    def _is_last_decoder_layer(self, i):
        """Return True if ``i`` is the index of the final decoder layer."""
        return i == self.n_of_decoder_layers - 1

    def _build(self):
        """Create ``self.encoder``, ``self.decoder`` and the end-to-end ``self.model``."""
        # ---- Encoder: strided convolutions, then a dense projection to the latent space.
        encoder_input = Input(shape=self.input_dims, name='encoder_input')
        x = encoder_input
        for i in range(self.n_of_encoder_layers):
            x = Conv2D(filters=self.encoder_conv_filters[i],
                       kernel_size=self.encoder_kernel_size[i],
                       strides=self.encoder_strides[i],
                       padding='same',
                       name='encoder_conv_' + str(i))(x)
            x = LeakyReLU()(x)

        # Remember the last feature-map shape (batch axis dropped) so the
        # decoder can reshape its dense output back into it.
        shape_before_flatten = K.int_shape(x)[1:]
        x = Flatten()(x)
        encoder_output = Dense(units=self.latent_dims, name='encoder_output')(x)
        self.encoder = Model(encoder_input, encoder_output)

        # ---- Decoder: dense + reshape, then transposed convolutions back to image space.
        decoder_input = Input(shape=(self.latent_dims,), name='decoder_input')
        # np.prod returns a NumPy integer; pass a plain int as the unit count.
        x = Dense(int(np.prod(shape_before_flatten)))(decoder_input)
        x = Reshape(shape_before_flatten)(x)

        for i in range(self.n_of_decoder_layers):
            x = Conv2DTranspose(filters=self.decoder_convt_filters[i],
                                kernel_size=self.decoder_convt_kernel_size[i],
                                strides=self.decoder_convt_strides[i],
                                padding='same',
                                name='decoder_conv_t_' + str(i))(x)
            if self._is_last_decoder_layer(i):
                # Sigmoid bounds the reconstruction to [0, 1], matching the
                # scaled inputs from load_mnist. The original condition
                # (`i < n`) was always true, so this branch never ran.
                x = Activation('sigmoid')(x)
            else:
                x = LeakyReLU()(x)

        decoder_output = x
        self.decoder = Model(decoder_input, decoder_output)

        # ---- Full autoencoder: image -> latent -> reconstructed image.
        model_input = encoder_input
        model_output = self.decoder(encoder_output)
        self.model = Model(model_input, model_output)

    def compile(self, learning_rate):
        """Compile ``self.model`` with the Adam optimizer and mean-squared-error loss.

        Args:
            learning_rate: Adam learning rate; also stored on ``self.learning_rate``.
        """
        self.learning_rate = learning_rate
        # Keeps the ``lr`` keyword of the standalone Keras this notebook was
        # written against; newer Keras releases spell it ``learning_rate``.
        optimizer = Adam(lr=learning_rate)
        self.model.compile(optimizer=optimizer, loss=mean_squared_error)

    def train(self, x_train, batch_size, epochs):
        """Fit the autoencoder to reconstruct ``x_train`` (input is also the target).

        Args:
            x_train: Training images, shaped to match ``input_dims`` plus a batch axis.
            batch_size: Samples per gradient update.
            epochs: Number of passes over ``x_train``; data is shuffled each epoch.

        Returns:
            The ``History`` object returned by ``Model.fit``.
        """
        return self.model.fit(
            x_train,
            x_train,
            batch_size=batch_size,
            shuffle=True,
            epochs=epochs)

In [None]:
#model_input= encoder_input
#model_output=self.decoder(encoder_output)

In [None]:
#self.model=Model(model_input,model_output)

In [None]:
#x = np.array([1, 2, 3, 4])

#np.prod(x)