# 1-1 AutoEncoder

<img src="./img/ae.png" alt="autoencoder" width="500" align="left"/>

In [None]:
import tensorflow as tf
import numpy as np
import os
from matplotlib import pyplot as plt
from matplotlib import gridspec as gridspec

In [None]:
# Default directory where train() writes weight checkpoints and TensorBoard
# logs, and where predict() looks up the latest checkpoint.
CKPT_DIR = "../generated_output/AE"

In [None]:
# Default learning rate handed to the Adam optimizer in ae_model_fn().
LEARNING_RATE = 0.0002
# NOTE(review): not referenced by any code visible in this file; train() runs
# a fixed number of epochs instead -- confirm whether this is still needed.
TRAINING_STEPS = 30000
# Default number of samples per batch in train().
BATCH_SIZE = 128

In [None]:
# Number of values in one flattened image (28 * 28 pixels).
IMAGE_DIM = 784
# Width of the encoder output (the latent code).
LATENT_DIM = 128
# Widths of the hidden Dense layers of the encoder, in order.
ENCODER_HIDDEN_DIM = [256]
# Misspelled historical name, kept as an alias so existing references keep working.
ENDOCER_HIDDEN_DIM = ENCODER_HIDDEN_DIM
# Widths of the hidden Dense layers of the decoder, in order.
DECODER_HIDDEN_DIM = [256]

[batch_size, 784]

$\rightarrow$ Dense(784, 256) $\rightarrow$ relu $\rightarrow$ [batch_size, 256]

$\rightarrow$ Dense(256, 128) $\rightarrow$ sigmoid $\rightarrow$ [batch_size, 128] 

In [None]:
def encoder_model(feature, encoder_hidden_dim=ENDOCER_HIDDEN_DIM, latent_dim=LATENT_DIM):
    """Map `feature` to a latent code through a stack of Dense layers.

    Args:
        feature: input tensor fed to the first Dense layer.
        encoder_hidden_dim: iterable of hidden layer widths; each hidden layer
            uses a ReLU activation.
        latent_dim: width of the final, sigmoid-activated layer.

    Returns:
        The output tensor of the final (latent) Dense layer.
    """
    # Hidden layers first, then the latent projection as the last entry.
    layer_specs = [(width, tf.nn.relu) for width in encoder_hidden_dim]
    layer_specs.append((latent_dim, tf.nn.sigmoid))

    with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
        hidden = feature
        for width, activation in layer_specs:
            dense = tf.layers.Dense(
                width,
                activation=activation,
                kernel_initializer=tf.initializers.he_normal(),
            )
            hidden = dense(hidden)
    return hidden

[batch_size, 128]

$\rightarrow$ Dense(128, 256) $\rightarrow$ relu $\rightarrow$ [batch_size, 256]

$\rightarrow$ Dense(256, 784) $\rightarrow$ sigmoid $\rightarrow$ [batch_size, 784] 

In [None]:
def decoder_model(feature, decoder_hidden_dim=DECODER_HIDDEN_DIM, image_dim=IMAGE_DIM):
    """Map a latent code back to image space through a stack of Dense layers.

    Args:
        feature: input tensor (the latent code) fed to the first Dense layer.
        decoder_hidden_dim: iterable of hidden layer widths; each hidden layer
            uses a ReLU activation.
        image_dim: width of the final, sigmoid-activated output layer.

    Returns:
        The output tensor of the final Dense layer.
    """
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        hidden = feature
        # ReLU hidden stack.
        for width in decoder_hidden_dim:
            hidden_layer = tf.layers.Dense(
                width,
                activation=tf.nn.relu,
                kernel_initializer=tf.initializers.he_normal(),
            )
            hidden = hidden_layer(hidden)
        # Sigmoid output layer producing `image_dim` values per sample.
        output_layer = tf.layers.Dense(
            image_dim,
            activation=tf.nn.sigmoid,
            kernel_initializer=tf.initializers.he_normal(),
        )
        reconstruction = output_layer(hidden)
    return reconstruction

[batch_size, 784]

$\rightarrow$ Dense(784, 256) $\rightarrow$ relu $\rightarrow$ [batch_size, 256]

$\rightarrow$ Dense(256, 128) $\rightarrow$ sigmoid $\rightarrow$ [batch_size, 128] 

$\rightarrow$ Dense(128, 256) $\rightarrow$ relu $\rightarrow$ [batch_size, 256]

$\rightarrow$ Dense(256, 784) $\rightarrow$ sigmoid $\rightarrow$ [batch_size, 784] 

In [None]:
def ae_model_fn(input_shape, learning_rate=LEARNING_RATE):
    """Assemble and compile the encoder -> decoder autoencoder as a Keras model.

    Args:
        input_shape: shape passed to ``tf.keras.Input`` (callers in this file
            pass the shape of a single sample).
        learning_rate: learning rate given to the Adam optimizer.

    Returns:
        A ``tf.keras.Model`` compiled with Adam and a mean-squared-error loss.
    """
    image_in = tf.keras.Input(shape=input_shape)
    reconstruction = decoder_model(encoder_model(image_in))
    autoencoder = tf.keras.Model(inputs=image_in, outputs=reconstruction)

    adam = tf.keras.optimizers.Adam(lr=learning_rate)
    autoencoder.compile(optimizer=adam, loss=tf.keras.losses.mean_squared_error)
    return autoencoder

