# Autoencoder

An autoencoder is a deep neural network made of two parts:
- An encoder: a network that learns to represent/compress the high-dimensional input data into a lower dimensional latent space
- A decoder: a network that learns to decompress a given representation/vector in the latent space to a high-dimensional representation

Autoencoders are often used, for example, to remove noise from images (denoising).

It is really easy to create an autoencoder using the Keras Model Subclassing API, as shown in this [Tensorflow tutorial](https://www.tensorflow.org/tutorials/generative/autoencoder).

See also: https://drive.google.com/drive/folders/1KPsQvVDUcJzsDRw7YU-F2TvPLGdcV0An

Let's build an autoencoder "from scratch" to have a better understanding!

## Hand Made Autoencoder

In [None]:
import os
import numpy as np
import pandas as pd
import pickle

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, LeakyReLU, Flatten, Dense, Reshape, Conv2DTranspose, Activation, BatchNormalization, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.callbacks import Callback, LearningRateScheduler
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.datasets import mnist
from keras.utils import plot_model

# Clear TensorFlow session
tf.keras.backend.clear_session()

import matplotlib.pyplot as plt

In [None]:
class CustomCallback(Callback):
    """Periodically decode a random latent vector and save the result as an image.

    Every ``print_every_n_batches`` training batches, a latent vector is drawn
    from a standard normal distribution, decoded with ``vae.decoder`` and
    written to ``<run_folder>/images/img_<epoch>_<batch>.jpg``.

    Args:
        run_folder: Root folder of the run; images go to its ``images`` sub-folder.
        print_every_n_batches: An image is saved when
            ``batch % print_every_n_batches == 0``.
        initial_epoch: Start value of the epoch counter. The counter is
            incremented at the beginning of every epoch, so the first epoch
            appears as ``initial_epoch + 1`` in file names.
        vae: Object exposing ``latent_dim`` and a ``decoder`` with a
            ``predict`` method.
    """

    def __init__(self, run_folder, print_every_n_batches, initial_epoch, vae):
        # Let the Keras base class initialise its own attributes (params,
        # model reference) before adding ours.
        super().__init__()
        self.epoch = initial_epoch
        self.run_folder = run_folder
        self.print_every_n_batches = print_every_n_batches
        self.vae = vae

    def on_batch_end(self, batch, logs=None):
        if batch % self.print_every_n_batches != 0:
            return

        z_new = np.random.normal(size=(1, self.vae.latent_dim))
        # [0] drops the batch axis; squeeze() drops a singleton channel axis
        # so grayscale images become 2-D.
        reconst = self.vae.decoder.predict(z_new)[0].squeeze()

        image_dir = os.path.join(self.run_folder, 'images')
        # Create the folder on demand so training does not crash when it is missing.
        os.makedirs(image_dir, exist_ok=True)
        filepath = os.path.join(image_dir, 'img_' + str(self.epoch).zfill(3) + '_' + str(batch) + '.jpg')
        if reconst.ndim == 2:
            plt.imsave(filepath, reconst, cmap='gray_r')
        else:
            plt.imsave(filepath, reconst)

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch += 1



def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1):
    '''
    Build a Keras LearningRateScheduler that applies step decay.

    The learning rate used for epoch ``e`` is
    ``initial_lr * decay_factor ** floor(e / step_size)``: it is multiplied
    by ``decay_factor`` once every ``step_size`` epochs.
    '''
    def _lr_for_epoch(epoch):
        n_decays = np.floor(epoch / step_size)
        return initial_lr * (decay_factor ** n_decays)

    return LearningRateScheduler(_lr_for_epoch)

### Encoder

