In [1]:
import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

In [None]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, InputLayer, Dense

In [None]:
# Load MNIST. Labels are discarded because an autoencoder's target is its input.
(X_train, _), (X_test, _) = mnist.load_data()

# Flatten every 28x28 image into one row vector, then scale pixel values
# from 0..255 to floats in [0, 1].
X_train = X_train.reshape(len(X_train), -1).astype('float32') / 255
X_test = X_test.reshape(len(X_test), -1).astype('float32') / 255

In [None]:
# Symmetric fully-connected autoencoder: 784 -> 128 -> 64 -> 32 -> 64 -> 128 -> 784.
autoencoder = Sequential([
    InputLayer(input_shape=(784,)),
    # Encoder: compress the flattened image down to a 32-unit code.
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
    Dense(32, activation='relu'),
    # Decoder: expand the code back to 784 pixels. The sigmoid keeps outputs
    # in [0, 1], matching the scaled inputs and the binary cross-entropy loss.
    Dense(64, activation='relu'),
    Dense(128, activation='relu'),
    Dense(784, activation='sigmoid'),
])

autoencoder.summary()

In [None]:
# Inputs are scaled to [0, 1] and the output layer is a sigmoid, so per-pixel
# binary cross-entropy works as the reconstruction loss.
# 'accuracy' is a classification metric. For 784-dim continuous targets, Keras
# resolves it to a classification accuracy variant (in tf.keras 2.x,
# categorical accuracy: does the brightest predicted pixel match the brightest
# target pixel?), which says little about reconstruction quality. Mean squared
# error measures reconstruction error directly.
autoencoder.compile(optimizer='adam', loss='binary_crossentropy', metrics=['mse'])

In [None]:
# Train the network to reproduce its own input (targets == inputs).
# NOTE: the test split is used as validation data. That is fine for monitoring
# the loss, but do not tune hyperparameters against it.
# Keep the History object so the loss curves can be plotted later. Assigning it
# also stops the cell from ending with a bare History repr.
history = autoencoder.fit(
    X_train, X_train,
    epochs=50,
    batch_size=256,
    shuffle=True,  # Keras default, made explicit: reshuffle training data each epoch
    validation_data=(X_test, X_test),
)

In [None]:
# Build a standalone encoder from the trained first half of the autoencoder
# (784 -> 128 -> 64 -> 32).
# In tf.keras, Sequential.layers does not list the InputLayer, so
# layers[:3] are exactly the three encoder Dense layers.
# The layer objects are reused, so the encoder shares the trained weights;
# nothing is copied or retrained.
encoder_input = Input(shape=(784,))
encoded = encoder_input
for layer in autoencoder.layers[:3]:
    encoded = layer(encoded)
encoder = Model(encoder_input, encoded)
encoder.summary()

In [None]:
# Bottleneck codes for every test image, one row per image.
coded_imgs = encoder.predict(x=X_test)

In [None]:
# Full encode -> decode pass: a flattened 784-pixel reconstruction per test image.
decoded_imgs = autoencoder.predict(x=X_test)

In [None]:
# Visual check of the autoencoder on a random sample of test digits.
# Row 1: original input, row 2: bottleneck code, row 3: reconstruction.
num_images = 10
# Seeded generator, so every re-run shows the same digits. Sampling without
# replacement ensures each column shows a different digit.
rng = np.random.default_rng(0)
test_images = rng.choice(X_test.shape[0], size=num_images, replace=False)

# Each panel is (row label, array to index, display shape).
# The code row reshapes each code vector into an 8-row grid. This assumes the
# code length is a multiple of 8 (the bottleneck here has 32 units).
panels = [
    ('original', X_test, (28, 28)),
    ('encoded', coded_imgs, (8, -1)),
    ('reconstruction', decoded_imgs, (28, 28)),
]

# One row per panel and one column per sampled digit, so the rows stay aligned
# for any num_images. A fixed 10x10 subplot grid only lines up when
# num_images == 10. squeeze=False keeps `axes` 2-D even for one column.
fig, axes = plt.subplots(len(panels), num_images,
                         figsize=(1.8 * num_images, 1.8 * len(panels)),
                         squeeze=False)
for row_axes, (label, images, shape) in zip(axes, panels):
    for ax, image_idx in zip(row_axes, test_images):
        ax.imshow(images[image_idx].reshape(shape), cmap='gray')
        # Hide ticks but keep the axes visible, so the row label still renders.
        ax.set_xticks([])
        ax.set_yticks([])
    row_axes[0].set_ylabel(label)
fig.tight_layout()
plt.show()
    