<a href="https://colab.research.google.com/github/MParsaMo/Generative-Adversarial-Network-GAN-on-MNIST-Digits/blob/main/GAN_mnist_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, Activation, Flatten
from keras.layers import LeakyReLU
from keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

# Run every tf.function eagerly (op by op) instead of as a traced graph.
tf.config.run_functions_eagerly(True)

# load_data() always returns both splits together with their labels, even
# though only the training images (x_train) are used further down.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Report the shapes as loaded ...
print(f'x_train.shape befor reshape : {x_train.shape}')
print(f'x_test.shape befor reshape : {x_test.shape}')

# ... flatten every image into one 784-long row (28 * 28 = 784) ...
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# ... and report the shapes again.
print(f'x_train.shape after reshape : {x_train.shape}')
print(f'x_test.shape after reshape : {x_test.shape}')

# Convert to float32 and divide by 255; for 8-bit pixel values this puts
# every feature in [0, 1], the same range as the generator's sigmoid output.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
# ---- features are ready from here on ----

# Length of the noise (latent) vector fed to the generator.
z_dim = 100

# Training history: one entry is appended per train_on_batch call,
# "D" for the discriminator and "G" for the generator.
losses = {"D": [], "G": []}

# Two independent Adam instances with identical settings.
# An optimizer keeps internal state tied to the variables it first optimised.
# Per the original author's note, reusing one instance for both the
# discriminator and the combined model (where the discriminator is frozen)
# led to an "Unknown variable" error, hence one optimizer per role.
adam_generator = Adam(learning_rate=0.0002, beta_1=0.5)
adam_discriminator = Adam(learning_rate=0.0002, beta_1=0.5)


