# My First GAN
This is my first working implementation of a Generative Adversarial Network (GAN), trained on the MNIST handwritten-digit dataset.

### 1. Importing all necessary libraries
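`hyperdash` is only used to monitor training remotely and can be removed. If you remove it, also remove every piece of code that uses it: the `Experiment` import, `exp = Experiment(...)` and `exp.end()`.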

In [None]:
# Import all the libraries
import tensorflow as tf
import tensorflow.contrib.layers as tcl
import numpy as np
import matplotlib.pyplot as plt
from scipy.misc import imsave

from hyperdash import Experiment # To view the training from my phone, CAN BE REMOVED

### 2. All the necessary functions
These are functions for:
1. Sampling from a Gaussian distribution,
2. Creating the generator,
3. Creating the discriminator.

In [None]:
# Helper funcs
def sample_z(m, n, mean=0., std=0.1):
    """Sample a batch of latent noise vectors from a Gaussian distribution.

    Args:
        m: number of samples (rows), e.g. the batch size.
        n: dimensionality of each latent vector (columns).
        mean: mean of the Gaussian. Defaults to 0.0.
        std: standard deviation of the Gaussian. Defaults to 0.1, the value
            this notebook has always used, so existing calls are unchanged.

    Returns:
        A NumPy array of shape (m, n) drawn from N(mean, std**2).
    """
    return np.random.normal(loc=mean, scale=std, size=(m, n))


# Side length of the square feature map that the generator's dense layer is
# reshaped into. It is also used as the kernel size of the conv layers in both networks.
start = 28
# Number of channels of that feature map (1024 // 8 == 128).
depth = 1024 // 8

# Number of the current training session. Checkpoint and log paths are built
# under ./model/<session number>/, and training picks up where the previous
# session left off.
training_session_number = 4

# NOTE: Before running this notebook, look through the code and fill in any
# missing code in the training loop and after it.

# NOTE: To only generate images from a saved model, skip the training loop and
# run the session at the bottom of the notebook.


# Generator model
def gen(z, reuse):
    """Build (or reuse) the generator sub-graph.

    Args:
        z: latent input tensor. In this notebook it is the (None, 100)
            "latent_input" placeholder.
        reuse: forwarded to tf.variable_scope. False creates the "Generator"
            variables; True reuses the existing ones.

    Returns:
        A tanh-activated image tensor with a single channel. Given the reshape to
        (-1, start, start, depth) and the layer's default stride of 1, this is
        presumably (batch, start, start, 1).
    """
    with tf.variable_scope("Generator", reuse=reuse):
        # Dense projection of z, then reshape into a start x start map with `depth` channels.
        projected = tcl.fully_connected(z, num_outputs=start * start * depth,
                                        activation_fn=tf.nn.leaky_relu)
        feature_map = tf.reshape(projected, (-1, start, start, depth))

        # One transposed convolution with a start x start kernel collapses the
        # channels to 1. tanh matches the [-1, 1] scaling applied to the MNIST data.
        # NOTE(review): DCGAN guidelines advise against batch-norm on the generator's
        # output layer. It is kept here because removing it would change the set of
        # variables and break restoring existing checkpoints.
        image = tcl.conv2d_transpose(feature_map, num_outputs=1, kernel_size=start,
                                     normalizer_fn=tcl.batch_norm,
                                     activation_fn=tf.nn.tanh)
    return image


# Discriminator model
def dis(x, reuse):
    """Build (or reuse) the discriminator sub-graph.

    Args:
        x: batch of single-channel images. In this notebook, either the real-image
            placeholder or the generator's output.
        reuse: forwarded to tf.variable_scope. False creates the "Discriminator"
            variables; True shares them with an earlier call.

    Returns:
        A (batch, 1) tensor of sigmoid probabilities that each image is real.
    """
    with tf.variable_scope("Discriminator", reuse=reuse):
        # A single stride-2 convolution with a start x start kernel and depth // 2 filters.
        # NOTE(review): DCGAN guidelines also advise against batch-norm on the
        # discriminator's input layer. Left unchanged to keep checkpoints compatible.
        features = tcl.conv2d(x, num_outputs=depth // 2, kernel_size=start, stride=2,
                              padding="SAME", normalizer_fn=tcl.batch_norm,
                              activation_fn=tf.nn.leaky_relu)
        # keep_prob=0.85, i.e. 15% of activations are dropped.
        # NOTE(review): tcl.dropout's is_training defaults to True, so dropout is
        # presumably active on every call (including the generator's step). Confirm if unintended.
        features = tcl.dropout(features, keep_prob=0.85)
        flat = tcl.flatten(features)
        probability = tcl.fully_connected(flat, num_outputs=1, activation_fn=tf.nn.sigmoid)
    return probability

### 3. Loading the MNIST Dataset

In [None]:
# Loading the dataset. Only the training images are kept; labels and the test split are discarded.
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()

# Normalize pixel values from [0, 255] to [-1, 1] to match the generator's tanh output.
# Store the result as float32 (the dtype of the image placeholder) instead of NumPy's
# default float64. This halves the memory held by the dataset and avoids re-converting
# every batch when it is fed to the graph.
X_train = ((X_train / 127.5) - 1.).astype(np.float32)

# Append a single channel axis: (N, H, W) -> (N, H, W, 1).
X_train = X_train[..., np.newaxis]

## Creating the Graph

### 4. Creating the placeholders

In [None]:
# Creating the placeholders
# Real images: (batch, height, width, 1), with height and width taken from the loaded dataset.
X = tf.placeholder(dtype=tf.float32,
                   shape=[None] + list(X_train.shape[1:3]) + [1],
                   name="image_input")
# Latent noise vectors fed to the generator: (batch, 100).
Z = tf.placeholder(dtype=tf.float32, shape=[None, 100], name="latent_input")

### 5. Getting the outputs from both the networks

In [None]:
# Get outputs from nets
# Call order matters. The first dis() call (reuse=False) creates the discriminator's
# variables, and the second (reuse=True) shares them when scoring generated images.
gen_imgs = gen(Z, reuse=False)       # images produced from the latent input Z
d_real = dis(X, reuse=False)         # D's probability that each real image is real
d_fake = dis(gen_imgs, reuse=True)   # D's probability that each generated image is real

# Log generated images to TensorBoard.
# NOTE(review): tf.summary.image defaults to max_outputs=3, so only the first 3 of
# these (up to) 50 images are written. Pass max_outputs to log more.
tf.summary.image(name="GEN_IMAGES", tensor=gen_imgs[:50])

### 6. Computing the loss
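Here, I am using the "-log D" (non-saturating) trick for the generator: instead of minimizing $\log(1 - D(G(z)))$, the generator minimizes $-\log D(G(z))$. Both push $D(G(z))$ towards 1, but the latter gives much stronger gradients early in training, when the discriminator easily rejects the generated samples.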

In [None]:
# Computing the loss
# Small constant that keeps every log() argument strictly positive. The discriminator
# ends in a float32 sigmoid, which can saturate to exactly 0.0 or 1.0. log(0) = -inf
# would make the loss infinite and every gradient NaN, silently ruining training.
LOG_EPS = 1e-8

# Discriminator: maximize log D(x) + log(1 - D(G(z))), i.e. minimize the negated batch mean.
d_loss = -tf.reduce_mean(tf.log(d_real + LOG_EPS) + tf.log(1. - d_fake + LOG_EPS))
# Generator: non-saturating "-log D" loss, i.e. minimize -log D(G(z)) rather than log(1 - D(G(z))).
g_loss = -tf.reduce_mean(tf.log(d_fake + LOG_EPS))

# Add losses to summary
tf.summary.scalar("DISC_LOSS", d_loss)
tf.summary.scalar("GEN_LOSS", g_loss)

### 7. Getting the variables for both the networks to train

In [None]:
# Getting variables
# Collect only each network's *trainable* variables for its optimizer.
# GLOBAL_VARIABLES would also include non-trainable state, such as the moving
# mean/variance created by tcl.batch_norm. That state does not belong in an
# optimizer's var_list: it gets no gradients today, but would be silently
# "trained" if it ever did.
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Generator")
dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Discriminator")

### 8. Creating optimizers to train

In [None]:
# Optimizers
# Generator: Adam. Discriminator: momentum SGD with a tiny momentum (0.001), which is
# effectively plain SGD; hence the op name "SGD". Each optimizer only updates the
# variables of its own network.
# NOTE(review): tcl.batch_norm normally registers its moving-average updates in
# tf.GraphKeys.UPDATE_OPS, and those ops are not run here. With is_training left at
# its default the moving averages are presumably never read, but confirm before
# adding an inference-mode path.
g_opt = tf.train.AdamOptimizer(learning_rate=0.001).minimize(g_loss, var_list=gen_vars)
d_opt = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.001).minimize(
    d_loss, var_list=dis_var, name="SGD")

### 9. Merging all summaries to be viewed using tensorboard

In [None]:
# Merge summaries
# Combine every summary op defined so far (the GEN_IMAGES image summary and the
# DISC_LOSS / GEN_LOSS scalars) into one op. A single sess.run() fetch then
# yields the whole serialized summary for the TensorBoard writer.
summary_op = tf.summary.merge_all()

### 10. Hyperparameters

In [None]:
# Hyperparameters
# Number of real images, and of latent vectors, drawn per training iteration.
batch_size = 1000

### 11. Creating a saver to save the model variables

In [None]:
# Creating saver
# Without a var_list, tf.train.Saver() covers the global variables that exist when it
# is constructed. Created after the optimizer cell, that presumably includes the
# optimizers' slot variables as well as both networks' weights. Confirm if cells
# are run out of order.
saver = tf.train.Saver()

### 12. Training the networks.

In [None]:
# Train

# Creating experiment for monitoring (hyperdash; remove together with exp.end() if unused)
exp = Experiment("GAN Test Final")

# Creating session to train the model
with tf.Session() as sess:
    # TensorBoard summaries for the current session go to ./model/<session>/logs
    writer = tf.summary.FileWriter("./model/{}/logs".format(training_session_number), sess.graph)
    try:
        sess.run(tf.global_variables_initializer())  # Initialize all variables

        # Resume from the final checkpoint (model_ckpt_9) of the previous training session.
        # NOTE: in the very first session there is nothing to restore; skip these two lines then.
        saver.restore(sess, "./model/{}/saved/model_ckpt_9.ckpt".format(training_session_number - 1))
        print("Starting where the last training session left off.")

        for i in range(10000):
            # Sample a batch of real images (indices drawn with replacement)
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # One discriminator step (real batch + fresh noise), then one generator step (new noise)
            _, d_loss_curr, s_str = sess.run([d_opt, d_loss, summary_op], feed_dict={X: imgs, Z: sample_z(batch_size, 100)})
            _, g_loss_curr = sess.run([g_opt, g_loss], feed_dict={Z: sample_z(batch_size, 100)})

            # Write the summary of the iteration
            writer.add_summary(s_str, i)

            # Log
            print("Iteration: {}, D_LOSS: {}, G_LOSS: {}".format(i + 1, d_loss_curr, g_loss_curr))

            # Save the model every 1000 iterations, right after iterations 999, 1999, ..., 9999.
            # model_ckpt_<k> therefore holds the weights after (k + 1) * 1000 steps, and
            # model_ckpt_9 (the checkpoint that gets restored later) contains the fully
            # trained model, including the last iteration.
            if (i + 1) % 1000 == 0:
                save_path = saver.save(sess, "./model/{}/saved/model_ckpt_{}.ckpt".format(training_session_number, i // 1000))
                print("Saved model {} at: {}".format(training_session_number, save_path))
    finally:
        # Flush and close the event file even if training is interrupted or fails
        writer.close()

In [None]:
# End experiment and stop sending to my phone
# Closes the hyperdash Experiment created at the top of the training cell. Because it
# lives in its own cell, it can be run manually even if the training cell was interrupted.
exp.end()

### 13. Generating samples from the learned model.

In [None]:
genned_imgs = None

# Restore the final checkpoint (model_ckpt_9) of the previous training session, then
# run the generator on 10 latent vectors to produce 10 images.
with tf.Session() as sess:
    # Initialize first; restore() then overwrites every variable the saver covers.
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, "./model/{}/saved/model_ckpt_9.ckpt".format(training_session_number - 1))
    genned_imgs = gen_imgs.eval(session=sess, feed_dict={Z: sample_z(10, 100)})

### 14. Reshaping and rescaling the images back to the original format.

In [None]:
# Reshaping the image
# Drop the trailing single-channel axis: (num_images, H, W, 1) -> (num_images, H, W).
# Unlike a hard-coded reshape to (10, 28, 28), this keeps working if the number of
# generated images or the image size changes. It raises if the last axis is not 1.
genned_imgs = np.squeeze(genned_imgs, axis=-1)

In [None]:
# Rescale to 0-255 and convert to uint8
# Invert the training normalization x / 127.5 - 1, mapping the tanh range [-1, 1] to [0, 255].
# Round to the nearest integer before casting. astype alone truncates, so float round-off
# (e.g. 0.9999 for pixel value 1) would drop a level. Clip so nothing wraps around in uint8.
genned_imgs = np.clip(np.rint((genned_imgs + 1.) * 127.5), 0, 255).astype(np.uint8)

### 15. Saving the generated images.

In [None]:
# Saving generated images
import os

# Writing a PNG fails if the target directory does not exist, so create it first.
os.makedirs("./generated/imgs", exist_ok=True)

for i, img in enumerate(genned_imgs, start=1):
    # scipy.misc.imsave was removed in SciPy 1.2, so use matplotlib's imsave instead.
    # With the gray colormap and a fixed vmin/vmax, uint8 values 0..255 map directly to
    # black..white without per-image contrast stretching.
    plt.imsave("./generated/imgs/img_{}.png".format(i), img, cmap="gray", vmin=0, vmax=255)