# GAN 101

In [1]:
import os
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
# Load the MNIST data from a directory given relative to the working directory.
# normpath converts the separators to the platform's native form.
dataset_path = os.path.normpath(r"../DataSet/mnist")
mnist_dataset = input_data.read_data_sets(dataset_path, one_hot=True)

# Report how many samples each split contains.
summary_template = "Number of training samples: {}\nNumber of test samples: {}"
print(summary_template.format(mnist_dataset.train.num_examples,
                              mnist_dataset.test.num_examples))

Extracting ..\DataSet\mnist\train-images-idx3-ubyte.gz
Extracting ..\DataSet\mnist\train-labels-idx1-ubyte.gz
Extracting ..\DataSet\mnist\t10k-images-idx3-ubyte.gz
Extracting ..\DataSet\mnist\t10k-labels-idx1-ubyte.gz
Number of training samples: 55000
Number of test samples: 10000


In [3]:
class GAN:
    """Fully-connected GAN graph: one generator and one weight-shared discriminator.

    The generator maps the noise input to a tensor whose width equals the second
    dimension of the image input. The discriminator is applied twice with the
    same variables: once to the image input and once to the generator output.
    """

    def __init__(self, noise_input_tensor, image_input_tensor, generator_hidden_dim, discriminator_hidden_dim):
        """Build the generator and both discriminator heads in the current graph.

        Args:
            noise_input_tensor: 2-D tensor fed to the generator.
            image_input_tensor: 2-D tensor fed to the discriminator; its second
                dimension also sets the generator's output width.
            generator_hidden_dim: width of the generator's hidden layer.
            discriminator_hidden_dim: width of the discriminator's hidden layer.
        """
        self._noise_input = noise_input_tensor
        self._image_input = image_input_tensor

        generated_width = image_input_tensor.shape[1]
        with tf.variable_scope("generator"):
            self._generator_output = self._fnn(
                noise_input_tensor, generator_hidden_dim, generated_width)

        with tf.variable_scope("discriminator"):
            self._discriminator_output_for_real_data = self._fnn(
                image_input_tensor, discriminator_hidden_dim, 1)

        # Re-enter the scope with reuse=True so the variables created above
        # are shared when scoring the generator's output.
        with tf.variable_scope("discriminator", reuse=True):
            self._discriminator_output_for_synth = self._fnn(
                self._generator_output, discriminator_hidden_dim, 1)

    def _fnn(self, input_tensor, hidden_dim, output_dim):
        """Two-layer feed-forward net: affine -> ReLU -> affine -> sigmoid.

        Variables are obtained with tf.get_variable under the caller's variable
        scope, which is what makes weight sharing via reuse=True possible.

        Args:
            input_tensor: 2-D input; its second dimension sets the first weight shape.
            hidden_dim: number of hidden units.
            output_dim: number of output units.

        Returns:
            The sigmoid of the output layer's pre-activation.
        """
        init = tf.truncated_normal_initializer
        input_width = input_tensor.shape[1]

        # Input -> hidden layer.
        hidden_weights = tf.get_variable(name="W_xh", shape=[input_width, hidden_dim], initializer=init)
        hidden_bias = tf.get_variable(name="b_xh", shape=[hidden_dim], initializer=init)
        hidden_activation = tf.nn.relu(tf.add(tf.matmul(input_tensor, hidden_weights), hidden_bias))

        # Hidden -> output layer.
        output_weights = tf.get_variable(name="W_ho", shape=[hidden_dim, output_dim], initializer=init)
        output_bias = tf.get_variable(name="b_ho", shape=[output_dim], initializer=init)
        pre_activation = tf.add(tf.matmul(hidden_activation, output_weights), output_bias)
        return tf.sigmoid(pre_activation)

    @property
    def noise_input(self):
        """Tensor passed in as the generator's noise input."""
        return self._noise_input

    @property
    def image_input(self):
        """Tensor passed in as the discriminator's image input."""
        return self._image_input

    @property
    def generator_output(self):
        """Output tensor of the generator network."""
        return self._generator_output

    @property
    def discriminator_output_from_generator(self):
        """Discriminator output when fed the generator's output."""
        return self._discriminator_output_for_synth

    @property
    def discriminator_output_from_image_input(self):
        """Discriminator output when fed the image input."""
        return self._discriminator_output_for_real_data

In [4]:
# Network sizes used when building the graph.
noise_dim = 128                  # width of the generator's noise input
image_dim = 28 * 28              # width of the flattened image input (784)
generator_hidden_dim = 256       # hidden units in the generator
discriminator_hidden_dim = 256   # hidden units in the discriminator

In [5]:
# Start from an empty graph so re-running this cell does not collide with
# variables created by a previous run.
tf.reset_default_graph()
generator_input = tf.placeholder(shape=[None, noise_dim], dtype=tf.float32)
discriminator_input = tf.placeholder(shape=[None, image_dim], dtype=tf.float32)
gan = GAN(generator_input, discriminator_input, generator_hidden_dim, discriminator_hidden_dim)

# The discriminator outputs are sigmoids; a small epsilon keeps tf.log finite
# if an output saturates to exactly 0 or 1.
log_eps = 1e-8
d_real = gan.discriminator_output_from_image_input
d_fake = gan.discriminator_output_from_generator

# Generator objective: push D(G(z)) towards 1. The constant 1.0 offset does not
# affect the gradients; it is kept so reported values stay comparable.
generator_loss = tf.reduce_mean(1.0 - tf.log(d_fake + log_eps))

# Discriminator objective: maximise log D(x) + log(1 - D(G(z))). Optimizers
# minimise, so the loss is the negative of that quantity.
discriminator_loss = -tf.reduce_mean(
    tf.log(d_real + log_eps) + tf.log(1.0 - d_fake + log_eps))

# Each optimizer may only touch its own network's variables; without var_list
# the generator step would also update the discriminator and vice versa.
generator_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="generator")
discriminator_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="discriminator")
g_train_op = tf.train.AdamOptimizer(learning_rate=2e-4).minimize(
    generator_loss, var_list=generator_vars)
d_train_op = tf.train.AdamOptimizer(learning_rate=2e-4).minimize(
    discriminator_loss, var_list=discriminator_vars)

In [None]:
n_epoch = 30
batch_size = 50
num_batch = mnist_dataset.train.num_examples // batch_size
k = 5  # discriminator updates per generator update

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epoch):
        for batch_no in range(num_batch):
            image_batch, _ = mnist_dataset.train.next_batch(batch_size)  # discard the labels
            # Train the discriminator network k times, drawing fresh noise each step.
            # The loop variable is deliberately not named `k`, so the setting above
            # is honoured and not overwritten.
            for d_step in range(k):
                noise_batch = np.random.uniform(-1.0, 1.0, size=[batch_size, noise_dim])
                feed_dict = {gan.image_input: image_batch, gan.noise_input: noise_batch}
                sess.run(d_train_op, feed_dict=feed_dict)
            # Train the generator once on a new noise batch and fetch both losses.
            noise_batch = np.random.uniform(-1.0, 1.0, size=[batch_size, noise_dim])
            feed_dict = {gan.image_input: image_batch, gan.noise_input: noise_batch}
            _, g_loss, d_loss = sess.run([g_train_op, generator_loss, discriminator_loss], feed_dict=feed_dict)
        # Report once per epoch (losses of the last batch) so the notebook output
        # is not flooded with one message per batch.
        print("Epoch {}/{}".format(epoch + 1, n_epoch))
        print("Generator loss:{}\nDiscriminator loss:{}".format(g_loss, d_loss))