# 1 AUTOENCODERS

In this lab we will create a simple convolutional autoencoder, train it on MNIST, and then use its decoder to generate images from vectors of random numbers.

We begin with our usual imports.

In [None]:
import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as K

from tensorflow.keras import (layers, optimizers, losses, callbacks, models, datasets)
from utils import display
from tensorflow.keras.preprocessing.image import smart_resize
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
from tensorflow.keras.datasets import mnist
# --- Data / model geometry ---------------------------------------------------
CHANNELS = 1     # MNIST digits are greyscale, so a single channel suffices
EMB_DIM = 3      # length of the latent vector produced by the encoder
IMAGE_SIZE = 16  # images are resized to IMAGE_SIZE x IMAGE_SIZE pixels

# --- Files the trained models are written to / loaded from -------------------
SAVE_FILENAME_ENC = "mnist-encoder.keras"
SAVE_FILENAME_DEC = "mnist-decoder.keras"
SAVE_FILENAME_AE = "mnist-ae.keras"

# --- Training hyper-parameters ----------------------------------------------
EPOCHS = 5        # number of passes over the training set
BATCH_SIZE = 64   # samples per gradient update

In [None]:
(x_train, _), (x_test, _) = mnist.load_data()
# Add a trailing channel axis so each image is (height, width, 1).
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

# Resize to IMAGE_SIZE x IMAGE_SIZE and scale pixel values from [0, 255] to [0, 1].
# smart_resize accepts a whole (batch, height, width, channels) array, so each
# split is resized in a single call instead of one TensorFlow op per image.
x_train = np.asarray(smart_resize(x_train, (IMAGE_SIZE, IMAGE_SIZE))) / 255.0
x_test = np.asarray(smart_resize(x_test, (IMAGE_SIZE, IMAGE_SIZE))) / 255.0


In [None]:
# Create the encoder.
# Three stride-2 Conv2D + BatchNormalization stages halve the spatial size each
# time (16 -> 8 -> 4 -> 2 for IMAGE_SIZE = 16). A sigmoid Dense layer then maps
# the flattened features to an EMB_DIM latent vector.
# Per the original author's note, BatchNormalization was needed for good
# convergence: without it, deeper stacks collapsed to a flat grey square.

encoder_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS), name="encoder_input")

_stage_outputs = []
_x = encoder_input
for _idx, _filters in enumerate((32, 64, 128), start=1):
    _x = layers.Conv2D(
        _filters,
        kernel_size=(3, 3),
        strides=2,
        padding='same',
        activation='relu',
        name=f"encoder_layer{_idx}",
    )(_x)
    _x = layers.BatchNormalization(momentum=0.8)(_x)
    _stage_outputs.append(_x)
layer1, layer2, layer3 = _stage_outputs
del _stage_outputs, _x, _idx, _filters

# NOTE: despite its name, this is the shape of the *second* stage (layer2), not
# of layer3, which is the one actually flattened. The decoder depends on this.
# It reshapes to layer2's shape and applies transposed convs with strides
# 1, 2, 2, 1, which brings it back to IMAGE_SIZE. Using layer3's shape would
# make the decoder output half the input size.
shape_before_flattening = K.int_shape(layer2)[1:]
flatten_layer = layers.Flatten()(layer3)

# Latent embedding; the sigmoid keeps every component in (0, 1).
encoder_output = layers.Dense(EMB_DIM, activation='sigmoid', name="encoder_output")(flatten_layer)

encoder = models.Model(inputs=encoder_input, outputs=encoder_output)
encoder.summary()


In [None]:
# Create the decoder.
# A Dense layer expands the EMB_DIM latent vector and reshapes it to
# shape_before_flattening. Conv2DTranspose + BatchNormalization stages with
# strides 1, 2, 2 then upsample it (4 -> 4 -> 8 -> 16 for IMAGE_SIZE = 16).
# A final stride-1 sigmoid Conv2DTranspose emits CHANNELS values per pixel
# in (0, 1).
# Per the original author, adding BatchNormalization here restored the image
# intensity (their wording: "restores the color").

decoder_input = layers.Input(shape=(EMB_DIM,), name="decoder_input")
layer1 = layers.Dense(
    np.prod(shape_before_flattening),
    activation='relu',
    name="decoder_layer1",
)(decoder_input)
reshape = layers.Reshape(target_shape=shape_before_flattening)(layer1)

_stage_outputs = []
_x = reshape
for _name, _filters, _strides in (
    ('decoder_layer2', 128, 1),
    ('decoder_layer3', 64, 2),
    ('decoder_layer4', 32, 2),
):
    _x = layers.Conv2DTranspose(
        _filters,
        kernel_size=(3, 3),
        strides=_strides,
        padding='same',
        activation='relu',
        name=_name,
    )(_x)
    _x = layers.BatchNormalization(momentum=0.8)(_x)
    _stage_outputs.append(_x)
layer2, layer3, layer4 = _stage_outputs
del _stage_outputs, _x, _name, _filters, _strides

decoder_output = layers.Conv2DTranspose(
    CHANNELS,
    kernel_size=(3, 3),
    strides=1,
    padding='same',
    activation='sigmoid',
    name='decoder_output',
)(layer4)

