## Setup

In [None]:
%matplotlib qt
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_gan as tfgan
import numpy as np
import os, sys
from tqdm.notebook import tqdm
from pathlib import Path

sys.path.append( os.path.abspath('..') )
import utils

In [None]:
# All run artifacts (sample grids, logs, saved models) are written with paths
# relative to the current directory, so the notebook works from inside ./Fashion.
# The starting directory is remembered the first time this cell runs; re-running
# the cell in the same kernel therefore re-enters the same folder instead of
# creating (and descending into) Fashion/Fashion.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = Path.cwd()
WORK_DIR = NOTEBOOK_DIR / 'Fashion'
WORK_DIR.mkdir(exist_ok=True)
os.chdir(WORK_DIR)

In [None]:
# Fashion-MNIST: the train and test splits are merged into one pool of samples.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Images: cast to float32, normalize to [-1, 1] and append a channel axis.
pixels = np.concatenate((x_train, x_test)).astype('float32')
data = ((pixels - 127.5) / 127.5)[..., np.newaxis]
assert data.shape == (70000, 28, 28, 1)  # (batch, height, width, channel)

# Labels: one class id per image, stored as a column so that it lines up with
# the (1,)-shaped `label` inputs of the models.
NUM_CLASSES = 10
labels = np.concatenate((y_train, y_test))[..., np.newaxis]
assert labels.shape == (70000, 1)

## 1 Models

### 1.1 Architecture

In [None]:
def generator_model(latent_dims):
    """Build the conditional generator.

    The class label is embedded and projected to a single 7x7 feature map, the
    latent seed is projected to 255 maps of 7x7, and the 256 stacked maps are
    upsampled twice (7 -> 14 -> 28) into one single-channel image.

    Args:
        latent_dims: Length of the latent seed vector.

    Returns:
        A `tf.keras.Model` named 'generator' that takes `[seed, label]` and
        returns a (28, 28, 1) image squashed to [-1, 1] by the final tanh.
    """
    ## Label input
    label = tf.keras.Input(shape=(1,), name='label', dtype=tf.int32)
    embedding = tf.keras.layers.Embedding(input_dim=NUM_CLASSES, output_dim=36)(label)
    label_channel = tf.keras.layers.Dense(7*7)(embedding)
    label_channel = tf.keras.layers.Reshape((7, 7, 1))(label_channel)

    ## Latent input
    seed = tf.keras.Input(shape=(latent_dims,), name='seed')
    # No `input_shape=` here: in the functional API the layer takes its input
    # shape from the `seed` tensor it is called on.
    # 255 maps so that, together with the label map, the stack has 256 channels.
    seed_channels = tf.keras.layers.Dense(7*7*255)(seed)
    seed_channels = tf.keras.layers.Reshape((7, 7, 255))(seed_channels)

    channels = tf.keras.layers.Concatenate(axis=-1)([label_channel, seed_channels])
    channels = tf.keras.layers.LeakyReLU()(channels)
    channels = tf.keras.layers.BatchNormalization()(channels)

    ## 7x7 block
    channels = tf.keras.layers.Conv2D(128, kernel_size=5, strides=1, padding='same')(channels)
    channels = tf.keras.layers.LeakyReLU()(channels)
    channels = tf.keras.layers.BatchNormalization()(channels)

    ## 7x7 -> 14x14 block
    channels = tf.keras.layers.UpSampling2D(size=2, interpolation='bilinear')(channels)
    channels = tf.keras.layers.Conv2D(64, kernel_size=5, strides=1, padding='same')(channels)
    channels = tf.keras.layers.LeakyReLU()(channels)
    channels = tf.keras.layers.BatchNormalization()(channels)

    ## 14x14 -> 28x28 output; tanh matches the [-1, 1] range of the training data
    channels = tf.keras.layers.UpSampling2D(size=2, interpolation='bilinear')(channels)
    img = tf.keras.layers.Conv2D(1, kernel_size=5, strides=1, padding='same', activation='tanh')(channels)

    return tf.keras.Model(inputs=[seed, label], outputs=img, name='generator')

