# Advanced GAN Base Workspace

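Building a base Advanced GAN. The first cell downloads and previews the CelebA dataset; the second cell builds and trains a DCGAN on MNIST.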

In [None]:
import keras
import tensorflow as tf

from keras import layers
from keras import ops
import matplotlib.pyplot as plt
import os
import gdown
from zipfile import ZipFile


# Create the download directory. exist_ok=True lets this cell be re-run
# without raising FileExistsError once the folder already exists.
os.makedirs("celeba_gan", exist_ok=True)

url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
output = "celeba_gan/data.zip"
gdown.download(url, output, quiet=True)

# Unpack the archive into the same folder so the images can be found by
# image_dataset_from_directory below.
with ZipFile(output, "r") as zipobj:
    zipobj.extractall("celeba_gan")


# Unlabelled 64x64 images in batches of 32, divided by 255
# (assumes 0-255 pixel values, giving a [0, 1] range).
dataset = keras.utils.image_dataset_from_directory(
    "celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32
)
dataset = dataset.map(lambda x: x / 255.0)


# Preview the first image of a single batch, scaled back up for display.
for x in dataset.take(1):
    plt.axis("off")
    plt.imshow((x.numpy() * 255).astype("int32")[0])

In [None]:
from numpy import expand_dims, ones, zeros, vstack
from numpy.random import rand, randint, randn
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten, Dropout, LeakyReLU, Conv2DTranspose, Reshape
from matplotlib import pyplot
import logging
import io
from contextlib import redirect_stdout
import gdown
from zipfile import ZipFile


def initialize_logger(script_name='dcgan_base'):
    """
    Creates (or returns the already-configured) logger for hyperparameters and training data.

    INFO and above go to the console; DEBUG and above are also written to
    ``<script_name>.log`` in the current working directory.

    Input:
    script_name: logger name and log-file stem (defaults to 'dcgan_base')

    Return:
    The configured logging.Logger
    """
    log_filename = f'{script_name}.log'

    # Create (or fetch) the named logger
    logger = logging.getLogger(script_name)
    logger.setLevel(logging.DEBUG)

    # logging.getLogger returns the same object for the same name, so calling this
    # again (e.g. re-running the notebook cell) would otherwise stack another pair of
    # handlers and emit every message multiple times.
    if logger.handlers:
        return logger

    # Create handlers
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(log_filename)
    c_handler.setLevel(logging.INFO)
    f_handler.setLevel(logging.DEBUG)

    # Create formatters and add them to the handlers
    c_format = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
    f_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    c_handler.setFormatter(c_format)
    f_handler.setFormatter(f_format)

    # Add handlers to the logger
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)

    return logger


def log_model_summary(model, logger):
    """
    Log whatever ``model.summary()`` prints, as a single INFO message.

    Inputs:
    model: object exposing a ``summary()`` method that writes to stdout
    logger: logger that receives the captured text
    """
    with io.StringIO() as captured:
        # summary() prints rather than returning text, so capture stdout.
        with redirect_stdout(captured):
            model.summary()
        text = captured.getvalue()
    logger.info(text)


def create_discriminator(in_shape=(28,28,1)):
    """
    Build and compile the discriminator, a binary real-vs-fake image classifier.

    Two strided 3x3 convolutions (64 filters each, LeakyReLU + 40% dropout)
    feed a single sigmoid unit. The model is compiled with binary
    cross-entropy, Adam(lr=0.0002, beta_1=0.5) and an accuracy metric.

    Input:
    in_shape: shape of one input image (height, width, channels)

    Return:
    The compiled discriminator model
    """
    model = Sequential([
        Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=in_shape),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),
        Conv2D(64, (3, 3), strides=(2, 2), padding='same'),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),
        Flatten(),
        Dense(1, activation='sigmoid'),
    ])
    model.compile(
        loss='binary_crossentropy',
        optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
        metrics=['accuracy'],
    )
    return model


def load_mnist():
    """
    Load the MNIST training images as float32 values scaled into [0, 1].

    Labels and the test split are discarded. A trailing single-channel axis is
    appended so the images can be fed to Conv2D layers.
    """
    (train_images, _), _ = load_data()
    train_images = expand_dims(train_images, axis=-1).astype('float32')
    return train_images / 255.0



def select_real_samples(dataset, n_samples):
    """
    Draw ``n_samples`` random images from ``dataset`` and label them as real.

    Indices are drawn independently (with replacement), so one image can
    appear more than once in the returned batch.

    Inputs:
    dataset: array of images indexed along its first axis (e.g. MNIST)
    n_samples: number of images to draw

    Return:
    X: the selected images
    y: array of shape (n_samples, 1) filled with 1 ("real")
    """
    picked = dataset[randint(0, dataset.shape[0], n_samples)]
    labels = ones((n_samples, 1))
    return picked, labels



def initial_create_fake_samples(n_samples, shape=(28, 28, 1)):
    """
    Creates uniform-noise "images" labelled as fake, for pre-training the discriminator.

    Input:
    n_samples: number of samples to create
    shape: shape of one image (height, width, channels); defaults to the
           28x28 single-channel shape used by create_discriminator's default

    Return:
    X: noise images of shape (n_samples, *shape), values drawn uniformly from [0, 1)
    y: image tags for training, 0 meaning "not a real image"
    """
    # rand(n, *shape) draws the same values in the same order as
    # rand(n * prod(shape)).reshape((n, *shape)), so seeded runs are unchanged.
    X = rand(n_samples, *shape)
    y = zeros((n_samples, 1))
    return X, y


def train_discriminator(model, dataset, iterations=100, batch_size=256):
    """
    Pre-train the discriminator alone on real images versus uniform noise.

    Each iteration makes two separate train_on_batch calls of half a batch
    each: real images (label 1), then noise images (label 0). It prints the
    accuracy reported for each half.

    Inputs:
    model: compiled discriminator whose train_on_batch returns (loss, accuracy)
    dataset: array of real images (e.g. MNIST)
    iterations: number of training iterations
    batch_size: combined real + fake images per iteration
    """
    half_batch = int(batch_size / 2)

    for step in range(iterations):
        real_x, real_y = select_real_samples(dataset, half_batch)
        _, real_acc = model.train_on_batch(real_x, real_y)

        fake_x, fake_y = initial_create_fake_samples(half_batch)
        _, fake_acc = model.train_on_batch(fake_x, fake_y)

        print(f'>{step+1} real={real_acc*100:.0f}% fake={fake_acc*100:.0f}%')


def create_generator(latent_dim):
    """
    Build the (uncompiled) generator that maps latent vectors to images.

    A dense layer projects the latent vector to 7x7x128 feature maps. Two
    stride-2 transposed convolutions upsample to 14x14 and then 28x28. A
    7x7 sigmoid convolution produces a single-channel image.

    Input:
    latent_dim: length of each input latent vector

    Return:
    The generator model
    """
    n_nodes = 128 * 7 * 7
    return Sequential([
        Dense(n_nodes, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Reshape((7, 7, 128)),
        # 7x7 -> 14x14
        Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        LeakyReLU(alpha=0.2),
        # 14x14 -> 28x28
        Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        LeakyReLU(alpha=0.2),
        # sigmoid output matches the [0, 1]-scaled training images
        Conv2D(1, (7, 7), activation='sigmoid', padding='same'),
    ])


def generate_latent_points(latent_dim, n_samples):
    """
    Sample generator inputs from a standard normal distribution.

    Inputs:
    latent_dim: length of each latent vector
    n_samples: number of vectors to draw

    Return:
    Array of shape (n_samples, latent_dim)
    """
    flat = randn(latent_dim * n_samples)
    return flat.reshape(n_samples, latent_dim)


def create_fake_samples(g_model, latent_dim, n_samples):
    """
    Generates fake samples from the generator model and labels them 0.

    Inputs:
    g_model: generator model
    latent_dim: Dimension of the latent space
    n_samples: number of samples to generate

    Return:
    X: fake images produced by g_model.predict
    y: image tags for training, 0 meaning "not real"
    """
    x_input = generate_latent_points(latent_dim, n_samples)
    # verbose=0: this is called once per training step (and again in
    # performance_summary), so Keras's default per-call progress bar would
    # otherwise flood the output.
    X = g_model.predict(x_input, verbose=0)
    y = zeros((n_samples, 1))
    return X, y


def create_gan(g_model, d_model):
    """
    Stack generator and discriminator into one model used to train the generator.

    The discriminator is marked non-trainable before stacking, so fitting the
    combined model is intended to update only the generator's weights. The
    stack is compiled with binary cross-entropy and Adam(lr=0.0002, beta_1=0.5).

    Inputs:
    g_model: generator model
    d_model: discriminator model (already compiled on its own)

    Return:
    The compiled GAN model
    """
    # NOTE(review): this relies on d_model having been compiled *before*
    # trainable is switched off, so its own train_on_batch keeps updating its
    # weights. Keras 2 documents that behaviour; Keras 3 may handle `trainable`
    # after compile differently. Confirm d_model's weights actually change
    # during train().
    d_model.trainable = False

    gan = Sequential([g_model, d_model])
    gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return gan


def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=256):
    """
    Alternate discriminator and generator updates over the dataset.

    Each step trains the discriminator on half a batch of real images plus
    half a batch of generated ones. It then trains the generator through
    gan_model on a full batch of latent vectors labelled "real". Per-step
    losses are printed. Every 10th epoch performance_summary is called.

    Inputs:
    g_model: generator model
    d_model: discriminator model
    gan_model: stacked generator + discriminator model
    dataset: array of real images (e.g. MNIST)
    latent_dim: Dimension of the latent space
    n_epochs: number of epochs to train
    n_batch: images per training step
    """
    steps_per_epoch = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)

    for epoch in range(n_epochs):
        for step in range(steps_per_epoch):
            # Discriminator step: real images (label 1) stacked on generated ones (label 0).
            real_x, real_y = select_real_samples(dataset, half_batch)
            fake_x, fake_y = create_fake_samples(g_model, latent_dim, half_batch)
            d_loss, _ = d_model.train_on_batch(
                vstack((real_x, fake_x)), vstack((real_y, fake_y))
            )

            # Generator step: label generated images "real" so the loss is high
            # whenever the discriminator (non-trainable inside gan_model, per
            # create_gan) sees through them.
            latent = generate_latent_points(latent_dim, n_batch)
            g_loss = gan_model.train_on_batch(latent, ones((n_batch, 1)))

            print(f'>{epoch+1}, {step+1}/{steps_per_epoch}, d={d_loss:.3f}, g={g_loss:.3f}')

        if (epoch + 1) % 10 == 0:
            performance_summary(epoch, g_model, d_model, dataset, latent_dim)


def create_plot(examples, epoch, n=10):
    """
    Saves an n x n grid of generated images to ``generated_plot_eNNN.png``.

    Inputs:
    examples: images indexed as [sample, row, col, channel]; only channel 0 is drawn
    epoch: zero-based epoch number (the file name uses epoch + 1)
    n: grid side length; ``examples`` must hold at least n * n images
    """
    # Draw on a dedicated figure rather than pyplot's implicit "current" figure,
    # so a figure left open elsewhere (e.g. the dataset preview) is not saved
    # underneath the grid.
    fig = pyplot.figure()
    try:
        for i in range(n * n):
            ax = fig.add_subplot(n, n, 1 + i)
            ax.axis('off')
            ax.imshow(examples[i, :, :, 0], cmap='gray_r')
        filename = 'generated_plot_e%03d.png' % (epoch + 1)
        fig.savefig(filename)
    finally:
        # Always release the figure, even if plotting fails.
        pyplot.close(fig)


def performance_summary(epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    """
    Reports discriminator accuracy and saves a sample plot plus a generator checkpoint.

    train() calls this every 10 epochs; the function itself does no epoch check.

    Inputs:
    epoch: zero-based epoch number (file names use epoch + 1)
    g_model: generator model
    d_model: discriminator model
    dataset: array of real images (e.g. MNIST)
    latent_dim: Dimension of the latent space
    n_samples: number of real and of generated samples to evaluate
    """
    # take real samples and evaluate discriminator
    X_real, y_real = select_real_samples(dataset, n_samples)
    _, acc_real = d_model.evaluate(X_real, y_real, verbose=0)

    # take generated samples and evaluate using discriminator
    x_fake, y_fake = create_fake_samples(g_model, latent_dim, n_samples)
    _, acc_fake = d_model.evaluate(x_fake, y_fake, verbose=0)

    # summarize discriminator performance
    message = f'>Accuracy real: {acc_real*100:.0f}%, fake: {acc_fake*100:.0f}%'
    print(message)
    # Fetch the logger by the name initialize_logger() registers it under,
    # instead of depending on a module-level ``logger`` global existing.
    logging.getLogger('dcgan_base').info(message)

    # create the 10 x 10 plot picture
    create_plot(x_fake, epoch)

    # save the generator model to a per-epoch file
    filename = 'generator_model_%03d.h5' % (epoch + 1)
    g_model.save(filename)


# Load and prepare the MNIST training images.
def load_real_samples():
    """
    Loads the MNIST training images scaled into the sigmoid [0, 1] range.

    This was a verbatim copy of load_mnist(). Delegating keeps the two entry
    points from drifting apart.
    """
    return load_mnist()



_SEPARATOR = '------------------------------------------------------------'


def _log_built_model(model, label):
    """Log ``label``, then the model's summary, then a separator; return the model."""
    logger.info(label)
    log_model_summary(model, logger)
    logger.info(_SEPARATOR)
    return model


# Module-level name `logger` is kept: performance_summary may read it as a global.
logger = initialize_logger()
logger.info('Program started')
logger.info(_SEPARATOR)

# size of the latent space fed to the generator
latent_dim = 100

# Build each model and log its summary. The GAN stacks the two models built before it.
d_model = _log_built_model(create_discriminator(), 'Discriminator created')
g_model = _log_built_model(create_generator(latent_dim), 'Generator created')
gan_model = _log_built_model(create_gan(g_model, d_model), 'GAN created')

# Load the real images, then run the adversarial training loop.
dataset = load_real_samples()
logger.info('Training started')
train(g_model, d_model, gan_model, dataset, latent_dim)

logger.info('Training finished')