In [1]:
import numpy as np

from keras.layers import Dense, Input, Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose

from keras.models import Model
from keras import backend as K

from keras.datasets import mnist

import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

# Denoising autoencoder with Keras

## Load, examine & transform the data

In [2]:
# Load the MNIST train/test splits and keep only the images;
# the label arrays are not needed to train an autoencoder.
train_split, test_split = mnist.load_data()
X_train, X_test = train_split[0], test_split[0]

In [3]:
# Inspect the dimensions of the training images
np.shape(X_train)

(60000, 28, 28)

In [7]:
# Reshape to add a color channel.
# The target shape is built from each array's own leading three dimensions
# instead of a hard-coded 28, so the cell works for any image size and is
# still safe to re-run on data that already carries the channel axis.
X_train = X_train.reshape(X_train.shape[:3] + (1,))
X_test = X_test.reshape(X_test.shape[:3] + (1,))

In [8]:
# Image side length, read from axis 1 of the training array
img_size = np.size(X_train, axis=1)

In [15]:
# Scale pixel intensities by 1/255 after converting to float32.
# NOTE: running this cell twice divides the data twice.
pixel_scale = 255
X_train = np.divide(X_train.astype(np.float32), pixel_scale)
X_test = np.divide(X_test.astype(np.float32), pixel_scale)

* ### Get noisy images

In [32]:
# Generate train noise: Gaussian, mean 0.5 and standard deviation 0.5,
# one sample per pixel of X_train.
# A dedicated seeded Generator makes the noisy dataset reproducible after a
# kernel restart (the unseeded global RNG gave different noise on every run).
train_noise_rng = np.random.default_rng(seed=42)
noise = train_noise_rng.normal(loc   = .5,
                               scale = .5,
                               size  = X_train.shape)
# normal() returns float64; cast down so that adding it to the float32 images
# does not silently promote the whole noisy dataset to float64.
noise = noise.astype('float32')

In [33]:
# Corrupt the training images by adding the noise element-wise
X_train_noisy = np.add(X_train, noise)

In [34]:
# Generate test noise: Gaussian, mean 0.5 and standard deviation 0.5,
# one sample per pixel of X_test.
# Seeded for reproducibility, with a seed distinct from the training noise
# so the two noise fields are not copies of one another.
test_noise_rng = np.random.default_rng(seed=43)
noise_2 = test_noise_rng.normal(loc   = .5,
                                scale = .5,
                                size  = X_test.shape)
# normal() returns float64; cast down so that adding it to the float32 images
# does not silently promote the noisy test set to float64.
noise_2 = noise_2.astype('float32')

In [35]:
# Corrupt the test images by adding the noise element-wise
X_test_noisy = np.add(X_test, noise_2)

In [36]:
# Saturate the noisy pixels so every value lies in the [0, 1] interval
X_train_noisy = X_train_noisy.clip(0, 1)
X_test_noisy = X_test_noisy.clip(0, 1)

## Define network params

In [37]:
# Architecture
layer_filters = [32, 64]                  # filters per encoder conv layer (decoder walks it in reverse)
kernel_size = 3                           # kernel size shared by every (transposed) convolution
latent_dim = 16                           # width of the latent vector
input_shape = (img_size, img_size, 1)     # square, single-channel images

# Training
batch_size = 32

## Build the model

* ### Encoder

In [57]:
# Entry point of the encoder; `x` tracks the tensor as layers are stacked on it
x = inputs = Input(name='encoder_input', shape=input_shape)

In [58]:
# Stack one stride-2 Conv2D per entry of layer_filters on top of the input.
# All layers share the same settings; only the number of filters varies.
# NOTE: the cell rebinds `x`, so running it twice stacks the layers twice.
encoder_conv_kwargs = dict(kernel_size=kernel_size,
                           strides=2,
                           activation='relu',
                           padding='same')

for filters in layer_filters:
    conv_layer = Conv2D(filters, **encoder_conv_kwargs)
    x = conv_layer(x)

In [59]:
# Get encoder's output shape.
# This is the static shape of the last conv feature map; the decoder reads
# entries 1-3 of it to size its Dense layer and its Reshape target.
enc_out_shape = K.int_shape(x)

In [60]:
# Display the recorded shape as the cell output
enc_out_shape

