In [None]:
import os
import time
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.train import Checkpoint, CheckpointManager
from tensorflow.data import Dataset
from tensorflow.data.experimental import AUTOTUNE
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import Mean
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import TruncatedNormal, RandomNormal
from tensorflow.keras.layers import Input, Dense, Reshape, BatchNormalization, Conv2D, Conv2DTranspose, \
        LeakyReLU, Flatten, SpatialDropout2D, Dropout, MaxPool2D, GlobalAvgPool2D, Concatenate, LayerNormalization

from IPython import display
import numpy as np

In [None]:
# Root folder of the aligned CelebA JPEGs. Set the CELEBA_DIR environment variable
# to point elsewhere instead of editing this machine-specific default.
BASE_PATH = os.environ.get('CELEBA_DIR', 'C:/Users/s4571730/Downloads/img_align_celeba')
RANDOM_STATE = 7
SHUFFLE_BUFFER = 32_000
# Training resolution as (height, width) -- the order tf.image.resize expects.
# It must equal the generator's output size: the 6x5 seed map in generator_model()
# is upsampled by four stride-2 transposed convolutions to 96x80 (a portrait crop).
# Real and generated images therefore reach the discriminator at the same size.
IMAGE_SIZE = (96, 80)
BATCH_SIZE = 128
# Shape of the latent noise tensor fed to the generator.
GEN_NOISE_SHAPE = (6, 5, 8)
# Number of fixed latent samples drawn in the 3x3 progress grid.
PREDICT_COUNT = 9
GEN_LR = 4e-6
GEN_BETA_1 = 0.5
DISC_LR = 1e-6
DISC_BETA_1 = 0.9
GEN_RELU_ALPHA = 0.2
DISC_RELU_ALPHA = 0.3
EPOCHS = 50
# Upper bound of the random label noise: real targets fall in (1 - s, 1], fake ones in [0, s).
DISC_LABEL_SMOOTHING = 0.25
PLOTS_DPI = 150
# True when a previous run left a checkpoint index in ./ckpt; training then resumes from it.
RETRAIN = os.path.isfile('./ckpt/checkpoint')

# Seed TensorFlow's global RNG here, before any model weights are created,
# so weight initialisation and unseeded random ops are reproducible across runs.
tf.random.set_seed(RANDOM_STATE)

In [None]:
# Enumerate every JPEG under BASE_PATH; the fixed seed makes the file ordering
# produced by list_files reproducible between runs.
image_names = Dataset.list_files(file_pattern = os.path.join(BASE_PATH, '*.jpg'),
                                 seed = RANDOM_STATE)
image_count = image_names.cardinality().numpy()
print(f"\nTotal number of image files: {image_count}\n")

In [None]:
def load_image_data(filename):
    """Decode one JPEG file into a float RGB image of size IMAGE_SIZE scaled to [-1, 1]."""
    raw_bytes = tf.io.read_file(filename)
    rgb = tf.io.decode_jpeg(raw_bytes, channels = 3)
    resized = tf.image.resize(rgb, IMAGE_SIZE)
    # Same range as the generator's tanh output: pixel 0 -> -1, pixel 255 -> +1.
    return (resized - 127.5) / 127.5

# Input pipeline: cache the file names, shuffle them, decode/resize in parallel,
# and emit fixed-size batches (drop_remainder keeps every batch at BATCH_SIZE).
train_ds = (
    image_names
    .cache()
    .shuffle(SHUFFLE_BUFFER)
    .map(load_image_data, num_parallel_calls = AUTOTUNE)
    .batch(BATCH_SIZE, drop_remainder = True)
    .prefetch(buffer_size = AUTOTUNE)
)

train_ds

In [None]:
# Sanity-check the input pipeline: draw 9 images from the first batch,
# undoing the [-1, 1] scaling so imshow receives values in [0, 1].
fig, axes = plt.subplots(nrows = 3, ncols = 3, figsize = (9, 11))

first_batch = next(iter(train_ds.take(1)))
sample_images = first_batch.numpy()

for i, ax in enumerate(axes.flat):
    ax.imshow(sample_images[i] * 0.5 + 0.5)
    ax.axis(False)
    ax.grid(False)

plt.tight_layout()

