# MNIST Wasserstein GAN (Gradient Penalty)

## Notebook setup

In [None]:
# noqa
import os
COLAB = 'DATALAB_DEBUG' in os.environ

if COLAB:
    #!apt-get update && apt-get install git
    !rm -rf bstrap
    !git clone https://gist.github.com/oskopek/e27ca34cb2b813cae614520e8374e741 bstrap
    
    import bstrap.bootstrap as bootstrap
    import bstrap.drive_utils as drive_utils
    drive_u = drive_utils

else:
    wd = %%pwd
    if wd.endswith('notebooks'):
        print('Current directory:', wd)
        %cd ..
        %pwd
    
    import resources.our_colab_utils.bootstrap as bootstrap
    drive_u = None

bootstrap.bootstrap(branch='master', packages='dotmap==1.2.20 keras==2.1.4 pydicom==1.0.2 Pillow==5.0.0', drive_utils=drive_u)

if COLAB:
    !rm -rf bstrap

## Model definition and training

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.contrib.summary as tf_summary

from models.base import BaseModel
from resources.data.utils import next_batch
from resources.data.mnist import read_mnist
from resources.model_utils import noise, tile_images

# Flags: parse the GAN configuration from flags/gan.json and expose it as FLAGS.
from flags import flags_parser
flags_parser.parse('flags/gan.json', None)
FLAGS = flags_parser.FLAGS
# Fail loudly if parsing did not populate FLAGS. An explicit raise is used
# instead of `assert`, which is stripped when Python runs with -O.
if FLAGS is None:
    raise RuntimeError('flags_parser.parse did not populate FLAGS (flags/gan.json)')

In [None]:
# He initialization (He et al., 2015): normal (non-uniform) weights whose
# variance scales as factor / fan_in with factor=2.0, as recommended for
# ReLU-family activations. Shared by all conv/dense layers below.
he_normal = tf.contrib.layers.variance_scaling_initializer(
    factor=2.0,
    mode='FAN_IN',
    uniform=False,
    dtype=tf.float32,
)

In [None]:
# Discriminator
def discriminator(X, model, reuse, batch_norm=True):
    """Build the WGAN critic and return one unbounded score per example.

    Four 5x5, stride-2 convolutions (64, 128, 256, 512 filters) are applied.
    Each is optionally followed by batch normalization, then a leaky ReLU.
    A single linear dense unit produces the score.

    Args:
        X: image batch in NHWC layout. The non-batch dimensions must be
            statically known so the flattened feature size can be inferred.
        model: object whose ``training`` tensor is passed to batch norm as
            ``is_training``.
        reuse: forwarded to ``tf.variable_scope`` so that several calls share
            one set of "Discriminator" weights.
        batch_norm: whether to apply batch normalization after each convolution.

    Returns:
        A ``[batch, 1]`` tensor of raw (no activation) critic outputs.

    Raises:
        ValueError: if the flattened feature size is not statically known.
    """
    def _maybe_batch_norm(inputs):
        # Identical batch-norm configuration after every convolution.
        # updates_collections=None makes the moving-statistics update run in
        # place instead of being collected into UPDATE_OPS.
        if not batch_norm:
            return inputs
        return tf.contrib.layers.batch_norm(
            inputs,
            decay=0.9,
            updates_collections=None,
            epsilon=1e-5,
            center=True,
            scale=True,
            is_training=model.training)

    with tf.variable_scope("Discriminator", reuse=reuse):
        outputs = X
        # Conv 1-4 (layers are created in the same order as before, so the
        # auto-generated variable names are unchanged).
        for filters in (64, 128, 256, 512):
            outputs = tf.layers.conv2d(
                outputs, filters, [5, 5], strides=(2, 2), padding='SAME', kernel_initializer=he_normal)
            outputs = _maybe_batch_norm(outputs)
            outputs = tf.nn.leaky_relu(outputs)

        # FC layer. Flatten using the real feature-map size; a hard-coded 8192
        # (= 4 * 4 * 512) is only correct for 64x64 inputs and would silently
        # merge examples together for any other resolution.
        flat_size = outputs.get_shape()[1:].num_elements()
        if flat_size is None:
            raise ValueError(
                'discriminator needs statically known spatial/channel dimensions, got %s' % outputs.get_shape())
        reshape = tf.reshape(outputs, [-1, flat_size])
        outputs = tf.layers.dense(reshape, 1)
        return outputs


