In [None]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [None]:
# Load MNIST and pool the train and test images: the GAN is unsupervised, so the
# labels are unused and the test split simply adds more training samples.
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.concatenate((X_train, X_test))
# Scale uint8 pixels [0, 255] to [-1, 1], the range of the generator's tanh output.
# Cast to float32 first: plain arithmetic on uint8 promotes to float64 (twice the
# memory), while the WGAN input placeholder is float32 anyway.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
assert X_train.dtype == np.float32
assert X_train.min() >= -1.0 and X_train.max() <= 1.0

In [None]:
class Generator:
    """DCGAN-style generator turning a noise batch into single-channel images.

    Layer stack (feature-map shape per image in brackets):
      dense 7*7*128 + ReLU -> reshape [7x7x128]
      -> 5x5 transposed conv, stride 1 [7x7x64] -> batch norm -> ReLU
      -> 5x5 transposed conv, stride 2 [14x14x32] -> batch norm -> ReLU
      -> 5x5 transposed conv, stride 2, tanh [28x28x1]
    The tanh output lies in (-1, 1).

    NOTE(review): BatchNormalization is called without an explicit ``training``
    argument, so it follows the Keras learning phase; in TF1 graph mode that phase
    defaults to inference (moving statistics) unless it is fed -- confirm intended.
    """

    def __init__(self):
        kl = tf.keras.layers
        # Every transposed conv shares a 5x5 kernel with 'same' padding.
        same_5x5 = dict(kernel_size=(5, 5), padding='same')
        # Layers are constructed in the same order as they are applied, which also
        # fixes the auto-generated Keras names of the unnamed layers.
        self.layers = [
            kl.Dense(7*7*128, activation=tf.nn.relu, name="fc_generator"),
            kl.Reshape((7, 7, 128)),
            kl.Conv2DTranspose(filters=64, strides=(1, 1),
                               name="conv_t_generator1", **same_5x5),
            kl.BatchNormalization(),
            kl.ReLU(),
            kl.Conv2DTranspose(filters=32, strides=(2, 2),
                               name="conv_t_generator2", **same_5x5),
            kl.BatchNormalization(),
            kl.ReLU(),
            kl.Conv2DTranspose(filters=1, strides=(2, 2), activation=tf.nn.tanh,
                               name="conv_t_generator3", **same_5x5),
        ]

    def generate(self, rand_noise):
        """Feed ``rand_noise`` (presumably (batch, noise_dim)) through every layer.

        Returns the generated image batch, shaped (batch, 28, 28, 1).
        """
        images = rand_noise
        for stage in self.layers:
            images = stage(images)
        return images

In [None]:
class Critic:
    """WGAN critic mapping an image batch to one unbounded score per image.

    Two 5x5, stride-3 'valid' convolutions with leaky ReLU, each preceded by
    dropout (rate 0.2), then flatten and a 1-unit dense layer with no activation.
    For 28x28 inputs the feature maps shrink 28 -> 8 -> 2.

    NOTE(review): the Dropout layers are called without an explicit ``training``
    argument, so they follow the Keras learning phase; in TF1 graph mode that phase
    defaults to inference (dropout disabled) unless it is fed -- confirm intended.
    """

    def __init__(self):
        kl = tf.keras.layers
        # Built in application order; keeps the auto-generated Keras layer names stable.
        self.layers = [
            kl.Dropout(rate=0.2),
            kl.Conv2D(filters=32, kernel_size=(5, 5), strides=(3, 3),
                      activation=tf.nn.leaky_relu, name="conv_critic1"),
            kl.Dropout(rate=0.2),
            kl.Conv2D(filters=64, kernel_size=(5, 5), strides=(3, 3),
                      activation=tf.nn.leaky_relu, name="conv_critic2"),
            kl.Flatten(),
            kl.Dense(units=1, name="fc_discriminator"),
        ]

    def evaluate(self, image):
        """Return the critic scores for ``image``, shaped (batch, 1)."""
        score = image
        for stage in self.layers:
            score = stage(score)
        return score

