In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

import utils

In [None]:
# Load the Fashion-MNIST dataset. Class labels are discarded (bound to `_`)
# because the autoencoder below is trained with its input as its own target.
(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()

# Run the project's preprocessing on both splits (train first, then test).
# NOTE(review): `utils.preprocess` is project-local — confirm the scaling/shape it produces.
x_train, x_test = utils.preprocess(x_train), utils.preprocess(x_test)

(x_train.shape, x_test.shape)

In [None]:
# Autoencoder model definition
class Autoencoder(keras.models.Model):
    """Fully connected autoencoder for fixed-shape inputs.

    The encoder flattens each sample, passes it through a 512-unit ReLU layer
    and compresses it to a ``latent_dims``-wide ReLU code. The decoder expands
    that code to one sigmoid unit per input element and reshapes the result
    back to ``input_shape``.

    Args:
        latent_dims: Width of the latent (bottleneck) layer.
        input_shape: Shape of a single sample, excluding the batch axis.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, latent_dims, input_shape, **kwargs):
        super().__init__(**kwargs)

        self.latent_dims = latent_dims
        self.image_shape = input_shape

        # Number of scalar values in one sample. Converted to a plain Python int:
        # Dense `units` should be an int, and a NumPy scalar here can leak into the
        # layer config (e.g. breaking JSON serialization when the model is saved).
        n_values = int(np.prod(self.image_shape))

        # Encoder: sample -> flat vector -> 512 hidden units -> latent code
        self.encoder = keras.Sequential([
            keras.layers.Flatten(),
            keras.layers.Dense(512, activation='relu'),
            keras.layers.Dense(self.latent_dims, activation='relu')
        ])

        # Decoder: latent code -> one value per input element in (0, 1)
        # (sigmoid) -> reshaped back to the original sample shape
        self.decoder = keras.Sequential([
            keras.layers.Dense(n_values, activation="sigmoid"),
            keras.layers.Reshape(self.image_shape)
        ])

    def call(self, x):
        """Reconstruct ``x`` by encoding it to the latent space and decoding it back."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [None]:
# Constants definition
LATENT_DIMS = 64                 # width of the autoencoder's latent (bottleneck) layer
INPUT_SHAPE = x_train.shape[1:]  # shape of one sample: drops the leading sample axis
EPOCHS = 10                      # number of passes over x_train during training
SEED = 42                        # global RNG seed for reproducible runs

# Seed Python, NumPy and the Keras backend RNGs so weight initialization and
# per-epoch shuffling repeat across Restart & Run All. Must run before the
# model is constructed.
# NOTE: bit-exact results on GPU may additionally require
# tf.config.experimental.enable_op_determinism().
keras.utils.set_random_seed(SEED)

In [None]:
# Build the autoencoder from the configured latent width and per-sample shape
autoencoder = Autoencoder(input_shape=INPUT_SHAPE, latent_dims=LATENT_DIMS)

# Mean-squared-error reconstruction loss, optimized with Adam
autoencoder.compile(loss='mse', optimizer='adam')

# Inputs double as targets: the model learns to reproduce x_train from x_train.
# The test split is used for validation; only the per-epoch metrics dict is kept.
history = autoencoder.fit(
    x=x_train,
    y=x_train,
    validation_data=(x_test, x_test),
    shuffle=True,
    epochs=EPOCHS,
).history

In [None]:
# Plot loss curve
# `history` is the metrics dict taken from `Model.fit(...).history`; since training
# used validation_data it presumably contains both 'loss' and 'val_loss'.
# NOTE(review): `utils.plot_loss` is project-local — confirm which keys it plots.
utils.plot_loss(history)

In [None]:
# Reconstruct the test set in two explicit steps so the latent codes remain
# available: encode every test image, then decode those codes back to images.
encoded_images = np.asarray(autoencoder.encoder(x_test))
decoded_images = np.asarray(autoencoder.decoder(encoded_images))

# Reconstructions share the decoder's Reshape target shape, per sample
decoded_images.shape

In [None]:
# Plot the results: original test images (top row) vs. reconstructions (bottom row)
n = 10

# squeeze=False keeps `axes` 2-D even when n == 1, so axes[row, col] always works
fig, axes = plt.subplots(2, n, figsize=(16, 4), dpi=200, squeeze=False)

for i in range(n):
    # Row 1: original images
    axes[0, i].imshow(x_test[i], cmap='gray')
    axes[0, i].set_title('original')
    axes[0, i].axis('off')

    # Row 2: decoder output. Use the same grayscale colormap as row 1;
    # without it imshow falls back to the default 'viridis' colormap and
    # the two rows are not visually comparable.
    axes[1, i].imshow(decoded_images[i], cmap='gray')
    axes[1, i].set_title('generated')
    axes[1, i].axis('off')

plt.show()