## LSTM AE

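This notebook demonstrates how to train an LSTM autoencoder on the MNIST dataset. Each 28×28 image is treated as a sequence of 28 timesteps (rows), each with 28 features (the pixels in that row).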

### Your task
1.   Learn and summarize the LSTM AE with respect to: architecture, cost function, latent space, reparameterization (note that a plain LSTM AE, unlike a VAE, has no reparameterization step), etc.
2.   Try switching the dataset to MNIST denoising (noisy inputs, clean targets) and to the Face-Sketch dataset.



## Load Dataset

In [None]:
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt

# Labels are not used by the autoencoder, but keep the splits distinct:
# unpacking the test labels into `y_train` would silently overwrite the train labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel intensities to [0, 1]; binary cross-entropy targets must lie in that range.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

# The LSTM reads each image as a sequence: 28 timesteps (rows) x 28 features (pixels per row).
# Keep the data rank-3 (samples, rows, cols). A trailing channel axis would not match the
# model's Input(shape=(timesteps, input_dim)) and fit/predict would reject it.
x_train = np.reshape(x_train, (len(x_train), 28, 28))
x_test = np.reshape(x_test, (len(x_test), 28, 28))

assert x_train.shape[1:] == (28, 28) and x_test.shape[1:] == (28, 28)
assert 0.0 <= x_train.min() and x_train.max() <= 1.0
print(x_train.shape)
print(x_test.shape)


## Define LSTM AE Model

In [2]:
import keras
from keras import layers

# Each 28x28 image is treated as a sequence with one row of pixels per timestep.
timesteps = 28   # sequence length = number of image rows
input_dim = 28   # features per timestep = pixels per row
latent_dim = 32  # size of the bottleneck code produced by the encoder

# Encoder: without return_sequences, the LSTM emits only its final output,
# which serves as the fixed-size latent code for the whole sequence.
inputs = keras.Input(shape=(timesteps, input_dim))
encoded = layers.LSTM(latent_dim)(inputs)

# Decoder: repeat the code once per timestep and unroll it back into a sequence.
decoded = layers.RepeatVector(timesteps)(encoded)
decoded = layers.LSTM(input_dim, return_sequences=True)(decoded)
# The LSTM's default tanh activation produces values in (-1, 1), but the model is
# trained with binary cross-entropy against pixel targets in [0, 1], which expects
# probabilities. A per-timestep sigmoid head maps each row's pixels into (0, 1).
decoded = layers.TimeDistributed(layers.Dense(input_dim, activation='sigmoid'))(decoded)

# Full autoencoder (image -> reconstruction) and encoder-only model (image -> latent code).
model_lstmAE = keras.Model(inputs, decoded)
model_lstmAE_encoder = keras.Model(inputs, encoded)

## Training and Testing

In [None]:
from keras.callbacks import TensorBoard

# Binary cross-entropy treats each pixel intensity in [0, 1] as a target probability.
model_lstmAE.compile(optimizer='adam', loss='binary_crossentropy')

# Autoencoder training: inputs are also the targets. The test split is only used
# to monitor reconstruction loss after each epoch.
history_lstmAE = model_lstmAE.fit(
    x_train, x_train,
    epochs=50,
    batch_size=128,
    shuffle=True,
    validation_data=(x_test, x_test),
    # Log under an LSTM-specific directory so these runs don't mix with a conv-AE's logs.
    callbacks=[TensorBoard(log_dir='/tmp/model_lstmAE')],
)

In [None]:
# Training vs. validation reconstruction loss per epoch, from the model's last fit().
loss_history = model_lstmAE.history.history
plt.plot(loss_history["loss"], label="train")
plt.plot(loss_history["val_loss"], label="validation")
plt.xlabel("epoch")
plt.ylabel("binary cross-entropy")
plt.title("LSTM AE training curve")
plt.legend()
plt.show()

In [None]:
decoded_imgs = model_lstmAE.predict(x_test)

n = 10  # number of test digits to display
fig, axes = plt.subplots(2, n, figsize=(20, 4))
# Show x_test[1..n] (same samples as before): top row = original, bottom row = reconstruction.
for col, idx in enumerate(range(1, n + 1)):
    for ax, img in ((axes[0, col], x_test[idx]), (axes[1, col], decoded_imgs[idx])):
        # Pass the colormap per call rather than plt.gray(), which changes the global default.
        ax.imshow(np.reshape(img, (28, 28)), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()

In [None]:
encoded_imgs = model_lstmAE_encoder.predict(x_test)

# Lay each latent vector out as a small 2-D grid for display. The grid shape is derived
# from the code size (8x4 for a 32-dim code) instead of being hard-coded, so changing
# `latent_dim` does not break this cell.
code_size = encoded_imgs.shape[1]
grid_cols = max(c for c in range(1, int(code_size ** 0.5) + 1) if code_size % c == 0)
grid_shape = (code_size // grid_cols, grid_cols)

n = 10  # number of test codes to display (x_test[1..n], matching the reconstruction cell)
fig, axes = plt.subplots(1, n, figsize=(20, 8))
for ax, code in zip(axes, encoded_imgs[1:n + 1]):
    # Per-call colormap avoids mutating the global default via plt.gray().
    ax.imshow(code.reshape(grid_shape), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()