In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [None]:
# Load MNIST and pool the official train and test splits: the GAN only needs
# "real" images, so both splits are used as training data (labels are unused).
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.concatenate((X_train, X_test))
# Scale uint8 pixels to [0, 1], matching the sigmoid output of GeneratorMLP.
# (For a tanh-output generator use (X - 127.5) / 127.5 to get [-1, 1] instead.)
# Casting to float32 first matches the GAN's float32 input placeholder and
# avoids the float64 array that `uint8 / 255` would otherwise produce.
X_train = X_train.astype(np.float32) / 255

In [None]:
class Generator:
    """Convolutional generator: noise vector -> single-channel image.

    Layer stack:
        Dense(7*7*128, no bias) -> BN -> ReLU -> reshape to (7, 7, 128)
        -> ConvTranspose(64, 5x5, stride 1, same) -> BN -> ReLU   => 7x7x64
        -> ConvTranspose(32, 5x5, stride 2, same) -> BN -> ReLU   => 14x14x32
        -> ConvTranspose(1, 5x5, stride 2, same) + output activation => 28x28x1

    The default tanh output lies in [-1, 1]. For training data scaled to
    [0, 1], pass ``output_activation=tf.nn.sigmoid``.
    """

    def __init__(self, output_activation=None):
        """Create the generator layers.

        Args:
            output_activation: activation of the final transposed convolution.
                ``None`` (default) keeps the original ``tf.nn.tanh``.
        """
        if output_activation is None:
            output_activation = tf.nn.tanh

        # No bias: the following BatchNormalization supplies its own offset.
        fc1 = tf.keras.layers.Dense(7*7*128, use_bias=False, name="fc_generator1")
        bn1 = tf.keras.layers.BatchNormalization()
        act1 = tf.keras.layers.ReLU()

        reshape = tf.keras.layers.Reshape((7, 7, 128))

        conv_t2 = tf.keras.layers.Conv2DTranspose(filters=64,
                                                  kernel_size=(5,5),
                                                  strides=(1,1),
                                                  padding='same',
                                                  name="conv_t_generator1")
        bn2 = tf.keras.layers.BatchNormalization()
        act2 = tf.keras.layers.ReLU()

        conv_t3 = tf.keras.layers.Conv2DTranspose(filters=32,
                                                  kernel_size=(5,5),
                                                  strides=(2,2),
                                                  padding='same',
                                                  name="conv_t_generator2")
        bn3 = tf.keras.layers.BatchNormalization()
        act3 = tf.keras.layers.ReLU()

        conv_t4 = tf.keras.layers.Conv2DTranspose(filters=1,
                                                  kernel_size=(5,5),
                                                  strides=(2,2),
                                                  activation=output_activation,
                                                  padding='same',
                                                  name="conv_t_generator3")

        self.layers = [fc1, bn1, act1, reshape, conv_t2, bn2, act2, conv_t3, bn3, act3, conv_t4]
        # BatchNormalization behaves differently in training and inference, so
        # only these layers receive the ``training`` flag given to generate().
        self._training_layer_ids = {id(bn1), id(bn2), id(bn3)}

    def generate(self, rand_noise, training=None):
        """Map a batch of noise vectors to images.

        Args:
            rand_noise: noise tensor fed to the first Dense layer; presumably
                shaped (batch, noise_dim) -- TODO confirm with the caller.
            training: forwarded to the BatchNormalization layers only.
                ``None`` (default) is identical to calling them without the
                argument, i.e. Keras falls back on its global learning phase.

        Returns:
            The generated image tensor, (batch, 28, 28, 1) for 2-D noise input.

        NOTE(review): when these layers are used outside a Keras Model in TF1
        graph mode, the BatchNormalization moving-average updates are exposed
        through each layer's ``updates`` and may need to be run explicitly --
        confirm before relying on inference-mode statistics.
        """
        training_ids = getattr(self, "_training_layer_ids", frozenset())
        x = rand_noise
        for layer in self.layers:
            if id(layer) in training_ids:
                x = layer(x, training=training)
            else:
                x = layer(x)
        return x
    
