## Convolutional autoencoder trained on MNIST

In [None]:
import os
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import TensorBoard
from datetime import datetime
import shutil

In [None]:
#Load the dataset
def preprocess_image(features):
    """Turn one dataset feature dict into an (input, target) pair.

    Args:
        features: Feature dictionary of a single example; only the
            'image' entry is read.

    Returns:
        A tuple ``(image, image)``. Both elements are the same float32
        tensor of shape (28, 28, 1) with values divided by 255, because
        the autoencoder is trained to reproduce its own input.
    """
    # Cast to float and rescale the pixel values by 1/255.
    pixels = tf.cast(features['image'], tf.float32) / 255.0
    pixels = tf.image.resize(pixels, [28, 28])

    # Keep only the first channel, then put a trailing channel axis back.
    first_channel = pixels[:, :, 0]
    image = tf.expand_dims(first_channel, -1)

    return image, image

# Input pipelines. Both yield batches of (image, image) pairs produced by
# preprocess_image.
BATCH_SIZE = 64
SHUFFLE_BUFFER = 10_000

# Training data: cache the preprocessed examples so later epochs do not redo
# the decoding/map work, shuffle so every epoch sees a different example order,
# and prefetch so input preparation overlaps with the training step.
ds_train = tfds.load('mnist', split='train', as_supervised=False)
ds_train = (
    ds_train
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Test data: kept in its original order (no shuffle) so that picking a batch
# by position is reproducible; only prefetching is added.
ds_test = tfds.load('mnist', split='test', as_supervised=False)
ds_test = (
    ds_test
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Define architecture
# The network consumes 28x28 images with a single channel.
input_img = Input(shape=(28, 28, 1))

# Encoder: two conv + max-pool stages. Each 2x2 pooling with 'same' padding
# halves the spatial size (28 -> 14 -> 7); `encoded` is the bottleneck tensor.
x = Conv2D(filters=512, kernel_size=(3, 3), padding='same', activation='relu')(input_img)
x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
x = Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu')(x)
encoded = MaxPooling2D(pool_size=(2, 2), padding='same')(x)

In [None]:
# Decoder
# Two conv + upsample stages starting from `encoded`; each UpSampling2D doubles
# the spatial size. The final 1-filter convolution uses a sigmoid so every
# output pixel lies in [0, 1].
x = Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu')(encoded)
x = UpSampling2D(size=(2, 2))(x)
x = Conv2D(filters=512, kernel_size=(3, 3), padding='same', activation='relu')(x)
x = UpSampling2D(size=(2, 2))(x)
decoded = Conv2D(filters=1, kernel_size=(3, 3), padding='same', activation='sigmoid')(x)

In [None]:
# The model: maps the input image tensor to its reconstruction
# (end-to-end encoder followed by decoder).
autoencoder = Model(inputs=input_img, outputs=decoded)

In [None]:
# Training
# Per-pixel binary cross-entropy between the reconstruction and the input image.
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# Setup for profiling: every run logs into its own timestamped directory under
# ./logs; weight histograms are written every epoch and batches 500-520 are
# profiled.
log_dir = './logs/' + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1, profile_batch='500,520')

# Train the model; the test split is used as validation data.
autoencoder.fit(ds_train,
                epochs=10,
                validation_data=ds_test,
                callbacks=[tensorboard_callback])

# Copy file "events.out.tfevents.1583461681.localhost.profile-empty" to each recorded log for display data.
# The stub is optional: if it is not present next to the notebook, report that
# and carry on instead of raising FileNotFoundError after training has finished.
source_path = "./events.out.tfevents.1583461681.localhost.profile-empty"
destination_path = log_dir
if os.path.isfile(source_path):
    shutil.copy(source_path, destination_path)
else:
    print(f"Profile stub '{source_path}' not found; skipping copy to '{destination_path}'.")

In [None]:
# Show one original test image next to its reconstruction

import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds

# Extract a single test image: the first image of the 30th test batch.
# skip/take lets tf.data advance to that batch directly instead of pulling the
# first 29 batches into Python only to throw them away.
# NOTE(review): assumes ds_test holds at least 30 batches - confirm if the
# batch size or split changes.
for test_images, _ in ds_test.skip(29).take(1):
    # Slice (rather than index) so the batch axis is kept for predict().
    test_image = test_images[0:1]

reconstructed_image = autoencoder.predict(test_image)

# Plot the original image and its reconstruction side by side
fig, axes = plt.subplots(1, 2)
panels = (
    ('Original Image', test_image),
    ('Reconstructed Image', reconstructed_image),
)
for ax, (title, batch) in zip(axes, panels):
    # Each entry is a one-image batch; drop the batch and channel axes to plot it.
    ax.imshow(batch[0, :, :, 0], cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.show()

## Load TensorBoard

In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic can be used
%load_ext tensorboard

In [None]:
# Launch TensorBoard from the notebook, reading the runs recorded under ./logs on port 6006
%tensorboard --logdir='./logs' --bind_all --port 6006 #will take a few seconds to show