In [None]:
def discriminator_model():
    """Build the conditional discriminator.

    The class label is embedded and projected to a 28x28 map that is stacked
    with the image as a second channel; two strided convolutions followed by a
    dense layer then produce a single score.

    Returns:
        A `tf.keras.Model` named 'discriminator' that takes `[image, label]`
        and returns one raw logit per sample (no sigmoid is applied).
    """
    ## Label input
    label = tf.keras.Input(shape=(1,), name='label', dtype=tf.int32)
    embedding = tf.keras.layers.Embedding(input_dim=NUM_CLASSES, output_dim=36)(label)
    label_channel = tf.keras.layers.Dense(28*28)(embedding)
    label_channel = tf.keras.layers.Reshape((28, 28, 1))(label_channel)

    ## Image input
    image = tf.keras.Input(shape=(28, 28, 1), name='image')

    channels = tf.keras.layers.Concatenate(axis=-1)([label_channel, image])

    # No `input_shape=` here: the tensor fed to this layer is (28, 28, 2)
    # (label map + image), and the functional API infers that from the input.
    channels = tf.keras.layers.Conv2D(64, kernel_size=5, strides=2, padding='same')(channels)
    channels = tf.keras.layers.LeakyReLU()(channels)
    channels = tf.keras.layers.Dropout(0.3)(channels)

    channels = tf.keras.layers.Conv2D(128, kernel_size=5, strides=2, padding='same')(channels)
    channels = tf.keras.layers.LeakyReLU()(channels)
    channels = tf.keras.layers.Dropout(0.3)(channels)

    channels = tf.keras.layers.Flatten()(channels)
    logit = tf.keras.layers.Dense(1)(channels)
    # Named for consistency with the generator model.
    return tf.keras.Model(inputs=[image, label], outputs=logit, name='discriminator')

### 1.2 Losses

The binary cross entropy (BCE) between $y$ and $\hat{y}$ is calculated as:

$$
    \mathrm{BCE}(y, \hat{y}) = - y \log\left(\hat{y}\right) - (1-y) \log\left(1 - \hat{y}\right)
$$

In [None]:
# Shared BCE loss object used by the generator and discriminator losses.
# `from_logits=True` because the discriminator ends in a plain Dense(1) with no
# sigmoid, so the values passed to this loss are raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

The generator tries to maximize the chance of the discriminator being wrong. This is equivalent to minimizing the following loss function:

$$
    J^{(G)} = -\log\bigl(D\bigl(G(z)\bigr)\bigr)
$$

In [None]:
def generator_loss(fake_output):
    """Return the BCE between the discriminator's fake logits and all-ones targets.

    The loss is small when the discriminator scores the generated samples as real.
    """
    real_targets = tf.ones_like(fake_output)
    return cross_entropy(y_true=real_targets, y_pred=fake_output)

The discriminator tries to correctly classify real data as real and fake data as fake. This is equivalent to minimizing the following loss function:

$$
    J^{(D)} = -\log\bigl(D(x)\bigr) - \log\bigl(1 - D\bigl(G(z)\bigr)\bigr)
$$

Here the loss is scaled by a factor of $0.5$, i.e. the real and fake terms are averaged.

In [None]:
def discriminator_loss_normal(real_output, fake_output):
    """Return the discriminator loss with hard targets (real -> 1, fake -> 0).

    Both arguments are discriminator logits; the result is the mean of the two
    BCE terms.
    """
    targets_and_logits = (
        (tf.ones_like(real_output), real_output),
        (tf.zeros_like(fake_output), fake_output),
    )
    real_loss, fake_loss = (
        cross_entropy(target, logits) for target, logits in targets_and_logits
    )
    return 0.5 * (real_loss + fake_loss)

This function applies one-sided label smoothing to the discriminator loss: the target for real samples is $0.9$ instead of $1$, while the target for fake samples stays at $0$.

In [None]:
def discriminator_loss_smooth(real_output, fake_output):
    """Return the discriminator loss with one-sided label smoothing.

    Real samples are compared against a target of 0.9 instead of 1; fake
    samples keep the hard target 0. The result is the mean of the two BCE terms.
    """
    smoothed_real_targets = 0.9 * tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)

    loss_on_real = cross_entropy(smoothed_real_targets, real_output)
    loss_on_fake = cross_entropy(fake_targets, fake_output)
    return 0.5 * (loss_on_real + loss_on_fake)

## 2 Training

### 2.1 Main Functions