In [None]:
class WGAN:
    """Wasserstein GAN with gradient penalty (WGAN-GP), built as a TF1 graph.

    Construction builds everything under the "WGAN" variable scope:
      - an initializable ``tf.data`` pipeline fed through ``original_image``;
      - a generator driven by standard-normal noise;
      - a critic scoring real, fake and interpolated images;
      - the generator and critic losses;
      - one Adam optimizer per network;
      - TensorBoard summaries and a ``tf.train.Saver``.

    Args:
        original_im_shape: per-image shape of the data fed to ``train``, without a
            channel axis (a trailing channel axis of size 1 is added in-graph).
        dim_noise: length of the noise vector sampled for each generated image.
    """

    def __init__(self, original_im_shape, dim_noise):
        self.original_im_shape = original_im_shape
        self.dim_noise = dim_noise

        with tf.variable_scope("WGAN"):
            self.generator = Generator()
            self.critic = Critic()

            # Input pipeline: the whole dataset is fed once through this placeholder when
            # the iterator is initialised; each session run then draws a fresh batch.
            self.original_image = tf.placeholder(tf.float32, (None, *(self.original_im_shape)), name="original_image")
            self.batch_size = tf.placeholder(tf.int64, None, name="batch_size")
            self.dataset = tf.data.Dataset.from_tensor_slices(self.original_image).shuffle(10000).batch(self.batch_size).repeat()
            self.iterator = self.dataset.make_initializable_iterator()

            # Add the channel axis: (batch, H, W) -> (batch, H, W, 1).
            self.original_image_exp = tf.expand_dims(self.iterator.get_next(), -1)

            # The last batch of an epoch can be smaller than batch_size, so per-batch
            # quantities are sized from the batch actually drawn.
            current_batch = tf.shape(self.original_image_exp)[0]

            # Sample noise and generate fake images
            with tf.variable_scope("generator"):
                self.rand_noise = tf.random_normal((current_batch, self.dim_noise), name="rand_noise")
                self.generated_images = self.generator.generate(self.rand_noise)

            # Score real and fake images with the same critic
            with tf.variable_scope("critic"):
                self.score_real = self.critic.evaluate(self.original_image_exp)
                self.score_fake = self.critic.evaluate(self.generated_images)

            # Gradient penalty (Gulrajani et al., 2017). Each sample gets its own
            # interpolation coefficient. The critic-gradient norm is taken per sample and
            # then averaged; a norm over the whole batch gradient would constrain the
            # batch as a single vector instead of each sample.
            epsilon = tf.random_uniform((current_batch, 1, 1, 1), 0., 1.)
            x_between = epsilon * self.original_image_exp + (1. - epsilon) * self.generated_images
            score_between = self.critic.evaluate(x_between)
            grad_between = tf.gradients(score_between, x_between)[0]
            # The tiny constant keeps sqrt differentiable if a gradient is exactly zero.
            grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad_between), axis=[1, 2, 3]) + 1e-12)
            self.grad_penalty = tf.reduce_mean(tf.square(grad_norm - 1.))

            # Compute losses
            self.lambda_gp = tf.placeholder(tf.float32, None, name="lambda_gp")
            self.loss_generator = - tf.reduce_mean(self.score_fake)
            self.loss_critic = - tf.reduce_mean(self.score_real - self.score_fake) + self.lambda_gp * self.grad_penalty

            # Separate trainable variables (filtered by the name prefix of their scopes)
            self.generator_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="WGAN/generator")
            self.critic_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="WGAN/critic")

            # Optimization: Adam with beta1=0, beta2=0.9 for both networks
            self.learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
            self.optimizer_generator = tf.train.AdamOptimizer(self.learning_rate, 0, 0.9)
            self.optimizer_critic = tf.train.AdamOptimizer(self.learning_rate, 0, 0.9)

            self.generator_train_op = self.optimizer_generator.minimize(self.loss_generator, var_list=self.generator_variables)
            self.critic_train_op = self.optimizer_critic.minimize(self.loss_critic, var_list=self.critic_variables)

            # Summaries; images are mapped from [-1, 1] back to [0, 1] for display
            tf.summary.scalar("loss_generator", self.loss_generator)
            tf.summary.scalar("loss_critic", self.loss_critic)
            tf.summary.scalar("grad_penalty", self.grad_penalty)
            tf.summary.image("generated_images", (self.generated_images + 1) / 2, 16)
            self.merged_summaries = tf.summary.merge_all()

            self.saver = tf.train.Saver()

    def train(self, X_train, batch_size, nb_steps, critic_steps, learning_rate, lambda_gp, log_every, save_every, sess,
              logdir="./tensorboard/", ckpt_path="./model/model.ckpt"):
        """Run WGAN-GP training in ``sess``.

        Each step runs ``critic_steps`` critic updates followed by one generator update.
        Model variables must already be initialised in ``sess``.

        Args:
            X_train: images shaped ``(N, *original_im_shape)``, fed once to initialise the
                shuffled, repeating input pipeline.
            batch_size: images per batch.
            nb_steps: number of generator updates.
            critic_steps: critic updates per generator update.
            learning_rate: Adam learning rate for both networks.
            lambda_gp: gradient-penalty weight in the critic loss.
            log_every: write TensorBoard summaries every this many steps.
            save_every: save a checkpoint every this many steps.
            sess: ``tf.Session`` running this model's graph.
            logdir: TensorBoard event directory.
            ckpt_path: checkpoint prefix. Its parent directory is created if missing,
                because ``Saver.save`` refuses to write into a non-existent directory.
        """
        import os  # only needed here, to create the checkpoint directory

        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

        # batch_size only matters when the iterator is built; later runs do not need it.
        sess.run(self.iterator.initializer, feed_dict={self.original_image: X_train,
                                                       self.batch_size: batch_size})
        train_feed = {self.learning_rate: learning_rate,
                      self.lambda_gp: lambda_gp}

        summary_writer = tf.summary.FileWriter(logdir, sess.graph)
        try:
            for step in range(1, nb_steps + 1):
                # Train critic
                for _ in range(critic_steps):
                    sess.run(self.critic_train_op, feed_dict=train_feed)

                # Train generator; evaluate the (image) summaries only on logging steps
                if step % log_every == 0:
                    _, summaries = sess.run([self.generator_train_op, self.merged_summaries],
                                            feed_dict=train_feed)
                    print("Write summaries")
                    summary_writer.add_summary(summaries, step)
                else:
                    sess.run(self.generator_train_op, feed_dict=train_feed)

                if step % save_every == 0:
                    print("Save model")
                    self.saver.save(sess, ckpt_path)
        finally:
            # Flush pending events and release the file even if training is interrupted.
            summary_writer.close()

    def restore(self, sess, ckpt_file):
        """Load all saved variables from checkpoint ``ckpt_file`` into ``sess``."""
        self.saver.restore(sess, ckpt_file)

In [None]:
# --- Data / model shapes ---
original_im_shape = (28, 28)  # image height x width; the channel axis is added in-graph
dim_noise = 100               # length of each generator input noise vector

# --- Optimisation ---
batch_size = 128
learning_rate = 2e-4

# --- WGAN-GP specifics ---
critic_steps = 5  # critic updates per generator update
lambda_gp = 10    # gradient-penalty weight in the critic loss

# --- Schedule (all counted in generator steps) ---
nb_steps = 50000
log_every = 250   # steps between TensorBoard summaries
save_every = 1000 # steps between checkpoints

In [None]:
# Start from an empty default graph. Otherwise, re-running this cell would add a
# second WGAN (scope "WGAN_1") to the same graph; its "WGAN/..." variable lists and
# merge_all() summaries would then pick up the first copy's tensors.
tf.reset_default_graph()
wgan = WGAN(original_im_shape, dim_noise)

In [None]:
# Initialise all variables, then train; the session is closed when the block exits.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    wgan.train(X_train=X_train,
               batch_size=batch_size,
               nb_steps=nb_steps,
               critic_steps=critic_steps,
               learning_rate=learning_rate,
               lambda_gp=lambda_gp,
               log_every=log_every,
               save_every=save_every,
               sess=sess)