In [None]:
# Variational AutoEncoder (no convolutional layers)
# https://blog.keras.io/building-autoencoders-in-keras.html

In [None]:
import keras
from keras import layers
from keras.datasets import mnist
import numpy as np

In [None]:
# Load MNIST through keras.datasets: load_data() hands back a (train, test)
# pair, each of which is an (images, labels) tuple.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
# Cast both image arrays to float32 and divide by 255 (presumably mapping
# 8-bit pixel intensities into [0, 1] -- confirm against the dataset).
# NOTE: this rebinds x_train / x_test, so re-running the cell divides again.
x_train, x_test = (
    images.astype("float32") / 255. for images in (x_train, x_test)
)

In [None]:
# Model hyper-parameters.
intermediate_dim = 64  # units in the hidden Dense layer of encoder and decoder
latent_dim = 2         # width of the latent code z

# Image geometry, read from the training array while it is still un-flattened:
# axis 0 indexes samples, axes 1 and 2 are the per-image dimensions.
# (This cell therefore has to run before the arrays are reshaped to 2-D.)
img_height, img_width = x_train.shape[1], x_train.shape[2]

# Length of one flattened image; used as the input/output width of the VAE.
original_dim = img_pixels = img_width * img_height
print(img_pixels)

In [None]:
# Flatten every sample into a single row vector (one row per sample, all
# remaining axes collapsed) so the arrays can feed the Dense layers.
x_train, x_test = (
    images.reshape((images.shape[0], np.prod(images.shape[1:])))
    for images in (x_train, x_test)
)

In [None]:
# Sanity check: show the shapes of the flattened train and test arrays.
print(*(images.shape for images in (x_train, x_test)))

In [None]:
# Encoder trunk: flattened image -> ReLU hidden layer.
inputs = keras.Input(shape=(original_dim,))
hidden_layer = layers.Dense(intermediate_dim, activation='relu')
h = hidden_layer(inputs)

# Two linear heads on the shared hidden activation, each latent_dim wide.
# They parameterize the latent distribution and are consumed by the sampling
# step and by the KL term of the loss further down.
mean_head = layers.Dense(latent_dim)
z_mean = mean_head(h)
log_sigma_head = layers.Dense(latent_dim)
z_log_sigma = log_sigma_head(h)

In [None]:
from keras import backend as K

def sampling(args):
    """Draw a latent sample z using the reparameterization trick.

    Args:
        args: Pair ``(z_mean, z_log_sigma)`` of encoder output tensors of
            matching shape ``(batch, latent_dim)``. ``z_log_sigma`` is
            interpreted as the log-variance of the latent Gaussian, which is
            how the KL term of the VAE loss in this notebook uses it
            (``1 + z_log_sigma - z_mean**2 - exp(z_log_sigma)``).

    Returns:
        Tensor shaped like ``z_mean`` holding
        ``z_mean + exp(0.5 * z_log_sigma) * epsilon`` with ``epsilon`` drawn
        from a standard normal distribution.
    """
    z_mean, z_log_sigma = args
    # Batch size is only known at run time; the latent width is static and is
    # read from the tensor itself so the function does not rely on a global.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # Unit-variance noise scaled by sigma = exp(0.5 * log-variance): this makes
    # the sampled posterior N(z_mean, exp(z_log_sigma)) the same distribution
    # the KL term regularizes towards N(0, I).
    epsilon = K.random_normal(shape=(batch, dim), mean=0., stddev=1.)
    return z_mean + K.exp(0.5 * z_log_sigma) * epsilon


# Wrap the sampler in a Lambda layer so it becomes part of the Keras graph.
z = layers.Lambda(sampling)([z_mean, z_log_sigma])

In [None]:
# Encoder model: maps a flattened image to [z_mean, z_log_sigma, z], so its
# outputs (and predict() results) come as three tensors in that order.
encoder = keras.Model(inputs=inputs,
                      outputs=[z_mean, z_log_sigma, z],
                      name='encoder')

