In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load the Fashion-MNIST train/test splits through the `tf` module imported in the
# first cell, so all imports stay in one place. The label arrays (y_*) are loaded
# for reference only; the autoencoder below trains on the images alone.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

In [None]:
def preprocess(images):
  """Convert a stack of 8-bit grayscale images into network-ready float tensors.

  Args:
    images: 3-D array (count, height, width) of pixel values in 0..255.

  Returns:
    float32 array of shape (count, height + 4, width + 4, 1): values scaled to
    [0, 1], two pixels of zero padding on every spatial edge, and a trailing
    channel axis.
  """
  scaled = images.astype(np.float32) / 255.0
  # No padding on the batch axis; 2 pixels before/after on each spatial axis.
  pad_width = ((0, 0), (2, 2), (2, 2))
  padded = np.pad(scaled, pad_width, constant_values=0.0)
  # Append a size-1 channel dimension.
  return padded[..., np.newaxis]

# Only preprocess the raw uint8 arrays returned by load_data(): re-running this cell
# on already-preprocessed float data would pad and rescale the images a second time.
if x_train.dtype == np.uint8:
  x_train = preprocess(x_train)
if x_test.dtype == np.uint8:
  x_test = preprocess(x_test)

In [None]:
class Encoder:
  """Factory that builds the convolutional encoder graph with the Keras functional API.

  It is called with shapes (not tensors) and returns symbolic Keras tensors plus a
  shape tuple, so it is a plain class rather than a ``tf.keras.layers.Layer``
  subclass. Keras 3 wraps ``Layer.build`` and does not pass its return value
  through, which would break the tuple returned here.
  """

  def build(self, input_shape, filter_seq, latent_dim=2):
    """Create the encoder's input layer and latent output tensor.

    Args:
      input_shape: Per-example image shape without the batch axis,
        (height, width, channels).
      filter_seq: Sequence of filter counts. Its last three entries are used in
        reverse order: filter_seq[-1] for the first conv, filter_seq[-3] for the third.
      latent_dim: Width of the final Dense (latent) layer; defaults to 2.

    Returns:
      (encoder_input_layer, encoder_output_layer, shape_before_flatten), where
      shape_before_flatten is the per-example shape of the last conv feature map.
    """
    encoder_input_layer = tf.keras.layers.Input(shape=input_shape, name="encoder_input")
    x = encoder_input_layer
    # Three stride-2 "same"-padded convs, each downsampling the spatial size by 2.
    conv_filters = (filter_seq[-1], filter_seq[-2], filter_seq[-3])
    for i, filters in enumerate(conv_filters, start=1):
      x = tf.keras.layers.Conv2D(filters, kernel_size=(3,3), strides=2, activation='relu', padding="same", name=f"enc_conv_{i}")(x)
    # Recorded so the decoder can reshape its Dense output back to this feature map.
    # tuple(x.shape) works for both TensorShape (Keras 2) and tuple shapes (Keras 3).
    shape_before_flatten = tuple(x.shape)[1:]
    x = tf.keras.layers.Flatten()(x)

    encoder_output_layer = tf.keras.layers.Dense(latent_dim, name="enc_dense_1")(x)
    return encoder_input_layer, encoder_output_layer, shape_before_flatten

  def __call__(self, input_shape, FILTER_SEQ, latent_dim=2):
    """Alias for ``build``; the upper-case parameter name is kept for keyword callers."""
    return self.build(input_shape, FILTER_SEQ, latent_dim)

In [None]:
class Decoder:
  """Factory that builds the transposed-convolution decoder graph with the Keras functional API.

  Like the encoder factory, it takes shapes (not tensors) and returns symbolic Keras
  tensors, so it is a plain class rather than a ``tf.keras.layers.Layer`` subclass.
  Keras 3 wraps ``Layer.build`` and drops its return value.
  """

  def build(self, shape_before_flatten, input_shape, filter_seq, latent_dim=2):
    """Create the decoder's latent input layer and image output tensor.

    Args:
      shape_before_flatten: Feature-map shape (without batch axis) that the Dense
        output is reshaped into before upsampling.
      input_shape: Original image shape; only its last entry (channel count) is
        used, as the number of output filters.
      filter_seq: Sequence of filter counts; entries 0, 1, 2 are used in order for
        the three transposed convolutions.
      latent_dim: Size of the latent input vector; defaults to 2.

    Returns:
      (decoder_input_layer, decoder_output_layer, decoder_shape), where
      decoder_shape is the per-example shape of the last transposed-conv feature
      map (before the sigmoid output conv).
    """
    decoder_input_layer = tf.keras.layers.Input(shape=(latent_dim,), name="decoder_input")
    # int(): np.prod returns a NumPy integer; Dense expects a plain int unit count.
    x = tf.keras.layers.Dense(int(np.prod(shape_before_flatten)), name="dec_dense_1")(decoder_input_layer)
    x = tf.keras.layers.Reshape(shape_before_flatten)(x)
    # Three stride-2 transposed convs, each upsampling the spatial size by 2.
    transpose_filters = (filter_seq[0], filter_seq[1], filter_seq[2])
    for i, filters in enumerate(transpose_filters, start=1):
      x = tf.keras.layers.Conv2DTranspose(filters, kernel_size=(3,3), strides=2, activation='relu', padding="same", name=f"dec_transpose_{i}")(x)

    decoder_shape = tuple(x.shape)[1:]

    # Sigmoid keeps reconstructed pixel values in [0, 1].
    decoder_output_layer = tf.keras.layers.Conv2D(input_shape[-1], kernel_size=(3,3), strides=1, activation='sigmoid', padding="same", name="decoder_output_layer")(x)

    return decoder_input_layer, decoder_output_layer, decoder_shape

  def __call__(self, shape_before_flatten, input_shape, filter_seq, latent_dim=2):
    """Alias for ``build``."""
    return self.build(shape_before_flatten, input_shape, filter_seq, latent_dim)

