# 1-3 GAN

<img src="./img/gan.png" alt="GAN architecture" width="500" align="left"/>

In [None]:
import os
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import gridspec as gridspec

In [None]:
# Directory for TensorBoard event files and the trained `gan.ckpt` checkpoint.
CKPT_DIR = '../generated_output/GAN'

In [None]:
LEARNING_RATE = 1e-4
TRAINING_STEPS = 30000
BATCH_SIZE = 100
# Number of images in the MNIST training split; used only for epoch bookkeeping.
NUM_TRAIN_EXAMPLES = 60000
TRAINING_SAMPLES = TRAINING_STEPS * BATCH_SIZE
TRAINING_EPOCHS = TRAINING_SAMPLES / NUM_TRAIN_EXAMPLES

In [None]:
# Flattened image size: each 28x28 MNIST image becomes one row.
IMAGE_DIM = 28 * 28
# Length of the uniform noise vector fed to the generator.
NOISE_DIM = 100
# Hidden-layer widths; one ReLU dense layer is built per entry.
GEN_HIDDEN_DIM = [256]
DISC_HIDDEN_DIM = [256]
# Every op in this notebook is built into this explicit graph, not the default one.
graph = tf.Graph()

In [None]:
def progress_bar(current, total, prefix='', suffix='', decimals=1, length=50, bar=u"\u25AF", fill=u"\u25AE"):
    """Redraw a single-line text progress bar in place.

    Args:
        current: number of completed units.
        total: number of units at completion. A non-positive total is drawn
            as a finished bar instead of raising ZeroDivisionError.
        prefix: text printed before the bar.
        suffix: text printed after the percentage.
        decimals: digits after the decimal point in the percentage.
        length: width of the bar in characters.
        bar: character used for the unfilled cells.
        fill: character used for the filled cells.

    Prints a newline once ``current == total`` so later output starts on a
    fresh line.
    """
    if total <= 0:
        # Nothing to track: render as complete rather than dividing by zero.
        current, total = 1, 1
    percent = ("{0:." + str(decimals) + "f}").format(100 * (current / float(total)))
    # Clamp so an overshooting `current` never draws a bar wider than `length`.
    filled_length = int(length * min(max(current, 0), total) // total)
    rendered = fill * filled_length + bar * (length - filled_length)
    print('\r%s [%s] %s%% %s' % (prefix, rendered, percent, suffix), end='\r')
    if current == total:
        print()

<img src="./img/gan_loss.png" alt="GAN loss functions" width="800" align="left"/>

In [None]:
def disc_model(features):
    """Discriminator: map a batch of flattened images to a sigmoid score.

    Variables are created under the ``'discriminator'`` scope with
    ``tf.AUTO_REUSE``, so calling this for real and for generated images
    shares one set of weights.

    Args:
        features: float tensor of images; presumably (batch, IMAGE_DIM) — TODO confirm.

    Returns:
        Tensor with one sigmoid-activated unit per example.
    """
    init = tf.initializers.he_normal
    with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
        hidden = features
        # One ReLU layer per configured width; creation order fixes variable names.
        for width in DISC_HIDDEN_DIM:
            hidden = tf.layers.dense(
                hidden, units=width, activation=tf.nn.relu, kernel_initializer=init())
        score = tf.layers.dense(
            hidden, 1, activation=tf.nn.sigmoid, kernel_initializer=init())
    return score

In [None]:
def gen_model(features):
    """Generator: map a batch of noise vectors to flattened images in [0, 1].

    Variables are created under the ``'generator'`` scope with
    ``tf.AUTO_REUSE``, so later calls (e.g. sampling after training) reuse
    the trained weights.

    Args:
        features: float tensor of noise; presumably (batch, NOISE_DIM) — TODO confirm.

    Returns:
        Tensor with IMAGE_DIM sigmoid-activated units per example.
    """
    init = tf.initializers.he_normal
    with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
        hidden = features
        # One ReLU layer per configured width; creation order fixes variable names.
        for width in GEN_HIDDEN_DIM:
            hidden = tf.layers.dense(
                hidden, units=width, activation=tf.nn.relu, kernel_initializer=init())
        image = tf.layers.dense(
            hidden, IMAGE_DIM, activation=tf.nn.sigmoid, kernel_initializer=init())
    return image

In [None]:
def train_input_fn(features, batch_size=BATCH_SIZE):
    """Return the next-batch tensor of an endlessly repeating, shuffled dataset.

    Args:
        features: NumPy array whose first axis indexes examples.
        batch_size: number of examples per batch.

    Returns:
        The ``get_next()`` tensor of a one-shot iterator built in ``graph``.

    NOTE: ``from_tensor_slices`` on a NumPy array embeds the array in the
    GraphDef as a constant, so large inputs bloat the saved graph.
    """
    with graph.as_default():
        # Shuffle buffer spans the whole array, i.e. a full shuffle each pass.
        buffer_size = features.shape[0]
        pipeline = (tf.data.Dataset.from_tensor_slices(features)
                    .shuffle(buffer_size)
                    .repeat()
                    .batch(batch_size))
        return pipeline.make_one_shot_iterator().get_next()

In [None]:
def train(features):
    """Train the GAN on `features` and save the final checkpoint in CKPT_DIR.

    Builds the input pipeline, generator and discriminator inside the
    module-level `graph`. It then runs TRAINING_STEPS iterations; each runs
    one discriminator Adam step and then one generator Adam step.

    TensorBoard summaries are written about ten times. They cover both
    losses, the discriminator's accuracy on fakes, and one generated image.
    `gan.ckpt` is saved after the last step.

    Args:
        features: NumPy array of training images, one row per example;
            presumably float32 of shape (num_examples, IMAGE_DIM) in [0, 1] —
            TODO confirm at the caller.
    """
    # Create the log/checkpoint directory itself (not only its parent).
    os.makedirs(CKPT_DIR, exist_ok=True)
    # Epoch bookkeeping uses the real dataset size instead of a hard-coded count.
    num_examples = features.shape[0]
    total_epochs = TRAINING_SAMPLES / num_examples
    # Integer interval so the summary schedule is sane for any TRAINING_STEPS.
    summary_every = max(1, TRAINING_STEPS // 10)
    # Keeps log() finite if the sigmoid output saturates to exactly 0 or 1.
    eps = 1e-8

    with graph.as_default():
        real_image = train_input_fn(features)
        fake_noise = tf.random.uniform(
            shape=[BATCH_SIZE, NOISE_DIM],
            minval=-1., maxval=1., dtype=tf.float32)
        fake_image = gen_model(fake_noise)
        # Both calls share one set of discriminator weights (AUTO_REUSE scope).
        disc_real = disc_model(real_image)
        disc_fake = disc_model(fake_image)
        # Discriminator maximizes log D(x) + log(1 - D(G(z))); minimize the negation.
        disc_loss = -tf.reduce_mean(
            tf.log(disc_real + eps) + tf.log(1. - disc_fake + eps))
        # Minimax generator objective: minimize log(1 - D(G(z))).
        gen_loss = tf.reduce_mean(tf.log(1. - disc_fake + eps))
        optimizer_disc = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
        optimizer_gen = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
        disc_train_op = optimizer_disc.minimize(
            disc_loss,
            var_list=tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope="discriminator"))
        gen_train_op = optimizer_gen.minimize(
            gen_loss,
            var_list=tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope="generator"))

        # Streaming fraction of generated images the discriminator labels fake
        # (D(G(z)) <= 0.5); its update op only runs when `merged` is evaluated.
        accuracy = tf.metrics.accuracy(
            labels=tf.zeros(shape=[BATCH_SIZE], dtype=tf.float32),
            predictions=tf.cast((disc_fake > 0.5), tf.float32),
            name='acc_op')
        gen_image = tf.reshape(fake_image, [-1, 28, 28, 1])
        tf.summary.scalar('loss_gen', gen_loss)
        tf.summary.scalar('loss_disc', disc_loss)
        tf.summary.scalar('accuracy', accuracy[1])
        tf.summary.image('gen_image', gen_image, max_outputs=1)
        merged = tf.summary.merge_all()
        saver = tf.train.Saver()

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(CKPT_DIR, sess.graph)
            try:
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                for step in range(TRAINING_STEPS):
                    train_step = step + 1
                    train_sample = train_step * BATCH_SIZE
                    train_epoch = train_sample / num_examples
                    # Separate runs give a well-defined order: the generator step
                    # sees the discriminator weights after this iteration's update.
                    sess.run(disc_train_op)
                    sess.run(gen_train_op)
                    if train_step % summary_every == 0:
                        summary_writer.add_summary(sess.run(merged), step)
                    if train_step == TRAINING_STEPS:
                        saver.save(sess, CKPT_DIR + '/gan.ckpt')
                    progress_bar(
                        train_step,
                        TRAINING_STEPS,
                        prefix='>>> Training',
                        suffix='steps: %i/%i, samples: %i/%i, epochs: %i/%i' % (
                            train_step,
                            TRAINING_STEPS,
                            train_sample,
                            TRAINING_SAMPLES,
                            train_epoch,
                            total_epochs))
            finally:
                # Flush buffered events to disk and release the event file.
                summary_writer.close()

            print('>>> Training Done')