In [None]:
class Encoder():
    """Convolutional encoder mapping an input image to a latent vector.

    Builds a Keras functional model made of ``encoder_n_layers`` blocks of
    Conv2D (padding "same") -> LeakyReLU [-> BatchNormalization] [-> Dropout],
    followed by Flatten and a linear Dense layer of size ``latent_dim``.

    Args:
        input_dim: Input shape without the batch axis, e.g. ``(28, 28, 1)``.
        encoder_n_layers: Number of convolutional blocks.
        encoder_conv_filters: Number of filters of each Conv2D layer.
        encoder_conv_kernel_sizes: Kernel size of each Conv2D layer.
        encoder_conv_strides: Stride of each Conv2D layer.
        latent_dim: Size of the latent vector.
        batch_norm: If true, add BatchNormalization after each activation.
        drop_out: If true, add Dropout at the end of each block.
        dropout_rate: Dropout rate used when ``drop_out`` is true.

    Raises:
        ValueError: If one of the per-layer lists has fewer than
            ``encoder_n_layers`` entries.

    Attributes:
        shape_before_flattening: Shape (without batch axis) of the last
            convolutional feature map, as a tuple. The decoder reshapes its
            dense output to this shape.
        model: The Keras encoder model.
    """

    def __init__(self, input_dim, encoder_n_layers, encoder_conv_filters, encoder_conv_kernel_sizes, encoder_conv_strides, latent_dim,
                 batch_norm=False, drop_out=False, dropout_rate=0.25):
        self.input_dim = input_dim
        self.encoder_n_layers = encoder_n_layers
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_sizes = encoder_conv_kernel_sizes
        self.encoder_conv_strides = encoder_conv_strides
        self.latent_dim = latent_dim
        self.batch_norm = batch_norm
        self.drop_out = drop_out
        self.dropout_rate = dropout_rate

        # Fail early with a clear message instead of an IndexError mid-build.
        for arg_name, values in (("encoder_conv_filters", encoder_conv_filters),
                                 ("encoder_conv_kernel_sizes", encoder_conv_kernel_sizes),
                                 ("encoder_conv_strides", encoder_conv_strides)):
            if len(values) < encoder_n_layers:
                raise ValueError(f"{arg_name} has {len(values)} entries, "
                                 f"but encoder_n_layers={encoder_n_layers}")

        self.input = Input(shape=self.input_dim, name="encoder_input")

        x = self.input
        for i in range(self.encoder_n_layers):
            x = Conv2D(filters=self.encoder_conv_filters[i],
                       kernel_size=self.encoder_conv_kernel_sizes[i],
                       strides=self.encoder_conv_strides[i],
                       padding="same",
                       name="encoder_conv_" + str(i))(x)
            x = LeakyReLU(name="encoder_leaky_relu_" + str(i))(x)

            if self.batch_norm:
                x = BatchNormalization(name="encoder_batch_norm_" + str(i))(x)

            if self.drop_out:
                x = Dropout(rate=self.dropout_rate, name="encoder_dropout_" + str(i))(x)

        # Needed by the decoder to undo the Flatten; tuple() normalises a
        # TensorShape (tf.keras 2) and a plain tuple (Keras 3) alike.
        self.shape_before_flattening = tuple(x.shape[1:])

        x = Flatten(name="encoder_flatten")(x)

        # Linear projection (no activation) to the latent space.
        self.output = Dense(self.latent_dim, name="output")(x)

        self.model = Model(self.input, self.output)

    def summary(self):
        """Print the Keras summary of the encoder model."""
        self.model.summary()

    def predict(self, x):
        """Encode ``x`` into latent vectors with the Keras model."""
        return self.model.predict(x)


### Decoder

