<a href="https://colab.research.google.com/github/Yashu2699/Deep_learning/blob/main/GAN/gan_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import tqdm
import numpy as np

In [7]:
# Length of the latent noise vector fed to the generator.
random_dim = 100
# Seed NumPy's global RNG so the noise draws and real-batch sampling below
# repeat across runs. Only NumPy's RNG is seeded here.
np.random.seed(10)

In [8]:
def load_mnist_data():
  """Load MNIST and prepare the training images for the GAN.

  The training images are cast to float32, rescaled from [0, 255] to
  [-1, 1] (the same range as the generator's tanh output) and flattened to
  784-value vectors.

  Returns:
    (x_train, y_train, x_test, y_test). Only x_train is rescaled and
    flattened; y_train, x_test and y_test are returned exactly as
    ``mnist.load_data()`` provides them.
  """
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_train = (x_train.astype(np.float32) - 127.5) / 127.5
  # Flatten each 28x28 image. -1 infers the sample count rather than
  # hard-coding 60000, so a subset or differently sized split still works.
  x_train = x_train.reshape(-1, 784)
  return (x_train, y_train, x_test, y_test)

In [9]:
def get_generator():
  """Build and compile the fully connected generator.

  Maps a latent vector of length ``random_dim`` through three
  Dense -> LeakyReLU(0.2) hidden layers (256, 512 and 1024 units) to a
  784-unit tanh output layer. Compiled with binary cross-entropy and
  RMSprop; the model summary is printed before returning.

  Returns:
    The compiled keras ``Sequential`` generator.
  """
  model = Sequential()
  for position, units in enumerate((256, 512, 1024)):
    # Only the first layer needs to be told the input (latent) size.
    shape_kwargs = {'input_dim': random_dim} if position == 0 else {}
    model.add(Dense(units, **shape_kwargs))
    model.add(LeakyReLU(0.2))
  model.add(Dense(784, activation = 'tanh'))
  model.compile(loss='binary_crossentropy', optimizer='rmsprop')
  model.summary()
  return model

In [10]:
from re import L
def get_discriminator():
  """Build and compile the fully connected discriminator.

  Takes a flattened 784-value image through three hidden blocks of
  Dense -> LeakyReLU(0.2) -> Dropout(0.3) with 1024, 512 and 256 units,
  ending in a single sigmoid unit. Compiled with binary cross-entropy and
  RMSprop; the model summary is printed before returning.

  Returns:
    The compiled keras ``Sequential`` discriminator.
  """
  model = Sequential()
  for position, units in enumerate((1024, 512, 256)):
    # Only the first layer needs to be told the input (flattened image) size.
    shape_kwargs = {'input_dim': 784} if position == 0 else {}
    model.add(Dense(units, **shape_kwargs))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.3))
  model.add(Dense(1, activation='sigmoid'))
  model.compile(loss='binary_crossentropy', optimizer='rmsprop')
  model.summary()
  return model

In [11]:
def get_gan_network(discriminator, random_dim, generator):
  """Stack the generator and discriminator into the combined GAN model.

  The discriminator is frozen *before* the combined model is compiled, so
  ``gan.train_on_batch`` updates only the generator's weights. The
  discriminator is still trained through its own, separately compiled model.

  Args:
    discriminator: keras model mapping a flattened image to a probability.
    random_dim: length of the latent noise vector fed to the generator.
    generator: keras model mapping a noise vector to a flattened image.

  Returns:
    The compiled keras ``Model`` that takes noise and returns the
    discriminator's probability that the generated image is real.
  """
  # Must be the real `trainable` attribute (a misspelling just creates an
  # unused attribute), and it must be set before compile() to take effect.
  discriminator.trainable = False
  gan_input = Input(shape=(random_dim, ))  # latent noise input
  fake_image = generator(gan_input)        # generator output (an image)
  gan_output = discriminator(fake_image)   # probability the image is real
  gan = Model(inputs=gan_input, outputs=gan_output)
  gan.compile(loss='binary_crossentropy', optimizer='adam')
  return gan

In [15]:
def train(epochs=1, batch_size=128, plot_interval=20):
  """Train the MNIST GAN and periodically save grids of generated samples.

  Each epoch runs ``x_train.shape[0] // batch_size`` iterations. Every
  iteration:
    1. trains the discriminator on a half-real, half-generated batch, using
       one-sided label smoothing (real = 0.9, fake = 0);
    2. trains the generator through the combined model, asking the frozen
       discriminator to label fresh fakes as real (1).

  Args:
    epochs: number of epochs to run.
    batch_size: number of real (and of generated) images per step.
    plot_interval: sample images are plotted after epoch 1 and after every
      epoch divisible by this positive value (default 20).
  """
  x_train, y_train, x_test, y_test = load_mnist_data()
  # Iterations per epoch; any fractional remainder is discarded.
  batch_count = x_train.shape[0] // batch_size

  generator = get_generator()
  discriminator = get_discriminator()
  gan = get_gan_network(discriminator, random_dim, generator)

  for e in range(1, epochs+1):
    print('-'*15, 'Epoch %d' % e, '-'*15)
    for _ in tqdm(range(batch_count)):
      # Random latent noise and a random (with replacement) batch of real images.
      noise = np.random.normal(0, 1, size=[batch_size, random_dim])
      image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

      # Generate fake images. verbose=0 stops predict() from printing its own
      # progress output on every step, which would clutter the tqdm bar.
      generated_images = generator.predict(noise, verbose=0)
      x = np.concatenate([image_batch, generated_images])

      # Labels: first half real, second half generated (0).
      y_dis = np.zeros(2*batch_size)
      # One-sided label smoothing: real targets are 0.9 instead of 1.
      y_dis[:batch_size] = 0.9

      # Train the discriminator on the mixed batch.
      discriminator.trainable = True
      discriminator.train_on_batch(x, y_dis)

      # Train the generator via the combined model: fresh noise labelled "real".
      noise = np.random.normal(0, 1, size=[batch_size, random_dim])
      y_gen = np.ones(batch_size)
      discriminator.trainable = False
      gan.train_on_batch(noise, y_gen)

    if e == 1 or e % plot_interval == 0:
      plot_generated_images(e, generator)



In [16]:
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):
  """Sample the generator and save a grid of its images as a PNG.

  Args:
    epoch: epoch number, used in the file name
      ``gan_generated_image_epoch_<epoch>.png``.
    generator: model mapping ``(examples, random_dim)`` noise to outputs
      that reshape to ``(examples, 28, 28)``.
    examples: number of images to generate and plot.
    dim: ``(rows, cols)`` of the subplot grid; must hold at least
      ``examples`` cells.
    figsize: matplotlib figure size in inches.

  Raises:
    ValueError: if ``dim`` has fewer than ``examples`` cells.
  """
  rows, cols = dim
  # Fail fast with a clear message instead of an opaque subplot-index error
  # after the (comparatively expensive) generator call.
  if examples > rows * cols:
    raise ValueError('dim %r has %d cells, fewer than examples=%d'
                     % (tuple(dim), rows * cols, examples))
  noise = np.random.normal(0, 1, size=[examples, random_dim])
  # verbose=0 keeps predict() from printing its own progress output.
  generated_images = generator.predict(noise, verbose=0)
  generated_images = generated_images.reshape(examples, 28, 28)

  plt.figure(figsize=figsize)
  for i in range(generated_images.shape[0]):
    plt.subplot(rows, cols, i+1)
    plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
    plt.axis('off')
  plt.tight_layout()
  plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

