In [None]:
# Convolutional autoencoder that learns to reconstruct MNIST digit images.
# Adapted from the Keras blog: https://blog.keras.io/building-autoencoders-in-keras.html

In [None]:
import keras
from keras import layers
from keras.datasets import mnist
import numpy as np

In [None]:
# An autoencoder reconstructs its own input, so only the images are kept;
# the digit labels from each (images, labels) split are discarded.
x_train, x_test = (images for images, _labels in mnist.load_data())

In [None]:
# Rescale raw pixel values (assumed 0-255) to float32 in [0, 1]; the sigmoid
# output layer and binary cross-entropy loss used below expect targets in that range.
x_train, x_test = (images.astype("float32") / 255. for images in (x_train, x_test))

In [None]:
# Seed Python, NumPy and the Keras backend so that weight initialisation and
# the shuffling in fit() repeat across Restart & Run All.
# NOTE(review): some GPU kernels may still be non-deterministic.
SEED = 42
keras.utils.set_random_seed(SEED)

# Only referenced by the commented-out dense-decoder cell further down; the
# convolutional model's bottleneck size is fixed by its layers, not by this value.
encoding_dim = 32

# Image geometry read from the data instead of hard-coding 28x28.
img_height, img_width = x_train.shape[1:3]
img_pixels = img_width * img_height
print(img_pixels)

In [None]:
# Add a trailing single-channel axis so each split matches the model's
# keras.Input shape of (img_height, img_width, 1).
# (The previous version had unbalanced parentheses and a nested tuple,
# which is a SyntaxError.)
x_train = x_train.reshape((len(x_train), img_height, img_width, 1))
x_test = x_test.reshape((len(x_test), img_height, img_width, 1))

In [None]:
# Fail fast if the reshape above doesn't produce the per-image shape the
# model's keras.Input expects: (img_height, img_width, 1).
expected_image_shape = (img_height, img_width, 1)
assert x_train.shape[1:] == expected_image_shape, x_train.shape
assert x_test.shape[1:] == expected_image_shape, x_test.shape
print(x_train.shape, x_test.shape)

In [None]:
input_img = keras.Input(shape=(img_height, img_width, 1))

# Encoder: three (3x3 conv + 2x2 'same' max-pool) stages.
# For 28x28 inputs the spatial size goes 28 -> 14 -> 7 -> 4.
x = input_img
for filters in (16, 8, 8):
    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
encoded = x

# at this point the representation is (4, 4, 8) i.e. 128-dimensional

# Decoder: two (3x3 conv + 2x upsample) stages, 4 -> 8 -> 16.
x = encoded
for filters in (8, 8):
    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.UpSampling2D((2, 2))(x)
# Deliberately no padding ('valid'): trims 16x16 to 14x14 so the final
# upsampling lands exactly on 28x28.
x = layers.Conv2D(16, (3, 3), activation='relu')(x)
x = layers.UpSampling2D((2, 2))(x)
# Single-channel sigmoid output, matching the [0, 1]-scaled input pixels.
decoded = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = keras.Model(input_img, decoded)

In [None]:
#encoder = keras.Model(input_img, encoded)

In [None]:
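# NOTE: leftover from the dense-autoencoder example. It does not fit the
# convolutional model above: its code is a (4, 4, 8) tensor rather than a flat
# encoding_dim vector, and its decoder is several layers, not just the last one.
#encoded_input = keras.Input(shape=(encoding_dim,))
#decoder_layer = autoencoder.layers[-1]
#decoder = keras.Model(encoded_input, decoder_layer(encoded_input))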

In [None]:
# Adam optimiser; per-pixel binary cross-entropy pairs with the sigmoid output
# layer, whose values (like the scaled input pixels) lie in [0, 1].
autoencoder.compile(loss="binary_crossentropy", optimizer="adam")

In [None]:
# Train the network to reproduce its input: the images are both features and
# targets, and the test split only reports validation loss after each epoch.
autoencoder.fit(
    x=x_train,
    y=x_train,
    epochs=10,
    batch_size=64,
    shuffle=True,
    validation_data=(x_test, x_test),
)

In [None]:
#encoded_imgs = encoder.predict(x_test)

In [None]:
#decoded_imgs = decoder.predict(encoded_imgs)

In [None]:
# Reconstruct every test image with the trained autoencoder.
decoded_imgs = autoencoder.predict(x=x_test)

In [None]:
import matplotlib.pyplot as plt

n = 10  # number of test digits to show

# Top row: original test images; bottom row: their reconstructions.
# squeeze=False keeps `axes` 2-D even when n == 1.
fig, axes = plt.subplots(2, n, figsize=(40, 4), squeeze=False)
for i in range(n):
    for ax, images in zip(axes[:, i], (x_test, decoded_imgs)):
        # cmap="gray" per image instead of plt.gray(), which would change the
        # global default colormap for every later figure.
        ax.imshow(images[i].reshape(img_height, img_width), cmap="gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()

In [None]:
# Standalone encoder that reuses the trained layers up to the bottleneck.
encoder = keras.Model(input_img, encoded)
encoded_imgs = encoder.predict(x_test)

# Each code is (code_h, code_w, code_c), i.e. (4, 4, 8) for this architecture.
# Flatten width and channels together and transpose to draw it as a tall strip.
code_h, code_w, code_c = encoded_imgs.shape[1:]

n = 10
fig, axes = plt.subplots(1, n, figsize=(20, 8), squeeze=False)
# Index from 0 so these are the codes of the same test digits (x_test[0..n-1])
# shown in the reconstruction plot; the old range(1, n + 1) was off by one.
for i, ax in enumerate(axes[0]):
    ax.imshow(encoded_imgs[i].reshape((code_h, code_w * code_c)).T, cmap="gray")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()