In [1]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [2]:
# GAN training here is unsupervised, so the labels go unused and the train and
# test images are pooled into a single array. Pixels are then shifted and
# scaled with (x - 127.5) / 127.5, which maps 8-bit values onto [-1, 1] to
# match the tanh output of the generators.
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (np.concatenate([X_train, X_test], axis=0) - 127.5) / 127.5

In [3]:
class Generator:
    """Transposed-convolution generator.

    A noise vector is projected by a dense layer to 7*7*128 units, reshaped
    to a (7, 7, 128) feature map, and then passed through three 5x5
    transposed convolutions (strides 1, 2, 2, 'same' padding). The last one
    has a single filter and is followed by tanh.

    Attributes:
        layers: the Keras layers, in the order they are applied.
    """

    def __init__(self):
        """Instantiate the layer stack (no variables are built until called)."""
        keras_layers = tf.keras.layers

        def transposed_conv(filters, strides, name):
            # All transposed convolutions share the 5x5 kernel and 'same' padding.
            return keras_layers.Conv2DTranspose(filters=filters,
                                                kernel_size=(5,5),
                                                strides=strides,
                                                padding='same',
                                                name=name)

        stack = []

        # Projection of the noise vector, then reshape to a spatial map.
        stack.append(keras_layers.Dense(7*7*128, use_bias=False, name="fc_generator1"))
        stack.append(keras_layers.BatchNormalization())
        stack.append(keras_layers.ReLU())
        stack.append(keras_layers.Reshape((7, 7, 128)))

        # Stride-1 block: 64 filters.
        stack.append(transposed_conv(64, (1,1), "conv_t_generator1"))
        stack.append(keras_layers.BatchNormalization())
        stack.append(keras_layers.ReLU())

        # First stride-2 block: 32 filters.
        stack.append(transposed_conv(32, (2,2), "conv_t_generator2"))
        stack.append(keras_layers.BatchNormalization())
        stack.append(keras_layers.ReLU())

        # Second stride-2 block: single output channel squashed by tanh.
        stack.append(transposed_conv(1, (2,2), "conv_t_generator3"))
        stack.append(keras_layers.Activation('tanh'))

        self.layers = stack

    def generate(self, rand_noise):
        """Apply every layer in order to `rand_noise` and return the result."""
        output = rand_noise
        for apply_layer in self.layers:
            output = apply_layer(output)
        return output
    
class GeneratorMLP:
    """Fully connected generator.

    Two dense layers (128 and 28*28 units), each followed by batch
    normalization; the hidden block uses ReLU and the output block tanh.
    The flat output is reshaped to (28, 28, 1).

    Attributes:
        layers: the Keras layers, in the order they are applied.
    """

    def __init__(self):
        """Instantiate the layer stack (no variables are built until called)."""
        keras_layers = tf.keras.layers

        hidden_block = [
            keras_layers.Dense(128, name="fc_generator1"),
            keras_layers.BatchNormalization(),
            keras_layers.ReLU(),
        ]
        output_block = [
            keras_layers.Dense(28*28, name="fc_generator2"),
            keras_layers.BatchNormalization(),
            keras_layers.Activation('tanh'),
        ]
        # Turn the flat 784-vector into a single-channel image.
        to_image = [keras_layers.Reshape((28, 28, 1))]

        self.layers = hidden_block + output_block + to_image

    def generate(self, rand_noise):
        """Apply every layer in order to `rand_noise` and return the result."""
        output = rand_noise
        for apply_layer in self.layers:
            output = apply_layer(output)
        return output

In [4]:
class Critic:
    """Convolutional WGAN critic.

    Two strided 5x5 convolutions with ReLU (each followed by dropout), then
    the feature map is flattened and mapped to ONE unbounded score per image.

    The output layer is deliberately linear: a WGAN critic estimates a
    real-valued score, and a ReLU there would clamp every score to >= 0 and
    zero the gradient whenever the pre-activation is negative. The feature
    map is flattened first so that `Dense(1)` yields a single score per image
    instead of one score per spatial location.

    Attributes:
        layers: the Keras layers, in the order they are applied.
    """

    def __init__(self):
        """Instantiate the layer stack (no variables are built until called)."""
        conv1 = tf.keras.layers.Conv2D(filters=32,
                                       kernel_size=(5,5),
                                       strides=(2,2),
                                       activation=tf.nn.relu,
                                       name="conv_critic1")
        dropout1 = tf.keras.layers.Dropout(rate=0.3)

        conv2 = tf.keras.layers.Conv2D(filters=64,
                                       kernel_size=(5,5),
                                       strides=(2,2),
                                       activation=tf.nn.relu,
                                       name="conv_critic2")
        dropout2 = tf.keras.layers.Dropout(rate=0.3)

        # Collapse (height, width, channels) so the dense layer sees one
        # feature vector per image.
        flatten = tf.keras.layers.Flatten()

        # Linear output: critic scores must be unbounded.
        fc = tf.keras.layers.Dense(units=1,
                                   activation=None,
                                   name="fc_critic")

        self.layers = [conv1, dropout1, conv2, dropout2, flatten, fc]

    def evaluate(self, image):
        """Apply every layer in order to `image` and return the critic score."""
        x = image
        for layer in self.layers:
            x = layer(x)
        return x
    
