<a href="https://colab.research.google.com/github/FaiazS/Handwritten-Digit-Generation-using-Generative-Adversarial-Networks-GANs-/blob/main/Handwritten_Digit_Generation_using_Generative_Adversarial_Networks_(GANs)_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Load Libraries

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.keras import models, layers

In [None]:
#Loading and preprocessing MNIST Dataset

def load_data(batch_size = 128, shuffle_buffer = None):
  """Load the MNIST training images as a shuffled, batched ``tf.data.Dataset``.

  Labels and the test split are discarded; only the training images are used.

  Args:
    batch_size: Number of images per batch (default 128).
    shuffle_buffer: Shuffle buffer size. ``None`` (the default) uses the number
      of training images, i.e. a full uniform shuffle each epoch.

  Returns:
    A ``tf.data.Dataset`` of float32 batches with a trailing channel axis and
    pixel values scaled to [-1, 1].
  """
  (x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()

  # Map uint8 pixels [0, 255] -> [-1, 1] (matches a tanh generator output range).
  x_train = (x_train.astype(np.float32) - 127.5) / 127.5

  # Add the trailing channel axis expected by Conv2D layers.
  x_train = np.expand_dims(x_train, axis = -1)

  # Derive the buffer from the data instead of hard-coding the MNIST train size.
  if shuffle_buffer is None:
    shuffle_buffer = len(x_train)

  return (tf.data.Dataset.from_tensor_slices(x_train)
          .shuffle(shuffle_buffer)
          .batch(batch_size)
          .prefetch(tf.data.AUTOTUNE))

In [None]:
#Building the Generative Adversarial Network

def Generator_Model():
  """Build the generator network.

  Maps a 100-dimensional noise vector to a single-channel image:
  Dense projection -> reshape to 7x7x256 -> transposed convolutions that
  upsample 7 -> 7 -> 14 -> 28, ending in tanh so pixels lie in [-1, 1].
  """
  generator = tf.keras.Sequential()

  # Project the noise vector and reshape it into a 7x7 feature map.
  generator.add(layers.Dense(7 * 7 * 256, use_bias = False, input_shape = (100, )))
  generator.add(layers.BatchNormalization())
  generator.add(layers.ReLU())
  generator.add(layers.Reshape((7, 7, 256)))

  # Upsampling stages as (filters, stride): spatial size 7 -> 7 -> 14.
  # NOTE(review): 24 filters is unusual (DCGAN references commonly use 64 here);
  # confirm it is intended rather than a typo.
  for filters, stride in ((128, 1), (24, 2)):
    generator.add(layers.Conv2DTranspose(filters, (5, 5), strides = (stride, stride), padding = 'same', use_bias = False))
    generator.add(layers.BatchNormalization())
    generator.add(layers.ReLU())

  # Final upsample to 28x28 with one channel; tanh bounds the output to [-1, 1].
  generator.add(layers.Conv2DTranspose(1, (5, 5), strides = (2, 2), padding = 'same', activation = 'tanh'))

  return generator


def Discriminator_Model():
  """Build the convolutional discriminator.

  Takes a (28, 28, 1) image and returns one unnormalized score (logit) per
  image; training pushes real images toward label 1 and fakes toward 0.

  The final Dense layer deliberately has no sigmoid: the training loss is
  ``BinaryCrossentropy(from_logits=True)``, which applies the sigmoid itself.
  Adding a sigmoid here as well would squash the scores twice (outputs stuck in
  roughly [0.5, 0.73] after the loss's own sigmoid) and weaken the gradients.
  """
  model = tf.keras.Sequential([
      layers.Conv2D(64, (5, 5), strides = (1, 1), padding = 'same', input_shape = [28, 28, 1]),
      layers.BatchNormalization(),
      layers.LeakyReLU(alpha = 0.2),
      layers.Dropout(0.3),

      # Stride 2 halves the spatial resolution (28x28 -> 14x14).
      layers.Conv2D(128, (5, 5), strides = (2, 2), padding = 'same'),
      layers.LeakyReLU(alpha = 0.2),
      layers.Dropout(0.3),

      layers.Flatten(),

      # Raw logit output (no activation) to pair with the from_logits=True loss.
      layers.Dense(1),
  ])

  return model







In [None]:
#Instantiating the models, loss function and optimizers

def get_optimizers():
  """Return a fresh ``(generator_optimizer, discriminator_optimizer)`` pair.

  Both are independent Adam instances with learning rate 1e-4. Each network gets
  its own optimizer because Adam keeps per-variable moment estimates.
  """
  learning_rate = 1e-4
  generator_opt = tf.keras.optimizers.Adam(learning_rate)
  discriminator_opt = tf.keras.optimizers.Adam(learning_rate)
  return generator_opt, discriminator_opt

# Build both networks (generator first, then discriminator).
generator_model, discriminator_model = Generator_Model(), Discriminator_Model()

# The loss expects raw logits (from_logits=True) and applies the sigmoid itself.
# NOTE(review): the discriminator's last layer must therefore not apply a sigmoid.
loss_function = tf.keras.losses.BinaryCrossentropy(from_logits = True)

# One independent optimizer per network.
generator_optimizer, discriminator_optimizer = get_optimizers()

In [None]:
#The discriminator wants real images classified as 1 and fake images as 0;
#the generator, on the other hand, wants its fake images classified as 1.

#Compiling to a TensorFlow Graph for Speed.

@tf.function

def training_step(images):

  noise = tf.random.normal((128, 100))  #Defining a batch of 128 noise vectors each of size 100 dimension

  with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:

    generated_images = generator_model(noise, training = True) # Generator takes the random noise as input produces fake images resembling real MNIST Digits.

    real_image_output = discriminator_model(images, training = True) # Discriminator predicts whether real MNIST images are real.(Should be close to 1).

    fake_image_output = discriminator_model(generated_images, training = True) #Discriminator predicts whether the generated(fake) images are real.(Should be close to 0).

    #tf.ones_like(fake_image_output) defines a tensor of ones, representing real labels

    #loss_function(Binary Cross Entropy) measures how close fake_image_output is to 1.

    #Low generator loss means the Generator Model is successfully fooling the Discriminator Model.

    #High generator loss means the Generator Model is not Performing well.

    generator_loss = loss_function(tf.ones_like(fake_image_output), fake_image_output)


     #How well the discriminator model classfies real images as real(Close to 1). + How well the discriminator model classifies fake images as fake(Close to 0).

    discriminator_loss = loss_function(tf.ones_like(real_image_output), real_image_output) + loss_function(tf.zeros_like(fake_image_output), fake_image_output)


    generator_gradient = generator_tape.gradient(generator_loss, generator_model.trainable_variables) #Computes the gradient for the parameters of the Generator Model.

    discriminator_gradient = discriminator_tape.gradient(discriminator_loss, discriminator_model.trainable_variables)  #Computes the gradient for the parameters of the Discriminator Model.

    #Updating the Generator Model's weights

    generator_optimizer.apply_gradients(zip (generator_gradient, generator_model.trainable_variables))


    #Updating the Discriminator Model's weights

    discriminator_optimizer.apply_gradients(zip (discriminator_gradient, discriminator_model.trainable_variables))

    return generator_loss, discriminator_loss


In [None]:
#Training both the Generator Model and Discriminator Model

def train_generator_and_discriminator_models(dataset, epochs = 27, verbose = False):
  """Train the generator and discriminator jointly.

  Each epoch calls ``training_step`` once per batch in ``dataset``.

  Args:
    dataset: Iterable of real-image batches (e.g. the output of ``load_data``).
    epochs: Number of full passes over ``dataset`` (default 27).
    verbose: If True, print the mean losses after every epoch.

  Returns:
    A list with one ``(mean_generator_loss, mean_discriminator_loss)`` tuple of
    Python floats per epoch; both values are NaN for an epoch with no batches.
  """
  history = []

  for epoch in range(epochs):
    generator_total = 0.0
    discriminator_total = 0.0
    num_batches = 0

    for batch in dataset:
      gen_loss, disc_loss = training_step(batch)
      # Accumulate as tensors; convert to Python floats once per epoch.
      generator_total = generator_total + gen_loss
      discriminator_total = discriminator_total + disc_loss
      num_batches += 1

    if num_batches:
      epoch_losses = (float(generator_total) / num_batches,
                      float(discriminator_total) / num_batches)
    else:
      epoch_losses = (float("nan"), float("nan"))

    history.append(epoch_losses)

    if verbose:
      print(f"Epoch {epoch + 1}/{epochs} - generator loss: {epoch_losses[0]:.4f}, "
            f"discriminator loss: {epoch_losses[1]:.4f}")

  return history

def generate_and_show_image(grid_size = 4):
  """Sample images from the generator and display them in a square grid.

  Args:
    grid_size: Number of rows and columns; ``grid_size ** 2`` images are drawn
      (default 4, i.e. 16 images).
  """
  noise = tf.random.normal((grid_size * grid_size, 100))

  # Inference mode: BatchNormalization uses its moving statistics, Dropout is off.
  images = generator_model(noise, training = False)

  # Generator output lies in [-1, 1] (tanh); map it to [0, 1] for display.
  images = (images + 1) / 2.0

  # squeeze=False keeps `axes` a 2-D array even when grid_size == 1.
  _, axes = plt.subplots(grid_size, grid_size, figsize = (7, 7), squeeze = False)

  for i, ax in enumerate(axes.flat):
    # Pin the colour range. Without vmin/vmax, imshow rescales each image to its
    # own min/max, which would make the [0, 1] rescale above meaningless.
    ax.imshow(images[i, :, :, 0], cmap = 'gray', vmin = 0.0, vmax = 1.0)
    ax.axis('off')

  plt.show()


# Load MNIST, train the generator and discriminator together,
# then display a grid of digits sampled from the trained generator.
mnist_data = load_data()
train_generator_and_discriminator_models(dataset = mnist_data, epochs = 27)
generate_and_show_image()