decoder = models.Model(inputs=decoder_input, outputs=decoder_output)
decoder.summary()


In [None]:
# Create the autoencoder. The encoder graph's latent output is fed through the
# decoder model, giving an end-to-end image -> reconstructed-image model that
# shares its layers with `encoder` and `decoder`.
autoencoder = models.Model(
    encoder_input,
    decoder(encoder_output),
    name='autoencoder',
)
autoencoder.summary()

In [None]:
# Set up the training: Adam with a fixed learning rate of 0.01 and per-pixel
# binary cross-entropy (the targets are the inputs, scaled to [0, 1]).
Adam = optimizers.Adam(learning_rate=0.01)
autoencoder.compile(loss='binary_crossentropy', optimizer=Adam)

# Write the full autoencoder model to SAVE_FILENAME_AE at the end of every
# epoch. The standalone encoder/decoder are saved by a separate custom callback.
checkpoint = callbacks.ModelCheckpoint(save_freq='epoch', filepath=SAVE_FILENAME_AE)

# Early stopping on Keras's default monitored metric (val_loss). Training stops
# once it fails to improve by at least 0.01 for 5 consecutive epochs.
# NOTE(review): with EPOCHS = 5 and patience = 5 this can never trigger.
earlystop = callbacks.EarlyStopping(patience=5, min_delta=0.01)

# Create our own custom callback to save the standalone encoder and decoder.
class saveCallback(callbacks.Callback):
    """Save the given encoder and decoder models at the end of every epoch.

    The models are written to SAVE_FILENAME_ENC and SAVE_FILENAME_DEC
    respectively, overwriting the previous epoch's files.

    Args:
        encoder: model saved to SAVE_FILENAME_ENC.
        decoder: model saved to SAVE_FILENAME_DEC.
    """

    def __init__(self, encoder, decoder):
        # Initialise the Keras Callback base class first (as displayImage does)
        # so its internal state, e.g. model/params, exists before Keras
        # attaches the callback to a model.
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index; report it 1-based.
        print("\nSaving at epoch: %d.\n" % (epoch + 1))
        self.encoder.save(SAVE_FILENAME_ENC)
        self.decoder.save(SAVE_FILENAME_DEC)

class displayImage(callbacks.Callback):
    """After every epoch, decode 10 random latent vectors and display them.

    Args:
        decoder: model mapping (batch, emb_dim) latent vectors to images.
        emb_dim: length of each latent vector to sample.
    """

    def __init__(self, decoder, emb_dim):
        super().__init__()
        self.decoder = decoder
        self.emb_dim = emb_dim

    def on_epoch_end(self, epoch, logs=None):
        # Sample uniformly in [0, 1), the same range as the encoder's sigmoid output.
        latent_batch = tf.random.uniform(shape=(10, self.emb_dim))
        display(self.decoder.predict(latent_batch))
        
# Instantiate the custom callbacks that are passed to autoencoder.fit.
display_image = displayImage(decoder=decoder, emb_dim=EMB_DIM)
save_callback = saveCallback(encoder=encoder, decoder=decoder)


In [None]:
# Start the training: the autoencoder learns to reproduce its own input, so the
# images serve as both inputs and targets (for training and for validation).
autoencoder.fit(
    x_train,
    x_train,
    validation_data=(x_test, x_test),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    callbacks=[checkpoint, save_callback, earlystop, display_image],
)

In [None]:
# Generate test images.
# Reconstruct the first 10 train/test digits, embed 500 samples of each split
# into the latent space, and decode 10 uniform random latent vectors.
train_set, test_set = x_train[:10], x_test[:10]
fake_vect = tf.random.uniform(shape=(10, EMB_DIM), minval=0, maxval=1)

train_img = autoencoder.predict(train_set)
test_img = autoencoder.predict(test_set)
train_scatter = encoder.predict(x_train[:500])
test_scatter = encoder.predict(x_test[:500])
fake_img = decoder.predict(fake_vect)

# Show each original row directly above its reconstruction.
for _heading, _originals, _reconstructed in (
    ("Train Images", train_set, train_img),
    ("\nTest Images", test_set, test_img),
):
    print(_heading)
    display(_originals)
    display(_reconstructed)

print("\nGenerated Images")
display(fake_img)


# Plot the latent codes in 3-D: encoded training samples, encoded test samples,
# and the random vectors that were decoded into "generated" images.
# Only columns 0..2 are plotted, so this assumes EMB_DIM == 3.
fake_points = np.asarray(fake_vect)  # tf tensor -> numpy for matplotlib

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(train_scatter[:, 0], train_scatter[:, 1], train_scatter[:, 2],
           color='blue', label='Training Data')
ax.scatter(test_scatter[:, 0], test_scatter[:, 1], test_scatter[:, 2],
           color='green', label='Test Data')
ax.scatter(fake_points[:, 0], fake_points[:, 1], fake_points[:, 2],
           color='red', label='Random Latent Vectors')
ax.set_xlabel("z[0]")
ax.set_ylabel("z[1]")
ax.set_zlabel("z[2]")
ax.set_title("Latent Space Plot")
# Labels are only rendered once a legend is requested.
ax.legend()
plt.show()