# Importing Libraries

In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Input
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
# Last expression of the cell, so the notebook displays the installed TensorFlow version.
tf.__version__

# Load Dataset

In [2]:
# Load MNIST through the Keras dataset helper. The result is unpacked into a
# training pair and a test pair; X_* are the images shown below, Y_* their labels.
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

In [3]:
# Report the (images, labels) shapes of the training and test splits.
train_shapes = (X_train.shape, Y_train.shape)
test_shapes = (X_test.shape, Y_test.shape)
train_shapes, test_shapes

# Visualize Images

In [4]:
# Show one randomly chosen training digit together with its label.
# randint's upper bound is exclusive, so passing the dataset size keeps every
# index (including the last one) eligible and avoids a hard-coded sample count.
i = np.random.randint(0, X_train.shape[0])
plt.imshow(X_train[i], cmap='gray');
print(Y_train[i])

In [5]:
# 10 x 10 grid of randomly sampled training digits, each titled with its label
width = 10
height = 10
fig, axes = plt.subplots(height, width, figsize=(20, 20))
axes = axes.ravel()  # (height, width) grid of Axes -> flat sequence

# Draw all indices in one call. The upper bound is exclusive, so using the
# dataset size makes every training image (including the last) selectable.
sample_indices = np.random.randint(0, X_train.shape[0], size=width * height)
for ax, index in zip(axes, sample_indices):
    ax.imshow(X_train[index], cmap='gray')
    ax.set_title(Y_train[index], fontsize=8)
    ax.axis('off')
plt.subplots_adjust(hspace=0.4)

# Preprocessing Images

In [6]:
# Rescale pixel intensities so they fall in the range 0 to 1.
# A bare expression in the middle of a cell is not displayed, so the
# pre-scaling range is printed explicitly.
print('before scaling:', X_train[0].min(), X_train[0].max())

# Only divide while values are still above 1: this keeps the cell safe to
# re-run without shrinking the data by another factor of 255 each time.
if X_train.max() > 1:
    X_train = X_train / 255
if X_test.max() > 1:
    X_test = X_test / 255

X_train[0].min(), X_train[0].max()  # range after scaling: 0 to 1

In [7]:
# Shapes of both splits at this point in the pipeline (images first, labels second).
train_shapes = (X_train.shape, Y_train.shape)
test_shapes = (X_test.shape, Y_test.shape)
train_shapes, test_shapes

In [8]:
# Flatten every image into a single feature vector for the Dense layers.
# -1 lets NumPy infer the flattened width, which also keeps the cell
# re-runnable: indexing shape[2] would fail on arrays that are already 2-D.
X_train = X_train.reshape(X_train.shape[0], -1)
X_test = X_test.reshape(X_test.shape[0], -1)
(X_train.shape, Y_train.shape), (X_test.shape, Y_test.shape)

# Building Linear (Fully Connected) Autoencoder

### 784 -> 128 -> 64 -> 32 ===> 64 -> 128 -> 784

In [9]:
# Fully connected autoencoder: 784 -> 128 -> 64 -> 32 -> 64 -> 128 -> 784.
#
# Layer names are pinned to the ones Keras auto-generates in a fresh session.
# Later cells look layers up as 'dense_2' / 'dense_3'; without explicit names the
# auto-generated suffix keeps counting up every time this cell is re-run and
# those lookups would stop resolving.
autoencoder = Sequential()

# Explicit Input instead of the `input_dim` keyword on the first Dense layer
autoencoder.add(Input(shape=(784,)))

# Encode
autoencoder.add(Dense(units=128, activation='relu', name='dense'))
autoencoder.add(Dense(units=64, activation='relu', name='dense_1'))
autoencoder.add(Dense(units=32, activation='relu', name='dense_2'))  # bottleneck / latent code

# Decode
autoencoder.add(Dense(units=64, activation='relu', name='dense_3'))
autoencoder.add(Dense(units=128, activation='relu', name='dense_4'))
# Sigmoid keeps every output in [0, 1], matching the normalized pixel values
autoencoder.add(Dense(units=784, activation='sigmoid', name='dense_5'))
autoencoder.summary()

In [10]:
# Train the network to reproduce its own input: the images are both features and targets.
autoencoder.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])