In [None]:
# Generator
def generator(X, model, reuse=False, batch_norm=True):
    """Map a batch of noise vectors to single-channel images in [-1, 1].

    A dense projection is reshaped to a 4x4x512 feature map. Three 5x5,
    stride-2 transposed convolutions (256, 128, 64 filters) follow, each with
    optional batch norm and ReLU. A final stride-2 transposed convolution to
    one channel ends in tanh, so the spatial side grows 4 -> 8 -> 16 -> 32 -> 64.

    Args:
        X: noise batch (assumed ``[batch, noise_dim]`` - TODO confirm).
        model: object whose ``training`` tensor is passed to batch norm as
            ``is_training``.
        reuse: forwarded to ``tf.variable_scope('Generator', ...)``.
        batch_norm: whether to apply batch normalization before each ReLU.

    Returns:
        A ``[batch, 64, 64, 1]`` tensor after ``tf.tanh``.
    """
    base = 4  # side length of the first (projected) feature map

    def _norm_then_relu(tensor):
        # Optional batch norm (in-place moving-stat updates), then ReLU.
        if batch_norm:
            tensor = tf.contrib.layers.batch_norm(
                tensor,
                decay=0.9,
                updates_collections=None,
                epsilon=1e-5,
                center=True,
                scale=True,
                is_training=model.training)
        return tf.nn.relu(tensor)

    with tf.variable_scope('Generator', reuse=reuse):
        projected = tf.layers.dense(X, 512 * base * base, kernel_initializer=he_normal)
        feature_map = _norm_then_relu(tf.reshape(projected, [-1, base, base, 512]))

        # Deconv 1-3
        for depth in [256, 128, 64]:
            upsampled = tf.layers.conv2d_transpose(
                feature_map, depth, [5, 5], strides=(2, 2), padding='SAME', kernel_initializer=he_normal)
            feature_map = _norm_then_relu(upsampled)

        # Deconv 4: project down to a single channel and squash into [-1, 1].
        image = tf.layers.conv2d_transpose(
            feature_map, 1, [5, 5], strides=(2, 2), padding='SAME', kernel_initializer=he_normal)
        return tf.tanh(image)


In [None]:
# Mean critic score used by the Wasserstein losses (the name is historical).
def cross_entropy_loss(logits=None, labels=None):
    """Return the batch mean of the critic outputs, ``mean(logits)``.

    Despite its name, this is not a cross-entropy. The Wasserstein objective
    averages the raw critic scores directly. ``labels`` is accepted only so
    existing call sites keep working; it is ignored.

    Args:
        logits: tensor of raw critic outputs.
        labels: unused.

    Returns:
        Scalar tensor with the mean of ``logits``.

    Raises:
        ValueError: if ``logits`` is not provided.
    """
    del labels  # unused; kept for call-site compatibility
    if logits is None:
        raise ValueError('cross_entropy_loss requires logits')
    return tf.reduce_mean(logits)

