In [1]:
'''Trains a denoising autoencoder on the MNIST dataset.
Denoising is one of the classic applications of autoencoders.
The denoising process removes unwanted noise that corrupted the
true signal.
Noise + Data ---> Denoising Autoencoder ---> Data
Given a training dataset of corrupted data as input and
true signal as output, a denoising autoencoder can recover the
hidden structure to generate clean data.
This example has a modular design. The encoder, decoder and autoencoder
are 3 models that share weights. For example, after training the
autoencoder, the encoder can be used to generate latent vectors
of input data for low-dim visualization like PCA or TSNE.
adapted from https://github.com/keras-team/keras/blob/master/examples/mnist_denoising_autoencoder.py
'''
import tensorflow as tf
mnist = tf.keras.datasets.mnist
Model = tf.keras.models.Model
Input = tf.keras.layers.Input
Dense = tf.keras.layers.Dense
Flatten = tf.keras.layers.Flatten
Conv2D = tf.keras.layers.Conv2D 
MaxPooling2D = tf.keras.layers.MaxPooling2D
K = tf.keras.backend
Reshape = tf.keras.layers.Reshape
Conv2DTranspose = tf.keras.layers.Conv2DTranspose
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image

  from ._conv import register_converters as _register_converters


In [2]:
# Seed NumPy's global random generator so the Gaussian noise drawn with
# np.random.normal later in the notebook is repeatable across runs.
# NOTE(review): only NumPy is seeded; no TensorFlow seed is set in the visible
# cells, so weight initialisation may still differ between runs.
np.random.seed(1337)

In [3]:
# MNIST dataset
# Only the image arrays are kept; the second element of each split is bound
# to `_` and ignored, because the autoencoder is trained with the clean
# images themselves as targets.
(x_train, _), (x_test, _) = mnist.load_data()

In [5]:
# Side length of the images, read from the loaded array.
image_size = x_train.shape[1]

# Add a trailing channel axis of size 1 (channels-last layout).
# Reshaping is idempotent, so re-running this cell is harmless here.
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])

# Convert to float32 and divide by 255. The arrays are overwritten in place,
# so the scaling is guarded on the dtype: once an array is already floating
# point (i.e. this cell has run before) it is left alone instead of being
# divided by 255 a second time.
if np.issubdtype(x_train.dtype, np.integer):
    x_train = x_train.astype('float32') / 255
if np.issubdtype(x_test.dtype, np.integer):
    x_test = x_test.astype('float32') / 255

# Generate corrupted MNIST images by adding noise with normal dist
# centered at 0.5 and std=0.5
train_noise = np.random.normal(loc=0.5, scale=0.5, size=x_train.shape)
x_train_noisy = x_train + train_noise
test_noise = np.random.normal(loc=0.5, scale=0.5, size=x_test.shape)
x_test_noisy = x_test + test_noise

# Clamp the noisy images back into [0, 1]: if < 0 -> 0, if > 1 -> 1
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)

In [10]:
# Network parameters
# Shape of one input sample passed to the encoder's Input layer:
# (image_size, image_size, 1), i.e. a single-channel square image.
input_shape = (image_size, image_size, 1)
# Mini-batch size passed to autoencoder.fit.
batch_size = 128
# Kernel size shared by every Conv2D / Conv2DTranspose layer.
kernel_size = 3
# Number of units in the encoder's 'latent_vector' Dense layer.
latent_dim = 16
# Filter counts of the two strided convolution stages.
filters1 = 32
filters2 = 64
# Stride of the strided (transposed) convolutions.
stride = 2

In [15]:
# Build the Autoencoder Model
def cnn_encode():
    """Build the encoder half of the autoencoder.

    The encoder takes an image of shape ``input_shape``, passes it through
    two strided ReLU convolutions (``filters1`` then ``filters2`` filters),
    flattens the result and projects it to a ``latent_dim``-sized vector.

    Reads the notebook-level settings ``input_shape``, ``filters1``,
    ``filters2``, ``kernel_size``, ``stride`` and ``latent_dim``.

    Returns:
        tuple: ``(encoder, conv_shape)`` where ``encoder`` is the Keras
        ``Model`` named 'encoder' and ``conv_shape`` is the static shape
        (as given by ``K.int_shape``) of the last convolutional feature map
        before flattening; the decoder needs it to undo the flatten.
    """
    encoder_input = Input(shape=input_shape, name='encoder_input')

    # Stack of Conv2D blocks.
    # Notes from the original example:
    # 1) Use Batch Normalization before ReLU on deep networks
    # 2) MaxPooling2D is an alternative to strides>1
    #    - faster but not as good as strides>1
    features = encoder_input
    for n_filters in (filters1, filters2):
        features = Conv2D(filters=n_filters,
                          kernel_size=kernel_size,
                          strides=stride,
                          activation='relu',
                          padding='same')(features)

    # Remember the feature-map shape before it is flattened away.
    conv_shape = K.int_shape(features)

    # Generate the latent vector (encoding).
    flat = Flatten()(features)
    latent = Dense(latent_dim, name='latent_vector')(flat)

    encoder_model = Model(inputs=encoder_input, outputs=latent, name='encoder')
    return encoder_model, conv_shape

In [17]:
# Instantiate Encoder Model
# a2_shape (shape of the encoder's last conv feature map) is kept as a
# notebook-level variable because cnn_decode() reads it to size its first
# Dense and Reshape layers.
encoder, a2_shape = cnn_encode()
encoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_input (InputLayer)   (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 14, 14, 32)        320       
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 7, 7, 64)          18496     
_________________________________________________________________
flatten_4 (Flatten)          (None, 3136)              0         
_________________________________________________________________
latent_vector (Dense)        (None, 16)                50192     
Total params: 69,008
Trainable params: 69,008
Non-trainable params: 0
_________________________________________________________________


