In [None]:
import keras
import tensorflow as tf

In [None]:
import os

from keras.datasets import mnist 
from keras.layers import Input, Dense, Reshape, Flatten 
from keras.layers import BatchNormalization 
from keras.layers import LeakyReLU
from keras.models import Sequential, Model 
from keras.optimizers.legacy import Adam 
import matplotlib.pyplot as plt 
import numpy as np

In [None]:
# Input image geometry: MNIST digits are 28x28 with a single (grayscale) channel.
# Small images keep training time and resource usage manageable.
img_rows, img_cols = 28, 28
channels = 1
img_shape = (img_rows, img_cols, channels)

In [None]:
# Given an input noise (latent) vector, the Generator produces an image.
def build_generator():
    """Build the generator: maps a 100-element noise (latent) vector to an image.

    The network stacks three fully connected blocks
    (Dense -> LeakyReLU(alpha=0.2) -> BatchNormalization(momentum=0.8)) of
    widths 256, 512 and 1024, then a tanh Dense layer with
    ``np.prod(img_shape)`` units reshaped to ``img_shape``. tanh keeps the
    generated pixels in [-1, 1], the same range the real images are rescaled to
    in ``train``.

    Returns:
        A functional ``Model`` taking a ``(batch, 100)`` noise input and
        producing ``(batch, *img_shape)`` images.
    """
    noise_shape = (100,)  # 1D latent/noise vector of length 100

    # Only Dense layers are used here; other applications may call for a more
    # elaborate network (e.g. VGG-style blocks for a super-resolution GAN).
    stack = Sequential()
    for position, width in enumerate((256, 512, 1024)):
        if position == 0:
            # The first layer declares the input shape for the Sequential stack.
            stack.add(Dense(width, input_shape=noise_shape))
        else:
            stack.add(Dense(width))
        stack.add(LeakyReLU(alpha=0.2))
        stack.add(BatchNormalization(momentum=0.8))

    stack.add(Dense(np.prod(img_shape), activation='tanh'))
    stack.add(Reshape(img_shape))
    stack.summary()

    # Wrap the stack in a functional Model with an explicit noise input.
    noise = Input(shape=noise_shape)
    return Model(noise, stack(noise))

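# alpha - the slope LeakyReLU applies to negative inputs: f(x) = alpha * x for x < 0 (x otherwise), so negative activations are scaled down rather than zeroed
# momentum - the decay rate of BatchNormalization's moving mean/variance, which replace per-batch statistics at inference time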

In [None]:
# Given an input, the Discriminator outputs the likelihood of the image being real
# binary classification - true or false (we're calling it validity)

def build_discriminator():
    """Build the discriminator: scores how likely an image is to be real.

    The image is flattened, then passed through Dense(512) and Dense(256)
    layers, each followed by LeakyReLU(alpha=0.2). A single sigmoid unit
    produces the output.

    Returns:
        A functional ``Model`` mapping ``(batch, *img_shape)`` images to a
        ``(batch, 1)`` validity score in (0, 1).
    """
    stages = [Flatten(input_shape=img_shape)]
    for width in (512, 256):
        stages += [Dense(width), LeakyReLU(alpha=0.2)]
    stages.append(Dense(1, activation='sigmoid'))

    # Sequential adds the given layers in list order.
    stack = Sequential(stages)
    stack.summary()

    # Wrap the stack in a functional Model with an explicit image input.
    image = Input(shape=img_shape)
    return Model(image, stack(image))

# The validity is the Discriminator's guess of input being real or not.

In [None]:
# Now to pit them against each other.
# define a training function, load the data set, re-scale training images and set the ground truths

def train(epochs, batch_size=128, save_interval=2000):
    """Train the GAN with alternating discriminator / generator updates.

    Each iteration (called an "epoch" here, although it processes one batch
    rather than a full pass over the dataset) does two things:
      1. It trains ``discriminator`` on a half batch of real MNIST images
         (label 1) and a half batch of generated images (label 0).
      2. It trains ``generator`` through ``combined`` on a full batch of noise
         labelled 1, so the generator learns to fool the discriminator.

    Relies on the notebook-level globals ``generator``, ``discriminator``,
    ``combined`` and ``save_imgs``. They must be defined before calling.

    Args:
        epochs: number of training iterations to run.
        batch_size: noise vectors per generator update. The discriminator sees
            ``batch_size // 2`` real and ``batch_size // 2`` fake images per
            iteration.
        save_interval: save a grid of generated samples every this many
            iterations, and after the final iteration.
    """
    # Load the dataset (labels and the test split are unused)
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5

    # Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1), matching img_shape
    X_train = np.expand_dims(X_train, axis=3)

    half_batch = batch_size // 2

    # Targets never change between iterations, so build them once.
    real_y = np.ones((half_batch, 1))
    fake_y = np.zeros((half_batch, 1))
    # The generator wants the discriminator to label its samples as real (1).
    # Same (n, 1) float layout as the discriminator targets.
    valid_y = np.ones((batch_size, 1))

    for epoch in range(epochs):
        # ---- Train the discriminator ----
        # Select a random half batch of real images
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        imgs = X_train[idx]

        # Generate a half batch of fake images from standard-normal noise.
        # verbose=0 stops Keras from printing a progress bar on every iteration.
        noise = np.random.normal(0, 1, (half_batch, 100))
        gen_imgs = generator.predict(noise, verbose=0)

        # Real and fake half batches are presented as separate batches
        d_loss_real = discriminator.train_on_batch(imgs, real_y)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_y)
        # Average [loss, accuracy] over the real and fake batches
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---- Train the generator ----
        # A full batch of noise goes through the combined model (the
        # discriminator is frozen there), pushing generated images towards "real".
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = combined.train_on_batch(noise, valid_y)

        # Report progress
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # Save generated samples at each interval and after the last iteration
        if epoch % save_interval == 0 or epoch == epochs - 1:
            save_imgs(epoch)

In [None]:
# train() calls save_imgs whenever the save_interval is hit:

def save_imgs(epoch, out_dir='./mnist100k'):
    """Save a 5x5 grid of generator samples to ``<out_dir>/mnist_<epoch>.png``.

    Uses the notebook-level ``generator``. The output directory is created if
    it does not exist, so the first call on a fresh checkout does not fail.

    Args:
        epoch: training iteration number, used in the file name.
        out_dir: directory the PNG is written to (default ``./mnist100k``).
    """
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    # verbose=0 keeps Keras from printing a progress bar for this call
    gen_imgs = generator.predict(noise, verbose=0)

    # Rescale from the generator's tanh range [-1, 1] to [0, 1] for imshow
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, pairing each axis with one sample
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, 'mnist_%d.png' % epoch))
    # Close this specific figure so repeated calls during training don't accumulate figures
    plt.close(fig)

# save_imgs writes a 5x5 grid of generated samples to a PNG file

In [None]:
# Define the optimizer once so every model compiled below uses the same settings;
# change the hyperparameters here to change them everywhere.
# learning_rate=0.0002; beta_1=0.5 is Adam's exponential decay rate for the
# first-moment (momentum-like) estimate, lowered from its default of 0.9.
optimizer = Adam(learning_rate=0.0002, beta_1=0.5)

# Build and compile the discriminator first.
# The generator will be trained later, as part of the combined model.
# Binary cross-entropy matches the discriminator's real (1) / fake (0)
# classification with a sigmoid output; accuracy is tracked as an extra metric.
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# build and compile our discriminator, pick the loss function

In [None]:
# The generator is only used to generate (fake) images here and is trained
# through the combined model, so no metrics are tracked for it.
generator = build_generator()
generator.compile(optimizer=optimizer, loss='binary_crossentropy')

# This builds the generator (build_generator defines its own noise input layer).

In [None]:
# In a GAN, the generator takes a noise vector z as input and produces an image.
# Pass a symbolic 100-element input through it so it can be stacked with the discriminator.
z = Input((100,))
img = generator(z)

In [None]:
# Freeze the discriminator for models compiled from this point on, so that when
# the networks are combined below only the generator is trained.
# The discriminator itself was already compiled (in an earlier cell) while
# trainable, so its own train_on_batch calls in train() still update it.
discriminator.trainable = False

In [None]:
# Feed the generator's output image through the discriminator. `valid` is the
# discriminator's estimate of whether that generated image is real.
valid = discriminator(img) # Validity check on the generated image

In [None]:
# Stack the generator and discriminator into one model:
# noise => generated image => validity score.
# Training this model updates only the generator, because the discriminator was
# frozen before it was compiled; the generator's goal is to fool the discriminator.
combined = Model(inputs=z, outputs=valid)
combined.compile(optimizer=optimizer, loss='binary_crossentropy')

train(epochs=30_000)

In [None]:
# Save the trained generator so it can be reloaded later to generate fake images.
# Compare with GAN4.

generator.save('generator_model_test.h5') # Test the model on GAN4_predict...
# If epochs was lowered for a quick test, set it back to 30K before a full run.

# In train(), each "epoch" is one training iteration on a single batch (one forward/backward pass per model), batch_size is the number of samples per generator update, and save_interval sets after how many iterations save_imgs is called.