### Imports

In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os
from helper import *

### To use this on Google Colab

In [2]:
# Colab-only cell: mounts Google Drive at /content/drive.
# NOTE(review): google.colab is only importable inside Colab - skip this cell elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Initial Parameters

In [3]:
# Shape of a single image as (height, width, channels)
image_dims = (28, 28, 1)

# Size of the generator's input noise vector
noise_vect_dim = 100

In [4]:
def create_generator_model(noise_vect_dim = noise_vect_dim, img_dim = image_dims):
    """
    This method creates and returns the generator model.

    The noise vector is projected by a Dense layer onto a
    (height/4, width/4, 256) feature map, which is then upsampled by three
    transposed convolutions (strides 1, 2 and 2) up to the requested image
    size. The last layer uses a tanh activation, so generated pixel values
    lie in [-1, 1].

    @args
      noise_vect_dim = size of the noise vector that is the input to the generator,
      img_dim        = (height, width, channels) of the image that the generator is
                       supposed to generate; height and width must be positive
                       multiples of 4.

    @returns
      model  = Sequential model containing the structure of the generator model

    @raises
      ValueError if the height or width in img_dim is not a positive multiple of 4

    """

    height, width, channels = img_dim
    if height <= 0 or width <= 0 or height % 4 != 0 or width % 4 != 0:
        raise ValueError("img_dim height and width must be positive multiples of 4, got " + str(tuple(img_dim)))

    # The two stride-2 transposed convolutions below each double the spatial
    # size, so the first feature map is a quarter of the target height/width.
    base_h, base_w = height // 4, width // 4

    model = tf.keras.Sequential()

    # Project the noise vector and reshape it into the initial feature map
    model.add(layers.Dense(base_h * base_w * 256, use_bias=False, input_shape=(noise_vect_dim,)))
    model.add(layers.BatchNormalization())
    model.add(layers.ReLU())

    model.add(layers.Reshape((base_h, base_w, 256)))
    assert model.output_shape == (None, base_h, base_w, 256)

    # Stride 1: keeps the spatial size, reduces the channel count
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, base_h, base_w, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.ReLU())

    # Stride 2: doubles the spatial size
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 2 * base_h, 2 * base_w, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.ReLU())

    # Stride 2 again: reaches the full image size with the requested number of channels
    model.add(layers.Conv2DTranspose(channels, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, height, width, channels)

    return model


def create_discriminator_model(img_dim = image_dims):

    """
    This method creates and returns a sequential model which will be used as discriminator for GAN

    The model stacks two stride-2 convolution blocks (Conv2D -> LeakyReLU ->
    Dropout(0.15)) followed by Flatten and a single-unit Dense layer. The
    final Dense layer has no activation, so the model outputs one raw score
    (logit) per image.

    @args
      img_dim = dimension of the image whose authenticity will be validated by the discriminator
    @return
      model = Discriminator model

    """

    model = tf.keras.Sequential()

    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=img_dim))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.15))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.15))

    # One unactivated output unit per image
    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model



### Loss functions

In [5]:
# Binary cross-entropy shared by both loss functions below.
# from_logits=True because the discriminator's last layer is Dense(1) with no activation.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):

    """
    Loss of the discriminator network for one batch.

    @args
      real_output = discriminator output for real images, shape (batch_size, 1)
      fake_output = discriminator output for images created by the generator, shape (batch_size, 1)
    @returns
      Sum of the cross-entropy on the real images and the cross-entropy on the fake images
    """

    # The discriminator should label real images as 1 and generated images as 0
    targets_for_real = tf.ones_like(real_output)
    targets_for_fake = tf.zeros_like(fake_output)

    loss_on_real = cross_entropy(targets_for_real, real_output)
    loss_on_fake = cross_entropy(targets_for_fake, fake_output)

    return loss_on_real + loss_on_fake

def generator_loss(fake_output):

    """
    Loss of the generator network for one batch.

    @args
      fake_output = discriminator output when its inputs were images produced by the generator
    @returns
      Cross-entropy between an all-ones target and fake_output

    """
    # All-ones targets: the loss gets smaller as more generated images are labeled as real
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)



### Training Step Function

