In [1]:
import numpy as np

from keras.layers import Dense, Input, Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose

from keras.models import Model
from keras import backend as K

from keras.datasets import mnist

import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

# Seed NumPy's global RNG so the Gaussian noise drawn later with
# np.random.normal is identical on every Restart & Run All.
# NOTE(review): this covers NumPy only; the Keras backend's weight
# initialisation presumably needs its own seeding if bit-exact runs matter.
SEED = 42
np.random.seed(SEED)

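# Denoising autoencoder with Keras

Build a convolutional encoder–decoder for MNIST digits corrupted with additive Gaussian noise. The encoder compresses each 28×28 image into a 16-dimensional latent vector, and the decoder maps that vector back to a 28×28 image.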

## Load, examine & transform the data

In [2]:
# Load MNIST. The labels get real names instead of `_`: in IPython, assigning
# to `_` shadows the "last output" variable and IPython then stops updating
# `_`/`__`/`___`. The autoencoder only needs the images; the labels may be
# useful later (e.g. to annotate example digits).
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [3]:
# Shape of the raw training images: (samples, rows, cols)
np.shape(X_train)

(60000, 28, 28)

In [7]:
# Reshape to add a color channel 
# The conv layers take channels-last input (the model input is declared as
# (img_size, img_size, 1)). MNIST is grayscale, so append a length-1 channel
# axis. reshape (unlike np.newaxis / expand_dims) leaves an already-4D array
# unchanged, so re-running this cell is harmless.
X_train, X_test = (
    images.reshape([-1, 28, 28, 1]) for images in (X_train, X_test)
)

In [8]:
# Get image size 
# Side length of the images. The model input is later built as
# (img_size, img_size, 1), i.e. it assumes square images, so check that
# rows == cols instead of silently using the row count only.
assert X_train.shape[1] == X_train.shape[2], X_train.shape
img_size = X_train.shape[1]

In [15]:
# Normalize the data
# Scale pixel intensities from uint8 [0, 255] to float32 [0, 1].
# Guarded on dtype: without the guard, re-running this cell would divide the
# already-scaled data by 255 a second time.
if X_train.dtype == np.uint8:
    X_train = X_train.astype('float32') / 255
if X_test.dtype == np.uint8:
    X_test = X_test.astype('float32') / 255

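* ### Get noisy images

Add Gaussian noise with mean $0.5$ and standard deviation $0.5$ to every pixel, then clip the result back to the valid $[0, 1]$ range. These noisy images are the corrupted inputs for the denoising autoencoder.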

In [32]:
# Generate train noise
# Generate train noise: one Gaussian sample (mean 0.5, std 0.5) per pixel.
# np.random.normal returns float64. Cast it to float32 to match X_train, so
# that X_train + noise stays float32. For 60000x28x28x1 values that is about
# 188 MB instead of about 376 MB.
noise = np.random.normal(loc   = .5,
                         scale = .5,
                         size  = X_train.shape).astype('float32')

In [33]:
# Corrupt the clean training images with the additive noise.
X_train_noisy = np.add(X_train, noise)

In [34]:
# Generate test noise
# Generate test noise with the same distribution as the train noise
# (mean 0.5, std 0.5). Cast to float32 so that X_test + noise_2 is not
# upcast to a float64 array.
noise_2 = np.random.normal(loc   = .5,
                           scale = .5,
                           size  = X_test.shape).astype('float32')

In [35]:
# Corrupt the clean test images with their own independently drawn noise.
X_test_noisy = np.add(X_test, noise_2)

In [36]:
# Clip the values to 0 - 1 range
# Adding noise pushes pixels outside the valid intensity range; clamp both
# noisy sets back to [0, 1]. Clipping twice equals clipping once, so re-running
# this cell is safe.
X_train_noisy = X_train_noisy.clip(0, 1)
X_test_noisy  = X_test_noisy.clip(0, 1)

## Define network parameters

In [37]:
# Network hyper-parameters.
input_shape = (img_size, img_size, 1)   # rows, cols, channels (grayscale)
batch_size, kernel_size, latent_dim = 32, 3, 16
layer_filters = [32, 64]                # Conv2D filters per encoder layer; the decoder uses them reversed

## Build the model

* ### Encoder

In [44]:
# Encoder entry point: one (img_size, img_size, 1) image per sample.
inputs = Input(name='encoder_input', shape=input_shape)

x = inputs  # running tensor threaded through the encoder layers

In [45]:
# Build a stack of Conv2D layers

# Build a stack of Conv2D layers. Each stride-2, 'same'-padded convolution
# halves rows/cols (28 -> 14 -> 7 for MNIST).
# Restart from `inputs` so that re-running this cell rebuilds the same stack
# rather than appending more layers onto the previous run's output.
x = inputs
for filters in layer_filters:
    x = Conv2D(filters     = filters,
               kernel_size = kernel_size,
               strides     = 2,
               activation  = 'relu',
               padding     = 'same')(x)

In [46]:
# Get encoder's output shape
# Static shape (batch, rows, cols, channels) of the final conv feature map;
# the decoder uses it to undo the Flatten step.
enc_out_shape = tuple(K.int_shape(x))

In [47]:
# Each stride-2 'same' conv divides rows/cols by 2 (rounding up), and the
# channel count is the last entry of layer_filters, giving (None, 7, 7, 64)
# for 28x28 inputs. Fail loudly if the encoder was built differently, e.g. by
# re-running the conv cell on top of its own output.
expected_side = -(-img_size // 2 ** len(layer_filters))  # ceil division
assert enc_out_shape[1:] == (expected_side, expected_side, layer_filters[-1]), enc_out_shape
enc_out_shape

(None, 7, 7, 64)

In [48]:
# Build the latent vector

# Bottleneck: flatten the conv feature map and project it linearly (no
# activation) down to a latent_dim-dimensional code.
x = Flatten()(x)
latent = Dense(units=latent_dim, name='latent_vector')(x)

In [49]:
# Instantiate encoder model

# Wrap the graph from the image input to the latent vector as a standalone
# encoder model, and print its layer-by-layer summary.
encoder = Model(inputs=inputs, outputs=latent, name='encoder')

encoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_input (InputLayer)   (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 14, 14, 32)        320       
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 7, 7, 64)          18496     
_________________________________________________________________
flatten_3 (Flatten)          (None, 3136)              0         
_________________________________________________________________
latent_vector (Dense)        (None, 16)                50192     
Total params: 69,008
Trainable params: 69,008
Non-trainable params: 0
_________________________________________________________________


* ### Decoder

In [50]:
# Decoder entry point: one latent_dim-sized code vector per sample.
decoder_inputs = Input(name='decoder_input', shape=(latent_dim,))

In [51]:
# Project the latent code back up to as many units as the encoder's final
# feature map has elements (rows * cols * channels).
x = Dense(int(np.prod(enc_out_shape[1:])))(decoder_inputs)

In [52]:
# Reshape input
# Restore the (rows, cols, channels) layout of the encoder's last feature map
# so that transposed convolutions can be applied.
x = Reshape(target_shape=enc_out_shape[1:])(x)

In [53]:
# Build a stack of Conv2DTransp.

# Build a stack of Conv2DTranspose layers that mirrors the encoder: same
# filter counts in reverse order, each stride-2 'same' layer doubling rows/cols.
for filters in reversed(layer_filters):
    x = Conv2DTranspose(filters,
                        kernel_size,
                        strides=2,
                        padding='same',
                        activation='relu')(x)

In [54]:
# Reconstruct the denoised input
# Reconstruct the denoised image: a final stride-1 transposed conv collapses
# to a single channel, and the sigmoid keeps every pixel in [0, 1], the same
# range as the normalised, clipped images.
outputs = Conv2DTranspose(1,
                          kernel_size,
                          padding='same',
                          activation='sigmoid',
                          name='decoder_output')(x)

In [55]:
# Instantiate decoder model
# Wrap the graph from the latent code to the reconstructed image as a
# standalone decoder model.
decoder = Model(inputs=decoder_inputs, outputs=outputs, name='decoder')

In [56]:
# The decoder must produce an image with the same shape as the encoder input.
# NOTE(review): the saved summary under this cell shows 7x7 feature maps and a
# (None, 7, 7, 1) output. Two stride-2 'same' Conv2DTranspose layers should
# give 14x14 and then 28x28, so that saved output looks stale (likely from a
# run without strides). This assertion catches any such mismatch.
assert K.int_shape(decoder.output)[1:] == input_shape, K.int_shape(decoder.output)
decoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 3136)              53312     
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 7, 7, 64)          36928     
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 7, 7, 32)          18464     
_________________________________________________________________
decoder_output (Conv2DTransp (None, 7, 7, 1)           289       
Total params: 108,993
Trainable params: 108,993
Non-trainable params: 0
_________________________________________________________________