In [None]:
def discriminator_train_step(generator, discriminator, images, labels, latent_dims, loss_fn=None):
    """Run one optimization step on the discriminator only.

    A batch of fake images is generated for the same labels as the real batch,
    both batches are scored, and the gradients of the loss are applied to the
    discriminator's variables through `discriminator.optimizer`. The generator
    is not updated.

    Args:
        generator: Model called as `generator([noise, labels])`.
        discriminator: Model called as `discriminator([images, labels])`; must
            carry an `optimizer` attribute.
        images: Batch of real images.
        labels: Class labels of `images`, also used to condition the fakes.
        latent_dims: Length of the latent noise vector.
        loss_fn: Optional callable `(real_output, fake_output) -> loss`. When
            omitted, the module-level `discriminator_loss` is used, which keeps
            existing callers working.
    """
    if loss_fn is None:
        # Backward-compatible fallback: this global is (re)bound by the
        # hyperparameter loop before each call to `train`.
        loss_fn = discriminator_loss

    noise = tf.random.normal([images.shape[0], latent_dims])
    with tf.GradientTape() as disc_tape:
        generated_imgs = generator([noise, labels], training=True)
        real_output = discriminator([images, labels], training=True)
        fake_output = discriminator([generated_imgs, labels], training=True)
        loss_D = loss_fn(real_output, fake_output)

    # Differentiate with respect to the discriminator's variables only.
    grads_D = disc_tape.gradient(loss_D, discriminator.trainable_variables)
    discriminator.optimizer.apply_gradients(zip(grads_D, discriminator.trainable_variables))

In [None]:
def generator_train_step(generator, discriminator, y, latent_dims):
    """Run one optimization step on the generator only.

    Fake images are generated for the labels `y`, scored by the discriminator,
    and the gradients of `generator_loss` are applied to the generator's
    variables through `generator.optimizer`. The discriminator is not updated.
    """
    batch_size = y.shape[0]
    z = tf.random.normal([batch_size, latent_dims])

    with tf.GradientTape() as tape:
        fakes = generator([z, y], training=True)
        fake_logits = discriminator([fakes, y], training=True)
        loss = generator_loss(fake_logits)

    gen_vars = generator.trainable_variables
    gradients = tape.gradient(loss, gen_vars)
    generator.optimizer.apply_gradients(zip(gradients, generator.trainable_variables))

In [None]:
def train(generator, discriminator, data, labels, epochs, batch_size=None, callbacks=None):
    """Alternate discriminator and generator steps over the whole dataset.

    Args:
        generator: Generator model; its first input defines the latent size.
        discriminator: Discriminator model.
        data: Array of training images.
        labels: Array of class labels aligned with `data`.
        epochs: Number of passes over the dataset.
        batch_size: Samples per step; 32 when None.
        callbacks: Optional objects exposing `on_epoch_begin` / `on_epoch_end`,
            called with the 1-based epoch number and both models.
    """
    if batch_size is None:
        batch_size = 32
    if not callbacks:
        callbacks = []

    # The generator's inputs are [seed, label]; the seed shape is (None, latent_dims).
    latent_dims = generator.input_shape[0][1]

    n_samples = data.shape[0]
    num_batches = (n_samples + batch_size - 1) // batch_size  # ceiling division
    dataset = (
        tf.data.Dataset.from_tensor_slices((data, labels))
        .shuffle(n_samples)
        .batch(batch_size)
    )

    # Compile both steps into graphs once per call to `train`.
    generator_step = tf.function(generator_train_step)
    discriminator_step = tf.function(discriminator_train_step)

    for epoch_index in tqdm(range(epochs)):
        epoch_number = epoch_index + 1

        for callback in callbacks:
            callback.on_epoch_begin(epoch=epoch_number, generator=generator, discriminator=discriminator)

        for image_batch, label_batch in tqdm(dataset, leave=False, total=num_batches):
            discriminator_step(generator, discriminator, image_batch, label_batch, latent_dims)
            generator_step(generator, discriminator, label_batch, latent_dims)

        for callback in callbacks:
            callback.on_epoch_end(epoch=epoch_number, generator=generator, discriminator=discriminator)

### 2.2 Metrics classifier

Loading the classifier that will be used to calculate the *Classifier Score* (CS) and *Fréchet Classifier Distance* (FCD). \
The features of the real data are also precalculated to avoid doing that for each epoch.

In [None]:
# Pre-trained Fashion-MNIST classifier used by the metrics. The path is relative
# to the current working directory, which is the 'Fashion' folder at this point.
CLASSIFIER_PATH = '../../Classifiers/fashion.h5'
classifier = tf.keras.models.load_model(CLASSIFIER_PATH)

# Layers whose outputs feed the metrics.
feature_layer, logits_layer = (
    classifier.get_layer(layer_name) for layer_name in ('features', 'logits')
)

# Features of the real data are computed once here and reused by the metrics callback.
precalculated_features = utils.fn.calculate_features(classifier, feature_layer, data)

