# Advanced GAN Base Workspace

Building a base Advanced Gan

In [6]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import mixed_precision
from tensorflow import keras
from tensorflow.data import Dataset
from numpy import expand_dims, ones, zeros, vstack
from numpy.random import rand, randint, randn
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten, Dropout, LeakyReLU, Conv2DTranspose, Reshape
from matplotlib import pyplot
import logging
import io
from contextlib import redirect_stdout
import gdown
from zipfile import ZipFile
import os

# Enable mixed precision
mixed_precision.set_global_policy('mixed_float16')

def initialize_logger():
    """
    Creates a logger for hyperparameters and training data.
    """
    script_name = 'dcgan_base'
    log_filename = f'{script_name}.log'
    logger = logging.getLogger(script_name)
    logger.setLevel(logging.DEBUG)
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(log_filename)
    c_handler.setLevel(logging.INFO)
    f_handler.setLevel(logging.DEBUG)
    c_format = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
    f_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    c_handler.setFormatter(c_format)
    f_handler.setFormatter(f_format)
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)
    return logger

def log_model_summary(model, logger):
    with io.StringIO() as buf, redirect_stdout(buf):
        model.summary()
        summary = buf.getvalue()
    logger.info(summary)

def create_discriminator(in_shape=(64,64,3)):
    """
    Creates a discrimator model

    Input:
    in-shape: This is the shape of the photos that will be put into the discriminator model.

    Output:
    The model for discriminating fake vs real images
    """
    model = Sequential()
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Conv2D(128, (3,3), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    opt = Adam(learning_rate=0.0005, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

def load_celeba():
    """
    Loads the CelebA dataset from Google Drive.
    """
    if not os.path.exists("celeba_gan"):
        os.makedirs("celeba_gan")
        url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
        output = "celeba_gan/data.zip"
        gdown.download(url, output, quiet=True)
        with ZipFile("celeba_gan/data.zip", "r") as zipobj:
            zipobj.extractall("celeba_gan")

    dataset = keras.utils.image_dataset_from_directory(
        "celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32
    )
    dataset = dataset.map(lambda x: x / 255.0)
    dataset = dataset.cache().prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

def select_real_samples(dataset, n_samples):
    """
    Selects some number of real samples to train with from the input dataset.
    Uses random to select random images. Labels the images as 'real' with label = 1

    Inputs:
    dataset: Input dataset(MNIST)
    n_sample: number of samples to select

    Return:
    X: Selected images
    y: tags for images
    """
    dataset_np = np.concatenate([x for x in dataset], axis=0)
    i = randint(0, dataset_np.shape[0], n_samples)
    X = dataset_np[i]
    y = ones((n_samples, 1))
    return X, y

def initial_create_fake_samples(n_samples):
    """
    Creates fake samples to train discriminator with correct dimensions

    Input:
    n_samples: number of samples to create

    Return:
    X: fake images
    y: image tags for training, 0 to mean not real images
    """
    X = rand(64 * 64 * 3 * n_samples)
    X = X.reshape((n_samples, 64, 64, 3))
    y = zeros((n_samples, 1))
    return X, y

def train_discriminator(model, dataset, iterations=100, batch_size=256):
    """
    Trains the discriminator using mnist dataset and fake images.
    Takes half batch size of real and fake for each iteration.

    Inputs:
    model: input model
    dataset: loaded dataset (MNIST)
    iterations: number of iterations of training
    batch_size: images to train with in each iteration
    """
    for i in range(iterations):
        X_real, y_real = select_real_samples(dataset, int(batch_size / 2))
        _, real_acc = model.train_on_batch(X_real, y_real)
        X_fake, y_fake = initial_create_fake_samples(int(batch_size / 2))
        _, fake_acc = model.train_on_batch(X_fake, y_fake)
        print(f'>{i+1} real={real_acc*100:.0f}% fake={fake_acc*100:.0f}%')

def create_generator(latent_dim):
    """
    Creates a generator model.

    Input:
    latent-dim: Dimension of the latent space

    Return:
    The generator model
    """
    model = Sequential()
    n_nodes = 256 * 8 * 8
    model.add(Dense(n_nodes, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((8, 8, 256)))
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2DTranspose(64, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(3, (3,3), activation='tanh', padding='same'))
    return model

def generate_latent_points(latent_dim, n_samples):
    """
    Generates points in latent space.
    Used as input for the generator

    Inputs:
    latent_dim: Dimension of the latent space
    n_samples: number of samples to generate

    Return:
    x_input: points in latent space
    """
    x_input = randn(latent_dim * n_samples)
    x_input = x_input.reshape(n_samples, latent_dim)
    return x_input

def create_fake_samples(g_model, latent_dim, n_samples):
    """
    Generates fake samples from the generator model.
    Creates labels of 0 for the fake images

    Inputs:
    g_model: generator model
    latent_dim: Dimension of the latent space
    n_samples: number of samples to generate

    Return:
    X: fake images
    y: image tags for training, 0 to mean not real
    """
    x_input = generate_latent_points(latent_dim, n_samples)
    X = g_model.predict(x_input)
    y = zeros((n_samples, 1))
    return X, y

def create_gan(g_model, d_model):
    """
    Creating the GAN model. The generator is trained but the discriminator is untrainable.

    Inputs:
    g_model: generator model
    d_model: discriminator model

    Return:
    The GAN model
    """
    d_model.trainable = False
    model = Sequential()
    model.add(g_model)
    model.add(d_model)
    opt = Adam(learning_rate=0.0005, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=1024):
    """
    Trains the generator and discriminator.

    Inputs:
    g_model: generator model
    d_model: discriminator model
    gan_model: GAN model
    dataset: MNIST dataset
    latent_dim: Dimension of the latent space
    n_epochs: number of epochs to train
    n_batch: number of images to train with in each iteration
    """
    num_batch = int(len(dataset) * 32 / n_batch)
    half_batch = int(n_batch / 2)
    for i in range(n_epochs):
        for j in range(num_batch):
            X_real, y_real = select_real_samples(dataset, half_batch)
            X_fake, y_fake = create_fake_samples(g_model, latent_dim, half_batch)
            X, y = vstack((X_real, X_fake)), vstack((y_real, y_fake))
            d_loss, _ = d_model.train_on_batch(X, y)
            X_gan = generate_latent_points(latent_dim, n_batch)
            y_gan = ones((n_batch, 1))
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            print(f'>{i+1}, {j+1}/{num_batch}, d={d_loss:.3f}, g={g_loss:.3f}')
        if (i+1) % 2 == 0:
            performance_summary(i, g_model, d_model, dataset, latent_dim)


def create_plot(examples, epoch, n=10):
    """
    Creates a physical picture to look at to see how the training is coming after each 10 epochs

    Inputs:
    examples: images to plot
    epoch: epoch number
    n: number of images to plot
    """
    examples = examples.astype(np.float32)  # Convert to float32 for plotting
    examples = (examples * 127.5 + 127.5).astype(np.uint8)  # Rescale to [0, 255] and convert to uint8

    for i in range(n * n):
        pyplot.subplot(n, n, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(examples[i])
    filename = 'generated_plot_e%03d.png' % (epoch+1)
    pyplot.savefig(filename)
    pyplot.close()


def performance_summary(epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    """
    Every 10 epochs, save a copy of the model as well as a plot of generated images.

    Inputs:
    epoch: epoch number
    g_model: generator model
    d_model: discriminator model
    dataset: MNIST dataset
    latent_dim: Dimension of the latent space
    n_samples: number of samples to generate
    """
    X_real, y_real = select_real_samples(dataset, n_samples)
    _, acc_real = d_model.evaluate(X_real, y_real, verbose=0)
    x_fake, y_fake = create_fake_samples(g_model, latent_dim, n_samples)
    _, acc_fake = d_model.evaluate(x_fake, y_fake, verbose=0)
    print(f'>Accuracy real: {acc_real*100:.0f}%, fake: {acc_fake*100:.0f}%')
    logger.info(f'>Accuracy real: {acc_real*100:.0f}%, fake: {acc_fake*100:.0f}%')
    create_plot(x_fake, epoch)
    filename = 'generator_model_%03d.h5' % (epoch + 1)
    g_model.save(filename)

# Initialize the logger
logger = initialize_logger()
logger.info('Program started')
logger.info('------------------------------------------------------------')

# size of the latent space
latent_dim = 100
# create the discriminator
d_model = create_discriminator()
logger.info('Discriminator created')
log_model_summary(d_model, logger)
logger.info('------------------------------------------------------------')
# create the generator
g_model = create_generator(latent_dim)
logger.info('Generator created')
log_model_summary(g_model, logger)
logger.info('------------------------------------------------------------')
# create the gan
gan_model = create_gan(g_model, d_model)
logger.info('GAN created')
log_model_summary(gan_model, logger)
logger.info('------------------------------------------------------------')
# load image data
dataset = load_celeba()
# train model
logger.info('Training started')
train(g_model, d_model, gan_model, dataset, latent_dim)
logger.info('Training finished')


dcgan_base - INFO - Program started
dcgan_base - INFO - Program started
dcgan_base - INFO - Program started
dcgan_base - INFO - Program started
dcgan_base - INFO - Program started
dcgan_base - INFO - Program started
INFO:dcgan_base:Program started
dcgan_base - INFO - ------------------------------------------------------------
dcgan_base - INFO - ------------------------------------------------------------
dcgan_base - INFO - ------------------------------------------------------------
dcgan_base - INFO - ------------------------------------------------------------
dcgan_base - INFO - ------------------------------------------------------------
dcgan_base - INFO - ------------------------------------------------------------
INFO:dcgan_base:------------------------------------------------------------
dcgan_base - INFO - Discriminator created
dcgan_base - INFO - Discriminator created
dcgan_base - INFO - Discriminator created
dcgan_base - INFO - Discriminator created
dcgan_base - INFO - D

Found 202599 files belonging to 1 classes.


dcgan_base - INFO - Training started
dcgan_base - INFO - Training started
dcgan_base - INFO - Training started
dcgan_base - INFO - Training started
dcgan_base - INFO - Training started
dcgan_base - INFO - Training started
INFO:dcgan_base:Training started






>1, 1/197, d=0.696, g=0.691
>1, 2/197, d=0.542, g=0.675
>1, 3/197, d=0.470, g=0.622
>1, 4/197, d=0.505, g=0.520
>1, 5/197, d=0.702, g=0.446
>1, 6/197, d=0.862, g=0.569
>1, 7/197, d=0.706, g=1.065
>1, 8/197, d=0.592, g=1.318
>1, 9/197, d=0.670, g=0.982
>1, 10/197, d=0.712, g=0.707
>1, 11/197, d=0.711, g=0.646
>1, 12/197, d=0.655, g=0.769
>1, 13/197, d=0.565, g=0.896
>1, 14/197, d=0.526, g=0.801
>1, 15/197, d=0.523, g=0.649
>1, 16/197, d=0.532, g=0.565
>1, 17/197, d=0.581, g=0.526
>1, 18/197, d=0.657, g=0.572
>1, 19/197, d=0.678, g=0.704
>1, 20/197, d=0.678, g=0.714
>1, 21/197, d=0.633, g=0.729
>1, 22/197, d=0.601, g=0.807
>1, 23/197, d=0.626, g=0.843
>1, 24/197, d=0.669, g=0.737
>1, 25/197, d=0.713, g=0.731
>1, 26/197, d=0.701, g=0.891
>1, 27/197, d=0.647, g=1.129
>1, 28/197, d=0.580, g=1.194
>1, 29/197, d=0.587, g=1.188
>1, 30/197, d=0.625, g=1.038
>1, 31/197, d=0.665, g=0.880
>1, 32/197, d=0.683, g=0.967
>1, 33/197, d=0.687, g=0.845
>1, 34/197, d=0.666, g=0.899
>1, 35/197, d=0.660, g=

dcgan_base - INFO - >Accuracy real: 91%, fake: 6%
dcgan_base - INFO - >Accuracy real: 91%, fake: 6%
dcgan_base - INFO - >Accuracy real: 91%, fake: 6%
dcgan_base - INFO - >Accuracy real: 91%, fake: 6%
dcgan_base - INFO - >Accuracy real: 91%, fake: 6%
dcgan_base - INFO - >Accuracy real: 91%, fake: 6%
INFO:dcgan_base:>Accuracy real: 91%, fake: 6%


>Accuracy real: 91%, fake: 6%




>3, 1/197, d=0.698, g=0.649
>3, 2/197, d=0.697, g=0.662
>3, 3/197, d=0.693, g=0.686
>3, 4/197, d=0.694, g=0.711
>3, 5/197, d=0.689, g=0.739
>3, 6/197, d=0.698, g=0.751
>3, 7/197, d=0.703, g=0.741
>3, 8/197, d=0.698, g=0.754
>3, 9/197, d=0.697, g=0.780
>3, 10/197, d=0.696, g=0.781
>3, 11/197, d=0.688, g=0.781
>3, 12/197, d=0.688, g=0.781
>3, 13/197, d=0.691, g=0.790
>3, 14/197, d=0.688, g=0.753
>3, 15/197, d=0.695, g=0.727
>3, 16/197, d=0.704, g=0.705
>3, 17/197, d=0.707, g=0.701
>3, 18/197, d=0.697, g=0.712
>3, 19/197, d=0.691, g=0.714
>3, 20/197, d=0.691, g=0.707
>3, 21/197, d=0.687, g=0.703
>3, 22/197, d=0.690, g=0.713
>3, 23/197, d=0.697, g=0.730
>3, 24/197, d=0.694, g=0.754
>3, 25/197, d=0.693, g=0.797
>3, 26/197, d=0.694, g=0.824
>3, 27/197, d=0.699, g=0.820
>3, 28/197, d=0.702, g=0.814
>3, 29/197, d=0.727, g=0.787
>3, 30/197, d=0.731, g=0.774
>3, 31/197, d=0.750, g=0.748
>3, 32/197, d=0.741, g=0.748
>3, 33/197, d=0.733, g=0.787
>3, 34/197, d=0.701, g=0.869
>3, 35/197, d=0.665, g=

dcgan_base - INFO - >Accuracy real: 43%, fake: 56%
dcgan_base - INFO - >Accuracy real: 43%, fake: 56%
dcgan_base - INFO - >Accuracy real: 43%, fake: 56%
dcgan_base - INFO - >Accuracy real: 43%, fake: 56%
dcgan_base - INFO - >Accuracy real: 43%, fake: 56%
dcgan_base - INFO - >Accuracy real: 43%, fake: 56%
INFO:dcgan_base:>Accuracy real: 43%, fake: 56%


>Accuracy real: 43%, fake: 56%




>5, 1/197, d=0.706, g=0.803
>5, 2/197, d=0.695, g=0.820
>5, 3/197, d=0.703, g=0.782
>5, 4/197, d=0.691, g=0.751
>5, 5/197, d=0.670, g=0.783
>5, 6/197, d=0.678, g=0.772
>5, 7/197, d=0.674, g=0.757
>5, 8/197, d=0.676, g=0.719
>5, 9/197, d=0.671, g=0.693
>5, 10/197, d=0.682, g=0.678
>5, 11/197, d=0.690, g=0.665
>5, 12/197, d=0.690, g=0.656
>5, 13/197, d=0.684, g=0.666
>5, 14/197, d=0.679, g=0.672
>5, 15/197, d=0.685, g=0.717
>5, 16/197, d=0.681, g=0.739
>5, 17/197, d=0.683, g=0.764
>5, 18/197, d=0.695, g=0.794
>5, 19/197, d=0.710, g=0.857
>5, 20/197, d=0.702, g=0.880
>5, 21/197, d=0.700, g=0.902
>5, 22/197, d=0.693, g=0.893
>5, 23/197, d=0.691, g=0.858
>5, 24/197, d=0.692, g=0.848
>5, 25/197, d=0.706, g=0.779
>5, 26/197, d=0.721, g=0.706
>5, 27/197, d=0.728, g=0.678
>5, 28/197, d=0.727, g=0.674
>5, 29/197, d=0.714, g=0.702
>5, 30/197, d=0.696, g=0.755
>5, 31/197, d=0.677, g=0.793
>5, 32/197, d=0.660, g=0.814
>5, 33/197, d=0.656, g=0.782
>5, 34/197, d=0.667, g=0.778
>5, 35/197, d=0.685, g=

dcgan_base - INFO - >Accuracy real: 49%, fake: 11%
dcgan_base - INFO - >Accuracy real: 49%, fake: 11%
dcgan_base - INFO - >Accuracy real: 49%, fake: 11%
dcgan_base - INFO - >Accuracy real: 49%, fake: 11%
dcgan_base - INFO - >Accuracy real: 49%, fake: 11%
dcgan_base - INFO - >Accuracy real: 49%, fake: 11%
INFO:dcgan_base:>Accuracy real: 49%, fake: 11%


>Accuracy real: 49%, fake: 11%




>7, 1/197, d=0.748, g=0.658
>7, 2/197, d=0.729, g=0.684
>7, 3/197, d=0.703, g=0.708
>7, 4/197, d=0.681, g=0.731
>7, 5/197, d=0.665, g=0.751
>7, 6/197, d=0.649, g=0.761
>7, 7/197, d=0.645, g=0.765
>7, 8/197, d=0.651, g=0.766
>7, 9/197, d=0.663, g=0.769
>7, 10/197, d=0.676, g=0.772
>7, 11/197, d=0.694, g=0.822
>7, 12/197, d=0.708, g=0.801
>7, 13/197, d=0.723, g=0.767
>7, 14/197, d=0.727, g=0.773
>7, 15/197, d=0.719, g=0.786
>7, 16/197, d=0.710, g=0.779
>7, 17/197, d=0.700, g=0.784
>7, 18/197, d=0.679, g=0.800
>7, 19/197, d=0.656, g=0.841
>7, 20/197, d=0.640, g=0.895
>7, 21/197, d=0.638, g=0.970
>7, 22/197, d=0.646, g=1.041
>7, 23/197, d=0.674, g=0.935
>7, 24/197, d=0.715, g=0.801
>7, 25/197, d=0.747, g=0.792
>7, 26/197, d=0.763, g=0.679
>7, 27/197, d=0.738, g=0.688
>7, 28/197, d=0.712, g=0.758
>7, 29/197, d=0.693, g=0.773
>7, 30/197, d=0.676, g=0.724
>7, 31/197, d=0.667, g=0.725
>7, 32/197, d=0.681, g=0.693
>7, 33/197, d=0.697, g=0.716
>7, 34/197, d=0.710, g=0.763
>7, 35/197, d=0.738, g=

dcgan_base - INFO - >Accuracy real: 5%, fake: 99%
dcgan_base - INFO - >Accuracy real: 5%, fake: 99%
dcgan_base - INFO - >Accuracy real: 5%, fake: 99%
dcgan_base - INFO - >Accuracy real: 5%, fake: 99%
dcgan_base - INFO - >Accuracy real: 5%, fake: 99%
dcgan_base - INFO - >Accuracy real: 5%, fake: 99%
INFO:dcgan_base:>Accuracy real: 5%, fake: 99%


>Accuracy real: 5%, fake: 99%




>9, 1/197, d=0.694, g=0.839
>9, 2/197, d=0.694, g=0.827
>9, 3/197, d=0.691, g=0.848
>9, 4/197, d=0.686, g=0.805
>9, 5/197, d=0.671, g=0.756
>9, 6/197, d=0.675, g=0.729
>9, 7/197, d=0.684, g=0.711
>9, 8/197, d=0.692, g=0.729
>9, 9/197, d=0.720, g=0.659
>9, 10/197, d=0.733, g=0.620
>9, 11/197, d=0.731, g=0.662
>9, 12/197, d=0.723, g=0.698
>9, 13/197, d=0.707, g=0.700
>9, 14/197, d=0.688, g=0.688
>9, 15/197, d=0.671, g=0.688
>9, 16/197, d=0.660, g=0.704
>9, 17/197, d=0.653, g=0.730
>9, 18/197, d=0.646, g=0.759
>9, 19/197, d=0.645, g=0.768
>9, 20/197, d=0.657, g=0.789
>9, 21/197, d=0.659, g=0.794
>9, 22/197, d=0.672, g=0.795
>9, 23/197, d=0.680, g=0.769
>9, 24/197, d=0.692, g=0.759
>9, 25/197, d=0.713, g=0.795
>9, 26/197, d=0.717, g=0.789
>9, 27/197, d=0.727, g=0.767
>9, 28/197, d=0.735, g=0.753
>9, 29/197, d=0.731, g=0.773
>9, 30/197, d=0.722, g=0.808
>9, 31/197, d=0.706, g=0.797
>9, 32/197, d=0.687, g=0.818
>9, 33/197, d=0.679, g=0.829
>9, 34/197, d=0.671, g=0.855
>9, 35/197, d=0.669, g=

dcgan_base - INFO - >Accuracy real: 78%, fake: 59%
dcgan_base - INFO - >Accuracy real: 78%, fake: 59%
dcgan_base - INFO - >Accuracy real: 78%, fake: 59%
dcgan_base - INFO - >Accuracy real: 78%, fake: 59%
dcgan_base - INFO - >Accuracy real: 78%, fake: 59%
dcgan_base - INFO - >Accuracy real: 78%, fake: 59%
INFO:dcgan_base:>Accuracy real: 78%, fake: 59%


>Accuracy real: 78%, fake: 59%


dcgan_base - INFO - Training finished
dcgan_base - INFO - Training finished
dcgan_base - INFO - Training finished
dcgan_base - INFO - Training finished
dcgan_base - INFO - Training finished
dcgan_base - INFO - Training finished
INFO:dcgan_base:Training finished
