<a href="https://colab.research.google.com/github/MLandML/MLandML/blob/learning_projects/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import sys,os

from tensorflow.keras.layers import Input,Dense,BatchNormalization,LeakyReLU,Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

In [None]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values from 0..255 to [-1, 1]. This matches the range of the
# generator's tanh output layer and the 0.5*x + 0.5 display rescale used for
# generated samples. Leaving real images in [0, 1] would let the discriminator
# tell real from fake by value range alone.
x_train = x_train / 255.0 * 2 - 1
x_test = x_test / 255.0 * 2 - 1
x_train.shape

In [None]:
# Flatten every H x W image into a single vector of length D = H * W so it can
# be fed to the dense (fully connected) generator/discriminator networks.
N, H, W = x_train.shape
D = H * W
x_train, x_test = (images.reshape(-1, D) for images in (x_train, x_test))

In [None]:
latent_dims = 100  # dimension of the latent (noise) space

def build_generator(latent_dims, img_size=None):
  """Build the generator network.

  The network is an MLP that maps a noise vector of length `latent_dims` to a
  flattened image. It has three Dense(512, LeakyReLU(0.1)) + BatchNormalization
  stages followed by a Dense tanh output layer, so every output value lies in
  [-1, 1].

  Args:
    latent_dims: length of the input noise vector.
    img_size: length of the flattened output image. When None, falls back to
      the module-level `D` from the flattening cell (the original behavior),
      so existing calls keep working. Passing it explicitly mirrors
      `build_discriminator(img_size)`.

  Returns:
    An uncompiled keras Model from noise to flattened image.
  """
  if img_size is None:
    img_size = D
  i = Input(shape=(latent_dims,))
  x = i
  for _ in range(3):
    x = Dense(512, activation=LeakyReLU(alpha=0.1))(x)
    x = BatchNormalization(momentum=0.8)(x)
  x = Dense(img_size, activation='tanh')(x)

  model = Model(i, x)
  return model

In [None]:
def build_discriminator(img_size):
  """Build the discriminator network.

  The network is an MLP scoring a flattened image of length `img_size`. It has
  two Dense(512, LeakyReLU(0.1)) hidden layers and a single sigmoid unit, so
  the output is one value in (0, 1) per image.

  Returns:
    An uncompiled keras Model.
  """
  inputs = Input(shape=(img_size,))
  hidden = inputs
  for units in (512, 512):
    hidden = Dense(units, activation=LeakyReLU(alpha=0.1))(hidden)
  score = Dense(1, activation='sigmoid')(hidden)
  return Model(inputs, score)

In [None]:
# Stand-alone discriminator, compiled while still trainable so that its own
# train_on_batch calls update its weights.
discriminator = build_discriminator(D)
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=['accuracy']
)

# Noise input for the combined model. The trailing comma makes the shape a
# 1-tuple; `(latent_dims)` would just be an int.
z = Input(shape=(latent_dims,))
generator = build_generator(latent_dims)
img = generator(z)

# Freeze the discriminator inside the combined model so that training the
# combined model only updates the generator.
# NOTE(review): this relies on tf.keras 2.x capturing `trainable` at compile
# time (the discriminator above was compiled before this line). Verify the
# behavior if migrating to Keras 3.
discriminator.trainable = False

# Discriminator's score for generated images. The generator is trained by
# fitting this score to "real" labels during training.
fake_pred = discriminator(img)
combined_model = Model(z, fake_pred)  # noise -> generator -> (frozen) discriminator
combined_model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5)
)

# Train the GAN

# Config
batch_size = 32
epochs = 30000        # number of training iterations (one batch each), not passes over the dataset
sample_period = 200   # save a grid of generated samples every `sample_period` iterations

# Constant label vectors for train_on_batch (1 = real, 0 = fake)
ones = np.ones(batch_size)
zeros = np.zeros(batch_size)

# Per-iteration loss history
g_losses = []
d_losses = []

# Output folder for generated sample grids; exist_ok avoids the
# check-then-create race and makes re-running the cell harmless.
os.makedirs('gan_images', exist_ok=True)