In [None]:
class autoEncoder:
  """Wires Encoder and Decoder into encoder, decoder and end-to-end Keras models.

  After ``build`` runs, ``self.encoder`` and ``self.decoder`` are available, and
  they share layers (and therefore weights) with the returned end-to-end model.
  """

  def build(self, input_shape, filter_seq):
    """Build the encoder, the decoder and the combined autoencoder model.

    Args:
      input_shape: Per-example image shape, (height, width, channels).
      filter_seq: Filter counts forwarded to both Encoder and Decoder.

    Returns:
      A tf.keras Model mapping images to reconstructed images.
    """
    encoder = Encoder()
    decoder = Decoder()

    encoder_input_layer, encoder_output_layer, shape_before_flatten = encoder(input_shape, filter_seq)
    self.encoder = tf.keras.models.Model(encoder_input_layer, encoder_output_layer)

    decoder_input_layer, decoder_output_layer, _ = decoder(shape_before_flatten, input_shape, filter_seq)
    self.decoder = tf.keras.models.Model(decoder_input_layer, decoder_output_layer)
    # summary() prints the table itself and returns None, so it is not wrapped in print().
    self.decoder.summary()

    # Feed the encoder's output tensor through the decoder sub-model, so the combined
    # model reuses exactly the layers held by self.encoder and self.decoder.
    model = tf.keras.models.Model(encoder_input_layer, self.decoder(encoder_output_layer))
    return model

  def predict(self, image):
    """Decode a batch of latent vectors into images.

    Args:
      image: Despite the name, a batch of latent vectors shaped (batch, 2) that
        matches the decoder's input layer, not an image.

    Returns:
      The decoder's output tensor for that batch.
    """
    return self.decoder(image)

In [None]:
# Model configuration. Image size and channel count come from the preprocessed data,
# shaped (N, height, width, channels), so they always match what preprocess() produced.
IMG_SIZE = x_train.shape[1]
CHANNEL = x_train.shape[-1]
# The encoder consumes FILTER_SEQ from the end (32 -> 64 -> 128) and the decoder from
# the start (128 -> 64 -> 32), giving a mirrored architecture.
# NOTE(review): these filter counts were not defined anywhere in the notebook; confirm.
FILTER_SEQ = [128, 64, 32]

model_autoEncoder = autoEncoder()

model = model_autoEncoder.build(input_shape = (IMG_SIZE, IMG_SIZE, CHANNEL), filter_seq=FILTER_SEQ)
model.summary()

In [None]:
# Targets are the input images themselves: continuous intensities in [0, 1].
# 'accuracy' (Keras binary accuracy) checks y_true for equality with thresholded
# predictions, which is not meaningful for such targets. Track reconstruction MSE
# alongside the binary cross-entropy loss instead.
model.compile(
    optimizer = tf.keras.optimizers.Adam(learning_rate = 0.01),
    loss = 'binary_crossentropy',
    metrics = ['mse']
)

In [None]:
EPOCHS = 10
# Mini-batches: with batch_size=1, every single image was its own optimizer step,
# which made each epoch extremely slow.
BATCH_SIZE = 100

history = model.fit(
    x_train, x_train,  # autoencoder: each image is its own reconstruction target
    validation_data = (x_test, x_test),
    epochs = EPOCHS,
    shuffle = True,
    batch_size = BATCH_SIZE,
)

## Image Generation

### Encoding

In [None]:
num_samples = 6

# Random points in the 2-D latent space, drawn from a standard normal.
# Seeded so the generated images are reproducible across Restart & Run All.
rng = np.random.default_rng(42)
x_sample = rng.normal(size=(num_samples, 2))
print(x_sample)

In [None]:
# The encoder maps images to 2-D latent vectors, so it must be fed real images.
# A random (num_samples, 2) array does not match its image-shaped input layer.
encoding = model_autoEncoder.encoder(x_test[:num_samples])

In [None]:
latent_points = np.asarray(encoding)
fig, ax = plt.subplots()
for i in range(num_samples):
  # One scatter call per point, so each sample gets its own colour and legend entry.
  ax.scatter(latent_points[i, 0], latent_points[i, 1], label=f"sample {i}")
ax.set(title="Encoder outputs in the 2-D latent space", xlabel="z[0]", ylabel="z[1]")
ax.legend()
plt.show()

### Decoding

In [None]:
# The decoder's input layer has shape (2,): feed it the random 2-D latent samples
# drawn earlier (x_sample), not 32x32 test images.
generate = model_autoEncoder.predict(x_sample)

In [None]:
# Show the decoded images in a two-column grid; round the row count up so an
# odd num_samples still fits.
decoded = np.asarray(generate)
n_rows = (num_samples + 1) // 2
fig, axes = plt.subplots(n_rows, 2, squeeze=False)
for i, ax in enumerate(axes.flat):
  if i < num_samples:
    # Plot sample i (not always the first); squeeze() drops the size-1 channel axis.
    ax.imshow(decoded[i].squeeze(), cmap='gray')
  ax.axis('off')
fig.suptitle("Decoder outputs")
plt.show()