# 1-1 AutoEncoder

<img src="./img/ae.png" alt="autoencoder" width="500" align="left"/>

In [None]:
import tensorflow as tf
import numpy as np
import os
from matplotlib import pyplot as plt
from matplotlib import gridspec as gridspec

In [None]:
# Output directory for TensorBoard event files and the trained checkpoint.
CKPT_DIR = "/".join(["..", "generated_output", "AE"])

In [None]:
LEARNING_RATE = 1e-4
TRAINING_STEPS = 3000
BATCH_SIZE = 100
# Number of rows in the MNIST training split; used only to express training
# progress in epochs.
NUM_TRAIN_SAMPLES = 60000
TRAINING_SAMPLES = TRAINING_STEPS * BATCH_SIZE  # total rows seen during training
TRAINING_EPOCHS = TRAINING_SAMPLES / NUM_TRAIN_SAMPLES  # passes over the training split

In [None]:
IMAGE_DIM = 784  # flattened 28 x 28 image
LATENT_DIM = 128  # width of the encoder output / decoder input
ENCODER_HIDDEN_DIM = [256]  # widths of the encoder's hidden ReLU layers
# Misspelled original name, kept as an alias because existing code reads it.
ENDOCER_HIDDEN_DIM = ENCODER_HIDDEN_DIM
DECODER_HIDDEN_DIM = [256]  # widths of the decoder's hidden ReLU layers
# One graph shared by train(), train_input_fn() and random_25_image_plot():
# each of them builds its ops inside `graph.as_default()`, so everything the
# notebook creates lives in this single graph.
graph = tf.Graph()

