## Setup

In [None]:
%matplotlib qt
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_gan as tfgan
import numpy as np
import os, sys
from tqdm.notebook import tqdm
from pathlib import Path

sys.path.append( os.path.abspath('..') )
import utils

In [None]:
# Run the rest of the notebook from a dedicated CIFAR/ directory: all result
# folders and the relative classifier path are resolved from there.
# NOTE(review): the chdir is relative, so re-running this cell descends again
# (CIFAR/CIFAR); restart the kernel before re-running.
os.makedirs('CIFAR', exist_ok=True)
os.chdir('CIFAR')

In [None]:
# The GAN is unsupervised: labels are discarded and both CIFAR-10 splits are
# pooled into one array of images.
(x_train, _), (x_test, _) = tf.keras.datasets.cifar10.load_data()
data = np.concatenate([x_train, x_test], axis=0)
# Rescale pixel values to [-1, 1], the range of the generators' tanh output.
data = (data.astype(np.float32) - 127.5) / 127.5
# Layout: (batch, height, width, channel).
assert data.shape == (60000, 32, 32, 3)

## 1 Models

### 1.1 Architecture

CIFAR-10 was very hard to get working. In these tests, training only succeeded after removing batch normalization and dropout (except for a single dropout layer before the discriminator's output) and changing some hyperparameters.

In hindsight, however, the problem was probably the momentum term $\beta_1$ of the Adam optimizer, which is set to $0$ below. Batch normalization and dropout might also work now. The experiments for the document were run without them, though, so they are left commented out here.

In [None]:
def generator_model_transp_conv(latent_dims):
    """Build a generator that upsamples with transposed convolutions.

    A latent vector is projected to a 4x4x256 feature map. Three stride-2
    ``Conv2DTranspose`` stages then double the spatial size (4 -> 8 -> 16 -> 32),
    and a final 3x3 convolution maps the result to 3 channels.

    Batch normalization is intentionally omitted from this architecture.

    Args:
        latent_dims: length of each input noise vector.

    Returns:
        An uncompiled ``tf.keras.Sequential`` producing (batch, 32, 32, 3)
        images with ``tanh`` outputs in [-1, 1].
    """
    layers = tf.keras.layers
    stack = [
        layers.Dense(4 * 4 * 256, input_shape=(latent_dims,)),
        layers.LeakyReLU(0.2),
        layers.Reshape((4, 4, 256)),
    ]
    # Three identical upsampling stages, each doubling height and width.
    for _ in range(3):
        stack.append(layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'))
        stack.append(layers.LeakyReLU(0.2))
    stack.append(layers.Conv2D(3, kernel_size=3, strides=1, padding='same', activation='tanh'))
    return tf.keras.Sequential(stack)

In [None]:
def generator_model_bilinear(latent_dims):
    """Build a generator that upsamples with bilinear interpolation + convolution.

    A latent vector is projected to a 4x4x256 feature map. Three stages of
    2x bilinear ``UpSampling2D`` followed by a 3x3 convolution grow it to
    32x32, and a final 3x3 convolution maps the result to 3 channels.

    Batch normalization is intentionally omitted from this architecture.

    Args:
        latent_dims: length of each input noise vector.

    Returns:
        An uncompiled ``tf.keras.Sequential`` producing (batch, 32, 32, 3)
        images with ``tanh`` outputs in [-1, 1].
    """
    layers = tf.keras.layers
    stack = [
        layers.Dense(4 * 4 * 256, input_shape=(latent_dims,)),
        layers.LeakyReLU(0.2),
        layers.Reshape((4, 4, 256)),
    ]
    # Resize first, then convolve.
    for _ in range(3):
        stack += [
            layers.UpSampling2D(size=2, interpolation='bilinear'),
            layers.Conv2D(128, kernel_size=3, strides=1, padding='same'),
            layers.LeakyReLU(0.2),
        ]
    stack.append(layers.Conv2D(3, kernel_size=3, strides=1, padding='same', activation='tanh'))
    return tf.keras.Sequential(stack)

In [None]:
def discriminator_model():
    """Build the convolutional discriminator for 32x32 RGB images.

    Four 3x3 convolutions with LeakyReLU reduce the spatial size
    32 -> 16 -> 8 -> 4 -> 4. The features are then flattened, passed through
    Dropout(0.5) and a single linear unit.

    Batch normalization and per-layer dropout are intentionally omitted.

    Returns:
        An uncompiled ``tf.keras.Sequential`` that outputs one raw logit per
        image (no sigmoid); pair it with a ``from_logits=True`` loss.
    """
    layers = tf.keras.layers
    # (filters, strides) of each 3x3 'same'-padded convolution, in order.
    conv_specs = [(64, 2), (128, 2), (128, 2), (256, 1)]
    stack = []
    for index, (filters, strides) in enumerate(conv_specs):
        # Only the first layer declares the input image shape.
        shape_kwargs = {'input_shape': (32, 32, 3)} if index == 0 else {}
        stack.append(
            layers.Conv2D(filters, kernel_size=3, strides=strides, padding='same', **shape_kwargs)
        )
        stack.append(layers.LeakyReLU(0.2))
    stack += [
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(1),
    ]
    return tf.keras.Sequential(stack)

### 1.2 Losses

The binary cross entropy (BCE) between $y$ and $\hat{y}$ is calculated as:

$$
    \mathrm{BCE}(y, \hat{y}) = - y \log\left(\hat{y}\right) - (1-y) \log\left(1 - \hat{y}\right)
$$

In [None]:
# Shared BCE loss. The discriminator ends in a Dense(1) layer with no
# activation, so its scores are logits and the sigmoid is applied inside
# the loss.
cross_entropy = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

The generator tries to maximize the probability of the discriminator being wrong about its samples. In practice, this is done by minimizing the following (non-saturating) loss function:

$$
    J^{(G)} = -\log\bigl(D\bigl(G(z)\bigr)\bigr)
$$

In [None]:
def generator_loss(fake_output):
    """Generator loss J^(G) = -log(D(G(z))).

    Args:
        fake_output: discriminator logits for generated images.

    Returns:
        BCE between the logits and an all-ones ("real") target. The loss is
        low when the discriminator classifies the generated images as real.
    """
    real_targets = tf.ones_like(fake_output)
    return cross_entropy(real_targets, fake_output)

The discriminator tries to correctly classify real data as real and fake data as fake. This is equivalent to minimizing the following loss function:

$$
    J^{(D)} = -\log\bigl(D(x)\bigr) - \log\bigl(1 - D\bigl(G(z)\bigr)\bigr)
$$

Here the loss is scaled by a factor of $0.5$, so it is the average of the real and fake terms rather than their sum.

In [None]:
def discriminator_loss_normal(real_output, fake_output):
    """Discriminator loss without label smoothing.

    Logits of real images are pushed toward label 1 and logits of generated
    images toward label 0. The two BCE terms are averaged (factor 0.5).

    Args:
        real_output: discriminator logits for real images.
        fake_output: discriminator logits for generated images.
    """
    targets_real = tf.ones_like(real_output)
    targets_fake = tf.zeros_like(fake_output)
    loss_sum = (
        cross_entropy(targets_real, real_output)
        + cross_entropy(targets_fake, fake_output)
    )
    return 0.5 * loss_sum

This function applies one-sided label smoothing to the discriminator loss. Real samples get a target of $0.9$ instead of $1$, while fake samples keep a target of $0$.

In [None]:
def discriminator_loss_smooth(real_output, fake_output, real_label=0.9):
    """Discriminator loss with one-sided label smoothing.

    This matches the unsmoothed loss, except that the target for real images
    is ``real_label`` instead of 1. The target for generated images stays at
    0, which is why the smoothing is "one-sided".

    Args:
        real_output: discriminator logits for real images.
        fake_output: discriminator logits for generated images.
        real_label: target used for real images. Defaults to 0.9, the value
            used in the experiments; 1.0 recovers the unsmoothed loss.

    Returns:
        The average (factor 0.5) of the real and fake BCE terms.
    """
    real_loss = cross_entropy(real_label * tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return 0.5 * (real_loss + fake_loss)

## 2 Training

### 2.1 Main functions

In [None]:
def discriminator_train_step(generator, discriminator, images, latent_dims):
    """Run one optimisation step of the discriminator on a batch of real images.

    The step samples one latent vector per real image, generates fakes, scores
    both batches, and applies the gradient of ``discriminator_loss`` using
    ``discriminator.optimizer``. Only the discriminator's trainable variables
    are updated.

    The loss is read from the module-level name ``discriminator_loss``, which
    must be bound before the first call.

    Args:
        generator: Keras model mapping (batch, latent_dims) noise to images.
        discriminator: Keras model returning one logit per image; must have an
            ``optimizer`` attribute.
        images: batch of real images; axis 0 is the batch axis.
        latent_dims: size of each latent vector.
    """
    # Use the dynamic batch size. Under tf.function the static
    # `images.shape[0]` can be None (unknown or relaxed batch dimension, e.g. a
    # final partial batch), and tf.random.normal would then fail.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, latent_dims])
    with tf.GradientTape() as disc_tape:
        generated_imgs = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_imgs, training=True)
        loss_D = discriminator_loss(real_output, fake_output)

    grads_D = disc_tape.gradient(loss_D, discriminator.trainable_variables)
    discriminator.optimizer.apply_gradients(zip(grads_D, discriminator.trainable_variables))

In [None]:
def generator_train_step(generator, discriminator, batch_size, latent_dims):
    """Run one optimisation step of the generator.

    The step samples ``batch_size`` latent vectors and scores the resulting
    images with the discriminator (in training mode). It then applies the
    gradient of ``generator_loss`` using ``generator.optimizer``. Only the
    generator's trainable variables are updated.

    Args:
        generator: Keras model with an ``optimizer`` attribute.
        discriminator: Keras model returning one logit per image.
        batch_size: number of latent vectors to sample.
        latent_dims: size of each latent vector.
    """
    latent = tf.random.normal([batch_size, latent_dims])
    with tf.GradientTape() as tape:
        fakes = generator(latent, training=True)
        scores = discriminator(fakes, training=True)
        loss = generator_loss(scores)

    gen_vars = generator.trainable_variables
    gradients = tape.gradient(loss, gen_vars)
    generator.optimizer.apply_gradients(zip(gradients, gen_vars))

In [None]:
def train(generator, discriminator, data, epochs, batch_size=None, callbacks=None):
    """Train a GAN with alternating discriminator and generator updates.

    Each epoch iterates over a shuffled ``tf.data`` pipeline built from
    ``data``. For every batch, the discriminator takes one step on that batch
    and then the generator takes one step on ``batch_size`` fresh noise
    vectors.

    Args:
        generator: Keras model whose ``input_shape[1]`` is the latent size;
            must have an ``optimizer`` attribute.
        discriminator: Keras model returning one logit per image; must have an
            ``optimizer`` attribute.
        data: array of real images, indexed along axis 0.
        epochs: number of passes over ``data``.
        batch_size: samples per batch; 32 when None.
        callbacks: optional iterable of objects exposing
            ``on_epoch_begin(epoch=, generator=, discriminator=)`` and
            ``on_epoch_end(...)`` with the same keywords; ``epoch`` is 1-based.
            None (the default) means no callbacks.
    """
    latent_dims = generator.input_shape[1]
    batch_size = batch_size if batch_size is not None else 32
    # The default used to crash at the callback loops; treat None as "none".
    callbacks = [] if callbacks is None else list(callbacks)
    num_batches = 1 + (data.shape[0] - 1) // batch_size  # ceil division, for tqdm
    dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(data.shape[0]).batch(batch_size)

    # Fresh tf.function wrappers on every call, traced against these models.
    generator_step = tf.function(generator_train_step)
    discriminator_step = tf.function(discriminator_train_step)
    for epoch in tqdm(range(epochs)):
        for c in callbacks:
            c.on_epoch_begin(epoch=epoch + 1, generator=generator, discriminator=discriminator)

        for batch in tqdm(dataset, leave=False, total=num_batches):
            discriminator_step(generator, discriminator, batch, latent_dims)
            generator_step(generator, discriminator, batch_size, latent_dims)

        for c in callbacks:
            c.on_epoch_end(epoch=epoch + 1, generator=generator, discriminator=discriminator)

### 2.2 Metrics Classifier

Loading the classifier that will be used to calculate the *Classifier Score* (CS) and *Fréchet Classifier Distance* (FCD). \
The features of the real data are also precalculated to avoid doing that for each epoch.

In [None]:
# Pretrained CIFAR-10 classifier used by the metrics callback (CS and FCD).
classifier = tf.keras.models.load_model('../../Classifiers/cifar.h5')
feature_layer, logits_layer = (
    classifier.get_layer(layer_name) for layer_name in ('features', 'logits')
)
# Real-data features are computed once here instead of on every epoch.
precalculated_features = utils.fn.calculate_features(classifier, feature_layer, data)

### 2.3 Hyperparameter Testing

These are the hyperparameter combinations tested for the final document. They are trained one after another, so a single run of all of them may take a long time. Consider commenting out some options to run the tests individually.

In [None]:
# Adam momentum (beta_1) is disabled: it was the suspected cause of unstable
# CIFAR-10 training. Batch size and latent size are shared by every run.
BETA_1, BATCH_SIZE, LATENT_DIMS = 0.0, 16, 64

In [None]:
# Every combination of upsampling method and label smoothing: the
# unsmoothed runs first, and within each group transposed conv before bilinear.
hparams_list = [
    {'use_transp_conv': use_transp_conv, 'smooth_labels': smooth_labels}
    for smooth_labels in (False, True)
    for use_transp_conv in (True, False)
]

In [None]:
for hparams in hparams_list:
    # One output folder per configuration, e.g. "TRPCONV" or "BILINEAR_SMOOTH".
    dirname = (
        ('TRPCONV' if hparams['use_transp_conv'] else 'BILINEAR')
        + ('_SMOOTH' if hparams['smooth_labels'] else '')
    )
    Path(dirname).mkdir(exist_ok=True)

    ## Models
    # Both networks get their own Adam optimizer with identical settings.
    generator = (
        generator_model_transp_conv if hparams['use_transp_conv']
        else generator_model_bilinear
    )(LATENT_DIMS)
    generator.optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=BETA_1, epsilon=1e-8)

    discriminator = discriminator_model()
    discriminator.optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=BETA_1, epsilon=1e-8)
    # Rebinding this module-level name selects the loss that
    # discriminator_train_step uses for this configuration.
    discriminator_loss = (
        discriminator_loss_smooth if hparams['smooth_labels']
        else discriminator_loss_normal
    )

    ## Callbacks
    timer = utils.callback.TimerCallback()
    save_samples = utils.callback.SaveSamplesCallback(
        path_format=os.path.join(dirname, 'epoch-{}'),
        # 100 latent vectors drawn once per configuration, shown in 10 columns.
        inputs=tf.random.normal((10 * 10, LATENT_DIMS)),
        n_cols=10,
        savefig_kwargs={'bbox_inches': 'tight', 'pad_inches': 0, 'dpi': 192},
        # Map tanh outputs from [-1, 1] back to [0, 1] for plotting.
        transform_samples=lambda samples: (1 + samples) * 0.5,
        grid_params={'border': 1, 'pad': 1, 'pad_value': 0},
    )
    metrics = utils.callback.MetricsCallback(
        generator=generator,
        classifier=classifier,
        latent_dims=LATENT_DIMS,
        feature_layer=feature_layer,
        logits_layer=logits_layer,
        n_samples=200 * 128,
        batch_size=128,
        precalculated_features=precalculated_features,
        save_after=5,
        save_to=os.path.join(dirname, 'best.h5'),
    )

    ## Train and save results
    train(
        generator,
        discriminator,
        data,
        epochs=30,
        batch_size=BATCH_SIZE,
        callbacks=[timer, save_samples, metrics],
    )

    # Collect the callback metrics plus wall-clock time into the run's log.
    metrics_obj = metrics.get_metrics()
    metrics_obj['time'] = timer.get_time()
    utils.fn.update_json_log(os.path.join(dirname, 'log.json'), metrics_obj)

    # Final (last-epoch) weights of both networks.
    generator.save(os.path.join(dirname, 'generator.h5'), overwrite=True, save_format='h5')
    discriminator.save(os.path.join(dirname, 'discriminator.h5'), overwrite=True, save_format='h5')

\
On Windows, the command below shuts down the machine 60 seconds after training finishes. This is very useful if you want to leave the computer running while you go to sleep :)

In [None]:
# !shutdown /s /t 60