<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"><li><span><a href="#Data" data-toc-modified-id="Data-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Data</a></span></li><li><span><a href="#Model" data-toc-modified-id="Model-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Model</a></span></li><li><span><a href="#Training" data-toc-modified-id="Training-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Training</a></span></li><li><span><a href="#Test/Generate" data-toc-modified-id="Test/Generate-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Test/Generate</a></span></li><li><span><a href="#[TOFIX]-Training---Estimator" data-toc-modified-id="[TOFIX]-Training---Estimator-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>[TOFIX] Training - Estimator</a></span></li><li><span><a href="#[TOFIX]-Contrib-GAN" data-toc-modified-id="[TOFIX]-Contrib-GAN-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>[TOFIX] Contrib GAN</a></span></li></ul></div>

In [None]:
import yaml
import tensorflow as tf
import numpy as np
#tf.enable_eager_execution()
import functools

from pathlib import Path
import matplotlib.pyplot as plt
import time
from tqdm import tqdm_notebook as tqdm

from IPython.display import clear_output

%load_ext autoreload
%autoreload 2

import dcgan
import gan_utils

In [None]:
# Directory that receives the TensorBoard event files and the saved model weights.
#model_dir = Path("/notebooks/models/dcgan")
model_dir = Path.home().joinpath("Documents/models/tf_playground/dcgan")

In [None]:
# Read the experiment configuration (data, model and training sections) from YAML.
with open('mnist_config.yaml', 'r') as f:
    # safe_load only builds plain Python objects. Calling yaml.load without an
    # explicit Loader is deprecated, can construct arbitrary objects, and raises
    # a TypeError on PyYAML >= 6.
    config = yaml.safe_load(f)

# Data

In [None]:
# Data-related settings; later cells read 'input_shape', 'z_size' and 'buffer_size' from it.
data_conf = config['data']
data_conf

In [None]:
# Constants shared by the data pipeline and the model cells, read once from the config.
BATCH_SIZE = config['training']['batch_size']
# Per-image shape; preprocess_images reshapes every image to this.
INPUT_SHAPE = data_conf['input_shape']
# Latent vector shape, wrapped in a 1-tuple.
Z_SHAPE = (data_conf['z_size'],)

In [None]:
#(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

In [None]:
# load_data() returns a (train, test) pair, each of which is an (images, labels) pair.
train_split, test_split = tf.keras.datasets.fashion_mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

In [None]:
def preprocess_images(images, input_shape=None):
    """Reshape a batch of 8-bit images and scale the pixel values to [-1, 1].

    Args:
        images: NumPy array whose first axis indexes the images; the remaining
            elements must be reshapeable to ``input_shape``.
        input_shape: Per-image target shape. Defaults to the notebook-level
            ``INPUT_SHAPE`` when omitted, which is the original behaviour.

    Returns:
        A float32 array of shape ``(len(images), *input_shape)`` in which the
        pixel values 0 and 255 map to -1.0 and 1.0 respectively.
    """
    if input_shape is None:
        input_shape = INPUT_SHAPE
    images = images.reshape(images.shape[0], *input_shape).astype('float32')
    # Centre on 127.5 and divide by the same value so [0, 255] maps onto [-1, 1].
    return (images - 127.5) / 127.5

In [None]:
# Normalise both splits in one pass.
# NOTE: the names are rebound to the normalised arrays, so re-running this cell
# without re-running the load cell would apply the scaling a second time.
train_images, test_images = (
    preprocess_images(split) for split in (train_images, test_images)
)

In [None]:
# Training pipeline: shuffle individual images, group them into batches, then keep
# at most 512 batches (take() comes after batch(), so it counts batches, not images).
train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(data_conf['buffer_size'])
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.take(512)

# Test pipeline: the first 128 test images, served one per batch.
test_dataset = tf.data.Dataset.from_tensor_slices(test_images)
test_dataset = test_dataset.take(128).batch(1)

# Model

In [None]:
# Feed points of the graph. Both placeholders leave the batch dimension open.
# real_input now carries the per-image shape so that a mismatch with the
# discriminator (built from the same INPUT_SHAPE) is reported at graph
# construction time instead of at the first sess.run.
real_input = tf.placeholder(tf.float32, shape=(None, *INPUT_SHAPE), name='real_input')
input_noise = tf.placeholder(tf.float32, shape=(None, data_conf['z_size']), name='input_noise')

In [None]:
# Build both networks from their sections of the config; the section entries are
# forwarded to the dcgan factory functions as keyword arguments.
model_conf = config['model']
generator = dcgan.get_generator(Z_SHAPE, **model_conf['generator'])
discriminator = dcgan.get_discriminator(INPUT_SHAPE, **model_conf['discriminator'])

In [None]:
# TODO validate generator output shape equal to discriminator input shape
# Until that check exists, show the generator summary so the shapes can be inspected by eye.
generator.summary()

In [None]:
# Wire the two networks together. The statement order is kept as-is because each
# call adds ops to the default graph.
# Discriminator output for real images fed through the real_input placeholder.
D_real = discriminator(real_input)
# Generator output for the latent vectors fed through the input_noise placeholder.
G_z = generator(input_noise)

# Discriminator output for the generated images (same discriminator object as D_real).
D_fake = discriminator(G_z)