# Generator: maps a z_dim-long noise vector to a 784-long vector, i.e. one
# flattened fake image. The sigmoid on the last layer keeps outputs in (0, 1).
generator = Sequential([
    Dense(256, input_dim=z_dim, activation=LeakyReLU(alpha=0.2)),
    Dense(512, activation=LeakyReLU(alpha=0.2)),
    Dense(784, activation='sigmoid'),
])
# Compiled with the generator's own optimizer (adam_generator).
generator.compile(
    optimizer=adam_generator,
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
generator.summary()

# Discriminator: scores a flattened 784-pixel image with a single sigmoid
# output in the range [0, 1].
#   0 -> the image is fake
#   1 -> the image is real
discriminator = Sequential([
    Dense(512, input_dim=784, activation=LeakyReLU(alpha=0.2)),
    Dropout(0.3),
    Dense(256, activation=LeakyReLU(alpha=0.2)),
    Dropout(0.3),
    Dense(1, activation='sigmoid'),
])
# Compiled with the discriminator's own optimizer (adam_discriminator).
discriminator.compile(
    optimizer=adam_discriminator,
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
discriminator.summary()

# Combined model: noise -> generator -> discriminator -> real/fake score.
# The generator's weights are updated by training this stacked model, so the
# discriminator inside it must not be updated: its trainable flag is switched
# off before the combined model is built and compiled, so that the optimizer
# of the combined model only optimises the generator's weights.
discriminator.trainable = False

gan_input = Input(shape=(z_dim,))
gan_hidden = generator(gan_input)        # fake image produced from the noise
gan_output = discriminator(gan_hidden)   # discriminator's verdict on that image
gan = Model(gan_input, gan_output)

# adam_generator is used here because only the generator part is optimised.
gan.compile(
    optimizer=adam_generator,
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
gan.summary()

def plot_loss(loss_values):
    """Plot the discriminator and generator loss curves on one figure.

    Args:
        loss_values: Mapping with the keys "D" and "G". Each value is a
            sequence holding one entry per recorded training step. An entry
            is either a ``[loss, accuracy, ...]`` sequence, of which only
            index 0 (the loss) is plotted, or a bare scalar loss.
    """
    # np.atleast_1d lets a bare scalar entry be read the same way as a
    # [loss, accuracy] entry.
    d_loss = [np.atleast_1d(v)[0] for v in loss_values["D"]]
    g_loss = [np.atleast_1d(v)[0] for v in loss_values["G"]]

    plt.figure(figsize=(10, 8))
    plt.plot(d_loss, label="Discriminator loss")
    plt.plot(g_loss, label="Generator loss")
    # One point per recorded step: the training loop appends an entry after
    # every batch, not once per epoch, so the axis is labelled accordingly.
    plt.xlabel('Training step (batch)')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def plot_fake_images(n=10, dim=(1, 10)):
    """Sample ``n`` images from the generator and show them in a grid.

    Args:
        n: Number of images to generate.
        dim: ``(rows, columns)`` of the subplot grid.

    Raises:
        ValueError: If the grid has fewer than ``n`` cells.
    """
    rows, cols = dim[0], dim[1]
    # Fail before sampling/drawing instead of part-way through the subplot loop.
    if n > rows * cols:
        raise ValueError(
            f"a {rows}x{cols} grid has only {rows * cols} cells, cannot show {n} images"
        )

    noise = np.random.normal(0, 1, size=(n, z_dim))
    # The generator emits flat 784-long rows; fold them back into 28x28 images.
    generated_images = generator.predict(noise).reshape(n, 28, 28)

    # 1.2 x 2 inches per cell: exactly the 12x2 figure for the default 1x10
    # grid, and proportionally larger for bigger grids.
    plt.figure(figsize=(1.2 * cols, 2 * rows))
    for position, image in enumerate(generated_images, start=1):
        plt.subplot(rows, cols, position)
        plt.imshow(image, interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def train(epochs=1, batch_size=128):
    """Train the GAN by alternating discriminator and generator updates.

    Every step first updates the discriminator on a batch of real images plus
    a batch of generated ones, then updates the generator through the combined
    ``gan`` model while the discriminator is frozen. The value returned by
    each ``train_on_batch`` call is appended to the module-level ``losses``
    dict ("D" / "G"). After the last epoch a row of generated images and the
    loss curves are plotted.

    Args:
        epochs: Number of epochs to run.
        batch_size: Number of real images (and of fake images) per step.
            ``x_train.shape[0] // batch_size`` steps are run per epoch.

    Note:
        The discriminator's ``trainable`` flag is toggled during training and
        is left at ``False`` when the function returns.
    """
    sample_count = x_train.shape[0]
    batch_count = sample_count // batch_size

    # The targets are identical for every step, so build them once.
    # Discriminator: real images (first half) -> 0.9 (soft label), fakes -> 0.
    discriminator_targets = np.zeros((2 * batch_size, 1))
    discriminator_targets[:batch_size] = 0.9
    # Generator: it wants its fakes to be classified as real (label 1).
    generator_targets = np.ones((batch_size, 1))

    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}")
        for _ in range(batch_count):
            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Random batch of real images (indices drawn with replacement).
            image_batch = x_train[np.random.randint(0, sample_count, size=batch_size)]

            # Batch of fake images. predict_on_batch runs one forward pass on
            # exactly this batch, without the per-call setup and progress-bar
            # output of predict().
            noise = np.random.normal(0, 1, size=(batch_size, z_dim))
            generated_images = generator.predict_on_batch(noise)

            # Real images first, fakes second, matching discriminator_targets.
            discriminator_features = np.concatenate((image_batch, generated_images))

            # The discriminator must be trainable for its own update.
            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(discriminator_features, discriminator_targets)
            losses["D"].append(d_loss)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Fresh noise so the generator is trained on a new set of fakes.
            new_noise = np.random.normal(0, 1, size=(batch_size, z_dim))

            # Re-freeze the discriminator before the combined step. The flag
            # was just set to True above; leaving it on would let this step
            # also push the discriminator's weights towards calling fakes
            # "real", instead of updating only the generator.
            discriminator.trainable = False
            g_loss = gan.train_on_batch(new_noise, generator_targets)
            losses["G"].append(g_loss)

    # Show a row of generated images and the loss curves after training.
    plot_fake_images()
    plot_loss(losses)

# Start training: 10 epochs with mini-batches of 128 images.
train(batch_size=128, epochs=10)