In [None]:
class Decoder():
    """Transposed-convolution decoder mapping a latent vector to an image.

    Builds a Keras functional model: Dense -> Reshape to
    ``shape_before_flattening``, then ``decoder_n_layers`` Conv2DTranspose
    layers (padding "same"). Every layer but the last is followed by
    LeakyReLU [-> BatchNormalization] [-> Dropout]. The last layer is followed
    by a sigmoid, so outputs lie in [0, 1].

    Args:
        latent_dim: Size of the latent input vector.
        shape_before_flattening: Shape (without batch axis) the dense output
            is reshaped to, typically ``Encoder.shape_before_flattening``.
        decoder_n_layers: Number of Conv2DTranspose layers.
        decoder_conv_t_filters: Number of filters of each Conv2DTranspose layer.
        decoder_conv_t_kernel_sizes: Kernel size of each Conv2DTranspose layer.
        decoder_conv_t_strides: Stride of each Conv2DTranspose layer.
        output_dim: Expected output shape. It is stored but not used to build
            the model; the actual output shape follows from the layer settings.
        batch_norm: If true, add BatchNormalization after hidden activations.
        drop_out: If true, add Dropout after hidden blocks.
        dropout_rate: Dropout rate used when ``drop_out`` is true.

    Raises:
        ValueError: If one of the per-layer lists has fewer than
            ``decoder_n_layers`` entries.
    """

    def __init__(self, latent_dim, shape_before_flattening, decoder_n_layers, decoder_conv_t_filters, decoder_conv_t_kernel_sizes, decoder_conv_t_strides, output_dim,
                 batch_norm=False, drop_out=False, dropout_rate=0.25):
        self.output_dim = output_dim
        self.shape_before_flattening = tuple(shape_before_flattening)
        self.decoder_n_layers = decoder_n_layers
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_sizes = decoder_conv_t_kernel_sizes
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.latent_dim = latent_dim
        self.batch_norm = batch_norm
        self.drop_out = drop_out
        self.dropout_rate = dropout_rate

        # Fail early with a clear message instead of an IndexError mid-build.
        for arg_name, values in (("decoder_conv_t_filters", decoder_conv_t_filters),
                                 ("decoder_conv_t_kernel_sizes", decoder_conv_t_kernel_sizes),
                                 ("decoder_conv_t_strides", decoder_conv_t_strides)):
            if len(values) < decoder_n_layers:
                raise ValueError(f"{arg_name} has {len(values)} entries, "
                                 f"but decoder_n_layers={decoder_n_layers}")

        self.input = Input(shape=(self.latent_dim,), name="decoder_input")

        # np.prod returns a NumPy integer; pass a plain int as the unit count.
        n_units = int(np.prod(self.shape_before_flattening))
        x = Dense(n_units, name="decoder_dense")(self.input)

        # Turn the flat vector back into a feature map for the transposed convolutions.
        x = Reshape(self.shape_before_flattening, name="decoder_reshape")(x)

        last_layer = self.decoder_n_layers - 1
        for i in range(self.decoder_n_layers):
            x = Conv2DTranspose(filters=self.decoder_conv_t_filters[i],
                                kernel_size=self.decoder_conv_t_kernel_sizes[i],
                                strides=self.decoder_conv_t_strides[i],
                                padding="same",
                                name="decoder_conv_t_" + str(i))(x)

            if i < last_layer:
                x = LeakyReLU(name="decoder_leaky_relu_" + str(i))(x)
                if self.batch_norm:
                    x = BatchNormalization(name="decoder_batch_norm_" + str(i))(x)
                if self.drop_out:
                    x = Dropout(rate=self.dropout_rate, name="decoder_dropout_" + str(i))(x)
            else:
                # Sigmoid keeps reconstructed pixels in [0, 1].
                x = Activation(tf.keras.activations.sigmoid, name="decoder_sigmoid_" + str(i))(x)

        self.output = x

        self.model = Model(self.input, self.output)

    def summary(self):
        """Print the Keras summary of the decoder model."""
        self.model.summary()

    def predict(self, x):
        """Decode latent vectors ``x`` into images with the Keras model."""
        return self.model.predict(x)
        

### Autoencoder

