<h1 style = "font-size:3rem;color:darkcyan"> Implementing autoencoders</h1>


Implementing a deep convolutional autoencoder class with mirrored encoder and decoder components.
<i>Based on the work of Valerio Velardo.</i>

In [6]:
# importing libraries
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Conv2D, ReLU, \
BatchNormalization, Flatten, Dense

In [18]:
class Autoencoder:
    """Deep convolutional autoencoder built from mirrored encoder/decoder parts.

    Currently only the encoder is constructed. It is a stack of
    Conv2D -> ReLU -> BatchNormalization blocks, followed by a
    Flatten + Dense bottleneck with ``latent_space_dim`` units.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``
            (e.g. ``(28, 28, 1)``).
        conv_filters: Number of filters for each conv layer.
        conv_kernels: Kernel size for each conv layer.
        conv_strides: Stride for each conv layer.
        latent_space_dim: Number of units in the Dense bottleneck layer.

    Raises:
        ValueError: If ``conv_filters``, ``conv_kernels`` and
            ``conv_strides`` do not all have the same length.
    """

    def __init__(self,
                 input_shape,
                 conv_filters,
                 conv_kernels,
                 conv_strides,
                 latent_space_dim):
        # The three per-layer sequences are indexed in lockstep. Without this
        # check, a mismatch either raises IndexError mid-build (filters
        # longer) or silently ignores extra kernels/strides (filters shorter).
        if not len(conv_filters) == len(conv_kernels) == len(conv_strides):
            raise ValueError(
                "conv_filters, conv_kernels and conv_strides must have the "
                f"same length, got {len(conv_filters)}, {len(conv_kernels)} "
                f"and {len(conv_strides)}"
            )

        self.input_shape = input_shape
        self.conv_filters = conv_filters
        self.conv_kernels = conv_kernels
        self.conv_strides = conv_strides
        self.latent_space_dim = latent_space_dim

        # Keras sub-models; populated by _build().
        self.encoder = None
        self.decoder = None
        self.model = None

        self._num_conv_layers = len(conv_filters)
        # Shape of the last conv output, recorded in _add_bottleneck.
        self._shape_before_bottleneck = None

        self._build()

    def summary(self):
        """Print the Keras summary of every sub-model built so far."""
        for sub_model in (self.encoder, self.decoder, self.model):
            if sub_model is not None:
                sub_model.summary()

    def _build(self):
        """Construct the Keras sub-models."""
        self._build_encoder()
        # TODO: also build the decoder and the full autoencoder model once
        # _build_decoder / _build_autoencoder are implemented.

    def _build_encoder(self):
        """Assemble the encoder: input -> conv stack -> Dense bottleneck."""
        encoder_input = self._add_encoder_input()
        conv_layers = self._add_conv_layers(encoder_input)
        bottleneck = self._add_bottleneck(conv_layers)
        self.encoder = Model(encoder_input, bottleneck, name = 'encoder')

    def _add_encoder_input(self):
        """Return the encoder's Keras ``Input`` layer."""
        return Input(shape = self.input_shape, name = 'encoder_input')

    def _add_conv_layers(self, encoder_input):
        """Chain all conv blocks onto ``encoder_input`` and return the output."""
        layer_graph = encoder_input
        for i in range(self._num_conv_layers):
            layer_graph = self._add_conv_layer(i, layer_graph)
        return layer_graph

    def _add_conv_layer(self, layer_index, layer_graph):
        """Append one Conv2D -> ReLU -> BatchNormalization block.

        Args:
            layer_index: Zero-based index into the conv_* sequences.
            layer_graph: Tensor to feed into the new Conv2D layer.

        Returns:
            The output tensor of the BatchNormalization layer.
        """
        # Layer names use a 1-based index.
        current_layer = layer_index + 1
        conv_layer = Conv2D(
            filters = self.conv_filters[layer_index],
            kernel_size = self.conv_kernels[layer_index],
            strides = self.conv_strides[layer_index],
            padding = 'same',
            name = f'encoder_conv_layer_{current_layer}'
        )

        layer_graph = conv_layer(layer_graph)
        # NOTE: 'encoder_relu{n}' lacks the underscore used by the other layer
        # names; kept as-is so existing layer names stay stable.
        layer_graph = ReLU(name=f'encoder_relu{current_layer}')(layer_graph)
        layer_graph = BatchNormalization(name=f'encoder_bn_{current_layer}')(layer_graph)

        return layer_graph

    def _add_bottleneck(self, layer_graph):
        """Flatten the conv output and project it to the latent space.

        Records the pre-flatten shape in ``self._shape_before_bottleneck``
        (intended for decoding, which is not implemented yet).
        """
        self._shape_before_bottleneck = K.int_shape(layer_graph)
        layer_graph = Flatten()(layer_graph)
        layer_graph = Dense(self.latent_space_dim, name = 'encoder_output')(layer_graph)
        return layer_graph
       

In [20]:
# 28x28 single-channel input. With padding='same', the strides (1, 2, 2, 1)
# take the spatial size 28 -> 28 -> 14 -> 7 -> 7. The Dense bottleneck then
# compresses the result to a 2-dimensional latent vector.
autoencoder = Autoencoder(
    input_shape=(28, 28, 1),
    conv_filters=(32, 64, 64, 64),
    conv_kernels=(3, 3, 3, 3),
    conv_strides=(1, 2, 2, 1),
    latent_space_dim=2,
)

In [22]:
# Print the layer-by-layer Keras summary. Only the encoder is built so far.
autoencoder.summary()


Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_input (InputLayer)  [(None, 28, 28, 1)]       0         
                                                                 
 encoder_conv_layer_1 (Conv2  (None, 28, 28, 32)       320       
 D)                                                              
                                                                 
 encoder_relu1 (ReLU)        (None, 28, 28, 32)        0         
                                                                 
 encoder_bn_1 (BatchNormaliz  (None, 28, 28, 32)       128       
 ation)                                                          
                                                                 
 encoder_conv_layer_2 (Conv2  (None, 14, 14, 64)       18496     
 D)                                                              
                                                           