(None, 7, 7, 64)

In [61]:
# Collapse the conv feature map to a vector, then project it down to
# latent_dim units (no activation is set on the latent layer).
flatten_layer = Flatten()
latent_layer = Dense(units=latent_dim, name='latent_vector')

x = flatten_layer(x)
latent = latent_layer(x)

In [62]:
# Wrap everything from the image input to the latent vector in its own model,
# then print the layer-by-layer summary.
encoder = Model(inputs=inputs, outputs=latent, name='encoder')
encoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_input (InputLayer)   (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 14, 14, 32)        320       
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 7, 7, 64)          18496     
_________________________________________________________________
flatten_4 (Flatten)          (None, 3136)              0         
_________________________________________________________________
latent_vector (Dense)        (None, 16)                50192     
Total params: 69,008
Trainable params: 69,008
Non-trainable params: 0
_________________________________________________________________


* ### Decoder

In [63]:
# The decoder starts from a latent vector of length latent_dim
decoder_inputs = Input(name='decoder_input', shape=(latent_dim,))

In [64]:
# Expand the latent vector to as many units as the encoder's last
# feature map holds (height * width * channels).
feat_height, feat_width, feat_channels = enc_out_shape[1:]
x = Dense(feat_height * feat_width * feat_channels)(decoder_inputs)

In [65]:
# Fold the flat vector back into the encoder's feature-map layout
# (every dimension of enc_out_shape except the leading one).
x = Reshape(target_shape=tuple(enc_out_shape[1:]))(x)

In [66]:
# Build a stack of Conv2DTranspose layers that mirrors the encoder:
# the same filter counts, visited in reverse order, each with stride 2.
# NOTE: the cell rebinds `x`, so running it twice stacks the layers twice.
decoder_deconv_kwargs = dict(kernel_size=kernel_size,
                             strides=2,
                             activation='relu',
                             padding='same')

for filters in reversed(layer_filters):
    deconv_layer = Conv2DTranspose(filters, **decoder_deconv_kwargs)
    x = deconv_layer(x)

In [67]:
# Final layer: a single-filter transposed convolution that produces the
# reconstructed (denoised) image; the sigmoid bounds every output pixel
# between 0 and 1.
output_layer = Conv2DTranspose(filters=1,
                               kernel_size=kernel_size,
                               activation='sigmoid',
                               padding='same',
                               name='decoder_output')
outputs = output_layer(x)

In [68]:
# Wrap everything from the latent input to the reconstructed image in its own model
decoder = Model(inputs=decoder_inputs, outputs=outputs, name='decoder')

In [69]:
# Print the decoder's layer-by-layer summary
decoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   (None, 16)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 3136)              53312     
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 28, 28, 32)        18464     
_________________________________________________________________
decoder_output (Conv2DTransp (None, 28, 28, 1)         289       
Total params: 108,993
Trainable params: 108,993
Non-trainable params: 0
_________________________________________________________________


* ### Autoencoder

In [71]:
# Chain the two sub-models: image -> encoder -> latent vector -> decoder -> image
encoded = encoder(inputs)
decoded = decoder(encoded)
autoencoder = Model(inputs=inputs, outputs=decoded, name='autoencoder')

In [72]:
# Print the summary; the encoder and decoder each appear as a single nested layer
autoencoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_input (InputLayer)   (None, 28, 28, 1)         0         
_________________________________________________________________
encoder (Model)              (None, 16)                69008     
_________________________________________________________________
decoder (Model)              (None, 28, 28, 1)         108993    
Total params: 178,001
Trainable params: 178,001
Non-trainable params: 0
_________________________________________________________________


In [73]:
# Configure training: Adam optimizer minimizing the mean squared error
# between the reconstruction and the target image.
autoencoder.compile(optimizer='adam', loss='mse')

In [74]:
# Train the denoiser: noisy images are the inputs, the clean originals are
# the targets. The test split (noisy -> clean) serves as validation data.
autoencoder.fit(x               = X_train_noisy,
                y               = X_train,
                batch_size      = batch_size,
                epochs          = 2,
                validation_data = (X_test_noisy, X_test))

Train on 60000 samples, validate on 10000 samples
Epoch 1/2
  352/60000 [..............................] - ETA: 5:40:27 - loss: 0.1942

KeyboardInterrupt: 