In [23]:
def cnn_decode():
    """Build the decoder half of the autoencoder.

    The decoder takes a ``latent_dim``-sized vector, expands it with a Dense
    layer to the size of the encoder's last feature map, reshapes it back to
    that map, and applies two strided ReLU transposed convolutions
    (``filters2`` then ``filters1`` filters). A final single-filter
    transposed convolution with a sigmoid activation produces the output.

    Reads the notebook-level ``a2_shape`` (entries 1..3 are used as the
    feature-map dimensions), ``latent_dim``, ``filters1``, ``filters2``,
    ``kernel_size`` and ``stride``.

    Returns:
        Model: the Keras model named 'decoder'.
    """
    decoder_input = Input(shape=(latent_dim,), name='decoder_input')

    # Undo the encoder's Flatten: Dense up to dim1*dim2*dim3 units, then
    # reshape back into a feature map of that shape.
    dim1, dim2, dim3 = a2_shape[1:4]
    features = Dense(dim1 * dim2 * dim3)(decoder_input)
    features = Reshape((dim1, dim2, dim3))(features)

    # Mirror the encoder's convolution stack in reverse filter order.
    for n_filters in (filters2, filters1):
        features = Conv2DTranspose(filters=n_filters,
                                   kernel_size=kernel_size,
                                   strides=stride,
                                   activation='relu',
                                   padding='same')(features)

    # Collapse to one channel; sigmoid keeps every output value in (0, 1).
    reconstruction = Conv2DTranspose(filters=1,
                                     kernel_size=kernel_size,
                                     padding='same',
                                     activation='sigmoid')(features)

    return Model(inputs=decoder_input, outputs=reconstruction, name='decoder')

In [24]:
# Instantiate Decoder Model
# cnn_decode() reads the notebook-level a2_shape, so the cell that assigns
# it (the encoder instantiation) must have been run first.
decoder = cnn_decode()
decoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   (None, 16)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 3136)              53312     
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 28, 28, 32)        18464     
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 28, 28, 1)         289       
Total params: 108,993
Trainable params: 108,993
Non-trainable params: 0
_________________________________________________________________


In [31]:
# Autoencoder = Encoder + Decoder
# Instantiate Autoencoder Model by chaining the two sub-models on the
# encoder's own input tensor:
#   encoder.input                   -> the encoder's input tensor (a0)
#   encoder(encoder.input)          -> latent tensor
#   decoder(encoder(encoder.input)) -> output tensor
# The encoder and decoder Model objects are reused as-is, which is how the
# three models end up sharing weights.
autoencoder = Model(inputs = encoder.input, outputs = decoder(encoder(encoder.input)), name='autoencoder')
autoencoder.summary()
# Mean-squared-error loss between the model output and the target images,
# optimised with Adam.
autoencoder.compile(loss='mse', optimizer='adam')

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_input (InputLayer)   (None, 28, 28, 1)         0         
_________________________________________________________________
encoder (Model)              (None, 16)                69008     
_________________________________________________________________
decoder (Model)              (None, 28, 28, 1)         108993    
Total params: 178,001
Trainable params: 178,001
Non-trainable params: 0
_________________________________________________________________
Instructions for updating:
keep_dims is deprecated, use keepdims instead


In [None]:
# Train the autoencoder
# Inputs are the noisy images and targets are the corresponding clean
# images; the noisy/clean test images are used as validation data.
autoencoder.fit(x_train_noisy,
                x_train,
                validation_data=(x_test_noisy, x_test),
                epochs=10,
                batch_size=batch_size)

Train on 60000 samples, validate on 10000 samples
Epoch 1/10

Epoch 2/10

Epoch 3/10

Epoch 4/10

Epoch 5/10
 4864/60000 [=>............................] 4864/60000 [=>............................] - ETA: 6:23 - loss: 0.0210

In [None]:
# Predict the Autoencoder output from corrupted test images
# (x_decoded is what the display cell shows as the denoised images).
x_decoded = autoencoder.predict(x_test_noisy)

In [None]:
# Display the first rows * cols (= 300) test images as one tiled picture:
# the originals, their corrupted versions and the denoised reconstructions.
rows, cols = 10, 30
num = rows * cols
# Stack the three sets back to back: originals, then noisy, then decoded.
imgs = np.concatenate([x_test[:num], x_test_noisy[:num], x_decoded[:num]])
# View the stack as a grid of (rows * 3) x cols tiles (channel axis dropped).
imgs = imgs.reshape((rows * 3, cols, image_size, image_size))
# Re-order the tiles so the grid rows come in groups of three
# (original, corrupted, denoised), with the same source image in the same
# column position of each row in a group.
# np.split needs cols to be divisible by rows.
imgs = np.vstack(np.split(imgs, rows, axis=1))
imgs = imgs.reshape((rows * 3, -1, image_size, image_size))
# Glue tiles together: side by side within each grid row, then rows stacked
# vertically, giving one 2-D array.
imgs = np.vstack([np.hstack(i) for i in imgs])
# Scale by 255 and cast to 8-bit so PIL can save the array as an image.
imgs = (imgs * 255).astype(np.uint8)
plt.figure()
plt.axis('off')
plt.title('Original images: top rows, '
          'Corrupted Input: middle rows, '
          'Denoised Input:  third rows')
plt.imshow(imgs, interpolation='none', cmap='gray')
# Also write the tiled picture to the current working directory.
Image.fromarray(imgs).save('corrupted_and_denoised.png')
plt.show()