In [None]:
def generator_model():
    """Build the generator network.

    Input: a noise tensor of shape GEN_NOISE_SHAPE (flattened by the first layer).
    Output: a (96, 80, 3) image with values in [-1, 1] from the final tanh.
    A dense projection is reshaped to a 6x5x512 map and upsampled by four
    stride-2 transposed convolutions: 6x5 -> 12x10 -> 24x20 -> 48x40 -> 96x80.

    Returns:
        The built keras Sequential model named 'Generator'.
    """
    init = TruncatedNormal(mean = 0.0, stddev = 0.02)
    stack = [
        Flatten(input_shape = GEN_NOISE_SHAPE),
        Dense(6 * 5 * 512, use_bias = False, kernel_initializer = init,
              activation = LeakyReLU(GEN_RELU_ALPHA), name = 'Gen_Dense'),
        Reshape((6, 5, 512), name = 'Gen_Reshape'),
        SpatialDropout2D(0.2, name = 'Gen_SD_1'),

        # 6x5: refine at constant resolution.
        Conv2DTranspose(512, (3, 3), padding = 'same', activation = LeakyReLU(GEN_RELU_ALPHA),
                        use_bias = False, kernel_initializer = init, name = 'Gen_Conv_T_1'),

        # -> 12x10
        Conv2DTranspose(256, (3, 3), padding = 'same', strides = (2, 2), use_bias = False,
                        kernel_initializer = init, name = 'Gen_Conv_T_2'),
        BatchNormalization(name = 'Gen_BN_1'),
        LeakyReLU(GEN_RELU_ALPHA, name = 'Gen_LR_1'),
        Conv2DTranspose(128, (4, 4), padding = 'same', activation = LeakyReLU(GEN_RELU_ALPHA),
                        use_bias = False, kernel_initializer = init, name = 'Gen_Conv_T_3'),

        # -> 24x20
        Conv2DTranspose(64, (4, 4), padding = 'same', strides = (2, 2), use_bias = False,
                        kernel_initializer = init, name = 'Gen_Conv_T_4'),
        BatchNormalization(name = 'Gen_BN_2'),
        LeakyReLU(GEN_RELU_ALPHA, name = 'Gen_LR_2'),
        SpatialDropout2D(0.15, name = 'Gen_SD_3'),

        # -> 48x40
        Conv2DTranspose(8, (6, 6), padding = 'same', strides = (2, 2), use_bias = False,
                        kernel_initializer = init, name = 'Gen_Conv_T_6'),
        BatchNormalization(name = 'Gen_BN_4'),
        LeakyReLU(GEN_RELU_ALPHA, name = 'Gen_LR_4'),
        SpatialDropout2D(0.15, name = 'Gen_SD_4'),

        # -> 96x80
        Conv2DTranspose(8, (7, 7), padding = 'same', activation = LeakyReLU(GEN_RELU_ALPHA),
                        use_bias = False, kernel_initializer = init, strides = (2, 2),
                        name = 'Gen_Conv_T_7'),

        # Project to RGB; tanh keeps pixels in [-1, 1] to match the real-image scaling.
        Conv2DTranspose(3, (5, 5), padding = 'same', kernel_initializer = init, use_bias = False,
                        activation = 'tanh', name = 'Gen_Conv_T_8'),
    ]
    return Sequential(stack, name = 'Generator')

generator = generator_model()
generator.summary()

In [None]:
def discriminator_model():
    """Build the multi-scale discriminator.

    Five conv + 2x2 max-pool stages process an input of shape (*IMAGE_SIZE, 3).
    Feature maps after stages 2, 4 and 5 are global-average-pooled and concatenated,
    then a small dense head produces one raw logit per image (no sigmoid;
    the loss uses from_logits=True).

    Returns:
        The keras functional Model named 'Discriminator'.
    """
    def conv_pool(x, filters, kernel_size, conv_name, pool_name):
        # 'same'-padded LeakyReLU convolution followed by a 2x2 max pool.
        x = Conv2D(filters, kernel_size, activation = LeakyReLU(DISC_RELU_ALPHA),
                   padding = 'same', name = conv_name)(x)
        return MaxPool2D(2, name = pool_name)(x)

    image = Input(shape = (*IMAGE_SIZE, 3), name = 'Disc_Input')

    # Stages 1-2, summarised by the first global pool.
    x = conv_pool(image, 32, (4, 4), 'Disc_Conv_1', 'Disc_MP_1')
    x = conv_pool(x, 64, (4, 4), 'Disc_Conv_2', 'Disc_MP_2')
    summary_1 = GlobalAvgPool2D(name = 'Disc_GAP_1')(x)

    # Stages 3-4.
    x = SpatialDropout2D(0.2, name = 'Disc_SD_1')(x)
    x = conv_pool(x, 128, (3, 3), 'Disc_Conv_3', 'Disc_MP_3')
    x = conv_pool(x, 256, (3, 3), 'Disc_Conv_4', 'Disc_MP_4')
    summary_2 = GlobalAvgPool2D(name = 'Disc_GAP_2')(x)

    # Stage 5.
    x = SpatialDropout2D(0.2, name = 'Disc_SD_2')(x)
    x = conv_pool(x, 512, (2, 2), 'Disc_Conv_5', 'Disc_MP_5')
    summary_3 = GlobalAvgPool2D(name = 'Disc_GAP_3')(x)

    # Dense head on the concatenated multi-scale summaries.
    features = Concatenate(name = 'Disc_Concat')([summary_1, summary_2, summary_3])
    features = Dropout(0.2, name = 'Disc_Dropout')(features)
    features = Dense(32, activation = LeakyReLU(DISC_RELU_ALPHA), name = 'Disc_Dense_1')(features)
    logit = Dense(1, name = 'Disc_Dense_2')(features)

    return Model(inputs = image, outputs = logit, name = 'Discriminator')