class GeneratorMLP:
    """Fully-connected generator: noise vector -> (28, 28, 1) image in (0, 1).

    Dense(128, ReLU) -> Dense(512, ReLU) -> Dense(784, sigmoid) -> reshape.
    """

    def __init__(self):
        # Hidden ReLU layers, built in order: fc_generator1 then fc_generator2.
        hidden_layers = [
            tf.keras.layers.Dense(width, activation=tf.nn.relu, name=layer_name)
            for width, layer_name in ((128, "fc_generator1"), (512, "fc_generator2"))
        ]
        # One sigmoid unit per pixel, then fold the flat vector into an image.
        pixel_layer = tf.keras.layers.Dense(28*28, activation=tf.nn.sigmoid, name="fc_generator3")
        to_image = tf.keras.layers.Reshape((28, 28, 1))

        self.layers = hidden_layers + [pixel_layer, to_image]

    def generate(self, rand_noise):
        """Apply every layer in order to ``rand_noise`` and return the result."""
        output = rand_noise
        for stage in self.layers:
            output = stage(output)
        return output

In [None]:
class Discriminator:
    """Convolutional discriminator: image batch -> one score per image.

    Conv(32, 5x5, stride 2, leaky ReLU) -> Dropout(0.3)
    -> Conv(64, 5x5, stride 2, leaky ReLU) -> Dropout(0.3)
    -> Flatten -> Dense(1).

    By default the final Dense applies a sigmoid (a probability, as before).
    Pass ``apply_sigmoid=False`` to get raw logits, which is what
    ``tf.nn.sigmoid_cross_entropy_with_logits``-based losses expect.
    """

    def __init__(self, apply_sigmoid=True):
        """Create the discriminator layers.

        Args:
            apply_sigmoid: if True (default) the output is passed through a
                sigmoid; if False the output is an unnormalised logit.
        """
        conv1 = tf.keras.layers.Conv2D(filters=32,
                                       kernel_size=(5,5),
                                       strides=(2,2),
                                       activation=tf.nn.leaky_relu,
                                       name="conv_discriminator1")
        dropout1 = tf.keras.layers.Dropout(rate=0.3)

        conv2 = tf.keras.layers.Conv2D(filters=64,
                                       kernel_size=(5,5),
                                       strides=(2,2),
                                       activation=tf.nn.leaky_relu,
                                       name="conv_discriminator2")
        dropout2 = tf.keras.layers.Dropout(rate=0.3)

        # Collapse the 4-D feature map into one vector per image; without this
        # the Dense layer acts on the channel axis only and yields a spatial map
        # of scores instead of a single score per image.
        flatten = tf.keras.layers.Flatten()

        fc = tf.keras.layers.Dense(units=1,
                                   activation=tf.nn.sigmoid if apply_sigmoid else None,
                                   name="fc_discriminator")

        self.layers = [conv1, dropout1, conv2, dropout2, flatten, fc]
        # Dropout behaves differently in training and inference, so only these
        # layers receive the ``training`` flag given to discriminate().
        self._training_layer_ids = {id(dropout1), id(dropout2)}

    def discriminate(self, image, training=None):
        """Score a batch of images.

        Args:
            image: image tensor; presumably (batch, 28, 28, 1) -- TODO confirm.
            training: forwarded to the Dropout layers only. ``None`` (default)
                is identical to calling them without the argument, i.e. Keras
                falls back on its global learning phase.

        Returns:
            A (batch, 1) tensor of probabilities, or logits when the instance
            was built with ``apply_sigmoid=False``.
        """
        training_ids = getattr(self, "_training_layer_ids", frozenset())
        x = image
        for layer in self.layers:
            if id(layer) in training_ids:
                x = layer(x, training=training)
            else:
                x = layer(x)
        return x
    
class DiscriminatorMLP:
    """Fully-connected discriminator returning one raw logit per image.

    Flatten -> Dropout(0.5) -> Dense(128, leaky ReLU) -> Dropout(0.2) -> Dense(1).
    The output has no sigmoid, so pair it with a logits-based loss such as
    ``tf.nn.sigmoid_cross_entropy_with_logits``.
    """

    def __init__(self):
        flatten = tf.keras.layers.Flatten()

        # Dropout is applied directly to the flattened input pixels.
        dropout1 = tf.keras.layers.Dropout(rate=0.5)
        fc1 = tf.keras.layers.Dense(128, activation=tf.nn.leaky_relu, name="fc_discriminator1")

        dropout2 = tf.keras.layers.Dropout(rate=0.2)
        # Linear output (a logit); the sigmoid is left to the loss function.
        fc2 = tf.keras.layers.Dense(1, name="fc_discriminator2")

        self.layers = [flatten, dropout1, fc1, dropout2, fc2]
        # Dropout behaves differently in training and inference, so only these
        # layers receive the ``training`` flag given to discriminate().
        self._training_layer_ids = {id(dropout1), id(dropout2)}

    def discriminate(self, image, training=None):
        """Score a batch of images.

        Args:
            image: image tensor; everything after the batch axis is flattened.
            training: forwarded to the Dropout layers only. ``None`` (default)
                is identical to calling them without the argument, i.e. Keras
                falls back on its global learning phase.

        Returns:
            A (batch, 1) tensor of unnormalised logits.
        """
        training_ids = getattr(self, "_training_layer_ids", frozenset())
        x = image
        for layer in self.layers:
            if id(layer) in training_ids:
                x = layer(x, training=training)
            else:
                x = layer(x)
        return x

