## Setup

In [None]:
%matplotlib qt
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_gan as tfgan
import numpy as np
import os, sys
from tqdm.notebook import tqdm
from pathlib import Path

sys.path.append( os.path.abspath('..') )
import utils

In [None]:
# Work inside a CelebA/ directory (created if missing) so that the relative
# output paths used later in the notebook land there.
workdir = Path('CelebA')
workdir.mkdir(exist_ok=True)
os.chdir(workdir)

In [None]:
dataset = tf.data.Dataset.from_tensor_slices(np.load(os.path.join('..', '..', 'celeba.npy')))
dataset = dataset.map(lambda img: (tf.cast(img, tf.float32) - 127.5) / 127.5)
NUM_IMAGES = int(dataset.cardinality())

## 1 Models

### 1.1 Architecture

In [None]:
def generator_model(latent_dims):
    """Build the generator: latent vector -> 48x48x3 image in [-1, 1].

    Args:
        latent_dims: length of the input noise vector.

    Returns:
        An uncompiled ``tf.keras.Sequential``. Its ``input_shape`` is
        ``(None, latent_dims)``.
    """
    # Project the noise vector and reshape it to a 6x6x512 feature map.
    stack = [
        tf.keras.layers.Dense(6*6*512, input_shape=(latent_dims,)),
        tf.keras.layers.LeakyReLU(0.2),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Reshape((6, 6, 512)),
    ]
    # Three stages of nearest-neighbour x2 upsampling plus a 3x3 conv:
    # 6x6x512 -> 12x12x256 -> 24x24x128 -> 48x48x64.
    for filters in (256, 128, 64):
        stack += [
            tf.keras.layers.UpSampling2D(size=2, interpolation='nearest'),
            tf.keras.layers.Conv2D(filters, kernel_size=3, strides=1, padding='same'),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.BatchNormalization(),
        ]
    # 1x1 conv down to 3 channels. tanh bounds the output to [-1, 1], the same
    # range produced by the (x - 127.5) / 127.5 scaling of the real images.
    stack.append(
        tf.keras.layers.Conv2D(3, kernel_size=1, strides=1, padding='same', activation='tanh')
    )
    return tf.keras.Sequential(stack)

In [None]:
def discriminator_model():
    """Build the convolutional discriminator for 48x48 RGB images.

    Returns:
        An uncompiled ``tf.keras.Sequential``. It maps a batch shaped
        ``(batch, 48, 48, 3)`` to one unnormalized logit per image, shaped
        ``(batch, 1)``. There is no final sigmoid, so pair it with a loss
        built with ``from_logits=True``.
    """
    return tf.keras.Sequential([
        # Stride 1 keeps full resolution here. With a 1x1 kernel, stride 2
        # would drop 3 of every 4 pixels and shift every later stage by a
        # factor of two.
        tf.keras.layers.Conv2D(64, kernel_size=1, strides=1, padding='same', input_shape=(48,48,3)),
        tf.keras.layers.LeakyReLU(0.2),
        tf.keras.layers.Dropout(0.3),
        #48x48x64

        tf.keras.layers.Conv2D(128, kernel_size=3, strides=2, padding='same'),
        tf.keras.layers.LeakyReLU(0.2),
        tf.keras.layers.Dropout(0.3),
        #24x24x128

        tf.keras.layers.Conv2D(256, kernel_size=3, strides=2, padding='same'),
        tf.keras.layers.LeakyReLU(0.2),
        tf.keras.layers.Dropout(0.3),
        #12x12x256

        tf.keras.layers.Conv2D(512, kernel_size=3, strides=2, padding='same'),
        tf.keras.layers.LeakyReLU(0.2),
        tf.keras.layers.Dropout(0.3),
        #6x6x512

        # A 6x6 kernel with 'valid' padding collapses the 6x6 map to 1x1.
        # 'same' padding would keep the map spatial.
        tf.keras.layers.Conv2D(512, kernel_size=6, strides=1, padding='valid'),
        tf.keras.layers.LeakyReLU(0.2),
        tf.keras.layers.Dropout(0.3),
        #1x1x512

        # Flatten so Dense(1) yields a single logit per image instead of a
        # per-position map.
        tf.keras.layers.Flatten(),
        #512

        tf.keras.layers.Dense(1)
    ])

### 1.2 Losses

The binary cross entropy (BCE) between $y$ and $\hat{y}$ is calculated as:

$$
    \mathrm{BCE}(y, \hat{y}) = - y \log\left(\hat{y}\right) - (1-y) \log\left(1 - \hat{y}\right)
$$

In [None]:
# Shared binary cross-entropy loss. from_logits=True because the
# discriminator's final Dense(1) has no sigmoid; the loss applies it
# internally, which is numerically more stable than a separate sigmoid.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