class CriticMLP:
    """Fully connected WGAN critic.

    The image is flattened, passed through a 128-unit ReLU layer and mapped
    to one unbounded score; a dropout layer precedes each dense layer.

    The output layer is deliberately linear: a WGAN critic estimates a
    real-valued score, and a ReLU there would clamp every score to >= 0 and
    zero the gradient whenever the pre-activation is negative.

    Attributes:
        layers: the Keras layers, in the order they are applied.
    """

    def __init__(self):
        """Instantiate the layer stack (no variables are built until called)."""
        flatten = tf.keras.layers.Flatten()

        dropout1 = tf.keras.layers.Dropout(rate=0.3)
        fc1 = tf.keras.layers.Dense(128, activation=tf.nn.relu, name="fc_critic1")

        dropout2 = tf.keras.layers.Dropout(rate=0.3)
        # Linear output: critic scores must be unbounded.
        fc2 = tf.keras.layers.Dense(1, activation=None, name="fc_critic2")

        self.layers = [flatten, dropout1, fc1, dropout2, fc2]

    def evaluate(self, image):
        """Apply every layer in order to `image` and return the critic score."""
        x = image
        for layer in self.layers:
            x = layer(x)
        return x

In [5]:
class WGAN:
    """Wasserstein GAN with critic weight clipping, built as a TF1 static graph.

    The constructor wires the whole graph: an initializable input pipeline
    over real images, a generator fed with Gaussian noise, a critic applied
    to both real and generated images, the two losses, one RMSProp train op
    per network, an op that clips the critic weights, TensorBoard summaries
    and a saver. `train` then alternates critic and generator updates.
    """

    def __init__(self, original_im_shape, dim_noise, clip_value=0.01):
        """Build the training graph in the current default graph.

        Args:
            original_im_shape: shape of one real image WITHOUT a channel
                axis; a trailing axis of size 1 is appended in the graph.
            dim_noise: length of the noise vector fed to the generator.
            clip_value: after every critic update each critic variable is
                clipped to [-clip_value, clip_value].
        """
        self.original_im_shape = original_im_shape
        self.dim_noise = dim_noise
        self.clip_value = clip_value

        with tf.variable_scope("WGAN"):
            self.generator = GeneratorMLP()
            self.critic = CriticMLP()

            # Data from mnist. `batch_size` is a placeholder because it sizes
            # both the dataset batches (fed when the iterator is initialized)
            # and the noise batch (fed at every training step).
            self.original_image = tf.placeholder(tf.float32, (None, *(self.original_im_shape)), name="original_image")
            self.batch_size = tf.placeholder(tf.int64, None, name="batch_size")
            self.dataset = tf.data.Dataset.from_tensor_slices(self.original_image).shuffle(10000).batch(self.batch_size).repeat()
            self.iterator = self.dataset.make_initializable_iterator()

            # Append the channel axis expected by the critic.
            self.original_image_exp = tf.expand_dims(self.iterator.get_next(), -1)

            # Sample and generate fake images
            with tf.variable_scope("generator"):
                self.rand_noise = tf.random_normal((self.batch_size, self.dim_noise), name="rand_noise")
                self.generated_images = self.generator.generate(self.rand_noise)

            # Use critic (the same layer objects score both batches, so the
            # weights are shared)
            with tf.variable_scope("critic"):
                self.score_real = self.critic.evaluate(self.original_image_exp)
                self.score_fake = self.critic.evaluate(self.generated_images)

            # Compute losses: the critic maximizes mean(real) - mean(fake),
            # the generator maximizes mean(fake).
            self.loss_generator = - tf.reduce_mean(self.score_fake)
            self.loss_critic = - (tf.reduce_mean(self.score_real) - tf.reduce_mean(self.score_fake))

            # Separate trainable variables
            self.generator_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="WGAN/generator")
            self.critic_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="WGAN/critic")

            # Weight clipping. In graph mode `Variable.assign` only creates an
            # op; it has no effect until a session runs it. The op is
            # therefore built once here and executed from `train`.
            self.clip_critic_weights_op = tf.group(
                *[variable.assign(tf.clip_by_value(variable, -clip_value, clip_value))
                  for variable in self.critic_variables],
                name="clip_critic_weights")

            # Optimization
            self.learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
            self.optimizer_generator = tf.train.RMSPropOptimizer(self.learning_rate)
            self.optimizer_critic = tf.train.RMSPropOptimizer(self.learning_rate)

            self.generator_train_op = self.optimizer_generator.minimize(self.loss_generator, var_list=self.generator_variables)
            self.critic_train_op = self.optimizer_critic.minimize(self.loss_critic, var_list=self.critic_variables)

            # Summaries (images are mapped back from [-1, 1] to [0, 1])
            tf.summary.scalar("loss_generator", self.loss_generator)
            tf.summary.scalar("loss_critic", self.loss_critic)
            tf.summary.image("generated_images", (self.generated_images + 1) / 2, 16)
            self.merged_summaries = tf.summary.merge_all()

            self.saver = tf.train.Saver()

    def train(self, X_train, batch_size, nb_steps, learning_rate, critic_steps, save_every, sess):
        """Run the alternating WGAN training loop.

        Each step performs `critic_steps` critic updates, each immediately
        followed by weight clipping, and then one generator update. Every
        `save_every` steps a checkpoint is written to "./model/model.ckpt"
        and the merged summaries are written to "./tensorboard/".

        Args:
            X_train: real images fed to the `original_image` placeholder.
            batch_size: number of real images and of noise vectors per batch.
            nb_steps: number of generator updates to perform.
            learning_rate: learning rate fed to both RMSProp optimizers.
            critic_steps: critic updates per generator update.
            save_every: period, in steps, of checkpoints and summaries.
            sess: session in which the variables are already initialized.
        """
        summary_writer = tf.summary.FileWriter("./tensorboard/", sess.graph)
        step_feed = {self.learning_rate: learning_rate,
                     self.batch_size: batch_size}

        try:
            sess.run(self.iterator.initializer, feed_dict={self.original_image: X_train,
                                                           self.batch_size: batch_size})

            for step in range(1, nb_steps + 1):
                # Train critic
                for _ in range(critic_steps):
                    sess.run(self.critic_train_op, feed_dict=step_feed)
                    # Clip weights: must be run by the session to take effect.
                    sess.run(self.clip_critic_weights_op)

                # Train generator. Summaries (which include encoded images)
                # are only evaluated on the steps where they are written.
                if step % save_every == 0:
                    _, summaries = sess.run([self.generator_train_op, self.merged_summaries],
                                            feed_dict=step_feed)
                    print("Save and write summaries")
                    self.saver.save(sess, "./model/model.ckpt")
                    summary_writer.add_summary(summaries, step)
                else:
                    sess.run(self.generator_train_op, feed_dict=step_feed)
        finally:
            # Flush pending events even if training is interrupted.
            summary_writer.close()

    def restore(self, sess, ckpt_file):
        """Restore all saved variables from `ckpt_file` into `sess`."""
        self.saver.restore(sess, ckpt_file)

In [6]:
# Data and model dimensions
original_im_shape = (28, 28)  # one MNIST image, without the channel axis
dim_noise = 16                # length of the generator's noise vector

# Optimization
batch_size = 64
learning_rate = 1e-4
critic_steps = 4              # critic updates per generator update

# Schedule
nb_steps = 100000             # number of generator updates
save_every = 200              # steps between checkpoints / summary writes

In [7]:
# Start from an empty default graph so this cell can be re-run safely.
# WGAN selects its trainable variables by the literal name prefixes
# "WGAN/generator" and "WGAN/critic"; a second model built in the same graph
# would get a uniquified name scope and those prefixes would pick up the
# variables of the first model instead of its own.
tf.reset_default_graph()
wgan = WGAN(original_im_shape, dim_noise)

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Use tf.cast instead.


In [8]:
# Open a session, initialize every variable of the graph, then run the
# training loop with the hyper-parameters defined above.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    wgan.train(X_train=X_train,
               batch_size=batch_size,
               nb_steps=nb_steps,
               learning_rate=learning_rate,
               critic_steps=critic_steps,
               save_every=save_every,
               sess=sess)

Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries


KeyboardInterrupt: 