## Example CGAN
***
A quick and simple conditional GAN (CGAN) trained on MNIST

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

from kerasgan import CGAN, CGANSnapshot, CGANClassSnapshot, GANCheckpoint

### Key Parameters

In [None]:
BATCH_SIZE = 128
LATENT_DIM = 128
NUM_CLASSES = 10
AUTOTUNE = tf.data.AUTOTUNE

### Load MNIST Data

In [None]:
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

# combine all images and rescale to [-1,1]
image_data = np.concatenate((X_train,X_test))
image_data = image_data[...,np.newaxis].astype('float32')/127.5 - 1
# combine labels
image_labels = np.concatenate((y_train,y_test))
# note that labels must be an array of integers
# in order to work with CGANSnapshot and/or CGANClassSnapshot

# sanity checks
print(image_data.shape, image_data.dtype, np.max(image_data), np.min(image_data))

# create tf dataset (both data and labels)
training_data = tf.data.Dataset.from_tensor_slices(
    (image_data,image_labels)).batch(BATCH_SIZE).prefetch(AUTOTUNE)
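
The rescaling above maps pixel values from [0, 255] to [-1, 1], which matches the `tanh` output of the generator. To view generated images as ordinary pixels, that mapping has to be inverted. Below is a minimal sketch of both directions in NumPy; the helper names `to_model_range` and `to_pixel_range` are made up for this illustration and are not part of `kerasgan`.

```python
import numpy as np

def to_model_range(pixels):
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return pixels.astype('float32') / 127.5 - 1.0

def to_pixel_range(images):
    """Invert to_model_range: map floats in [-1, 1] back to uint8 in [0, 255]."""
    scaled = (np.asarray(images, dtype='float32') + 1.0) * 127.5
    # round before casting so small float errors do not shift a pixel value
    return np.clip(np.rint(scaled), 0, 255).astype('uint8')

pixels = np.array([0, 64, 128, 255], dtype='uint8')
scaled = to_model_range(pixels)
restored = to_pixel_range(scaled)
print(scaled.min(), scaled.max())
print(restored)
```

Rounding before the cast matters: casting straight to `uint8` truncates, so a value such as 63.99999 would become 63 instead of 64.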

### Create the Generator
The generator should accept two inputs: a latent vector of size `LATENT_DIM` and a label. In this example the label is transformed into a higher-dimensional embedding of length 64 and then concatenated with the latent vector. This is just one of many ways of incorporating labels; you could, for instance, use a `Multiply()` layer instead.

In [None]:
keras.backend.clear_session()

# generator
latent_input = layers.Input(shape=(LATENT_DIM,))
label_input = layers.Input(shape=(1,))
embedding = layers.Embedding(input_dim=NUM_CLASSES, output_dim=64, input_length=1)(label_input)
embedding = layers.Flatten()(embedding)
x = layers.Concatenate()([latent_input, embedding]) # shape = (LATENT_DIM+64,)
x = layers.Dense(4*4*128, use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Reshape((4,4,128))(x)
x = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', output_padding=1, use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2DTranspose(32, kernel_size=4, strides=2, padding='same', use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
gen_output = layers.Conv2DTranspose(1, kernel_size=4, strides=2, padding='same', activation='tanh')(x)
generator = keras.Model(inputs=[latent_input, label_input], outputs=gen_output, name='generator')
generator.summary()
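
The generator has to end at MNIST's 28x28 resolution, which is why the first transposed convolution sets `output_padding=1`. The sketch below traces the spatial size through the three `Conv2DTranspose` layers in plain Python. The length rule in `transposed_conv_same_length` (a name invented here) is an assumption based on the Keras convention for `padding='same'`; `generator.summary()` remains the authoritative check.

```python
def transposed_conv_same_length(input_length, kernel_size, stride, output_padding=None):
    """Output length of a transposed convolution with padding='same'.

    Assumed rule (Keras convention): without output_padding the length is
    input_length * stride; with it, the implicit padding is kernel_size // 2
    on each side.
    """
    if output_padding is None:
        return input_length * stride
    pad = kernel_size // 2
    return (input_length - 1) * stride + kernel_size - 2 * pad + output_padding

size = 4  # spatial size after Reshape((4, 4, 128))
sizes = [size]
for output_padding in (1, None, None):  # the three Conv2DTranspose layers, in order
    size = transposed_conv_same_length(
        size, kernel_size=4, stride=2, output_padding=output_padding)
    sizes.append(size)

print(sizes)  # [4, 7, 14, 28]
```

Without `output_padding=1` on the first layer, the same rule would give 4, 8, 16, 32, and the output would no longer match the 28x28 training images.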

### Create the Discriminator
The discriminator also accepts two inputs: an image and a label. In this example the label is transformed into a higher-dimensional embedding of length 28*28, which is then reshaped to the image size and concatenated to the image as an additional "channel".

In [None]:
image_input = layers.Input(shape=(28,28,1))
label_input = layers.Input(shape=(1,))
embedding = layers.Embedding(input_dim=NUM_CLASSES, output_dim=28*28, input_length=1)(label_input)
embedding = layers.Flatten()(embedding)
embedding = layers.Reshape((28,28,1))(embedding)
x = layers.Concatenate()([image_input, embedding]) # shape = (28,28,2)
x = layers.Conv2D(32, kernel_size=3, strides=2, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(64, kernel_size=3, strides=2, padding='same', use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, kernel_size=3, strides=2, padding='same', use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)
x = layers.Dropout(0.25)(x)
dsc_output = layers.Dense(1)(x)
discriminator = keras.Model(inputs=[image_input, label_input], outputs=dsc_output, name='discriminator')
discriminator.summary()
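
To make the "label as an extra channel" idea concrete, here is a small NumPy sketch of what the embedding, reshape, and concatenate steps do to the tensor shapes. The random `embedding_table` is only a stand-in for the weights the `Embedding` layer would learn; nothing here comes from `kerasgan`.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, batch = 10, 5

# Stand-in for the Embedding weights: one row of 28*28 values per class
embedding_table = rng.normal(size=(num_classes, 28 * 28)).astype('float32')

images = rng.uniform(-1, 1, size=(batch, 28, 28, 1)).astype('float32')
labels = np.array([3, 1, 4, 1, 5])

# Embedding lookup, then reshape each row into a 28x28 single-channel map
label_maps = embedding_table[labels].reshape(batch, 28, 28, 1)

# Concatenate along the last (channel) axis, the default axis of layers.Concatenate()
conditioned = np.concatenate([images, label_maps], axis=-1)
print(conditioned.shape)  # (5, 28, 28, 2)
```

Every image with the same label receives the same second channel, so the convolutions can learn to compare the digit in channel 0 against the class pattern in channel 1.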

### Training the CGAN
As with the normal GAN, training takes three steps: construct the `CGAN`, compile it with an optimizer for each network, and call `fit`. Pass both the latent vector dimension `LATENT_DIM` and the number of classes `NUM_CLASSES` to the `CGAN` constructor; the latter matters because it serves as the upper bound when labels are generated.

#### A Note on Labels
The callbacks `CGANSnapshot` and `CGANClassSnapshot` are both capable of randomly generating labels to pass to the generator. These labels are strictly integers in the range 0, 1, ..., `NUM_CLASSES`-1. Ensure your generator model and training data all use integer labels before using these callbacks.
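
A quick way to confirm that a label array meets this requirement is sketched below. `check_integer_labels` is a hypothetical helper written for this example, not a `kerasgan` function, and the random draw only mimics the kind of labels described above.

```python
import numpy as np

NUM_CLASSES = 10

def check_integer_labels(labels, num_classes):
    """Raise ValueError unless labels are integers in [0, num_classes)."""
    labels = np.asarray(labels)
    if not np.issubdtype(labels.dtype, np.integer):
        raise ValueError(f'labels must be integers, got dtype {labels.dtype}')
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(f'labels must lie in [0, {num_classes - 1}]')
    return labels

# Random integer labels in 0 .. NUM_CLASSES-1 (the upper bound is exclusive)
rng = np.random.default_rng(0)
random_labels = rng.integers(low=0, high=NUM_CLASSES, size=32)
check_integer_labels(random_labels, NUM_CLASSES)
```

One-hot or floating-point labels would fail this check and would need to be converted first, for example with `np.argmax(one_hot, axis=-1)`.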

In [None]:
cgan = CGAN(
    generator=generator,
    discriminator=discriminator,
    latent_dim=LATENT_DIM,
    num_classes=NUM_CLASSES
)

cgan.compile(
    generator_optimizer=keras.optimizers.Adam(learning_rate=2e-4,beta_1=0.5),
    discriminator_optimizer=keras.optimizers.Adam(learning_rate=2e-4,beta_1=0.5)
)

hist = cgan.fit(
    training_data, epochs=30, verbose=2,
    callbacks = [CGANClassSnapshot(examples_per_class=4), GANCheckpoint(save_dir='cgan_checkpoints')]
)

### CGAN Callbacks
Both `CGANSnapshot` and `CGANClassSnapshot` are intended for use with `CGAN`. **Do not use `GANSnapshot` with `CGAN`; it should only be used with `GAN`!**

`CGANSnapshot` works like `GANSnapshot`, except that it also randomly generates labels.

`CGANClassSnapshot` is the recommended snapshot callback to use. By default, it will generate 3 rows of examples with each column corresponding to a different label.

Each of these callbacks has several additional options that allow for greater flexibility. Some examples are shown below:

In [None]:
# display snapshots using the provided seed (will generate random labels)
seed = tf.random.normal(shape=(32,LATENT_DIM))
CGANSnapshot(seed=seed)

# display snapshots using the provided labels (will generate a random seed)
labels = np.array([0,1,2,3,4,5,6,7,8,9])
CGANSnapshot(labels=labels)

# display snapshots using the provided seed and labels
CGANSnapshot(seed=seed, labels=labels)

# save snapshots with 100 images into the directory 'mysnapshots'
CGANSnapshot(num_images=100, save_snapshots=True, save_dir='mysnapshots')

# save snapshots every 10 epochs with 5 examples per class into the directory 'myclass_snapshots'
CGANClassSnapshot(examples_per_class=5, save_snapshots=True,
    save_freq=10, save_dir='myclass_snapshots')

# display snapshots with 5 examples per class, except make each row show a different class
CGANClassSnapshot(examples_per_class=5, columns_indicate='examples')
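
The two grid layouts mentioned in the last two comments can be pictured as label arrays. The NumPy sketch below only illustrates those layouts under the assumption that one label is needed per grid cell; it is not how `CGANClassSnapshot` is implemented.

```python
import numpy as np

num_classes, examples_per_class = 10, 5

# Columns indicate classes: every row runs through labels 0..9
class_columns = np.tile(np.arange(num_classes), (examples_per_class, 1))

# Columns indicate examples: every row holds a single class
example_columns = np.repeat(
    np.arange(num_classes)[:, np.newaxis], examples_per_class, axis=1)

print(class_columns.shape)    # (5, 10)
print(example_columns.shape)  # (10, 5)
```

The second grid is simply the transpose of the first, so both layouts show the same set of labels, arranged with classes along columns or along rows.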