The generator tries to maximize the chance that the discriminator classifies its samples as real. In practice this is done by minimizing the following (non-saturating) loss function:

$$
    J^{(G)} = -\log\bigl(D\bigl(G(z)\bigr)\bigr)
$$

In [None]:
def generator_loss(fake_output):
    """Non-saturating generator loss, -log(D(G(z))).

    Scores the discriminator's logits on generated images against an
    all-"real" (1) target, so minimizing it pushes D(G(z)) toward 1.

    Args:
        fake_output: discriminator logits for a batch of generated images.

    Returns:
        Scalar loss tensor.
    """
    wanted_labels = tf.ones_like(fake_output)
    loss = cross_entropy(wanted_labels, fake_output)
    return loss

The discriminator tries to correctly classify real data as real and fake data as fake. This is equivalent to minimizing the following loss function:

$$
    J^{(D)} = -\log\bigl(D(x)\bigr) - \log\bigl(1 - D\bigl(G(z)\bigr)\bigr)
$$

Here we scale the loss by a factor of $0.5$ and apply one-sided label smoothing: real images are given a target of $0.9$ instead of $1$, while fake images keep a target of $0$.

In [None]:
def discriminator_loss(real_output, fake_output):
    """Discriminator loss with one-sided label smoothing, averaged over both terms.

    Real logits are scored against a target of 0.9 instead of 1, and fake
    logits against 0. The two cross-entropies are summed, then halved.

    Args:
        real_output: discriminator logits for a batch of real images.
        fake_output: discriminator logits for a batch of generated images.

    Returns:
        Scalar loss tensor.
    """
    # Only the real targets are smoothed (hence "one-sided").
    smoothed_targets = 0.9 * tf.ones_like(real_output)
    on_real = cross_entropy(smoothed_targets, real_output)
    on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total = on_real + on_fake
    return 0.5 * total

## 2 Training

### 2.1 Main functions

In [None]:
def discriminator_train_step(generator, discriminator, images, latent_dims):
    """Apply one optimizer update to the discriminator.

    Draws one latent vector per real image and generates the same number of
    fakes. It then applies the gradient of ``discriminator_loss`` to the
    discriminator's trainable variables only. The generator's weights are
    not updated here. Because it runs with ``training=True``, its
    BatchNormalization layers do update their moving statistics.

    Args:
        generator: Keras model mapping ``(batch, latent_dims)`` noise to images.
        discriminator: Keras model with an ``optimizer`` attribute.
        images: batch of real images, scaled like the generator's output.
        latent_dims: length of each noise vector.
    """
    # tf.shape gives the runtime batch size. The static images.shape[0] is
    # None whenever this function is traced with an unknown batch dimension.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, latent_dims])
    with tf.GradientTape() as disc_tape:
        generated_imgs = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_imgs, training=True)
        loss_D = discriminator_loss(real_output, fake_output)

    grads_D = disc_tape.gradient(loss_D, discriminator.trainable_variables)
    discriminator.optimizer.apply_gradients(zip(grads_D, discriminator.trainable_variables))

In [None]:
def generator_train_step(generator, discriminator, batch_size, latent_dims):
    """Apply one optimizer update to the generator.

    Generates ``batch_size`` fakes, scores them with the discriminator, and
    applies the gradient of ``generator_loss`` to the generator's trainable
    variables only.

    Args:
        generator: Keras model with an ``optimizer`` attribute.
        discriminator: Keras model mapping images to logits.
        batch_size: number of noise vectors to draw.
        latent_dims: length of each noise vector.
    """
    latent = tf.random.normal([batch_size, latent_dims])
    with tf.GradientTape() as tape:
        logits = discriminator(generator(latent, training=True), training=True)
        loss = generator_loss(logits)

    variables = generator.trainable_variables
    gradients = tape.gradient(loss, variables)
    generator.optimizer.apply_gradients(zip(gradients, variables))