In [6]:
@tf.function
def train_step(images, batch_size, generator, discriminator, generator_optimizer, discriminator_optimizer, noise_size):

    """
    This method executes a single training step (forward pass, gradient computation
    and weight update) for both networks of the GAN

    @args
      images                  =  (batch_size, image_shape) Real images from mnist dataset to be used by the discriminator
      batch_size              =  Number of noise vectors (and therefore fake images) to generate
      generator               =  Generator model to be trained
      discriminator           =  Discriminator model to be trained
      generator_optimizer     =  Optimizer for the generator
      discriminator_optimizer =  Optimizer for the discriminator
      noise_size              =  Length of each noise vector fed to the generator

    @returns
      gen_loss  = loss of the generator network
      disc_loss = loss of the discriminator network
    """

    # One noise vector per fake image, shaped (batch_size, noise_size).
    # No trailing singleton axis, so the generator is called with exactly the
    # input rank it was declared with.
    noise = tf.random.normal([batch_size, noise_size])

    # Two tapes record the same forward pass so that each network's gradient
    # can be taken separately from its own loss.
    with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
        generated_images = generator(noise, training = True)               # Generator produces fake images

        real_output = discriminator(images, training = True)               # Predictions for real images
        fake_output = discriminator(generated_images, training = True)     # Predictions for fake images

        gen_loss = generator_loss(fake_output)                             # Generator loss
        disc_loss = discriminator_loss(real_output, fake_output)           # Discriminator loss

    # Compute both gradients from the recorded forward pass before touching any weights
    gen_gradient = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_gradient = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Weight updates of both networks
    generator_optimizer.apply_gradients(zip(gen_gradient, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))

    return gen_loss, disc_loss

### Load MNIST Dataset

In [8]:
# Load the MNIST dataset; only the training split is kept
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# Give every image an explicit channel axis, then switch to float32
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
train_images = train_images.astype('float32')

# Normalize the images to [-1, 1]
train_images = (train_images - 127.5) / 127.5

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


### Parameters

In [9]:
# Learning rate passed to both Adam optimizers
lr = 1e-4

# Number of images per training batch
batch_size = 256

# Number of passes over the training images
epochs = 200

# Base location where training data supposed to be saved
base_location = ""

### Define Models

In [10]:
# Build both networks with their default arguments (noise_vect_dim and image_dims)
generator_model = create_generator_model()
discrminator_model = create_discriminator_model()

### Define Optimizers

In [11]:
# One Adam instance per network, both created with the same learning rate lr
generator_optimizer = tf.keras.optimizers.Adam(lr)
discriminator_optimizer = tf.keras.optimizers.Adam(lr)

### Training

In [1]:
gen_loss_ep = []                                                                 # List to save generator loss for each epoch
disc_loss_ep = []                                                                # List to save discriminator loss for each epoch

# Folder for this training session at base location
# NOTE(review): create_training_directory comes from helper.py - assumed to return the created path; confirm
run_path = create_training_directory(base_location)

# The split does not depend on the epoch, so it is computed once.
# max(1, ...) keeps the split valid when there are fewer images than batch_size.
num_batches = max(1, len(train_images) // batch_size)
image_batches = np.array_split(train_images, num_batches)                        # Create a list of image batches

for i in range(epochs):
    gen_loss_list = []                                                           # List to save generator loss for each batch
    disc_loss_list = []                                                          # List to save discriminator loss for each batch

    for batch in image_batches:                                                  # Iterate over image batch sets
        # np.array_split may return batches one image larger than batch_size,
        # so the real batch length is passed to generate the same number of fakes.
        gen_loss, disc_loss = train_step(  images = batch,
                                           batch_size = len(batch),
                                           generator = generator_model,
                                           discriminator = discrminator_model,
                                           generator_optimizer = generator_optimizer,
                                           discriminator_optimizer = discriminator_optimizer,
                                           noise_size = noise_vect_dim
                                        )
        gen_loss_list.append(gen_loss.numpy())
        disc_loss_list.append(disc_loss.numpy())

    # Epoch loss = mean of the per-batch losses
    gen_loss_ep.append(np.mean(gen_loss_list))
    disc_loss_ep.append(np.mean(disc_loss_list))

    print("----------------------------Epoch", i+1,"/",epochs,"-----------------------------------")
    print("Generator Loss", gen_loss_ep[-1])
    print("Discrminator Loss", disc_loss_ep[-1])

    # NOTE(review): plot_grids comes from helper.py - presumably draws a 5x5 grid of images; confirm
    plot_grids(train_images, generator_model,5 ,5)

    if i > 0:   # Plot only makes sense after first epoch
        plot_loss(gen_loss_ep, disc_loss_ep)

    # Save both losses as a numpy array (rewritten after every epoch)
    np.save(os.path.join(run_path, "gen_loss.npy"), np.array(gen_loss_ep))
    np.save(os.path.join(run_path, "disc_loss.npy"), np.array(disc_loss_ep))


# Save the training session data
generator_model.save(os.path.join(run_path, "generator_model.h5"))
discrminator_model.save(os.path.join(run_path, "discrminator_model.h5"))