In [None]:
class GAN:
    """Generative adversarial network on single-channel images (TF1 graph mode).

    The graph contains:
      * an initializable tf.data pipeline that shuffles, batches and repeats the
        images fed once through ``original_image``;
      * a ``GeneratorMLP`` mapping clipped standard-normal noise to images;
      * a ``DiscriminatorMLP`` scoring real and generated images (raw logits);
      * a non-saturating generator loss and a discriminator loss with one-sided
        label smoothing (real target 0.9);
      * one Adam optimizer per network, sharing a fed learning rate;
      * TensorBoard summaries and a ``tf.train.Saver``.
    """

    def __init__(self, original_im_shape, dim_noise):
        """Build the full training graph in the current default graph.

        Args:
            original_im_shape: shape of one image without the channel axis,
                e.g. ``(28, 28)``; a channel axis of size 1 is appended in-graph.
            dim_noise: length of each generator noise vector.
        """
        self.original_im_shape = original_im_shape
        self.dim_noise = dim_noise

        with tf.variable_scope("GAN"):
            self.generator = GeneratorMLP()
            self.discriminator = DiscriminatorMLP()

            # Real data: the whole image array is fed once when the iterator is
            # initialised. batch_size is a placeholder so it can also size the noise.
            self.original_image = tf.placeholder(tf.float32, (None, *(self.original_im_shape)), name="original_image")
            self.batch_size = tf.placeholder(tf.int64, None, name="batch_size")
            self.dataset = tf.data.Dataset.from_tensor_slices(self.original_image).shuffle(10000).batch(self.batch_size).repeat()
            self.iterator = self.dataset.make_initializable_iterator()

            self.original_image_exp = tf.expand_dims(self.iterator.get_next(), -1)

            # Fake data: standard-normal noise clipped to [-1, 1].
            with tf.variable_scope("generator"):
                self.rand_noise = tf.clip_by_value(tf.random_normal((self.batch_size, self.dim_noise), name="rand_noise"), -1, 1)
                self.generated_images = self.generator.generate(self.rand_noise)

            # NOTE: despite their names, these tensors are raw logits --
            # DiscriminatorMLP has no output sigmoid; the losses apply it.
            with tf.variable_scope("discriminator"):
                self.prob_true_real = self.discriminator.discriminate(self.original_image_exp)
                self.prob_true_fake = self.discriminator.discriminate(self.generated_images)

            # Non-saturating generator loss: push D(fake) towards "real".
            self.loss_generator = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.prob_true_fake,
                labels=tf.ones_like(self.prob_true_fake)))
            # Discriminator loss with one-sided label smoothing (real target 0.9).
            self.loss_discriminator = (
                tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.prob_true_real,
                    labels=.9 * tf.ones_like(self.prob_true_real)))
                + tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.prob_true_fake,
                    labels=tf.zeros_like(self.prob_true_fake))))

            # Each optimizer may only update its own network; the filters rely on
            # variable names carrying the scopes opened above.
            self.generator_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GAN/generator")
            self.discriminator_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GAN/discriminator")

            self.learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
            self.optimizer_generator = tf.train.AdamOptimizer(self.learning_rate)
            self.optimizer_discriminator = tf.train.AdamOptimizer(self.learning_rate)

            self.generator_train_op = self.optimizer_generator.minimize(self.loss_generator, var_list=self.generator_variables)
            self.discriminator_train_op = self.optimizer_discriminator.minimize(self.loss_discriminator, var_list=self.discriminator_variables)

            tf.summary.scalar("loss_generator", self.loss_generator)
            tf.summary.scalar("loss_discriminator", self.loss_discriminator)
            # GeneratorMLP ends in a sigmoid, so images are already in [0, 1].
            tf.summary.image("generated_images", self.generated_images, 16)
            self.merged_summaries = tf.summary.merge_all()

            self.saver = tf.train.Saver()

        # Keras Dropout layers called without an explicit ``training`` argument
        # follow this learning-phase tensor, which in TF1 graph mode defaults to
        # inference (dropout off); train() feeds it so dropout is active.
        self.learning_phase = tf.keras.backend.learning_phase()

    def train(self, X_train, batch_size, nb_steps, learning_rate, discriminator_steps, save_every, sess):
        """Run alternating discriminator / generator updates.

        Each step runs ``discriminator_steps`` discriminator updates followed by
        one generator update. Every ``save_every`` steps the summaries are
        evaluated, a checkpoint is written to ``./model/model.ckpt`` and the
        summaries are written to ``./tensorboard/``.

        Args:
            X_train: real images, shape ``(N, *original_im_shape)``.
            batch_size: number of real images and noise vectors per batch.
            nb_steps: number of generator updates.
            learning_rate: Adam learning rate for both networks.
            discriminator_steps: discriminator updates per generator update.
            save_every: checkpoint / summary period in steps; must be positive.
            sess: tf.Session whose variables are already initialised or
                restored; only the data iterator is initialised here.

        Raises:
            ValueError: if ``save_every`` is not positive.
        """
        if save_every <= 0:
            raise ValueError("save_every must be a positive integer, got {}".format(save_every))

        # tf.train.Saver refuses to save into a directory that does not exist.
        os.makedirs("./model", exist_ok=True)

        summary_writer = tf.summary.FileWriter("./tensorboard/", sess.graph)
        try:
            sess.run(self.iterator.initializer, feed_dict={self.original_image: X_train,
                                                           self.batch_size: batch_size})

            # The feed depends only on the hyper-parameters, so build it once.
            train_feed = {self.learning_rate: learning_rate,
                          self.batch_size: batch_size}
            learning_phase = getattr(self, "learning_phase", None)
            if isinstance(learning_phase, tf.Tensor):
                train_feed[learning_phase] = True

            for step in range(1, nb_steps + 1):
                for _ in range(discriminator_steps):
                    sess.run(self.discriminator_train_op, feed_dict=train_feed)

                if step % save_every == 0:
                    # Evaluate the (costly) summaries only when they are written.
                    _, summaries = sess.run([self.generator_train_op, self.merged_summaries],
                                            feed_dict=train_feed)
                    print("Save and write summaries")
                    self.saver.save(sess, "./model/model.ckpt")
                    summary_writer.add_summary(summaries, step)
                else:
                    sess.run(self.generator_train_op, feed_dict=train_feed)
        finally:
            # Flush pending events to disk and release the event file.
            summary_writer.close()

    def restore(self, sess, ckpt_file):
        """Load all saved variables from ``ckpt_file`` into ``sess``."""
        self.saver.restore(sess, ckpt_file)

