In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, LeakyReLU, Flatten, Dense, Activation
import numpy as np

In [None]:
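# ENCODER: stacked Conv2D + LeakyReLU blocks, flattened and projected by a Dense layer to a `z_dim`-dimensional latent vector.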

In [None]:
class Encoder():
  """Convolutional encoder: image tensor -> flat latent vector of size ``z_dim``.

  Every keyword argument is stored as an attribute on the instance.

  Required kwargs:
    input_dim: shape of one input sample (no batch axis), e.g. (28, 28, 1).
    encoder_conv_filters, encoder_conv_kernel_size, encoder_conv_strides:
      per-layer Conv2D settings, indexed by layer number.
    z_dim: width of the final Dense layer.

  Optional kwargs:
    n_layers_enc: number of Conv2D blocks to build; defaults to
      ``len(encoder_conv_filters)``.

  Any other kwargs (e.g. decoder_* settings) are stored but not used here.

  Attributes set after construction:
    encoder_input: the keras.Input tensor feeding the encoder.
    shape_before_flattening: tuple, output shape of the last conv block
      without the batch axis.
    encoder_output: the Dense output tensor.
    enc: keras Model mapping ``encoder_input`` -> ``encoder_output``.
  """
  def __init__(self, **kwargs):

    for key, value in kwargs.items():
      setattr(self, key, value)

    # Default the layer count to the number of configured filter sizes so the
    # caller does not have to keep two values in sync.
    if not hasattr(self, 'n_layers_enc'):
      self.n_layers_enc = len(self.encoder_conv_filters)

    encoder_inp = keras.Input(shape=self.input_dim, name='encoder_input')
    # Keep a handle on the input tensor so later cells can build the full
    # autoencoder without reaching into this method's locals.
    self.encoder_input = encoder_inp
    x = encoder_inp

    # Stack of Conv2D + LeakyReLU blocks; with padding='same', a stride > 1
    # reduces the spatial size.
    for i in range(self.n_layers_enc):
      conv_layer = Conv2D(
          filters = self.encoder_conv_filters[i],
          kernel_size = self.encoder_conv_kernel_size[i],
          strides = self.encoder_conv_strides[i],
          padding = 'same',
          name = 'enc_conv_' + str(i)
      )
      x = conv_layer(x)
      x = LeakyReLU()(x)

    # Plain tuple regardless of Keras version (TF-Keras gives a TensorShape,
    # Keras 3 a tuple) so it can be fed straight to np.prod / Reshape.
    self.shape_before_flattening = tuple(x.shape[1:])
    print(self.shape_before_flattening)
    x = Flatten()(x)
    encoder_output = Dense(self.z_dim, name='encoder_output')(x)

    self.encoder_output = encoder_output

    self.enc = Model(encoder_inp, encoder_output)

In [None]:
# Build the encoder for 28x28 single-channel inputs with a 2-D latent space.
# The decoder_* entries are stored on the Encoder instance but not used by it.
AE = Encoder(
  input_dim=(28, 28, 1),
  z_dim=2,
  n_layers_enc=4,
  encoder_conv_filters=[32, 64, 64, 64],
  encoder_conv_kernel_size=[3, 3, 3, 3],
  encoder_conv_strides=[1, 2, 2, 1],
  decoder_conv_t_filters=[64, 64, 32, 1],
  decoder_conv_t_kernel_size=[3, 3, 3, 3],
  decoder_conv_t_strides=[1, 2, 2, 1],
)

(7, 7, 64)


In [None]:
# (H, W, C) of the last encoder conv block, as a plain tuple of ints so it works
# with np.prod / Reshape whether Keras reported a TensorShape or a tuple.
# The Decoder cell reads this global.
shape_before_flattening = tuple(AE.shape_before_flattening)

In [None]:
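# DECODER: Dense projection reshaped to the encoder's pre-flatten feature map, then Conv2DTranspose blocks (LeakyReLU, sigmoid on the last layer) back to image space.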

In [None]:
from tensorflow.keras.layers import Reshape, Conv2DTranspose

In [None]:
class Decoder():
  """Convolutional decoder: latent vector of size ``z_dim`` -> image tensor.

  Every keyword argument is stored as an attribute on the instance.

  Required kwargs:
    z_dim: width of the latent input.
    decoder_conv_t_filters, decoder_conv_t_kernel_size, decoder_conv_t_strides:
      per-layer Conv2DTranspose settings, indexed by layer number.

  Optional kwargs:
    shape_before_flattening: (H, W, C) that the Dense projection is reshaped
      to. If omitted, the notebook-level global of the same name is used
      (the original behaviour).
    n_layers_dec: number of Conv2DTranspose blocks; defaults to
      ``len(decoder_conv_t_filters)``.

  Any other kwargs (e.g. encoder_* settings) are stored but not used here.

  Attributes set after construction:
    decoder_input: the keras.Input tensor feeding the decoder.
    decoder: keras Model mapping the latent input to the reconstruction.
  """
  def __init__(self, **kwargs):

    for key, value in kwargs.items():
      setattr(self, key, value)

    # Prefer an explicitly passed shape; fall back to the global set by an
    # earlier notebook cell so existing calls keep working.
    target_shape = getattr(self, 'shape_before_flattening', None)
    if target_shape is None:
      target_shape = shape_before_flattening
    target_shape = tuple(target_shape)

    if not hasattr(self, 'n_layers_dec'):
      self.n_layers_dec = len(self.decoder_conv_t_filters)

    decoder_inp = keras.Input(shape=(self.z_dim,), name='decoder_input')
    self.decoder_input = decoder_inp

    # Project the latent vector up to the pre-flatten feature-map size.
    # int() because np.prod returns a NumPy scalar.
    x = Dense(int(np.prod(target_shape)))(decoder_inp)
    x = Reshape(target_shape)(x)

    for i in range(self.n_layers_dec):
      conv_t_layer = Conv2DTranspose(
          filters = self.decoder_conv_t_filters[i],
          kernel_size = self.decoder_conv_t_kernel_size[i],
          strides = self.decoder_conv_t_strides[i],
          padding = 'same',
          name = 'decoder_conv_t_' + str(i)
      )

      x = conv_t_layer(x)

      # LeakyReLU on hidden layers; sigmoid on the last one bounds the
      # reconstruction to (0, 1).
      if i < self.n_layers_dec - 1:
        x = LeakyReLU()(x)
      else:
        x = Activation('sigmoid')(x)

    dec_output = x
    self.decoder = Model(decoder_inp, dec_output)

In [None]:
# Build the decoder with the same configuration as the encoder cell; only z_dim,
# n_layers_dec and the decoder_conv_t_* entries are read by Decoder.
AD = Decoder(
  z_dim=2,
  n_layers_dec=4,
  decoder_conv_t_filters=[64, 64, 32, 1],
  decoder_conv_t_kernel_size=[3, 3, 3, 3],
  decoder_conv_t_strides=[1, 2, 2, 1],
  input_dim=(28, 28, 1),
  encoder_conv_filters=[32, 64, 64, 64],
  encoder_conv_kernel_size=[3, 3, 3, 3],
  encoder_conv_strides=[1, 2, 2, 1],
)

In [None]:
# Pull the built Keras models out of their wrapper objects.
encoder, decoder = AE.enc, AD.decoder

In [None]:
# Symbolic output tensor of the encoder (the Dense latent layer).
encoder_output = getattr(AE, 'encoder_output')

In [None]:
# The encoder's Input tensor lives on the built model; the name `encoder_inp`
# was only a local inside Encoder.__init__ and is not defined at notebook level.
model_inp = encoder.input
# Chain the decoder onto the encoder's latent output to form the autoencoder graph.
model_out = decoder(encoder_output)

In [None]:
# End-to-end autoencoder: image -> latent -> reconstructed image.
model = Model(inputs=model_inp, outputs=model_out)