### 2.3 Hyperparameter testing

This function overrides the method of the same name on the `MetricsCallback` instance, because the default implementation does not generate the labels that the generator needs as input.

In [None]:
def get_random_inputs(n_samples):
    """Return `[seeds, labels]` for `n_samples` random generator inputs.

    `seeds` has shape (n_samples, LATENT_DIMS) and is drawn from a normal
    distribution; `labels` has shape (n_samples, 1) and holds int32 class ids
    drawn uniformly from [0, NUM_CLASSES).
    """
    seed_shape = (n_samples, LATENT_DIMS)
    label_shape = (n_samples, 1)

    latent_seeds = tf.random.normal(seed_shape)
    class_ids = tf.random.uniform(shape=label_shape, minval=0, maxval=NUM_CLASSES, dtype=tf.int32)
    return [latent_seeds, class_ids]

These were the hyperparameters tested for the final document. Training all of them in a single run may take a long time; consider commenting out some options to run the tests individually.

In [None]:
# Length of the generator's latent seed vector; shared by the model builder,
# the sample grid and `get_random_inputs`.
LATENT_DIMS = 32

In [None]:
# One entry per training run; comment out individual lines to skip a configuration.
hparams_list = [
    dict(batch_size=16, smooth_labels=False),
    dict(batch_size=32, smooth_labels=False),
    dict(batch_size=16, smooth_labels=True),
    dict(batch_size=32, smooth_labels=True),
]

In [None]:
for hparams in hparams_list:
    run_batch_size = hparams['batch_size']
    use_smoothing = hparams['smooth_labels']

    # One output directory per configuration, e.g. 'BS16' or 'BS32_SMOOTH'.
    dirname = 'BS{}{}'.format(run_batch_size, '_SMOOTH' if use_smoothing else '')
    Path(dirname).mkdir(exist_ok=True)

    # Fresh networks and optimizers for every run; the training steps read the
    # optimizer from the `optimizer` attribute attached to each model here.
    generator = generator_model(LATENT_DIMS)
    generator.optimizer = tf.keras.optimizers.Adam(1e-4)

    discriminator = discriminator_model()
    discriminator.optimizer = tf.keras.optimizers.Adam(1e-4)

    # `discriminator_train_step` looks this name up at module level, so it has
    # to be rebound before `train` is called.
    if use_smoothing:
        discriminator_loss = discriminator_loss_smooth
    else:
        discriminator_loss = discriminator_loss_normal

    ## Callbacks
    timer = utils.callback.TimerCallback()

    # Fixed inputs for the sample grid: 100 seeds, with labels 0..9 each
    # repeated 10 times in a row.
    sample_seeds = tf.random.normal((10*10, LATENT_DIMS))
    sample_labels = np.repeat(np.arange(10), 10, axis=0)[..., np.newaxis]
    save_samples = utils.callback.SaveSamplesCallback(
        path_format=os.path.join(dirname, 'epoch-{}'),
        inputs=[sample_seeds, sample_labels],
        n_cols=10,
        savefig_kwargs={'bbox_inches': 'tight', 'pad_inches': 0, 'dpi': 192},
        imshow_kwargs={'cmap': 'gray_r', 'vmin': -1, 'vmax': 1}
    )

    metrics = utils.callback.MetricsCallback(
        generator=generator,
        classifier=classifier,
        latent_dims=LATENT_DIMS,
        feature_layer=feature_layer,
        logits_layer=logits_layer,
        precalculated_features=precalculated_features,
        save_after=5,
        save_to=os.path.join(dirname, 'best.h5'),
    )
    # Replace the callback's input sampler with the label-aware one defined above.
    metrics.get_random_inputs = get_random_inputs

    ## Train and save results
    train(
        generator,
        discriminator,
        data,
        labels,
        epochs=30,
        batch_size=run_batch_size,
        callbacks=[timer, save_samples, metrics],
    )

    metrics_obj = metrics.get_metrics()
    metrics_obj['time'] = timer.get_time()
    utils.fn.update_json_log(os.path.join(dirname, 'log.json'), metrics_obj)

    for trained_model, filename in ((generator, 'generator.h5'), (discriminator, 'discriminator.h5')):
        trained_model.save(os.path.join(dirname, filename), overwrite=True, save_format='h5')

\
On Windows, the command below shuts down the machine after the training finishes, which is very useful if you want to leave the computer running while you go to sleep :)

In [None]:
# !shutdown /s /t 60