In [None]:
# --- Model geometry ---
original_im_shape = (28, 28)   # height, width of one image (channel axis added in-graph)
dim_noise = 16                 # length of each generator noise vector

# --- Optimisation ---
batch_size, learning_rate = 128, 4e-4
discriminator_steps = 1        # discriminator updates per generator update

# --- Schedule ---
nb_steps = 100000              # total generator updates
save_every = 1000              # checkpoint / summary period, in steps

In [None]:
# Start from an empty default graph. Re-running this cell would otherwise build
# a second model in the same graph under a uniquified scope (e.g. "GAN_1"): the
# "GAN/generator" / "GAN/discriminator" variable filters would then select the
# first model's variables, and merge_all() would collect both models' summaries.
tf.reset_default_graph()
gan = GAN(original_im_shape, dim_noise)

In [None]:
with tf.Session() as sess:
    # Fresh run: initialise every variable (network weights and Adam slots).
    sess.run(tf.global_variables_initializer())
    # To resume from the last checkpoint instead, also call
    # gan.restore(sess, "./model/model.ckpt") right after the initializer.
    gan.train(
        X_train=X_train,
        batch_size=batch_size,
        nb_steps=nb_steps,
        learning_rate=learning_rate,
        discriminator_steps=discriminator_steps,
        save_every=save_every,
        sess=sess,
    )