In [None]:
class Autoencoder():
    """Convolutional autoencoder assembled from an ``Encoder`` and a ``Decoder``.

    The full model feeds the encoder output into the decoder model and is
    compiled with Adam and mean squared error. Train it to reconstruct its
    input (``y == x``).

    Args:
        input_dim: Input image shape without the batch axis, e.g. ``(28, 28, 1)``.
        encoder_n_layers, encoder_conv_filters, encoder_conv_kernel_sizes,
        encoder_conv_strides: Encoder settings (see ``Encoder``).
        latent_dim: Size of the latent vector.
        decoder_n_layers, decoder_conv_t_filters, decoder_conv_t_kernel_sizes,
        decoder_conv_t_strides: Decoder settings (see ``Decoder``).
        output_dim: Expected output image shape (passed to ``Decoder``).
        learning_rate: Initial Adam learning rate.
        batch_norm: Enable BatchNormalization in encoder and decoder.
        drop_out: Enable Dropout in encoder and decoder.
    """

    def __init__(self, input_dim,
                 encoder_n_layers, encoder_conv_filters, encoder_conv_kernel_sizes, encoder_conv_strides,
                 latent_dim,
                 decoder_n_layers, decoder_conv_t_filters, decoder_conv_t_kernel_sizes, decoder_conv_t_strides,
                 output_dim,
                 learning_rate, batch_norm, drop_out):
        self.input_dim = input_dim
        self.encoder_n_layers = encoder_n_layers
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_sizes = encoder_conv_kernel_sizes
        self.encoder_conv_strides = encoder_conv_strides
        self.latent_dim = latent_dim
        self.decoder_n_layers = decoder_n_layers
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_sizes = decoder_conv_t_kernel_sizes
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.output_dim = output_dim
        self.batch_norm = batch_norm
        self.dropout = drop_out

        self.learning_rate = learning_rate

        # Create encoder. The regularisation flags must be forwarded so that
        # both halves honour the configuration.
        self.encoder = Encoder(self.input_dim,
                               self.encoder_n_layers, self.encoder_conv_filters, self.encoder_conv_kernel_sizes, self.encoder_conv_strides,
                               self.latent_dim,
                               batch_norm=self.batch_norm, drop_out=self.dropout)

        # Create decoder; it mirrors the encoder's last feature-map shape.
        self.decoder = Decoder(self.latent_dim, self.encoder.shape_before_flattening,
                               self.decoder_n_layers, self.decoder_conv_t_filters, self.decoder_conv_t_kernel_sizes, self.decoder_conv_t_strides,
                               self.output_dim,
                               batch_norm=self.batch_norm, drop_out=self.dropout)

        # Chain encoder and decoder into one end-to-end model.
        self.model_input = self.encoder.input
        self.model_output = self.decoder.model(self.encoder.output)

        self.model = Model(self.model_input, self.model_output)

        # Compile model
        self.optimizer = Adam(learning_rate=self.learning_rate)
        self.model.compile(optimizer=self.optimizer, loss=MeanSquaredError())

    def summary(self):
        """Print the Keras summary of the end-to-end model."""
        self.model.summary()

    def plot_model(self, run_folder):
        """Write diagrams of the full model, encoder and decoder to ``<run_folder>/viz``."""
        viz_folder = os.path.join(run_folder, 'viz')
        diagrams = ((self.model, 'model.png'),
                    (self.encoder.model, 'encoder.png'),
                    (self.decoder.model, 'decoder.png'))
        for keras_model, filename in diagrams:
            plot_model(keras_model, to_file=os.path.join(viz_folder, filename), show_shapes=True, show_layer_names=True)

    def load_weights(self, filepath="model/weights/weights.weights.h5"):
        """Load model weights, by default from the file written by ``save()``."""
        self.model.load_weights(filepath)

    def fit(self, x, y, batch_size, epochs, validation_split, shuffle, initial_epoch=0, print_every_n_batches=100, lr_decay=1):
        """Train the model and return the Keras ``History``.

        The learning rate is multiplied by ``lr_decay`` after every epoch
        (``lr_decay=1`` keeps it constant). ``print_every_n_batches`` is only
        used by the image callback, which is currently disabled.
        """

        # Callbacks
        # custom_callback = CustomCallback("model", print_every_n_batches, initial_epoch, self)
        lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
        # checkpoint2 = ModelCheckpoint(os.path.join("model", 'weights/weights.h5'), save_weights_only = True, verbose=1)

        callbacks_list = [lr_sched]

        # Training
        self.history = self.model.fit(x,
                                      y,
                                      batch_size=batch_size,
                                      epochs=epochs,
                                      callbacks=callbacks_list,
                                      validation_split=validation_split,
                                      shuffle=shuffle,
                                      initial_epoch=initial_epoch)
        return self.history

    def predict(self, x, **kwargs):
        """Reconstruct ``x``; extra keyword arguments (e.g. ``batch_size``) go to ``Model.predict``."""
        return self.model.predict(x, **kwargs)

    def save(self, folder="model"):
        """Save hyper-parameters, weights and architecture diagrams under ``folder``.

        Creates ``viz``, ``weights`` and ``images`` sub-folders, pickles the
        hyper-parameter list to ``weights/params.pkl``, writes the weights to
        ``weights/weights.weights.h5`` and plots the models into ``viz``.
        """
        # exist_ok also creates sub-folders that are missing when ``folder``
        # itself already exists.
        for sub_folder in ('viz', 'weights', 'images'):
            os.makedirs(os.path.join(folder, sub_folder), exist_ok=True)

        # Note: this list does not include the layer counts, output_dim or
        # learning_rate, so it is not enough on its own to rebuild the model.
        with open(os.path.join(folder, 'weights', 'params.pkl'), 'wb') as f:
            pickle.dump([self.input_dim,
                         self.encoder_conv_filters,
                         self.encoder_conv_kernel_sizes,
                         self.encoder_conv_strides,
                         self.decoder_conv_t_filters,
                         self.decoder_conv_t_kernel_sizes,
                         self.decoder_conv_t_strides,
                         self.latent_dim,
                         self.batch_norm,
                         self.dropout], f)

        # Keras 3 requires the ".weights.h5" suffix; tf.keras 2 treats any
        # ".h5" path as HDF5, so this name works with both.
        self.model.save_weights(os.path.join(folder, 'weights', 'weights.weights.h5'))

        self.plot_model(folder)