In [None]:
# Keep only the MNIST training images, scale pixels to [0, 1], flatten each
# image to IMAGE_DIM values and cast to float32 for the TF graph.
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = (x_train / 255.).reshape([-1, IMAGE_DIM]).astype(np.float32)

In [None]:
# Build the graph, train for TRAINING_STEPS steps, write logs + checkpoint to CKPT_DIR.
train(features=x_train)

In [None]:
def random_25_image_plot(seed=None):
    """Plot a 5x5 grid of generator samples restored from the latest checkpoint.

    Args:
        seed: optional int seed for the 25 noise vectors; ``None`` draws fresh
            noise. A private ``RandomState`` is used, so the global NumPy RNG
            is left untouched.

    Raises:
        ValueError: if CKPT_DIR holds no checkpoint (e.g. ``train`` never ran).
    """
    # Fail fast, before adding ops to the graph or opening an empty figure.
    checkpoint = tf.train.latest_checkpoint(CKPT_DIR)
    if checkpoint is None:
        raise ValueError('No checkpoint found in %s; run train() first.' % CKPT_DIR)

    rng = np.random.RandomState(seed)
    random_noise = rng.uniform(-1., 1., size=[25, NOISE_DIM]).astype(np.float32)

    with graph.as_default():
        # Feed the noise directly (no shuffled dataset), so image i comes from
        # noise row i and `seed` fully determines the grid for a given checkpoint.
        random_gen = gen_model(tf.constant(random_noise))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, checkpoint)
            random_image = sess.run(random_gen).reshape([-1, 28, 28])

    fig = plt.figure(figsize=(10, 10))
    gs = gridspec.GridSpec(5, 5)
    gs.update(wspace=0.05)
    for i in range(25):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.imshow(random_image[i], cmap='gray')

In [None]:
# Sample 25 images from the trained generator (seed=None -> fresh noise each run).
random_25_image_plot(seed=None)