## Example GAN
***
A quick and simple GAN trained on the MNIST handwritten-digit dataset.

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

from kerasgan import GAN, GANSnapshot, GANCheckpoint

### Key Parameters

In [None]:
BATCH_SIZE = 128
LATENT_DIM = 128   # length of the noise vector fed to the generator
AUTOTUNE = tf.data.AUTOTUNE
SEED = 42

# Seed the global TF and NumPy RNGs so that weight initialisation and random latent
# vectors (e.g. tf.random.normal) repeat across Restart & Run All.
# GPU kernels may still introduce some run-to-run nondeterminism.
tf.random.set_seed(SEED)
np.random.seed(SEED)

### Load MNIST Dataset

In [None]:
# An unconditional GAN needs neither labels nor a held-out split,
# so the train and test images are pooled into one training set.
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

# combine all images, add a channel axis and rescale pixel values from [0, 255]
# to [-1, 1] so real images match the range of the generator's tanh output
images = np.concatenate((X_train, X_test))
images = images[..., np.newaxis].astype('float32') / 127.5 - 1

# sanity checks: shape must match the discriminator input, values must lie in [-1, 1]
print(images.shape, images.dtype, np.max(images), np.min(images))
assert images.shape[1:] == (28, 28, 1), images.shape
assert images.min() >= -1.0 and images.max() <= 1.0

# create tf dataset; shuffle over the full set (reshuffled every epoch by default)
# so batches are not presented in the same fixed order each epoch
image_data = (
    tf.data.Dataset.from_tensor_slices(images)
    .shuffle(buffer_size=len(images))
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

### Create the Generator
The generator takes a latent vector as its input and outputs an image. This example uses transpose convolutions with a stride of 2. Note the use of `output_padding=1` in the first `Conv2DTranspose`: it makes the feature map go from (4,4,128) to (7,7,64) instead of (8,8,64), so that the two following stride-2 layers upsample it to 14×14 and then to the 28×28 size of an MNIST image.

In [None]:
keras.backend.clear_session()


def _upsample(tensor, filters, output_padding=None):
    """Apply a stride-2, 4x4, bias-free Conv2DTranspose, then BatchNorm and LeakyReLU."""
    tensor = layers.Conv2DTranspose(
        filters,
        kernel_size=4,
        strides=2,
        padding='same',
        output_padding=output_padding,
        use_bias=False,
    )(tensor)
    tensor = layers.BatchNormalization()(tensor)
    return layers.LeakyReLU()(tensor)


# project the latent vector onto a 4x4x128 feature map
gen_input = layers.Input(shape=(LATENT_DIM,))
x = layers.Dense(4 * 4 * 128, use_bias=False)(gen_input)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
x = layers.Reshape((4, 4, 128))(x)

# 4x4 -> 7x7 (output_padding=1 instead of the default 8x8), then 7x7 -> 14x14
x = _upsample(x, 64, output_padding=1)
x = _upsample(x, 32)

# 14x14 -> 28x28 single-channel image; tanh keeps pixels in [-1, 1]
gen_output = layers.Conv2DTranspose(1, kernel_size=4, strides=2, padding='same', activation='tanh')(x)
generator = keras.Model(inputs=gen_input, outputs=gen_output, name='generator')
generator.summary()

### Create the Discriminator
The discriminator takes an image and outputs a single unbounded score (its final `Dense` layer has no activation). This example uses regular, strided convolutions. Note there is no batch normalisation in the first layer.

In [None]:
def _downsample(tensor, filters, batch_norm=True):
    """Apply a stride-2 3x3 Conv2D, optional BatchNorm, then LeakyReLU.

    The convolution bias is dropped whenever BatchNorm follows, since BatchNorm's
    own offset makes it redundant.
    """
    tensor = layers.Conv2D(filters, kernel_size=3, strides=2, padding='same', use_bias=not batch_norm)(tensor)
    if batch_norm:
        tensor = layers.BatchNormalization()(tensor)
    return layers.LeakyReLU()(tensor)


dsc_input = layers.Input(shape=(28, 28, 1))
x = _downsample(dsc_input, 32, batch_norm=False)  # first layer: no batch normalisation
x = _downsample(x, 64)
x = _downsample(x, 128)
x = layers.Flatten()(x)
x = layers.Dropout(0.25)(x)
# single raw score with no activation; presumably the GAN loss works on logits — TODO confirm in kerasgan
dsc_output = layers.Dense(1)(x)
discriminator = keras.Model(inputs=dsc_input, outputs=dsc_output, name='discriminator')
discriminator.summary()

### Training the GAN
Simply instantiate a `GAN` object with the two models, then compile and fit it. Don't forget to pass `LATENT_DIM` to the `GAN` constructor (as `latent_dim`).

The example below also includes two callbacks:
- `GANSnapshot` produces a matplotlib figure showing example images from the generator at each epoch. You can provide a seed if you wish; otherwise one is generated randomly. By default, this callback plots 32 examples. By default the figures are only displayed in the notebook; set `save_snapshots=True` to also save them as PNG files.

- `GANCheckpoint`, with its default settings, saves both the generator and the discriminator as .h5 files in the directory `gan_checkpoints` every 20 epochs.

In [None]:
EPOCHS = 50
LEARNING_RATE = 2e-4
# beta_1 = 0.5 is the DCGAN-recommended Adam momentum: the Keras default of 0.9 tends
# to make GAN training oscillate, while a near-zero value like 0.05 discards almost
# all momentum.
BETA_1 = 0.5

gan = GAN(
    generator=generator,
    discriminator=discriminator,
    latent_dim=LATENT_DIM,
)

# two separate optimizer instances so each model keeps its own Adam moment estimates
gan.compile(
    generator_optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=BETA_1),
    discriminator_optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=BETA_1),
)

# NOTE(review): with GANCheckpoint's default 20-epoch interval, the weights after the
# final epoch may not be checkpointed when EPOCHS is not a multiple of 20 — confirm
# against kerasgan and save the final models explicitly if needed.
hist = gan.fit(
    image_data,
    epochs=EPOCHS,
    verbose=2,
    callbacks=[GANSnapshot(), GANCheckpoint()],
)

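### Further Callback Examples
Both callbacks have several options for greater flexibility. Each example below is bound to a name so it can be passed directly to `gan.fit(..., callbacks=[...])`.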

In [None]:
# Each example callback is kept in a variable (rather than constructed and thrown away)
# so it can be reused, e.g. gan.fit(image_data, epochs=..., callbacks=[...]).

# display snapshots using the provided seed (64 examples)
seed = tf.random.normal(shape=(64, LATENT_DIM))
fixed_seed_snapshot = GANSnapshot(seed=seed)

# display 100 examples
hundred_image_snapshot = GANSnapshot(num_images=100)

# save snapshots every 10 epochs, and write the files to the directory 'mysnapshots'.
# files will be named mymodel_0000.png, mymodel_0010.png, mymodel_0020.png, etc
saving_snapshot = GANSnapshot(
    save_snapshots=True,
    save_dir='mysnapshots',
    save_freq=10,
    save_prefix='mymodel',
)

# save snapshots of 42 examples with the 'cividis' colormap and a higher resolution
cividis_snapshot = GANSnapshot(num_images=42, save_snapshots=True, cmap='cividis', dpi=200)

# save just the generator, every 100 epochs
generator_only_checkpoint = GANCheckpoint(save_freq=100, save_discriminator=False)

# usage example (not run here):
# gan.fit(image_data, epochs=100, callbacks=[saving_snapshot, generator_only_checkpoint])