# Generator and Discriminator Losses
G_loss = dcgan.generator_loss(D_fake)
D_loss = dcgan.discriminator_loss(D_real, D_fake)

In [None]:
#generator_optimizer = tf.train.AdamOptimizer(config['training']['generator']['learning_rate'])
#discriminator_optimizer = tf.train.AdamOptimizer(config['training']['discriminator']['learning_rate'])

# Training

In [None]:
# One training op per network, each configured from its own section of the
# training config. The discriminator op is created first, as before.
training_conf = config['training']
train_discriminator = dcgan.get_discriminator_train(
    discriminator, D_loss, training_conf['discriminator'])
train_generator = dcgan.get_generator_train(
    generator, G_loss, training_conf['generator'])

In [None]:
# Training schedule. EPOCHS is set here; the other two values come from the config.
EPOCHS = 40
schedule_conf = config['training']
# A checkpoint is written every CHECKPOINT_STEPS epochs.
CHECKPOINT_STEPS = schedule_conf['checkpoint_steps']
# Number of latent vectors used for the sample plots.
PLOT_SAMPLE_SIZE = schedule_conf['plot_sample_size']

In [None]:
# Set to True to resume from the weights previously saved under model_dir.
LOAD_WEIGHTS = False
# Upper bound on the number of batches consumed per epoch. The loop used to break
# out after a single batch through a hard-coded check; the value 1 keeps that
# behaviour. Set to None to train on every batch of train_dataset each epoch.
MAX_BATCHES_PER_EPOCH = 1

iterator = train_dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # tensorboard
    summ_writer = tf.summary.FileWriter(str(model_dir), sess.graph)
    gen_loss_summary = tf.summary.scalar('gen_loss', G_loss)
    disc_loss_summary = tf.summary.scalar('disc_loss', D_loss)

    test_gen_images = tf.placeholder(tf.float32, name='gen_images', shape=(1, 288, 288, 4))
    summary_imgs = tf.summary.image("plot", test_gen_images)

    # PNG decoding is built once and fed through a placeholder. Creating the
    # decode op inside the epoch loop would add new nodes to the graph on every
    # epoch.
    plot_png = tf.placeholder(tf.string, shape=(), name='plot_png')
    decoded_plot = tf.expand_dims(tf.image.decode_png(plot_png, channels=4), 0)

    # Fixed latent vectors so the sample plots are comparable across epochs.
    z_fixed = np.random.normal(size=(PLOT_SAMPLE_SIZE, data_conf['z_size']))
    in_noise = tf.random_normal([BATCH_SIZE, data_conf['z_size']])

    sess.run(tf.global_variables_initializer())

    if LOAD_WEIGHTS:
        generator.load_weights(str(model_dir / "generator"))
        discriminator.load_weights(str(model_dir / "discriminator"))

    try:
        for epoch in tqdm(range(EPOCHS)):
            sess.run(iterator.initializer)

            # Summaries of the last batch trained in this epoch (None if no batch ran).
            g_loss_summ = d_loss_summ = None
            batch_num = 0
            while MAX_BATCHES_PER_EPOCH is None or batch_num < MAX_BATCHES_PER_EPOCH:
                try:
                    # TODO cleaner way to connect input_image and noise to dataset and function
                    input_image = sess.run(next_element)
                except tf.errors.OutOfRangeError:
                    # Dataset exhausted for this epoch.
                    break
                noise = sess.run(in_noise)
                g_loss_summ, d_loss_summ, _, _ = sess.run(
                    [gen_loss_summary, disc_loss_summary,
                     train_discriminator, train_generator],
                    feed_dict={real_input: input_image, input_noise: noise})
                batch_num += 1

            if g_loss_summ is not None:
                summ_writer.add_summary(g_loss_summ, epoch)
                summ_writer.add_summary(d_loss_summ, epoch)

            # Reuse the generator output tensor that is already in the graph.
            # The previous code referenced `test_noise`, which is only defined in
            # a later cell, and called the generator again on every epoch.
            predictions = sess.run(G_z, {input_noise: z_fixed})
            plot_buf = gan_utils.display_prediction(predictions, epoch)
            plot_image = sess.run(decoded_plot, {plot_png: plot_buf.getvalue()})

            summary_imgs_val = sess.run(summary_imgs, {test_gen_images: plot_image})
            summ_writer.add_summary(summary_imgs_val, epoch)

            # saving checkpoint
            if (epoch + 1) % CHECKPOINT_STEPS == 0:
                # TODO rely on TF instead
                # TODO Validate if training params are also loaded
                # TODO export by epoch
                generator.save_weights(str(model_dir / "generator"))
                discriminator.save_weights(str(model_dir / "discriminator"))
    finally:
        # Flush pending events and release the event file even if training is
        # interrupted or raises.
        summ_writer.close()

# Test/Generate

In [None]:
# Restore the generator weights saved by the training cell and plot a batch of samples.
generator.load_weights(str(model_dir / "generator"))
# Sample the latent vectors from a standard normal, the same distribution used
# for z_fixed and in_noise during training. np.random.rand draws uniform values
# in [0, 1), which the generator was never trained on.
test_noise = np.random.normal(size=(PLOT_SAMPLE_SIZE, data_conf['z_size']))
predictions = generator.predict(test_noise)
gan_utils.display_prediction(predictions, 0)