<a href="https://www.kaggle.com/code/omarsaad34/mnist-gan?scriptVersionId=106595084" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

**Generating MNIST digits with a Generative Adversarial Network (GAN)**

Imports
---

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.datasets import mnist
from keras.layers import Input,Dense,Flatten,Reshape,BatchNormalization,Conv2D,LeakyReLU,Dropout,Conv2DTranspose
from keras.models import Sequential,Model
import matplotlib.pyplot as plt
from keras.utils.vis_utils import plot_model
from tensorflow.keras.optimizers import Adam

In [None]:
from tensorflow.python.client import device_lib
# List the local devices TensorFlow can see.
local_devices = device_lib.list_local_devices()
print(local_devices)

In [None]:
import tensorflow as tf
# tf.test.is_gpu_available() is deprecated; tf.config.list_physical_devices is
# the supported query. The cell evaluates to True when at least one GPU is visible.
bool(tf.config.list_physical_devices('GPU'))

In [None]:
# Show the Keras version in use; several imports above are version-sensitive.
keras_version = keras.__version__
keras_version

Load Dataset
---

In [None]:
# MNIST digits are 28x28 greyscale images, i.e. a single channel.
IMG_ROW, IMG_COL = 28, 28
IMG_CHANNEL = 1
IMG_SHAPE = (IMG_ROW, IMG_COL, IMG_CHANNEL)  # (height, width, channels)

In [None]:
def load_real_samples():
    """Load the MNIST training images as float32 values scaled to [0, 1].

    Labels and the test split are discarded. A trailing channel axis is
    appended, so an ``(n, H, W)`` image array becomes ``(n, H, W, 1)``.
    """
    (train_images, _), _ = mnist.load_data()
    # Add the channel axis and convert from uint8 before scaling.
    images = train_images[..., np.newaxis].astype('float32')
    return images / 255.0

In [None]:
def generate_real_samples(dataset, n_samples):
    """Draw ``n_samples`` random entries of ``dataset`` and label them real.

    Indices are drawn independently with NumPy's global RNG, so the same
    entry may appear more than once.

    Returns:
        (images, labels): ``dataset`` rows at the drawn indices and an
        ``(n_samples, 1)`` float array of ones (the "real" class).
    """
    picks = np.random.randint(0, len(dataset), n_samples)
    labels = np.ones((n_samples, 1))
    return dataset[picks], labels

In [None]:
# x_train.shape

Constants
---

In [None]:
# Optimiser and loss settings used when compiling the combined GAN model.
LR = 2e-4  # Adam learning rate
# Passed to Adam as beta_1: the decay rate of its running (momentum-like)
# estimate of the gradient mean. 0.5 is lower than Adam's usual default.
MOMENTUM = 0.5
# Real-vs-fake is a two-class problem with sigmoid outputs.
BINARY_LOSS_FUNC = "binary_crossentropy"

Visualize Dataset
---

In [None]:
def show_images(images):
    """Plot ``images`` on a near-square grid of subplots in the current figure.

    Args:
        images: sequence of 2-D images, or images with a trailing axis of
            size 1 (e.g. ``(H, W, 1)``), which is squeezed away for display.
    """
    import math

    n = len(images)
    if n == 0:
        return
    # ceil(sqrt(n)) columns with enough rows always gives at least n cells.
    # An (n/2 x n/2) grid has too few cells for odd n and zero rows for n < 2.
    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)
    for position, image in enumerate(images, start=1):
        plt.subplot(rows, cols, position)
        # imshow rejects (H, W, 1); MNIST digits are greyscale, so use a grey map.
        plt.imshow(np.squeeze(image), cmap='gray')


In [None]:
# Show the first eight raw training digits, then drop the array to free memory.
train_images = mnist.load_data()[0][0]
show_images(train_images[:8])
del train_images

Building Model
---

Building the generator