discriminator = discriminator_model()
discriminator.summary()

In [None]:
# The discriminator's last Dense layer has no activation, so the loss works on logits.
cross_entropy = BinaryCrossentropy(from_logits = True)

# Running averages of the per-step losses; the training loop resets them every epoch.
gen_mean_loss = Mean(name = "Generator mean loss")
disc_mean_loss = Mean(name = "Discriminator mean loss")

# One Adam optimizer per network, each with its own learning rate and beta_1.
generator_optimizer = Adam(learning_rate = GEN_LR, beta_1 = GEN_BETA_1)
discriminator_optimizer = Adam(learning_rate = DISC_LR, beta_1 = DISC_BETA_1)

In [None]:
checkpoint_dir = './ckpt'

# Everything needed to resume training: both models, both optimizers' state, and
# `step`, the number of the next epoch to run (the training loop increments it
# after every epoch).
checkpoint = Checkpoint(
    step = tf.Variable(1),
    generator_optimizer = generator_optimizer,
    discriminator_optimizer = discriminator_optimizer,
    generator = generator,
    discriminator = discriminator)

ckpt_manager = CheckpointManager(checkpoint, checkpoint_dir, max_to_keep = 5)

EPOCH_START = 1
latest_ckpt = ckpt_manager.latest_checkpoint
# Only restore when the index file exists AND it points at checkpoint data on disk.
if RETRAIN and latest_ckpt:
    checkpoint.restore(latest_ckpt)
    # Cast to a plain Python int so numpy scalar types don't leak into epoch arithmetic.
    EPOCH_START = int(checkpoint.step.numpy())
    print(f"Restored checkpoint {latest_ckpt}")

print(f"Starting training from Epoch {EPOCH_START}")

In [None]:
# Fixed latent batch reused for every progress grid, so samples are comparable across epochs.
tf.random.set_seed(RANDOM_STATE)
seed_noise = tf.random.normal(shape = [PREDICT_COUNT, *GEN_NOISE_SHAPE], seed = RANDOM_STATE)

In [None]:
def generate_images(seed, save = False, epoch = None):
    """Plot a 3x3 grid of generator samples for the latent batch `seed`.

    Args:
        seed: latent tensor shaped (N, *GEN_NOISE_SHAPE) with N >= 9; the first
            9 samples are drawn.
        save: if True, write the grid to 'Pred_Epoch_{epoch:04d}.png' and close the
            figure instead of leaving it open for inline display.
        epoch: epoch number used in the file name; required when `save` is True.

    Raises:
        ValueError: if `save` is True but `epoch` is None.
    """
    if save and epoch is None:
        raise ValueError("generate_images(save=True) requires an epoch number for the file name")

    # training=False: dropout is disabled and BatchNorm uses its moving statistics.
    pred = generator(seed, training = False)

    fig, axes = plt.subplots(nrows = 3, ncols = 3, figsize = (9, 11))

    for i, ax in enumerate(axes.flatten()):
        # Map the tanh output from [-1, 1] back to [0, 1] for imshow.
        ax.imshow((pred[i] * 0.5) + 0.5)
        ax.axis(False)
        ax.grid(False)

    fig.suptitle('Generator Predictions', fontsize = 20)

    fig.tight_layout()

    if save:
        fig.savefig(f'Pred_Epoch_{epoch:04d}.png', dpi = PLOTS_DPI, facecolor = 'white',
                    transparent = False, bbox_inches = 'tight')
        # Close this specific figure so repeated calls during training don't pile up open figures.
        plt.close(fig)

generate_images(seed_noise)