In [None]:
# Symmetric convolutional autoencoder for 28x28 grayscale images with a 2-D
# latent space. With padding="same", encoder strides [1, 2, 2, 1] give feature
# maps of 28 -> 28 -> 14 -> 7 -> 7 pixels. Decoder strides [1, 2, 2, 1]
# upsample 7 -> 7 -> 14 -> 28 -> 28, and its last layer (1 filter + sigmoid)
# produces a (28, 28, 1) image.
autoencoder = Autoencoder(input_dim=(28,28,1),
                          encoder_n_layers=4,
                          encoder_conv_filters=[32,64,64,64],
                          encoder_conv_kernel_sizes=[3,3,3,3],
                          encoder_conv_strides=[1,2,2,1],
                          latent_dim=2,
                          decoder_n_layers=4,
                          decoder_conv_t_filters=[64,64,32,1],
                          decoder_conv_t_kernel_sizes=[3,3,3,3],
                          decoder_conv_t_strides=[1,2,2,1],
                          output_dim=(28,28,1),
                          learning_rate=0.0005,
                          batch_norm=False,
                          drop_out=False)

In [None]:
# Print the encoder architecture (layers, output shapes, parameter counts).
autoencoder.encoder.summary()

In [None]:
# Print the decoder architecture (layers, output shapes, parameter counts).
autoencoder.decoder.summary()

In [None]:
# autoencoder.summary()

In [None]:
# DEBUG
# for layer in autoencoder.model.layers:
#     print(layer.name, layer.output_shape)

### Load Data

In [None]:
# Load the MNIST dataset (downloaded by Keras on first use).
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Convert pixels to float32 scaled to [0, 1], which matches the decoder's
# sigmoid output range. Also add an explicit channel axis so the arrays match
# input_dim / output_dim (28, 28, 1) exactly, instead of relying on Keras to
# reconcile the rank difference between targets and predictions.
X_train = (X_train.astype('float32') / 255)[..., np.newaxis]
X_test = (X_test.astype('float32') / 255)[..., np.newaxis]

### Train

In [None]:
# Mini-batch size used for training and for batched predictions.
BATCH_SIZE = 32
# Number of training epochs (a full pass over the training split each).
EPOCHS = 200