In [None]:
def progress_bar(current, total, prefix='', suffix='', decimals=1, length=50, bar=u"\u25AF", fill=u"\u25AE"):
    """Render a one-line text progress bar, redrawing it in place on stdout.

    Args:
        current: Number of completed units of work.
        total: Total number of units. A non-positive total is drawn as a
            complete bar instead of raising ZeroDivisionError.
        prefix: Text printed before the bar.
        suffix: Text printed after the percentage.
        decimals: Number of decimal places shown in the percentage.
        length: Width of the bar in characters.
        bar: Character used for the unfilled part of the bar.
        fill: Character used for the filled part of the bar.

    A newline is printed once ``current == total`` so later output starts on
    a fresh line.
    """
    if total <= 0:
        fraction = 1.0
        filled = length
    else:
        fraction = current / float(total)
        # Floor division on the raw counts keeps the cell count exact. Clamping
        # keeps the bar exactly `length` cells wide even when `current` is
        # negative or overshoots `total`.
        filled = min(max(int(length * current // total), 0), length)
    percent = ("{0:." + str(decimals) + "f}").format(100 * fraction)
    cells = fill * filled + bar * (length - filled)
    # The leading/trailing '\r' move the cursor back so the next call overwrites this line.
    print('\r%s [%s] %s%% %s' % (prefix, cells, percent, suffix), end='\r')
    if current == total:
        print()

Encoder: [batch_size, 784]

$\rightarrow$ Dense(784, 256) $\rightarrow$ relu $\rightarrow$ [batch_size, 256]

$\rightarrow$ Dense(256, 128) $\rightarrow$ sigmoid $\rightarrow$ [batch_size, 128]

In [None]:
def encoder_model(feature):
    """Map a batch of flattened images to latent codes.

    Builds one ReLU dense layer per entry of ENDOCER_HIDDEN_DIM, then a sigmoid
    dense layer of width LATENT_DIM. Every kernel is He-normal initialised. The
    layers live under the 'encoder' variable scope with AUTO_REUSE, so a
    repeated call in the same graph picks up the existing variables.

    Args:
        feature: float tensor of flattened images (presumably
            [batch_size, IMAGE_DIM] — confirm against caller).

    Returns:
        Tensor with LATENT_DIM columns whose values lie in (0, 1) (sigmoid).
    """
    # (units, activation) for every layer, in creation order.
    layer_specs = [(width, tf.nn.relu) for width in ENDOCER_HIDDEN_DIM]
    layer_specs.append((LATENT_DIM, tf.nn.sigmoid))

    with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
        hidden = feature
        for width, act in layer_specs:
            hidden = tf.layers.dense(
                hidden,
                width,
                activation=act,
                kernel_initializer=tf.initializers.he_normal(),
            )
    return hidden

Decoder: [batch_size, 128]

$\rightarrow$ Dense(128, 256) $\rightarrow$ relu $\rightarrow$ [batch_size, 256]

$\rightarrow$ Dense(256, 784) $\rightarrow$ sigmoid $\rightarrow$ [batch_size, 784]

In [None]:
def decoder_model(feature):
    """Map a batch of latent codes back to flattened images.

    Builds one ReLU dense layer per entry of DECODER_HIDDEN_DIM, then a sigmoid
    dense layer of width IMAGE_DIM. Every kernel is He-normal initialised. The
    layers live under the 'decoder' variable scope with AUTO_REUSE, so a
    repeated call in the same graph picks up the existing variables.

    Args:
        feature: float tensor of latent codes (presumably
            [batch_size, LATENT_DIM] — confirm against caller).

    Returns:
        Tensor with IMAGE_DIM columns whose values lie in (0, 1) (sigmoid).
    """
    # (units, activation) for every layer, in creation order.
    layer_specs = [(width, tf.nn.relu) for width in DECODER_HIDDEN_DIM]
    layer_specs.append((IMAGE_DIM, tf.nn.sigmoid))

    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        hidden = feature
        for width, act in layer_specs:
            hidden = tf.layers.dense(
                hidden,
                width,
                activation=act,
                kernel_initializer=tf.initializers.he_normal(),
            )
    return hidden

In [None]:
def train_input_fn(features, batch_size=BATCH_SIZE):
    """Return a tensor that yields shuffled, endlessly repeating batches of `features`.

    The pipeline is built inside the module-level `graph`. Its shuffle buffer
    spans every row of `features`, and the dataset repeats forever.

    Args:
        features: NumPy array whose first axis indexes samples.
        batch_size: Number of rows per batch.

    Returns:
        The `get_next()` tensor of a one-shot iterator. Each `sess.run` on it
        (or on anything depending on it) advances to the next batch.
    """
    with graph.as_default():
        num_rows = features.shape[0]
        pipeline = (
            tf.data.Dataset.from_tensor_slices(features)
            .shuffle(num_rows)
            .repeat()
            .batch(batch_size)
        )
        iterator = pipeline.make_one_shot_iterator()
        return iterator.get_next()

Autoencoder (encoder followed by decoder): [batch_size, 784]

$\rightarrow$ Dense(784, 256) $\rightarrow$ relu $\rightarrow$ [batch_size, 256]

$\rightarrow$ Dense(256, 128) $\rightarrow$ sigmoid $\rightarrow$ [batch_size, 128]

$\rightarrow$ Dense(128, 256) $\rightarrow$ relu $\rightarrow$ [batch_size, 256]

$\rightarrow$ Dense(256, 784) $\rightarrow$ sigmoid $\rightarrow$ [batch_size, 784]

In [None]:
def train(features):
    """Train the autoencoder on `features`, logging to and checkpointing in CKPT_DIR.

    Inside the module-level `graph`, this builds:
      * a shuffled batching pipeline over `features`,
      * the encoder/decoder pair,
      * a mean-squared-error reconstruction loss,
      * an Adam update op.
    It then runs TRAINING_STEPS updates of BATCH_SIZE rows each.

    Args:
        features: NumPy array with one flattened image per row. Rows are
            reshaped to 28x28 for the image summaries, so each row must hold
            28 * 28 values.

    Side effects:
        * Creates CKPT_DIR if it is missing.
        * Writes TensorBoard summaries about ten times per run: the loss plus
          one original/reconstructed image pair.
        * Saves the final variables to ``CKPT_DIR/ae.ckpt``.
        * Prints a progress bar.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)

    # Size the epoch counter from the data actually passed in rather than a
    # hard-coded MNIST size.
    num_samples = features.shape[0]
    total_epochs = TRAINING_SAMPLES / num_samples
    # Use an integer, non-zero interval. `TRAINING_STEPS / 10` is a float that
    # never divides the step evenly (e.g. 3005 steps) or divides every step
    # (fewer than 10 steps).
    summary_every = max(1, TRAINING_STEPS // 10)

    with graph.as_default():
        batch = train_input_fn(features)
        latents = encoder_model(batch)
        outputs = decoder_model(latents)
        loss = tf.losses.mean_squared_error(batch, outputs)
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        org_image = tf.reshape(batch, [-1, 28, 28, 1])
        rec_image = tf.reshape(outputs, [-1, 28, 28, 1])
        tf.summary.scalar('loss', loss)
        tf.summary.image('org_image', org_image, max_outputs=1)
        tf.summary.image('rec_image', rec_image, max_outputs=1)
        merged = tf.summary.merge_all()
        saver = tf.train.Saver()

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(CKPT_DIR, sess.graph)
            try:
                sess.run(tf.global_variables_initializer())

                for step in range(TRAINING_STEPS):
                    train_step = step + 1
                    train_sample = train_step * BATCH_SIZE
                    train_epoch = train_sample / num_samples
                    if train_step % summary_every == 0:
                        # Fetch the summary in the same run as the update. A
                        # separate sess.run(merged) would dequeue, and waste,
                        # an extra batch from the iterator.
                        _, summary = sess.run([train_op, merged])
                        summary_writer.add_summary(summary, step)
                    else:
                        sess.run(train_op)

                    progress_bar(
                        train_step,
                        TRAINING_STEPS,
                        prefix='>>> Training',
                        suffix='steps: %i/%i, samples: %i/%i, epochs: %.2f/%.2f' % (
                            train_step,
                            TRAINING_STEPS,
                            train_sample,
                            TRAINING_SAMPLES,
                            train_epoch,
                            total_epochs))

                saver.save(sess, os.path.join(CKPT_DIR, 'ae.ckpt'))
            finally:
                # FileWriter buffers events; close() flushes them to disk.
                summary_writer.close()

            print('>>> Training Done')

In [None]:
# Load MNIST (labels are discarded), divide pixel values by 255 and flatten
# every 28x28 image into a float32 row of length IMAGE_DIM.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train, x_test = (
    (images / 255.).reshape([-1, IMAGE_DIM]).astype(np.float32)
    for images in (x_train, x_test)
)

In [None]:
# Fit the autoencoder on the flattened, normalised MNIST training images.
train(features=x_train)

In [None]:
def random_25_image_plot(seed=None):
    """Decode 25 random latent vectors with the trained decoder and plot a 5x5 grid.

    Args:
        seed: Optional seed for the NumPy generator that draws the latent
            vectors. The same seed reproduces the same grid.

    Raises:
        ValueError: if CKPT_DIR contains no checkpoint (run ``train`` first).
    """
    checkpoint = tf.train.latest_checkpoint(CKPT_DIR)
    if checkpoint is None:
        raise ValueError('No checkpoint found in %s; run train() first.' % CKPT_DIR)

    # A private RandomState draws the same values as np.random.seed(seed)
    # followed by np.random.normal, without reseeding NumPy's global RNG.
    rng = np.random.RandomState(seed)
    # NOTE: encoder latents are sigmoid outputs in (0, 1), while these samples
    # come from N(0, 1), so many decoder inputs fall outside the range seen
    # during training.
    random_noise = rng.normal(size=[25, LATENT_DIM]).astype(np.float32)

    with graph.as_default():
        # Feed the noise as a constant. Routing it through train_input_fn would
        # shuffle the rows with an unseeded TF op, which defeats `seed`.
        random_gen = decoder_model(tf.constant(random_noise))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, checkpoint)
            random_image = sess.run(random_gen).reshape([-1, 28, 28])

    plt.figure(figsize=(10, 10))
    gs = gridspec.GridSpec(5, 5)
    gs.update(wspace=0.05)
    for i in range(25):
        plt.subplot(gs[i])
        plt.axis('off')
        plt.imshow(random_image[i], cmap='gray')

In [None]:
# Sample and display 25 decoded images from random latent vectors (unseeded).
random_25_image_plot(seed=None)