In [None]:
def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy for the discriminator with noisy label smoothing.

    Real targets are drawn uniformly from (1 - DISC_LABEL_SMOOTHING, 1] and fake
    targets from [0, DISC_LABEL_SMOOTHING), freshly for every call.

    Args:
        real_output: discriminator logits for real images.
        fake_output: discriminator logits for generated images.

    Returns:
        Scalar sum of the real and fake losses.
    """
    # tf.shape (dynamic) rather than .shape (static): inside tf.function the static
    # batch dimension can be None, which tf.random.uniform cannot accept.
    real_noise = tf.random.uniform(tf.shape(real_output)) * DISC_LABEL_SMOOTHING
    fake_noise = tf.random.uniform(tf.shape(fake_output)) * DISC_LABEL_SMOOTHING
    pos_labels = tf.ones_like(real_output) - real_noise
    neg_labels = tf.zeros_like(fake_output) + fake_noise
    real_loss = cross_entropy(pos_labels, real_output)
    fake_loss = cross_entropy(neg_labels, fake_output)
    return real_loss + fake_loss

def generator_loss(fake_output):
    """Generator loss: push the discriminator's logits on fakes towards the 'real' label 1."""
    return cross_entropy(tf.ones_like(fake_output), fake_output)

In [None]:
@tf.function
def train_step(images):
    """Run one simultaneous generator + discriminator update on a batch of real images.

    Both losses come from the same forward pass, recorded on two separate tapes.
    They are also fed into the gen_mean_loss / disc_mean_loss running averages.
    """
    latent = tf.random.normal([BATCH_SIZE, *GEN_NOISE_SHAPE])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fakes = generator(latent, training = True)

        real_logits = discriminator(images, training = True)
        fake_logits = discriminator(fakes, training = True)

        g_loss = generator_loss(fake_logits)
        d_loss = discriminator_loss(real_logits, fake_logits)

    gen_mean_loss(g_loss)
    disc_mean_loss(d_loss)

    g_grads = gen_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = disc_tape.gradient(d_loss, discriminator.trainable_variables)

    # Generator update first, then discriminator.
    for optimizer, grads, model in ((generator_optimizer, g_grads, generator),
                                    (discriminator_optimizer, d_grads, discriminator)):
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

In [None]:
gen_losses = []
disc_losses = []

def train(dataset, epochs):
    """Train the GAN for `epochs` epochs, numbering them from EPOCH_START.

    After every epoch it:
    - advances checkpoint.step;
    - saves a sample grid to disk;
    - appends the epoch's mean losses (as floats) to gen_losses / disc_losses;
    - prints a short summary.

    A checkpoint is written on every 5th epoch and after the final epoch.
    """
    for epoch in range(epochs):
        start = time.time()
        epoch_num = epoch + EPOCH_START

        gen_mean_loss.reset_states()
        disc_mean_loss.reset_states()

        print(f"\nTraining Epoch {epoch_num}\n")

        for batch_ind, image_batch in enumerate(dataset):
            train_step(image_batch)

            # Progress: a dot every 10 batches, the batch count every 250.
            if (batch_ind + 1) % 10 == 0:
                print(". ", end = '')
            if (batch_ind + 1) % 250 == 0:
                print(f"{batch_ind + 1}")

        # step holds the next epoch to run, so a restored checkpoint resumes after this one.
        checkpoint.step.assign_add(1)

        display.clear_output(wait = True)

        generate_images(seed_noise, True, epoch_num)

        # Also save after the last epoch; otherwise up to 4 epochs of training are
        # lost whenever the run does not end on a multiple of 5.
        if epoch_num % 5 == 0 or epoch == epochs - 1:
            ckpt_manager.save()

        # Store plain floats rather than tensors so the history is easy to plot or serialise.
        gen_epoch_loss = float(gen_mean_loss.result())
        disc_epoch_loss = float(disc_mean_loss.result())
        gen_losses.append(gen_epoch_loss)
        disc_losses.append(disc_epoch_loss)

        print(f"\nEpoch: {epoch_num}\n")
        print(f'Generator Loss: {gen_epoch_loss:.4f}')
        print(f'Discriminator Loss: {disc_epoch_loss:.4f}')
        print (f'Time elapsed: {time.time() - start:.2f} s')

    display.clear_output(wait = True)

In [None]:
train(train_ds, EPOCHS)

# Report the last epoch's mean losses in the same 4-decimal format as the per-epoch summary.
print(f'Final Generator Loss: {gen_mean_loss.result():.4f}')
print(f'Final Discriminator Loss: {disc_mean_loss.result():.4f}')

In [None]:
# Final samples from the fixed latent batch, displayed inline rather than saved to disk.
generate_images(seed_noise, save = False)

In [None]:
!pip install tensorflow-datasets
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds

In [19]:
# Load the CelebA training split via TFDS (downloads and prepares the dataset on first use).
# Bind it to a name so later cells don't have to rely on Out[...] / _.
celeba_tfds_train = tfds.load('celeb_a', split='train')
celeba_tfds_train

[1mDownloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to C:\Users\s4571730\tensorflow_datasets\celeb_a\2.0.1...[0m


Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]