In [None]:
def build_generator(latent_dim):
    """Build the (uncompiled) generator that maps latent vectors to 28x28x1 images.

    A Dense layer projects the ``latent_dim`` input to a 7x7x128 feature map.
    Two stride-2 transposed convolutions upsample it to 14x14 and then 28x28.
    A final 7x7 convolution with a sigmoid yields one channel with values in
    (0, 1).

    Args:
        latent_dim: length of each input latent vector.
    Returns:
        A Keras Sequential model; it is not compiled here.
    """
    seed_units = 128 * 7 * 7
    return Sequential([
        Dense(seed_units, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Reshape((7, 7, 128)),
        # 7x7 -> 14x14
        Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        LeakyReLU(alpha=0.2),
        # 14x14 -> 28x28
        Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        LeakyReLU(alpha=0.2),
        Conv2D(1, (7, 7), activation='sigmoid', padding='same'),
    ])

Build The Discriminator

> The discriminator predicts whether a given image is real or fake.



In [None]:
def build_discriminator(in_shape=IMG_SHAPE):
    """Build and compile the CNN that classifies an image as real (1) or fake (0).

    The network has two stride-2 3x3 convolutions (64 filters each, each
    followed by LeakyReLU and 40% dropout), which feed a single sigmoid unit.
    It is compiled here with Adam and binary cross-entropy and reports
    accuracy, so it can be trained directly with ``train_on_batch``.

    Args:
        in_shape: shape of one input image; defaults to IMG_SHAPE.
    Returns:
        The compiled Keras Sequential model.
    """
    model = Sequential()
    model.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    # Use the shared constants (same values as before) so this model and the
    # combined GAN are configured identically. `learning_rate` replaces the
    # deprecated `lr` alias, which Keras 3 rejects outright.
    opt = Adam(learning_rate=LR, beta_1=MOMENTUM)
    model.compile(loss=BINARY_LOSS_FUNC, optimizer=opt, metrics=['accuracy'])
    return model

Build The combined GAN model

In [None]:
def build_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one compiled model.

    ``discriminator.trainable`` is set to False before this model is compiled.
    Training the returned model therefore updates only the generator's weights.

    Returns:
        A Sequential model compiled with Adam(LR, beta_1=MOMENTUM) and
        BINARY_LOSS_FUNC.
    """
    # Must happen before compile() so the combined model treats the
    # discriminator's weights as frozen.
    discriminator.trainable = False
    combined = Sequential([generator, discriminator])
    combined.compile(
        optimizer=Adam(learning_rate=LR, beta_1=MOMENTUM),
        loss=BINARY_LOSS_FUNC,
    )
    return combined

In [None]:
# train the discriminator model on its own
def train_discriminator(model, dataset, epochs=100, batch_size=256,
                        generator=None, latent_dim=100):
    """Train ``model`` alone to separate real images from fake ones.

    Each iteration runs one ``train_on_batch`` step on ``batch_size // 2`` real
    images (label 1), then one on the same number of fake images (label 0).
    It then prints the accuracy returned by each step.

    Args:
        model: compiled discriminator whose ``train_on_batch`` returns
            ``(loss, accuracy)``, e.g. the model from ``build_discriminator``.
        dataset: array of real images indexed along axis 0.
        epochs: number of real/fake update pairs to run.
        batch_size: images per iteration, split half real and half fake.
        generator: optional generator model. When given, fake images come from
            ``generate_fake_samples(generator, latent_dim, ...)``. When None,
            fake images are uniform noise in [0, 1) shaped like one ``dataset``
            entry (the same range ``load_real_samples`` produces).
        latent_dim: length of the latent vectors fed to ``generator``.
    """
    half_batch = batch_size // 2
    print('Discriminator starts training')
    for i in range(epochs):
        X_real, y_real = generate_real_samples(dataset, half_batch)
        _, real_acc = model.train_on_batch(X_real, y_real)
        if generator is None:
            # generate_fake_samples needs a generator and latent size. Without
            # a generator, use random-noise images as the "fake" class so the
            # discriminator can still be trained on its own.
            X_fake = np.random.rand(half_batch, *dataset.shape[1:]).astype('float32')
            y_fake = np.zeros((half_batch, 1))
        else:
            X_fake, y_fake = generate_fake_samples(generator, latent_dim, half_batch)
        _, fake_acc = model.train_on_batch(X_fake, y_fake)
        print('>%d real acc.=%.0f%% fake acc.=%.0f%%' % (i+1, real_acc*100, fake_acc*100))

In [None]:
# train the generator and discriminator
def train(generator, discriminator, gan, dataset, latent_dim, n_epochs=100, n_batch=256, save_interval=2500):
    """Alternate discriminator and generator updates for ``n_epochs`` epochs.

    Each epoch runs ``dataset.shape[0] // n_batch`` steps. Each step:
      1. Trains ``discriminator`` on ``n_batch // 2`` real plus
         ``n_batch // 2`` generated images.
      2. Trains ``gan`` (expected to be the stacked model from ``build_gan``,
         with the discriminator frozen) on ``n_batch`` latent points labelled
         real, so the generator learns to fool the discriminator.
    The losses from every step are printed.

    ``evalute`` (accuracy report, sample grid, generator checkpoint) runs after
    every ``save_interval``-th completed epoch and always after the final one.

    Args:
        generator, discriminator, gan: the three Keras models being trained.
        dataset: array of real images indexed along axis 0.
        latent_dim: length of the generator's latent input vectors.
        n_epochs: number of passes over ``dataset``.
        n_batch: images per step for the generator update.
        save_interval: evaluate/checkpoint every this many completed epochs.
    """
    bat_per_epo = dataset.shape[0] // n_batch
    half_batch = n_batch // 2
    for i in range(n_epochs):
        for j in range(bat_per_epo):
            # Discriminator step on a mixed real/fake batch.
            X_real, y_real = generate_real_samples(dataset, half_batch)
            X_fake, y_fake = generate_fake_samples(generator, latent_dim, half_batch)
            X, y = np.vstack((X_real, X_fake)), np.vstack((y_real, y_fake))
            d_loss, _ = discriminator.train_on_batch(X, y)
            # Generator step: label generated images as real ("inverted" labels)
            # so the loss rewards fooling the frozen discriminator.
            X_gan = generate_latent_points(latent_dim, n_batch)
            y_gan = np.ones((n_batch, 1))
            g_loss = gan.train_on_batch(X_gan, y_gan)
            print('>%d, %d/%d, d=%.3f, g=%.3f' % (i+1, j+1, bat_per_epo, d_loss, g_loss))

        # Count completed epochs so evaluation happens *after* the interval's
        # training. The final epoch is always included so the fully trained
        # generator gets checkpointed.
        epochs_done = i + 1
        if epochs_done % save_interval == 0 or epochs_done == n_epochs:
            evalute(i, generator, discriminator, dataset, latent_dim, n_samples=10)




In [None]:
# NOTE: the misspelled name 'evalute' is kept because train() calls it by it.
def evalute(epoch, generator, discriminator, dataset, latent_dim, n_samples=10):
    """Report discriminator accuracy, then save a sample grid and the generator.

    Draws ``n_samples`` real and ``n_samples`` generated images and prints the
    discriminator's accuracy on each set. It then calls
    ``save_image(epoch, generator)`` and writes the generator to
    ``generator_model_<epoch + 1, zero-padded to 3 digits>.h5``.
    """
    real_imgs, real_labels = generate_real_samples(dataset, n_samples)
    _, real_accuracy = discriminator.evaluate(real_imgs, real_labels, verbose=0)

    fake_imgs, fake_labels = generate_fake_samples(generator, latent_dim, n_samples)
    _, fake_accuracy = discriminator.evaluate(fake_imgs, fake_labels, verbose=0)

    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (real_accuracy * 100, fake_accuracy * 100))

    save_image(epoch, generator)
    checkpoint_path = 'generator_model_%03d.h5' % (epoch + 1)
    generator.save(checkpoint_path)

In [None]:
def save_image(epoch, generator, latent_dim=100):
    """Save a 5x5 grid of generator samples to ``mnist_<epoch>.png``.

    Args:
        epoch: number embedded in the output filename.
        generator: model mapping ``(batch, latent_dim)`` latent vectors to
            images; channel 0 of each output is plotted.
        latent_dim: length of each latent vector. It must match the size the
            generator was built with. The default of 100 matches the
            noise_dimension used in this notebook.
    """
    rows, cols = 5, 5
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    images = generator.predict(noise, verbose=0)
    fig, axs = plt.subplots(rows, cols)
    for ax, image in zip(axs.flat, images):
        ax.imshow(image[:, :, 0], cmap='gray')
    fig.savefig('mnist_%d.png' % epoch)
    # Close this specific figure so repeated calls during training don't
    # accumulate open figures.
    plt.close(fig)

In [None]:
def generate_latent_points(latent_dim, n_samples):
    """Return an ``(n_samples, latent_dim)`` batch of standard-normal draws.

    Values come from NumPy's global RNG, so ``np.random.seed`` makes the
    output repeatable. Each row is one latent vector for the generator.
    """
    return np.random.randn(n_samples, latent_dim)

In [None]:
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples):
    """Generate ``n_samples`` images with ``g_model`` and label them fake.

    Args:
        g_model: Keras generator accepting ``(n_samples, latent_dim)`` inputs.
        latent_dim: length of each latent vector.
        n_samples: number of images to produce.
    Returns:
        (images, labels): the generator's predictions and an ``(n_samples, 1)``
        float array of zeros (the "fake" class).
    """
    x_input = generate_latent_points(latent_dim, n_samples)
    # This runs on every training step. verbose=0 suppresses predict's
    # progress bar (shown by default in recent Keras versions), which would
    # otherwise flood the log between the per-batch loss lines.
    X = g_model.predict(x_input, verbose=0)
    y = np.zeros((n_samples, 1))
    return X, y

Visualize the generator output before training (noise)
---

In [None]:
latent_im = 100
generator = build_generator(latent_im)
# Sample with the same latent size the generator was just built with, rather
# than a separately hard-coded 100. Show the single channel in greyscale.
untrained_sample = generate_fake_samples(generator, latent_im, 1)[0][0]
plt.imshow(np.reshape(untrained_sample, (IMG_ROW, IMG_COL)), cmap='gray')

Main Code
---



> Build discriminator



In [None]:
# Build the (already compiled) discriminator and print its layer summary.
discriminator = build_discriminator(in_shape=IMG_SHAPE)
discriminator.summary()

In [None]:
# Render the discriminator architecture, with tensor shapes, to a PNG.
plot_model(
    discriminator,
    to_file='discriminator_plot.png',
    show_shapes=True,
    show_layer_names=True,
)



> Build generator



In [None]:
# Build the generator. noise_dimension is also passed to train() below as
# latent_dim, so both must agree.
noise_dimension = 100
generator = build_generator(latent_dim=noise_dimension)
generator.summary()

In [None]:
# Render the generator architecture, with tensor shapes, to a PNG.
plot_model(
    generator,
    to_file='generator_plot.png',
    show_shapes=True,
    show_layer_names=True,
)



> Build combined GAN model



In [None]:
# Chain generator -> frozen discriminator into the model used for generator updates.
gan = build_gan(generator=generator, discriminator=discriminator)
gan.summary()

In [None]:
# Render the combined GAN architecture, with tensor shapes, to a PNG.
plot_model(
    gan,
    to_file='gan_plot.png',
    show_shapes=True,
    show_layer_names=True,
)

> Training GAN Model


In [None]:
# Load the scaled MNIST training images and run the full adversarial training loop.
X = load_real_samples()
train(
    generator,
    discriminator,
    gan,
    X,
    noise_dimension,
    n_epochs=100,
    n_batch=128,
    save_interval=10,
)