In [None]:
# Model
class Gan(BaseModel):
    """MNIST Wasserstein GAN trained with a gradient penalty (WGAN-GP).

    Real 28x28 MNIST digits are resized to 64x64 and mapped by ``(x - 0.5) * 2``
    before reaching the critic. The generator maps 100-d noise to 64x64 images.
    """
    # Setup constants
    IMAGE_SIZE = 28
    NEW_IMAGE_SIZE = 64
    IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
    NOISE_SIZE = 100

    def __init__(self):
        super(Gan, self).__init__(
            logdir_name=FLAGS.data.out_dir,
            checkpoint_dirname=FLAGS.training.checkpoint_dir,
            expname="GAN",
            threads=FLAGS.training.threads,
            seed=FLAGS.training.seed)
        with self.session.graph.as_default():
            self._build()
            self._init_variables()

    # Construct the graph
    def _build(self):
        """Create placeholders, Wasserstein losses, image summaries and optimizers."""
        self.d_step = tf.Variable(0, dtype=tf.int64, trainable=False, name="d_step")
        self.g_step = tf.Variable(0, dtype=tf.int64, trainable=False, name="g_step")

        # Flat 784-pixel rows -> NHWC 28x28x1 -> resized to 64x64 -> (x - 0.5) * 2,
        # i.e. [0, 1] -> [-1, 1] (assumes read_mnist yields pixels in [0, 1] - TODO confirm).
        self.images_input_pl = tf.placeholder(tf.float32, shape=(None, self.IMAGE_PIXELS))
        print(self.images_input_pl.get_shape())
        self.images_input = tf.reshape(self.images_input_pl, (-1, self.IMAGE_SIZE, self.IMAGE_SIZE, 1))
        print(self.images_input.get_shape())
        self.images_input = tf.image.resize_images(self.images_input, [self.NEW_IMAGE_SIZE, self.NEW_IMAGE_SIZE])
        print(self.images_input.get_shape())
        self.images_input = (self.images_input - 0.5) * 2.0
        print(self.images_input.get_shape())

        self.noise_input = tf.placeholder(tf.float32, shape=(None, self.NOISE_SIZE))
        self.training = tf.placeholder_with_default(False, shape=())
        self.noise_input_interpolated = tf.placeholder(tf.float32, shape=(None, self.NOISE_SIZE))

        # Losses
        g_sample = generator(self.noise_input, self, reuse=False, batch_norm=True)
        print("images_input", self.images_input.get_shape())
        d_real = discriminator(self.images_input, self, reuse=False, batch_norm=False)
        print("g_sample", g_sample.get_shape())
        d_fake = discriminator(g_sample, self, reuse=True, batch_norm=False)

        # cross_entropy_loss returns mean(logits); the labels are ignored.
        d_loss_real = cross_entropy_loss(logits=d_real, labels=tf.ones_like(d_real))
        d_loss_fake = cross_entropy_loss(logits=d_fake, labels=tf.zeros_like(d_fake))
        # Critic minimises E[D(fake)] - E[D(real)]; generator minimises -E[D(fake)].
        self.d_loss = -d_loss_real + d_loss_fake
        self.g_loss = -d_loss_fake

        # Gradient penalty (WGAN-GP)
        self.d_loss += self._gradient_penalty(self.images_input, g_sample)

        # Test summaries
        tiled_image_random = tile_images(g_sample, 6, 6, self.NEW_IMAGE_SIZE, self.NEW_IMAGE_SIZE)
        tiled_image_interpolated = tile_images(
            generator(self.noise_input_interpolated, self, reuse=True), 6, 6, self.NEW_IMAGE_SIZE, self.NEW_IMAGE_SIZE)
        with self.summary_writer.as_default(), tf_summary.always_record_summaries():
            gen_image_summary_op = tf_summary.image(
                'generated_images', tiled_image_random, max_images=1, step=self.g_step)
            gen_image_summary_interpolated_op = tf_summary.image(
                'generated_images_interpolated', tiled_image_interpolated, max_images=1, step=self.g_step)
            self.IMAGE_SUMMARIES = [gen_image_summary_op, gen_image_summary_interpolated_op]

        # Optimizers
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'Discriminator' in var.name]
        g_vars = [var for var in t_vars if 'Generator' in var.name]
        # NOTE(review): FLAGS.model.optimization.learning_rate is ignored in favour of
        # this hard-coded value (behaviour kept from the original); confirm intent.
        learning_rate = 1e-3
        # Weight-clipping ops from the original WGAN. train_batch never runs them;
        # WGAN-GP relies on the gradient penalty instead. Kept for compatibility.
        self.clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_vars]
        self.d_opt = tf.train.AdamOptimizer(
            learning_rate, beta1=0.5, beta2=0.9).minimize(
                self.d_loss, var_list=d_vars, global_step=self.d_step)
        self.g_opt = tf.train.AdamOptimizer(
            learning_rate, beta1=0.5, beta2=0.9).minimize(
                self.g_loss, var_list=g_vars, global_step=self.g_step)

        # saver = tf.train.Saver(max_to_keep=1) # TODO(jendelel): Set up saver.

    def _gradient_penalty(self, real, fake, penalty_weight=10.0):
        """Return ``penalty_weight * mean((||grad D(x_hat)||_2 - 1)^2)``.

        ``x_hat`` are random per-example interpolates between ``real`` and
        ``fake``, which must have the same (NHWC) shape.
        """
        # One mixing coefficient per example, sized from the actual batch so a
        # short final batch still broadcasts against the images.
        batch_size = tf.shape(real)[0]
        alpha = tf.random_uniform(shape=tf.stack([batch_size, 1, 1, 1]), minval=0., maxval=1.)
        differences = fake - real
        print("diff shape", differences.get_shape())
        interpolates = real + (alpha * differences)
        critic_out = discriminator(interpolates, self, reuse=True, batch_norm=False)
        gradients = tf.gradients(critic_out, [interpolates])[0]
        # Per-example L2 norm over H, W and C. Reducing only axis 1 of this 4-D
        # gradient would yield per-column norms instead of one norm per example.
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
        return penalty_weight * tf.reduce_mean((slopes - 1.) ** 2)

    def train_batch(self, batch):
        """Run five critic updates, then one generator update.

        Args:
            batch: real images as flat 784-pixel rows (fed to ``images_input_pl``).

        Returns:
            ``(d_error, g_error)``: the critic loss from the last critic step and
            the generator loss.
        """
        # Draw noise matching the real batch so g_sample and images_input have
        # the same batch dimension even for a short final batch.
        batch_size = len(batch)

        # 1. Train Discriminator (5 critic steps on the same real batch).
        for _ in range(5):
            batch_noise = noise((batch_size, self.NOISE_SIZE))
            feed_dict = {self.images_input_pl: batch, self.noise_input: batch_noise, self.training: True}
            d_error = self.session.run([self.d_loss, self.d_opt], feed_dict=feed_dict)[0]

        # 2. Train Generator, reusing the noise from the last critic step.
        feed_dict = {self.noise_input: batch_noise, self.training: True}
        g_error = self.session.run([self.g_loss, self.g_opt], feed_dict=feed_dict)[0]

        return d_error, g_error

    # Generate images from test noise
    def test_eval(self, noise_input, noise_input_interpolated):
        """Write the generated-image summaries for the given fixed test noise."""
        self.session.run(
            self.IMAGE_SUMMARIES,
            feed_dict={
                self.noise_input: noise_input,
                self.noise_input_interpolated: noise_input_interpolated
            })

    def run(self):
        """Train for ``FLAGS.model.optimization.epochs`` epochs with periodic logging and checkpoints."""
        BATCH_SIZE = FLAGS.model.optimization.batch_size
        train_X, _ = read_mnist(FLAGS.data.in_dir, no_gpu=FLAGS.training.no_gpu)

        # Fixed noise so the logged samples are comparable across steps.
        test_noise_random = noise(size=(FLAGS.eval.num_test_samples, self.NOISE_SIZE), dist='uniform')
        test_noise_interpolated = noise(size=(FLAGS.eval.num_test_samples, self.NOISE_SIZE), dist='linspace')

        # Iterate through epochs
        for epoch in range(FLAGS.model.optimization.epochs):
            print("Epoch %d" % epoch, flush=True)
            for n_batch, batch in enumerate(next_batch(train_X, BATCH_SIZE)):
                # next_batch may yield a shorter final batch; report it.
                if len(batch) != BATCH_SIZE:
                    print("Batch size: ", len(batch))
                d_error, g_error = self.train_batch(batch)

                # Test noise
                # NOTE(review): a fixed interval of 50 is used instead of FLAGS.training.log_interval.
                if n_batch % 50 == 0:
                    self.test_eval(test_noise_random, test_noise_interpolated)
                    print(
                        "Epoch: {}, Batch: {}, D_Loss: {}, G_Loss: {}".format(epoch, n_batch, d_error, g_error),
                        flush=True)

            if epoch % FLAGS.training.save_interval == 0:
                # NOTE(review): _build creates no saver, so this relies on BaseModel
                # providing self.saver - TODO confirm.
                self.saver.save(self.session, os.path.join(self.logdir, FLAGS.training.checkpoint_dir, "model.ckpt"))


In [None]:
# Run: build the GAN graph and start training.
gan_model = Gan()
gan_model.run()