Simple Auto Encoder

In [0]:
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.keras as keras
import tensorflow.keras.layers as layers


### loading data ###
# The labels (y_train / y_test) are loaded but the autoencoder below never uses
# them: it is trained to reproduce its own input.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Flatten every 28x28 image into a 784-element vector and scale the pixel
# values by 1/255. Casting to float32 before dividing avoids the float64 array
# that `uint8 / 255.0` would otherwise produce (twice the memory, and Keras
# casts inputs to float32 anyway); it also matches the CNN section's loader.
x_train = x_train.reshape((-1, 28*28)).astype('float32') / 255.0
x_test = x_test.reshape((-1, 28*28)).astype('float32') / 255.0

Build Model & Train Model

In [0]:
code_dim = 32                  # width of the bottleneck (size of the latent code)
input_dim = x_train.shape[1]   # length of one flattened input image

# Two-layer dense autoencoder: input -> code_dim (relu) -> input_dim (sigmoid).
auto_encoder = keras.Sequential(
[
    layers.Dense(code_dim, input_shape = (input_dim,), activation = 'relu', name = 'encode'),
    # Sigmoid, not softmax: every output unit is an independent pixel intensity
    # in [0, 1]. Softmax would force all outputs of one image to sum to 1, which
    # cannot represent an image and does not fit the per-pixel
    # binary_crossentropy loss used below.
    layers.Dense(input_dim, activation = 'sigmoid', name = 'output')
])

# Stand-alone encoder: re-applies the *same* 'encode' layer object to a fresh
# input, so it shares weights with auto_encoder and reflects its training.
encoder_input = layers.Input(shape = (input_dim,))
encoder_output = auto_encoder.get_layer('encode')(encoder_input)
encoder = keras.Model(encoder_input, encoder_output)

# Stand-alone decoder: likewise re-uses the trained 'output' layer, fed from a
# code-sized input.
decoder_input = keras.Input((code_dim,))
decoder_output = auto_encoder.get_layer('output')(decoder_input)
decoder = keras.Model(decoder_input, decoder_output)

auto_encoder.compile(optimizer = 'adam', loss = 'binary_crossentropy')

# Input and target are both x_train; the last 10% of the training data is held
# out for validation.
history = auto_encoder.fit(x_train, x_train, batch_size = 256, epochs = 5, validation_split = 0.1)

# Round-trip the test images through the two halves separately.
encoded = encoder.predict(x_test)
decoded = decoder.predict(encoded)

Plot original and decoded data

In [0]:
plt.figure(figsize=(10,4))

n = 5
for i in range(n):
    # Row 0 shows the test image, row 1 its reconstruction, in the same column.
    for row, images in enumerate((x_test, decoded)):
        ax = plt.subplot(2, n, row * n + i + 1)
        plt.imshow(images[i].reshape(28, 28))
        plt.gray()
        # Hide ticks and tick labels on both axes.
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
plt.show()

Auto Encoder with CNN

In [0]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

(train_x, train_y), (test_x, test_y) = keras.datasets.mnist.load_data()

# Convert to float32, scale by 1/255, and append a trailing channel axis of
# size 1 so each sample has the (28, 28, 1) shape the Conv2D input expects.
train_x = (train_x.astype('float32') / 255.).reshape(len(train_x), 28, 28, 1)
test_x = (test_x.astype('float32') / 255.).reshape(len(test_x), 28, 28, 1)

Build Model & Train Model

In [0]:
input_img = layers.Input(shape=(28, 28, 1))

# --- Encoder: three conv + 2x2 max-pool stages --------------------------------
# With 'same' padding on the pools the spatial size goes 28 -> 14 -> 7 -> 4,
# so `encoded` is a 4x4 map with 8 channels.
x = layers.Conv2D(16, (3, 3), activation = 'relu', padding = 'same')(input_img)
x = layers.MaxPool2D((2, 2), padding = 'same')(x)
x = layers.Conv2D(8, (3, 3), activation = 'relu', padding = 'same')(x)
x = layers.MaxPool2D((2, 2), padding = 'same')(x)
x = layers.Conv2D(8, (3, 3), activation = 'relu', padding = 'same')(x)
encoded = layers.MaxPool2D((2, 2), padding = 'same')(x)

# --- Decoder: three conv + 2x upsampling stages -------------------------------
x = layers.Conv2D(8, (3, 3), activation = 'relu', padding = 'same')(encoded)
x = layers.UpSampling2D((2, 2))(x)
x = layers.Conv2D(8, (3, 3), activation = 'relu', padding = 'same')(x)
x = layers.UpSampling2D((2, 2))(x)
# NOTE: this convolution deliberately uses the default 'valid' padding. It
# shrinks 16x16 to 14x14 so the final upsampling lands on 28x28; adding
# padding='same' here would produce a 32x32 output that no longer matches
# the input.
x = layers.Conv2D(16, (3, 3), activation = 'relu')(x)
x = layers.UpSampling2D((2, 2))(x)
# Sigmoid keeps every reconstructed pixel in [0, 1], matching the scaled input.
decoded = layers.Conv2D(1, (3, 3), activation = 'sigmoid', padding = 'same')(x)

cnn_autoencoder = keras.Model(input_img, decoded)

# Uncomment to print the per-layer output shapes described above.
# cnn_autoencoder.summary()

# Track mean absolute error instead of 'accuracy': the targets are continuous
# pixel intensities, so a classification accuracy (rounded prediction == target)
# is not a meaningful measure of reconstruction quality.
cnn_autoencoder.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics=['mae'])

# Input and target are the same images; the test split is used for validation.
cnn_autoencoder.fit(train_x, train_x,
                    epochs = 5,
                    batch_size = 128,
                    shuffle = True,
                    validation_data = (test_x, test_x))

In [0]:
### plot original data and decoded data ###

decoded_imgs = cnn_autoencoder.predict(test_x)

n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    sample = i + 20  # show test samples 20 .. 20 + n - 1
    # Row 0: original image; row 1: its reconstruction, same column.
    for row, batch in enumerate((test_x, decoded_imgs)):
        ax = plt.subplot(2, n, row * n + i + 1)
        plt.imshow(batch[sample].reshape(28, 28))
        plt.gray()
        # Hide ticks and tick labels on both axes.
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
plt.show()

In [0]:
# Sub-model from the image input to the bottleneck tensor. It is built from the
# same graph as cnn_autoencoder, so it uses the already-trained encoder layers.
encoder = keras.Model(input_img, encoded)
encoded_imgs = encoder.predict(test_x)

In [0]:
### plot encoded data ###
n = 10
plt.figure(figsize=(20, 8))
for i in range(n):
    ax = plt.subplot(1, n, i + 1)
    # Flatten one encoding to a 4 x 32 grid, then transpose so it is drawn as
    # a tall 32 x 4 strip.
    code_grid = encoded_imgs[i].reshape(4, 4 * 8).T
    plt.imshow(code_grid)
    plt.gray()
    # Hide ticks and tick labels on both axes.
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
plt.show()