In [None]:
# Train the autoencoder to reconstruct its input (inputs and targets are both
# X_train). 10% of the training set is held out for validation.
autoencoder.fit(x=X_train,
                y=X_train,
                batch_size=BATCH_SIZE,
                epochs=EPOCHS,
                validation_split=0.1,
                shuffle=True)

### Evaluate Autoencoder

In [None]:
# Plot the training and validation loss (mean squared error) for each epoch.
losses = autoencoder.history.history
plt.plot(losses["loss"], label="train loss")
plt.plot(losses["val_loss"], label="val loss")
plt.xlabel("epoch")
plt.ylabel("MSE loss")
plt.legend()
plt.show()

In [None]:
# Plot the first test images next to their autoencoder reconstructions.
n_images = 10
# Only the displayed images need to go through the model. Call the Keras
# model directly so that batch_size can be forwarded.
decoded_images = autoencoder.model.predict(X_test[:n_images], batch_size=BATCH_SIZE)
for i in range(n_images):
    plt.figure(figsize=(5, 5))
    plt.subplot(1, 2, 1)
    # squeeze() drops a singleton channel axis, if present, for imshow.
    plt.imshow(np.squeeze(X_test[i]), cmap="gray")
    plt.title("Test Image " + str(i+1))
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(np.squeeze(decoded_images[i]), cmap="gray")
    plt.title("Reconstruction " + str(i+1))
    plt.axis("off")
plt.show()

### Save Model

In [None]:
# Persist the model to the default "model" folder (see Autoencoder.save).
autoencoder.save()

### Load Pre-Trained Model

In [None]:
# autoencoder.load_weights()

### Predictions

#### Prediction from Dataset

In [None]:
# Reconstruct the whole test set with the trained autoencoder.
predictions = autoencoder.predict(X_test)

In [None]:
# Show the first test image in grayscale, consistent with the other plots
# (squeeze drops a singleton channel axis, if present).
plt.imshow(np.squeeze(X_test[0]), cmap="gray")

In [None]:
# Show the reconstruction of the first test image in grayscale
# (squeeze drops the singleton channel axis of the decoder output).
plt.imshow(np.squeeze(predictions[0]), cmap="gray")

#### Random Prediction

The distribution of the points in the latent space is unknown. For instance, with the MNIST dataset and a 2-D latent space, it is possible to plot where each digit class is located in the latent space. This reveals several problems when generating images from random latent points:
- The plot is not symmetrical about the point (0,0): how should one choose a point in the latent space to produce a specific digit?
- Some digits occupy small areas of the latent space and others large areas: sampling uniformly produces little diversity, because digits with larger areas are generated more often.
- There are large gaps between digit areas that contain few points: images generated from these gaps will be poorly formed.
- The latent space is not continuous: even points inside a digit's area can decode to ill-formed images.

In [None]:
# Sample a latent point uniformly in [0, 1) for every latent dimension. The
# region is arbitrary, since a plain autoencoder's latent distribution is
# unknown. Use the model's latent size instead of hard-coding 2.
latent_vector_1 = np.random.random(size=(1, autoencoder.latent_dim))
latent_vector_1

In [None]:
# Decode the random latent point into an image (the result keeps a batch axis of size 1).
prediction_1 = autoencoder.decoder.predict(latent_vector_1)

In [None]:
# Drop the batch axis, then show the decoded image in grayscale
# (squeeze removes the singleton channel axis).
prediction_1 = prediction_1[0]
plt.imshow(np.squeeze(prediction_1), cmap="gray")

In [None]:
# Move away from the first latent point by a fixed 2-D offset to see how the
# decoded image changes (the offset assumes a 2-D latent space).
latent_vector_2 = latent_vector_1 + np.array([1, -0.5])
latent_vector_2

In [None]:
# Decode the shifted latent point (the result keeps a batch axis of size 1).
prediction_2 = autoencoder.decoder.predict(latent_vector_2)

In [None]:
# Drop the batch axis, then show the decoded image in grayscale
# (squeeze removes the singleton channel axis).
prediction_2 = prediction_2[0]
plt.imshow(np.squeeze(prediction_2), cmap="gray")