In [None]:
def data_input_fn(features, is_training, batch_size):
    """Build the shuffled, batched (input, target) tensors for the autoencoder.

    Each sample serves as both the model input and its reconstruction target.

    Args:
        features: array of samples indexed along axis 0.
        is_training: when truthy the dataset repeats indefinitely; otherwise
            it is traversed exactly once.
        batch_size: number of samples per batch.

    Returns:
        The ``get_next()`` result of a one-shot iterator over the dataset: an
        (inputs, targets) pair of batch tensors.
    """
    # repeat(count=None) repeats without end; count=1 yields a single pass.
    repeat_count = None if is_training else 1
    dataset = tf.data.Dataset.from_tensor_slices((features, features))
    # The shuffle buffer spans the whole dataset (features.shape[0] samples).
    batch_dataset = (
        dataset.shuffle(features.shape[0])
        .repeat(count=repeat_count)
        .batch(batch_size)
    )
    return batch_dataset.make_one_shot_iterator().get_next()

In [None]:
def train(features, batch_size=BATCH_SIZE, ckpt_dir=CKPT_DIR):
    """Train the autoencoder on `features`, writing checkpoints and TensorBoard logs.

    Args:
        features: array of training samples indexed along axis 0; every sample
            is used as both input and reconstruction target.
        batch_size: number of samples per training step.
        ckpt_dir: directory that receives the per-epoch weight checkpoints
            (``cp-XXXX.ckpt``) and the TensorBoard ``Graph`` sub-directory.
            It is created when it does not exist yet.
    """
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    batch_x, batch_y = data_input_fn(features, is_training=True, batch_size=batch_size)
    model = ae_model_fn(features[0].shape)
    model.summary()

    # One epoch is one full pass over the data: ceil(num_samples / batch_size).
    # Derived from the inputs rather than fixed at 469, which is only correct
    # for 60000 samples with a batch size of 128.
    steps_per_epoch = (features.shape[0] + batch_size - 1) // batch_size

    cp_callback = tf.keras.callbacks.ModelCheckpoint(ckpt_dir+'/cp-{epoch:04d}.ckpt', verbose=1, period=1, save_weights_only=True)
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=ckpt_dir+'/Graph', histogram_freq=0, write_graph=True, write_images=True)
    model.fit(x=batch_x, y=batch_y, epochs=5, steps_per_epoch=steps_per_epoch, callbacks=[cp_callback, tb_callback])

In [None]:
def predict(features, ckpt_dir=CKPT_DIR):
    """Reconstruct one sample with the most recently checkpointed weights.

    Args:
        features: a single sample; a leading batch axis of size 1 is added
            before it is fed to the model.
        ckpt_dir: directory searched for the latest checkpoint.

    Returns:
        The model's prediction for the one-sample batch.

    Raises:
        IOError: if no checkpoint can be found in `ckpt_dir`.
    """
    # latest_checkpoint() returns None when the directory holds no checkpoint;
    # fail with a clear message before building the model.
    latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if latest_ckpt is None:
        raise IOError("No checkpoint found in %r; run train() first." % (ckpt_dir,))

    features = np.expand_dims(features, axis=0)
    model = ae_model_fn(features[0].shape)
    model.load_weights(latest_ckpt)
    return model.predict(features)

In [None]:
def image_plot(true, recon):
    """Show an original image and its reconstruction side by side.

    Args:
        true: array reshapeable to 28x28, drawn in the left panel.
        recon: array reshapeable to 28x28, drawn in the right panel.
    """
    plt.figure(figsize=(6, 3))
    grid = gridspec.GridSpec(1, 2)
    grid.update(wspace=0.05)
    # Left panel: original; right panel: reconstruction.
    for column, image in enumerate((true, recon)):
        plt.subplot(grid[column])
        plt.axis('off')
        plt.imshow(image.reshape([28, 28]), cmap='gray_r')
    plt.show()

In [None]:
# Load MNIST (labels are discarded), scale pixel values by 1/255, flatten each
# image to IMAGE_DIM values and store the result as float32.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = (x_train / 255.).reshape([-1, IMAGE_DIM]).astype(np.float32)
x_test = (x_test / 255.).reshape([-1, IMAGE_DIM]).astype(np.float32)

In [None]:
# Train the autoencoder on the flattened training images with the default
# batch size and checkpoint directory.
train(x_train)

In [None]:
# Display ten randomly chosen test images next to their reconstructions.
for i in range(10):
    # randint's upper bound is exclusive, so len(x_test) keeps every test
    # image (including the last one) eligible.
    j = np.random.randint(0, len(x_test))
    image_plot(x_test[j], predict(x_test[j]))