In [None]:
# Decoder: latent code -> ReLU hidden layer -> sigmoid layer with one unit per
# pixel of the flattened image.
latent_inputs = keras.Input(shape=(latent_dim,), name='z_sampling')
decoder_hidden = layers.Dense(intermediate_dim, activation='relu')
x = decoder_hidden(latent_inputs)
decoder_output = layers.Dense(original_dim, activation='sigmoid')
decoder = keras.Model(latent_inputs, decoder_output(x), name='decoder')

# End-to-end VAE: run the encoder on the image input, take its third output
# (the sampled z) and feed it through the decoder.
encoder_outputs = encoder(inputs)
outputs = decoder(encoder_outputs[2])
vae = keras.Model(inputs, outputs, name='vae_mlp')

In [None]:
#encoded_input = keras.Input(shape=(encoding_dim,))
#decoder_layer = autoencoder.layers[-1]
#decoder = keras.Model(encoded_input, decoder_layer(encoded_input))

In [None]:
#autoencoder.compile(optimizer="adam", loss="binary_crossentropy")

In [None]:
# Reconstruction term: binary cross-entropy between the input image and the
# VAE output, multiplied by the number of pixels.
reconstruction_loss = (
    keras.losses.binary_crossentropy(inputs, outputs) * original_dim
)

# KL term: -0.5 * sum(1 + z_log_sigma - z_mean^2 - exp(z_log_sigma)) over the
# latent axis, i.e. z_log_sigma plays the role of a log-variance here.
kl_loss = K.sum(
    1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma),
    axis=-1,
) * -0.5

# Average the per-sample total over the batch and attach it to the model; no
# `loss=` is given to compile() because the objective is supplied via add_loss.
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')

In [None]:
# Train the VAE to reproduce its own input (the images serve as both input
# and target), validating on the test images after every epoch.
fit_options = dict(
    epochs=10,
    batch_size=32,
    validation_data=(x_test, x_test),
)
vae.fit(x_train, x_train, **fit_options)

In [None]:
import matplotlib.pyplot as plt
# Push the test images through the encoder. Because the encoder model has
# three outputs, the result holds [z_mean, z_log_sigma, z] in that order.
batch_size = 32
x_test_encoded = encoder.predict(x=x_test, batch_size=batch_size)

# Peek at the first output (the latent means).
z_mean_preview = x_test_encoded[0]
print(z_mean_preview)

In [None]:
# Scatter the test digits in the 2-D latent space, coloured by class label.
# x_test_encoded is the encoder's output list [z_mean, z_log_sigma, z]; the
# plot uses the two columns of z_mean (one row per test image), not the first
# two list entries, so that each point lines up with one label in y_test.
z_mean_test = x_test_encoded[0]
plt.figure(figsize=(6, 6))
plt.scatter(z_mean_test[:, 0], z_mean_test[:, 1], c=y_test)
plt.colorbar()
plt.show()

In [None]:
# Display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# Sample n evenly spaced latent coordinates in [-15, 15] along each axis.
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)

# Build every latent point up front, in row-major tile order (row i uses
# grid_y[i] for the second latent coordinate, column j uses grid_x[j] for the
# first), and decode them with a single predict() call instead of n*n calls.
z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])
x_decoded = decoder.predict(z_grid)

# Paste each decoded digit into its tile of the mosaic.
for i in range(n):
    for j in range(n):
        digit = x_decoded[i * n + j].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()


In [None]:
# Reconstruct the test images. x_test_encoded is the encoder's output list
# [z_mean, z_log_sigma, z], while the decoder takes a single latent-code
# input, so decode the latent means only (a deterministic reconstruction).
decoded_imgs = decoder.predict(x_test_encoded[0])

In [None]:
# Compare the first n test digits (top row) with their reconstructions
# (bottom row).
n = 10
plt.figure(figsize=(40, 4))
for i in range(n):
    # row 0 -> original image, row 1 -> reconstruction of the same digit
    for row, images in enumerate((x_test, decoded_imgs)):
        ax = plt.subplot(2, n, i + 1 + row * n)
        # Rows are flattened images; restore the 2-D layout for display.
        plt.imshow(images[i].reshape(img_height, img_width))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()