## Setup

In [None]:
%matplotlib qt
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_gan as tfgan
import numpy as np
import os, sys
from tqdm.notebook import tqdm
from pathlib import Path

sys.path.append( os.path.abspath('..') )
import utils

In [None]:
# All outputs (samples, logs, models) are written under <notebook dir>/CIFAR.
# The notebook directory is captured only on the first run of this cell, so
# re-running it does not descend into CIFAR/CIFAR/... and break relative paths
# used later (e.g. '../../Classifiers/cifar.h5').
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', Path.cwd())
OUTPUT_DIR = NOTEBOOK_DIR / 'CIFAR'
OUTPUT_DIR.mkdir(exist_ok=True)
os.chdir(OUTPUT_DIR)

In [None]:
# CIFAR-10 comes pre-split into train and test sets. The GAN is unsupervised, so the
# labels are dropped and both splits are pooled into one training set.
(x_train, _), (x_test, _) = tf.keras.datasets.cifar10.load_data()
data = np.vstack((x_train, x_test)).astype('float32')
data = (data - 127.5) / 127.5  # rescale to [-1, 1], the range of the generator's tanh output
assert data.shape == (60000, 32, 32, 3)  # (batch, height, width, channel)

## 1 Models

### 1.1 Architecture

In [None]:
def generator_model_transp_conv(latent_dims):
    """Build a generator that upsamples with transposed convolutions.

    A ``latent_dims``-dimensional noise vector is projected by a dense layer to a
    4x4x256 feature map. Three stride-2 ``Conv2DTranspose`` layers (128 filters,
    kernel 4, each followed by LeakyReLU(0.2)) double the resolution
    4 -> 8 -> 16 -> 32. A final 3x3 convolution with tanh maps the result to a
    32x32x3 image with values in [-1, 1].
    """
    layers = tf.keras.layers
    stack = [
        layers.Dense(4 * 4 * 256, input_shape=(latent_dims,)),
        layers.LeakyReLU(0.2),
        layers.Reshape((4, 4, 256)),
    ]
    for _ in range(3):
        stack.append(layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'))
        stack.append(layers.LeakyReLU(0.2))
    stack.append(layers.Conv2D(3, kernel_size=3, strides=1, padding='same', activation='tanh'))
    return tf.keras.Sequential(stack)

In [None]:
def generator_model_upsample(latent_dims, interpolation):
    """Build a generator that upsamples with ``UpSampling2D`` + convolution.

    The noise vector is projected to a 4x4x256 feature map. Three stages follow,
    each doubling the resolution (4 -> 8 -> 16 -> 32). Each stage is a 2x
    ``UpSampling2D`` using ``interpolation`` (this notebook passes 'bilinear' or
    'nearest'), a 3x3 convolution with 128 filters, and LeakyReLU(0.2). A final
    3x3 tanh convolution outputs a 32x32x3 image in [-1, 1].
    """
    layers = tf.keras.layers
    stack = [
        layers.Dense(4 * 4 * 256, input_shape=(latent_dims,)),
        layers.LeakyReLU(0.2),
        layers.Reshape((4, 4, 256)),
    ]
    for _ in range(3):
        stack += [
            layers.UpSampling2D(size=2, interpolation=interpolation),
            layers.Conv2D(128, kernel_size=3, strides=1, padding='same'),
            layers.LeakyReLU(0.2),
        ]
    stack.append(layers.Conv2D(3, kernel_size=3, strides=1, padding='same', activation='tanh'))
    return tf.keras.Sequential(stack)

In [None]:
def critic_model():
    """Build the WGAN critic for 32x32x3 images.

    The model has four 3x3 convolutions with 64, 128, 128 and 256 filters, each
    followed by LeakyReLU(0.2). The first three use stride 2, so the spatial size
    goes 32 -> 16 -> 8 -> 4. The features are then flattened, passed through
    Dropout(0.5), and fed to a single dense unit with no activation. The output
    is therefore an unbounded real-valued score, not a probability.
    """
    layers = tf.keras.layers
    conv_specs = [(64, 2), (128, 2), (128, 2), (256, 1)]  # (filters, stride)
    stack = []
    for idx, (filters, stride) in enumerate(conv_specs):
        extra = {'input_shape': (32, 32, 3)} if idx == 0 else {}
        stack.append(layers.Conv2D(filters, kernel_size=3, strides=stride, padding='same', **extra))
        stack.append(layers.LeakyReLU(0.2))
    stack += [layers.Flatten(), layers.Dropout(0.5), layers.Dense(1)]
    return tf.keras.Sequential(stack)

### 1.2 Losses

For the Wasserstein distance, the loss of the generator (G) given the critic (f) is:

$$
    -\mathbb{E}_{z \sim p(z)}\bigl\lbrack f\bigl(G(z)\bigr) \bigr\rbrack
$$

In [None]:
def generator_loss(fake_output):
    """Wasserstein generator loss: the negated mean critic score of generated samples.

    Minimizing this value drives the generator towards samples the critic scores highly.
    """
    mean_fake_score = tf.reduce_mean(fake_output)
    return tf.negative(mean_fake_score)

The critic is trained to maximize:
$$
    \max_{\|f \|_{L} \leq 1} \mathbb{E}_{x \sim p_{data}} \bigl\lbrack f(x) \bigr\rbrack - 
    \mathbb{E}_{z \sim p_{z}} \bigl\lbrack f\bigl(G(z)\bigr) \bigr\rbrack
$$

This is equivalent to minimizing the negative of this value, which is what the function below computes.

In [None]:
def critic_loss(real_output, fake_output):
    """Wasserstein critic loss: mean(fake scores) - mean(real scores).

    This is the negative of the critic objective E[f(x)] - E[f(G(z))], so
    minimizing it maximizes the estimated Wasserstein distance. The gradient
    penalty is added separately.
    """
    return tf.subtract(tf.reduce_mean(fake_output), tf.reduce_mean(real_output))

The gradient penalty is given by:

$$
    \mathbb{E}_{\hat{x} \sim p_{\hat{x}}} \bigl\lbrack 
    \bigl( \|\nabla_{\hat{x}} f(\hat{x})\|_{2} - 1\bigr)^2
    \bigr\rbrack
$$

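where $\hat{x} = x + \alpha\,(\tilde{x} - x)$ is sampled uniformly on the straight line between a real sample $x$ and a generated sample $\tilde{x} = G(z)$. The coefficient $\alpha \sim \mathcal{U}[0, 1]$ is drawn independently for each pair. Note that $\hat{x}$ is an *input* to the critic, not a point between critic outputs.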

In [None]:
def gradient_penalty(critic, real_samples, fake_samples):
    """WGAN-GP gradient penalty.

    For each (real, fake) pair, draws alpha ~ U[0, 1] and evaluates the critic at
    x_hat = real + alpha * (fake - real). The penalty is the batch mean of
    (||d critic(x_hat) / d x_hat||_2 - 1) ** 2. The norm is taken per sample over
    all non-batch axes.

    Args:
        critic: callable mapping a batch of samples to a batch of scores.
        real_samples: batch of real samples; axis 0 is the batch axis.
        fake_samples: batch of generated samples, same shape as ``real_samples``.

    Returns:
        Scalar tensor with the penalty. The caller scales it by lambda.
    """
    differences = fake_samples - real_samples
    # alpha has shape (batch, 1, ..., 1): one coefficient per sample, broadcast
    # over the remaining axes. tf.shape keeps this valid when the batch size is
    # only known at run time.
    diff_shape = tf.shape(differences)
    alpha_shape = tf.concat([diff_shape[:1], tf.ones_like(diff_shape[1:])], axis=0)
    alpha = tf.random.uniform(alpha_shape, minval=0.0, maxval=1.0, dtype=differences.dtype)
    interpolated = real_samples + alpha * differences

    # A nested GradientTape works both eagerly and inside tf.function, unlike
    # tf.gradients, which is graph-only. An enclosing tape can still
    # differentiate the resulting penalty w.r.t. the critic's weights.
    # NOTE(review): the critic is called here without training=True. Its Dropout
    # layer presumably runs in inference mode, unlike the training=True critic
    # calls in critic_train_step -- confirm this is intended.
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        scores = critic(interpolated)
    grads = tape.gradient(scores, interpolated)

    # Per-sample L2 norm over ALL non-batch axes. Reducing only axis 1 would give
    # a (batch, width, channels) tensor for images instead of one norm per
    # sample. The epsilon avoids an infinite sqrt gradient when a gradient is 0.
    reduce_axes = tf.range(1, tf.rank(grads))
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=reduce_axes) + 1e-12)
    return tf.reduce_mean((slopes - 1.0) ** 2)

## 2 Training

### 2.1 Main functions

In [None]:
def critic_train_step(generator, critic, images, latent_dims, _lambda):
    """Run one WGAN-GP optimization step on the critic.

    Draws one latent vector per real image and generates a fake batch. It then
    minimizes ``critic_loss(real, fake) + _lambda * gradient_penalty`` w.r.t.
    the critic's trainable variables using ``critic.optimizer``. The generator's
    weights are not updated.

    Args:
        generator: model mapping (batch, latent_dims) noise to images.
        critic: model with an ``optimizer`` attribute.
        images: batch of real images; axis 0 is the batch axis.
        latent_dims: length of each latent vector.
        _lambda: weight of the gradient penalty.

    Returns:
        The critic loss tensor. Existing callers ignore it.
    """
    # Dynamic batch size: images.shape[0] is None when the step is traced with
    # an unknown batch dimension, which would make tf.random.normal fail.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, latent_dims])
    with tf.GradientTape() as crit_tape:
        generated_imgs = generator(noise, training=True)
        real_output = critic(images, training=True)
        fake_output = critic(generated_imgs, training=True)
        penalty = gradient_penalty(critic, images, generated_imgs)
        loss_C = critic_loss(real_output, fake_output) + _lambda * penalty

    grads_C = crit_tape.gradient(loss_C, critic.trainable_variables)
    critic.optimizer.apply_gradients(zip(grads_C, critic.trainable_variables))
    return loss_C

In [None]:
def generator_train_step(generator, critic, batch_size, latent_dims):
    """Run one optimization step on the generator against the current critic.

    Samples ``batch_size`` latent vectors and scores the generated images with
    the critic. It then applies ``generator.optimizer`` to the gradient of
    ``generator_loss`` w.r.t. the generator's trainable variables only; the
    critic is left unchanged.
    """
    latent = tf.random.normal([batch_size, latent_dims])
    with tf.GradientTape() as tape:
        scores = critic(generator(latent, training=True), training=True)
        loss = generator_loss(scores)
    variables = generator.trainable_variables
    generator.optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))

In [None]:
def train(generator, critic, data, epochs, _lambda, n_critic=1, batch_size=32, callbacks=None):
    """Train a WGAN-GP generator/critic pair.

    Each generator update is preceded by ``n_critic`` critic updates, each on a
    new real batch. The dataset iterator is not restarted at epoch boundaries.
    If it runs out part-way through a group of critic steps, a new iterator is
    started, the remaining steps of that group draw from it, and the epoch ends
    after the group's generator step. Those batches count towards the next epoch.

    Args:
        generator: model with an ``optimizer`` attribute; the latent size is read
            from ``generator.input_shape[1]``.
        critic: model with an ``optimizer`` attribute.
        data: array of real samples; axis 0 is the sample axis.
        epochs: number of passes over ``data``.
        _lambda: gradient-penalty weight.
        n_critic: critic steps per generator step.
        batch_size: samples per batch.
        callbacks: iterable of objects with ``on_epoch_begin`` / ``on_epoch_end``
            methods, called with ``epoch`` (1-based), ``generator`` and
            ``discriminator`` keyword arguments. ``None`` means no callbacks.
    """
    # The default was None but was iterated unconditionally; treat it as "no callbacks".
    callbacks = [] if callbacks is None else list(callbacks)

    latent_dims = generator.input_shape[1]
    dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(data.shape[0]).batch(batch_size)
    iterator = iter(dataset)
    num_batches = 1 + (data.shape[0] - 1) // batch_size  # ceil(len(data) / batch_size)
    batch_count = 0  # batches drawn from the current iterator

    generator_step = tf.function(generator_train_step)
    critic_step = tf.function(critic_train_step)
    for epoch in tqdm(range(epochs)):
        for c in callbacks:
            c.on_epoch_begin(epoch=epoch + 1, generator=generator, discriminator=critic)

        batch_pbar = tqdm(total=num_batches, leave=False)
        batches_left = True
        while batches_left:
            for _ in range(n_critic):
                if batch_count == num_batches:
                    # Iterator exhausted: start a new one and finish the epoch
                    # after the current group of critic steps.
                    batch_count = 0
                    batches_left = False
                    iterator = iter(dataset)
                batch_count += 1
                batch_pbar.update()
                batch = iterator.get_next()
                critic_step(generator, critic, batch, latent_dims, _lambda)
            generator_step(generator, critic, batch_size, latent_dims)
        # The bar was already advanced once per batch; do not add num_batches again.
        batch_pbar.close()

        for c in callbacks:
            c.on_epoch_end(epoch=epoch + 1, generator=generator, discriminator=critic)

### 2.2 Metrics classifier

Loading the classifier that will be used to calculate the *Classifier Score* (CS) and *Fréchet Classifier Distance* (FCD). \
The features of the real data are also precalculated to avoid doing that for each epoch.

In [None]:
# Pretrained CIFAR-10 classifier, used for the Classifier Score and Fréchet Classifier Distance.
classifier = tf.keras.models.load_model('../../Classifiers/cifar.h5')
feature_layer, logits_layer = (classifier.get_layer(name) for name in ('features', 'logits'))
# Real-data features are computed once here and handed to the metrics callback of every run.
precalculated_features = utils.fn.calculate_features(classifier, feature_layer, data)

### 2.3 Hyperparameter Testing

These are the hyperparameter configurations tested for the final document. Training all of them in one run takes a long time; consider commenting out some entries to run the tests individually.

In [None]:
# Shared by every hyperparameter run: latent vector size, samples per batch,
# and gradient-penalty weight (lambda).
LATENT_DIMS, BATCH_SIZE, LAMBDA = 64, 16, 10

In [None]:
# Each upsampling strategy is tried with Adam beta_1 in {0.0, 0.5}, using 5 critic
# steps per generator step and LR 2e-4. One extra run uses a single critic step
# and a lower LR.
hparams_list = [
    {'n_critic': 5, 'learning_rate': 2e-4, 'beta_1': beta_1, 'upsample': upsample}
    for upsample in ('TrpConv', 'bilinear', 'nearest')
    for beta_1 in (0.0, 0.5)
] + [
    {'n_critic': 1, 'learning_rate': 1e-4, 'beta_1': 0.0, 'upsample': 'nearest'},
]

In [None]:
# Train one WGAN-GP model per configuration in hparams_list. Each run writes its
# samples, metrics log and final models into its own directory.
for hparams in hparams_list:
    # e.g. n_critic=5, beta_1=0.0, learning_rate=2e-4, upsample='TrpConv'
    #   -> 'NCRIT5_beta1[0.0]_LR2e-04_TRPCONV'
    dirname = 'NCRIT{}_beta1[{:.1f}]_LR{:.0e}_{}'.format(
        hparams['n_critic'],
        hparams['beta_1'],
        hparams['learning_rate'],
        hparams['upsample'].upper()
    )
    Path(dirname).mkdir(exist_ok=True)

    # Compare strings by value: `is` against a literal depends on CPython string
    # interning and triggers a SyntaxWarning on Python 3.8+.
    if hparams['upsample'] == 'TrpConv':
        generator = generator_model_transp_conv(LATENT_DIMS)
    else:
        # Any other value ('bilinear' / 'nearest') is the UpSampling2D interpolation mode.
        generator = generator_model_upsample(LATENT_DIMS, hparams['upsample'])
    # The train steps read each model's optimizer from its `optimizer` attribute.
    generator.optimizer = tf.keras.optimizers.Adam(hparams['learning_rate'], beta_1=hparams['beta_1'])

    critic = critic_model()
    critic.optimizer = tf.keras.optimizers.Adam(hparams['learning_rate'], beta_1=hparams['beta_1'])

    ## Callbacks
    timer = utils.callback.TimerCallback()
    save_samples = utils.callback.SaveSamplesCallback(
        path_format=os.path.join(dirname, 'epoch-{}'),
        # One fixed batch of 100 latent vectors, drawn once per run (10 columns -> 10x10 grid).
        inputs=tf.random.normal((10*10, LATENT_DIMS)),
        n_cols=10,
        savefig_kwargs={'bbox_inches': 'tight', 'pad_inches': 0, 'dpi': 192},
        # Map the generator's tanh range [-1, 1] to [0, 1] for plotting.
        transform_samples= lambda samples: (1 + samples) * 0.5,
        grid_params={'border': 1, 'pad': 1, 'pad_value': 0}
    )
    metrics = utils.callback.MetricsCallback(
        generator=generator,
        classifier=classifier,
        latent_dims=LATENT_DIMS,
        feature_layer=feature_layer,
        logits_layer=logits_layer,
        n_samples=200*128,
        batch_size=128,
        precalculated_features=precalculated_features,
        save_after=5, save_to=os.path.join(dirname, 'best.h5')
    )

    ## Train and save results
    train(
        generator, critic, data, epochs=30,
        _lambda=LAMBDA,
        n_critic=hparams['n_critic'],
        batch_size=BATCH_SIZE,
        callbacks=[timer, save_samples, metrics]
    )

    metrics_obj = metrics.get_metrics()
    metrics_obj['time'] = timer.get_time()
    utils.fn.update_json_log(os.path.join(dirname, 'log.json'), metrics_obj)

    generator.save(os.path.join(dirname, 'generator.h5'), overwrite=True, save_format='h5')
    critic.save   (os.path.join(dirname, 'critic.h5'   ), overwrite=True, save_format='h5')

\
On Windows, the command below (uncomment it to enable) shuts down the machine 60 seconds after training finishes. This is useful if you want to leave the computer running overnight :)

In [None]:
# !shutdown /s /t 60