In [None]:
def train(generator, discriminator, dataset, epochs, batch_size, callbacks=None):
    """Run the alternating GAN training loop.

    For every batch of real images, the discriminator is updated once and
    then the generator is updated once. Both step functions are wrapped in
    ``tf.function``.

    Args:
        generator: Keras model mapping latent vectors to images. It must have
            an ``optimizer`` attribute. Its latent size is read from
            ``generator.input_shape[1]``.
        discriminator: Keras model mapping images to logits. It must have an
            ``optimizer`` attribute.
        dataset: ``tf.data.Dataset`` yielding already-batched real images.
        epochs: number of full passes over ``dataset``.
        batch_size: number of latent vectors drawn for each generator update.
        callbacks: optional iterable of objects with ``on_epoch_begin`` and
            ``on_epoch_end`` methods. Both are called with the keyword
            arguments ``epoch`` (1-based), ``generator`` and ``discriminator``.
            ``None`` means no callbacks.
    """
    # Treat None as "no callbacks"; iterating None would raise TypeError.
    # list() also lets a one-shot iterable be walked twice per epoch.
    callbacks = [] if callbacks is None else list(callbacks)
    latent_dims = generator.input_shape[1]

    # Take the progress-bar length from the dataset itself, so it is right for
    # whatever dataset is passed. tf.data reports negative sentinels for
    # unknown or infinite cardinality; in that case tqdm gets no total.
    cardinality = int(dataset.cardinality())
    num_batches = cardinality if cardinality >= 0 else None

    generator_step = tf.function(generator_train_step)
    discriminator_step = tf.function(discriminator_train_step)
    for epoch in tqdm(range(epochs)):
        for c in callbacks:
            c.on_epoch_begin(epoch=epoch + 1, generator=generator, discriminator=discriminator)

        for batch in tqdm(dataset, leave=False, total=num_batches):
            discriminator_step(generator, discriminator, batch, latent_dims)
            generator_step(generator, discriminator, batch_size, latent_dims)

        for c in callbacks:
            c.on_epoch_end(epoch=epoch + 1, generator=generator, discriminator=discriminator)

### 2.2 Training Model

This callback saves a copy of both the generator and the discriminator at the end of every epoch.

In [None]:
class SaveModelsCallback(tf.keras.callbacks.Callback):
    """Save the generator and the discriminator as HDF5 files after every epoch.

    Written for the custom ``train`` loop, which calls the hooks with the
    keyword arguments ``epoch``, ``generator`` and ``discriminator``.

    Args:
        g_path_format: ``str.format`` pattern for the generator file; its
            ``{}`` placeholder receives the epoch number.
        d_path_format: the same, for the discriminator file.
    """

    def __init__(self, g_path_format, d_path_format):
        # Let keras' Callback initialise its own attributes (e.g. `model`).
        super().__init__()
        self.__g_path_format = g_path_format
        self.__d_path_format = d_path_format

    def on_epoch_begin(self, **kwargs):
        """No-op: models are only written when an epoch ends."""
        return

    def on_epoch_end(self, epoch, generator, discriminator, **kwargs):
        """Write both models, overwriting any existing file at the formatted paths."""
        generator.save    (self.__g_path_format.format(epoch), overwrite=True, save_format='h5')
        discriminator.save(self.__d_path_format.format(epoch), overwrite=True, save_format='h5')

In [None]:
# Real images per batch (and fake samples per generator update).
BATCH_SIZE = 32
# Length of the generator's input noise vector.
LATENT_DIMS = 256

In [None]:
# Build both networks. Optimizers are attached as plain attributes (there is
# no compile()) because the custom step functions call
# model.optimizer.apply_gradients directly.
generator = generator_model(LATENT_DIMS)
generator.optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator = discriminator_model()
discriminator.optimizer = tf.keras.optimizers.Adam(1e-4)

In [None]:
# Epoch callbacks for train(). Each hook is called with epoch=<1-based>,
# generator=..., discriminator=...
# NOTE(review): TimerCallback lives in the project-local utils package;
# presumably it reports per-epoch timing - confirm there.
timer = utils.callback.TimerCallback()
save_samples = utils.callback.SaveSamplesCallback(
    path_format='epoch-{}',
    # 64 latent vectors drawn once here. Presumably they are reused every
    # epoch so the sample grids are comparable - confirm in utils.callback.
    inputs=tf.random.normal((8*8, LATENT_DIMS)),
    n_cols=8,
    savefig_kwargs={'bbox_inches': 'tight', 'pad_inches': 0, 'dpi': 256},
    grid_params={'border':1, 'pad':1, 'pad_value':0.0},
    # Map the generator's tanh output from [-1, 1] back to [0, 1] for display.
    transform_samples=lambda samples: (1 + samples) * 0.5
)
# Writes generator-<epoch>.h5 and discriminator-<epoch>.h5 after every epoch.
save_models = SaveModelsCallback('generator-{}.h5', 'discriminator-{}.h5')

In [None]:
# Train for 20 epochs on batches of BATCH_SIZE real images, shuffled through
# a 1024-element buffer.
train(
    generator,
    discriminator,
    dataset=dataset.shuffle(1024).batch(BATCH_SIZE),
    epochs=20,
    batch_size=BATCH_SIZE,
    callbacks=[timer, save_samples, save_models]
)

On Windows, the command below shuts the machine down 60 seconds after training finishes. This is useful if you want to leave the computer running while you go to sleep :)

In [None]:
# !shutdown /s /t 60