# The reconstruction loss on the test images is reported after every epoch so
# over-fitting is visible during the run. The returned History object is kept in
# a variable rather than echoed as the cell's output.
history = autoencoder.fit(
    X_train, X_train,
    epochs=50,
    validation_data=(X_test, X_test),
)

Accuracy will stay low here, and that is expected: it is a classification metric and says little about reconstruction quality.<br>
The way to validate the model is to compare the original and decoded images (not to look at accuracy).

# Encoding Images
### 784 -> 128 -> 64 -> 32

In [11]:
# Print the architecture again; presumably to read off which layer ends the
# encoder half before the encoder model is built in the following cells.
autoencoder.summary()

In [12]:
# Tensors that bound the encoder: the model input and the output of the third
# Dense layer (the 32-unit bottleneck). The layer is fetched by position, so the
# lookup does not depend on an auto-generated layer name.
autoencoder.input, autoencoder.layers[2].output

In [13]:
# Encoder = the trained autoencoder truncated at its 32-unit bottleneck.
# The bottleneck is the third Dense layer; fetching it by position avoids
# relying on an auto-generated layer name that changes when the model is rebuilt.
bottleneck_layer = autoencoder.layers[2]
encoder = Model(inputs=autoencoder.input, outputs=bottleneck_layer.output)
encoder.summary()

In [14]:
# Display the first test digit: restore the flat vector to a 28x28 grid first
first_test_digit = X_test[0].reshape(28, 28)
plt.imshow(first_test_digit, cmap='gray');

In [15]:
# Shape of a single test sample; the next cell reshapes it to (1, 784) before predicting.
X_test[0].shape

In [16]:
# Encode the first test digit; it is reshaped to a one-row batch for predict().
single_digit_batch = X_test[0].reshape(1, 784)
encoded_image = encoder.predict(single_digit_batch)

# View the encoder output as an 8x4 grid so it can be drawn as an image
latent_grid = encoded_image.reshape(8, 4)
plt.imshow(latent_grid, cmap='gray');
encoded_image.shape

# Decoding Images
### 32 -> 64 -> 128 -> 784

In [17]:
# Print the architecture again; presumably to read off which layers form the
# decoder half before the decoder model is assembled in the next cell.
autoencoder.summary()

In [18]:
# The decoder needs its own input layer: it starts from a 32-value code rather
# than from the autoencoder's image input.
input_layer_decoder = Input(shape=(32,))

# The last three layers of the autoencoder form the decoding half. All three
# are fetched by position, so none of the lookups depends on an auto-generated
# layer name.
decoder_layer1 = autoencoder.layers[3]
decoder_layer2 = autoencoder.layers[4]
decoder_layer3 = autoencoder.layers[5]

# Chain the existing layer objects onto the new input: 32 -> 64 -> 128 -> 784
decoded_output = input_layer_decoder
for layer in (decoder_layer1, decoder_layer2, decoder_layer3):
    decoded_output = layer(decoded_output)

decoder = Model(inputs=input_layer_decoder, outputs=decoded_output)
decoder.summary()

In [19]:
# Reconstruct the digit from the code produced by the encoder
decoded_image = decoder.predict(encoded_image)
print(decoded_image.shape)

# Restore the flat output to a 28x28 grid for display
reconstruction = decoded_image.reshape(28, 28)
plt.imshow(reconstruction, cmap='gray');

# Encode and Decode Test Images

In [20]:
# Compare originals, latent codes and reconstructions for random test digits.
n_images = 10

# The upper bound is exclusive, so passing the dataset size keeps the last test
# image selectable.
test_images = np.random.randint(0, X_test.shape[0], size=n_images)

# One batched forward pass per model instead of two predict() calls per image
originals = X_test[test_images]
encoded_batch = encoder.predict(originals)
decoded_batch = decoder.predict(encoded_batch)

# Three rows (original, encoded, decoded) with one column per sampled image.
# The grid is derived from n_images, so the rows stay aligned for any count.
fig, axes = plt.subplots(3, n_images, figsize=(2 * n_images, 6), squeeze=False)
rows = (
    originals.reshape(-1, 28, 28),      # original images
    encoded_batch.reshape(-1, 8, 4),    # encoded images (32 values shown as 8x4)
    decoded_batch.reshape(-1, 28, 28),  # decoded images
)
for row_axes, row_images in zip(axes, rows):
    for ax, image in zip(row_axes, row_images):
        ax.imshow(image, cmap='gray')
        ax.axis('off')