# Generate a grid of random samples from the generator and save it to a file.
def sample_images(epoch):
  """Save a 5x5 grid of generator samples to gan_images/<epoch>.png.

  Draws fresh standard-normal noise for each cell and maps the generator's
  tanh output from [-1, 1] back to [0, 1] for display. Each subplot shows a
  different sample.
  """
  rows, columns = 5, 5
  noise = np.random.randn(rows * columns, latent_dims)
  imgs = generator.predict(noise, verbose=0)  # verbose=0: no progress bar per call

  # Rescale from the tanh range [-1, 1] to [0, 1]
  imgs = 0.5 * imgs + 0.5
  fig, axs = plt.subplots(rows, columns)
  # axs.flat walks the grid row by row, pairing each subplot with its own sample
  for ax, sample in zip(axs.flat, imgs):
    ax.imshow(sample.reshape(H, W), cmap='gray')
    ax.axis('off')  # hide ticks and frame
  fig.savefig('gan_images/%d.png' % epoch)
  plt.close(fig)  # release this figure so repeated calls don't accumulate open figures

In [None]:
# Main training loop. Each iteration trains on a single batch, so `epochs`
# counts iterations rather than passes over the dataset.
for epoch in range(epochs):

  # --- Train discriminator ---

  # Sample a random batch of real images
  idx = np.random.randint(0, x_train.shape[0], batch_size)
  real_imgs = x_train[idx]

  # Generate a batch of fake images (verbose=0 avoids a progress bar every iteration)
  noise = np.random.randn(batch_size, latent_dims)
  fake_imgs = generator.predict(noise, verbose=0)

  d_real_loss, d_real_acc = discriminator.train_on_batch(real_imgs, ones)
  d_fake_loss, d_fake_acc = discriminator.train_on_batch(fake_imgs, zeros)
  d_loss = 0.5 * (d_real_loss + d_fake_loss)
  d_acc = 0.5 * (d_real_acc + d_fake_acc)

  # --- Train generator ---
  # Fresh noise labeled as real (ones). The discriminator is frozen inside the
  # combined model, so only the generator's weights move.
  noise = np.random.randn(batch_size, latent_dims)
  g_loss = combined_model.train_on_batch(noise, ones)

  # Record the losses
  d_losses.append(d_loss)
  g_losses.append(g_loss)

  if epoch % 100 == 0:
    # `.2f` = two decimal places (the old `2f` meant minimum width 2, six decimals)
    print(f"epoch: {epoch+1}/{epochs}, d_loss: {d_loss:.2f}, "
          f"g_loss: {g_loss:.2f}, d_acc: {d_acc:.2f}")
  if epoch % sample_period == 0:
    sample_images(epoch)

In [None]:
# Plot the per-iteration generator and discriminator losses on one chart
for series, series_label in ((g_losses, 'g_loss'), (d_losses, 'd_loss')):
  plt.plot(series, label=series_label)
plt.legend()

In [None]:
# List the sample grids written during training (one PNG per saved iteration)
!ls gan_images

In [None]:
from skimage.io import imread
# Show the sample grid saved at iteration 0 (after the first training step)
a = imread('gan_images/0.png')
plt.imshow(a)

In [None]:
# Sample grid saved at iteration 1000
a = imread('gan_images/1000.png')
plt.imshow(a)

In [None]:
# Sample grid saved at iteration 2000
a = imread('gan_images/2000.png')
plt.imshow(a)

In [None]:
# Sample grid saved at iteration 5000
a = imread('gan_images/5000.png')
plt.imshow(a)

In [None]:
# Sample grid saved at iteration 10000
a = imread('gan_images/10000.png')
plt.imshow(a)

In [None]:
# Sample grid saved at iteration 20000
a = imread('gan_images/20000.png')
plt.imshow(a)

In [None]:
# Show the last sample grid written during training. Its iteration number is
# derived from the config instead of hard-coded, so the cell keeps working if
# `epochs` or `sample_period` change. With the current config it is 29800.
last_saved = (epochs - 1) // sample_period * sample_period
a = imread('gan_images/%d.png' % last_saved)
plt.imshow(a)