In [0]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt


# Load CIFAR-10. The autoencoder reconstructs its own input, so the class
# labels are not needed for training and are only kept for reference.
((train_x, train_label), (test_x, test_label)) = keras.datasets.cifar10.load_data()

# Images arrive as (N, rows, cols, channels) arrays, so no reshape is needed;
# read the per-image shape from the data rather than hard-coding it.
img_rows, img_cols, channels = train_x.shape[1:]

# Scale pixel values from [0, 255] to [0, 1] so they match the sigmoid output
# and binary cross-entropy loss used by the model.
train_x = train_x.astype('float32') / 255.
test_x = test_x.astype('float32') / 255.

assert train_x.shape[1:] == test_x.shape[1:] == (img_rows, img_cols, channels)

Build and train the convolutional autoencoder

In [0]:
# Fix NumPy's and TensorFlow's global seeds so weight initialisation and the
# per-epoch shuffling in `fit` are repeatable across kernel restarts.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

EPOCHS = 50
BATCH_SIZE = 100

input_img = layers.Input(shape=(img_rows, img_cols, channels))

# Encoder: three Conv2D + 2x2 max-pool stages. Each pool halves the spatial
# size, so 32x32 CIFAR-10 images are compressed to a 4x4x8 code.
x = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
x = layers.MaxPool2D((2, 2), padding='same')(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = layers.MaxPool2D((2, 2), padding='same')(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = layers.MaxPool2D((2, 2), padding='same')(x)

# Decoder: mirror of the encoder; each UpSampling2D doubles the spatial size.
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = layers.UpSampling2D((2, 2))(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = layers.UpSampling2D((2, 2))(x)
x = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = layers.UpSampling2D((2, 2))(x)
# One output filter per input channel (previously hard-coded to 3). The sigmoid
# keeps outputs in [0, 1], matching the /255-scaled inputs.
decoded = layers.Conv2D(channels, (3, 3), activation='sigmoid', padding='same')(x)

# Three 'same'-padded pools followed by three 2x upsamplings give back
# 8 * ceil(size / 8). That equals the input size only when rows and cols are
# multiples of 8, so fail fast here instead of at the first `fit` step.
assert tuple(decoded.shape[1:]) == tuple(input_img.shape[1:]), (
    f'decoder output {decoded.shape[1:]} != input {input_img.shape[1:]}'
)

auto_encoder = keras.Model(input_img, decoded)

# Per-pixel binary cross-entropy works as a reconstruction loss here because
# both the targets and the sigmoid outputs lie in [0, 1].
auto_encoder.compile(optimizer='adam', loss='binary_crossentropy')

# NOTE(review): the test set doubles as validation data, so it is not a truly
# held-out evaluation set; consider `validation_split` if that matters.
history = auto_encoder.fit(train_x, train_x,
                           epochs=EPOCHS,
                           batch_size=BATCH_SIZE,
                           shuffle=True,
                           validation_data=(test_x, test_x))


Plot original and decoded (reconstructed) test images

In [0]:
decoded_imgs = auto_encoder.predict(test_x)


def _show_image(ax, img):
    """Draw one (rows, cols, channels) image on `ax` with its axis ticks hidden.

    Single-channel images are drawn as a 2-D array with a gray colormap;
    multi-channel (RGB) images are drawn as-is, since imshow ignores the
    colormap for them.
    """
    if img.shape[-1] == 1:
        ax.imshow(img[..., 0], cmap='gray')
    else:
        ax.imshow(img)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)


# Top row: original test images; bottom row: the autoencoder's reconstructions.
n = min(10, len(test_x))
fig, axes = plt.subplots(2, n, figsize=(20, 4), squeeze=False)
for i in range(n):
    _show_image(axes[0, i], test_x[i])
    _show_image(axes[1, i], decoded_imgs[i])
fig.suptitle('Top: original test images  |  Bottom: reconstructions')
plt.show()

Plot the encoded (latent) representations of the test images

In [0]:
encoder = keras.Model(inputs=input_img, outputs=encoded)
encoded_imgs = encoder.predict(test_x)

### plot encoded data ###
# Each code is a (height, width, filters) array: (4, 4, 8) for this model.
# Derive those sizes from the predictions instead of hard-coding them, so the
# plot follows any change to the encoder. Each code is flattened row-major to
# (height, width * filters) and transposed into a tall strip for display.
code_h, code_w, code_c = encoded_imgs.shape[1:]
n = min(10, len(encoded_imgs))
fig, axes = plt.subplots(1, n, figsize=(20, 8), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.imshow(encoded_imgs[i].reshape(code_h, code_w * code_c).T, cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
fig.suptitle('Encoded representations of the first test images')
plt.show()