# GAN
Train a GAN on the MNIST handwritten digit dataset. 

This makes use of: our custom Keras model class defined in vaegan.gan, our
class for loading the MNIST dataset defined in vaegan.data, and our custom Keras
callback defined in vaegan.callbacks.

A directory called 'outputs/mnist_gan' will be created to save figures and the
trained model.

In [None]:
# Number of training epochs. Kept small for quick testing during development;
# the original, longer training run used 40 epochs.
nEpochs = 3

: 

## 1. Import 3rd party libraries 

In [None]:
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from importlib import reload



## 2. Import our own classes (that we will complete together)

In [None]:
# Import our own classes 
from vaegan.data import MNIST
from vaegan.callbacks import GenerateImages
import vaegan.gan


## 3. Show some of our data

In [None]:
# Directory that receives the figures and the trained model produced by this
# notebook. Create it up front; exist_ok avoids an error when it is already there.
output_dir = './outputs/mnist_gan'
os.makedirs(output_dir, exist_ok=True)

# Load our training data through the MNIST wrapper class.
data = MNIST()

# Show some example images and their labels, handing the method the path of
# the figure file inside the output directory.
example_images_path = os.path.join(output_dir, 'example_images.png')
data.show_example_images(example_images_path)

## 4. Construct the model using the python class you completed

In [None]:
# Reload vaegan.gan so that edits made to the module since it was first
# imported are picked up by this cell.
pyModule = reload(vaegan.gan)

# Fix TensorFlow's random seed before the model is created.
tf.random.set_seed(1234)

# Build the GAN with its default arguments. This is the place to pass different
# settings (e.g. other loss weights) if you want to experiment.
model = pyModule.GAN()

# Ask Keras to compute explicit output shapes for every layer. Without this
# step the layers keep dynamic/variable output shapes, which is not compatible
# with saving and loading. The same (batch, 32, 32, 1) shape is used for the
# full model and for the discriminator.
image_batch_shape = (None, 32, 32, 1)
model.compute_output_shape(image_batch_shape)
model.discriminator.compute_output_shape(image_batch_shape)

## <span style="color:blue"> Correct model dimensions </span>
    

<span style="color:blue"> === OVERALL MODEL ==== </span>

```
Model: "gan"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator (Generator)        multiple                  1582465   
_________________________________________________________________
discriminator (Discriminator multiple                  50145     
_________________________________________________________________
gen_loss (Mean)              multiple                  2         
_________________________________________________________________
disc_loss (Mean)             multiple                  2         
=================================================================
Total params: 1,632,614
Trainable params: 1,632,418
Non-trainable params: 196
_________________________________________________________________

```
<span style="color:blue"> === GENERATOR SUBMODEL ====</span>
```
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        [(None, 128)]             0         
_________________________________________________________________
dense (Dense)                (None, 8192)              1056768   
_________________________________________________________________
relu_dense (LeakyReLU)       (None, 8192)              0         
_________________________________________________________________
reshape (Reshape)            (None, 8, 8, 128)         0         
_________________________________________________________________
conv0 (Conv2DTranspose)      (None, 16, 16, 128)       262272    
_________________________________________________________________
relu0 (LeakyReLU)            (None, 16, 16, 128)       0         
_________________________________________________________________
conv1 (Conv2DTranspose)      (None, 32, 32, 128)       262272    
_________________________________________________________________
relu1 (LeakyReLU)            (None, 32, 32, 128)       0         
_________________________________________________________________
conv_out (Conv2D)            (None, 32, 32, 1)         1153      
_________________________________________________________________
sigmoid_out (Activation)     (None, 32, 32, 1)         0         
=================================================================
Total params: 1,582,465
Trainable params: 1,582,465
Non-trainable params: 0

```
<span style="color:blue"> === DISCRIMINATOR SUBMODEL ====</span>
```
Layer (type)                 Output Shape              Param #   
=================================================================
input_12 (InputLayer)        [(None, 32, 32, 1)]       0         
_________________________________________________________________
conv0 (Conv2D)               (None, 32, 32, 32)        544       
_________________________________________________________________
bn0 (BatchNormalization)     (None, 32, 32, 32)        128       
_________________________________________________________________
relu0 (ReLU)                 (None, 32, 32, 32)        0         
_________________________________________________________________
conv1 (Conv2D)               (None, 16, 16, 64)        32832     
_________________________________________________________________
bn1 (BatchNormalization)     (None, 16, 16, 64)        256       
_________________________________________________________________
relu1 (ReLU)                 (None, 16, 16, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 16384)             0         
_________________________________________________________________
dense_out (Dense)            (None, 1)                 16385     
_________________________________________________________________
sigmoid_out (Activation)     (None, 1)                 0         
=================================================================
Total params: 50,145
Trainable params: 49,953
Non-trainable params: 192

```
## 5. Now check your model's dimensions against the list above

In [None]:
def _as_functional_model(submodel, input_shape):
    """Wrap `submodel` in a functional tf.keras.Model so summary() lists its layers.

    A symbolic Input of `input_shape` is passed through `submodel.call`, and the
    resulting input/output pair is used to build the wrapper model.
    """
    inputs = tf.keras.layers.Input(input_shape)
    outputs = submodel.call(inputs)
    return tf.keras.Model(inputs, outputs)


print("=== OVERALL MODEL ====")
model.summary()

# The generator takes latent vectors as input.
print("=== GENERATOR SUBMODEL ====")
gen = _as_functional_model(model.generator, model.n_latent_dims)
gen.summary()

# The discriminator takes images as input.
print("=== DISCRIMINATOR SUBMODEL ====")
disc = _as_functional_model(model.discriminator, model.image_shape)
disc.summary()

## 6. Compile the model 

In [None]:
# Compile the model with an optimizer. The learning rate of the optimizer can be
# specified here. Normally, this is also where you would select a loss function
# and any metrics. However, our custom model defines the loss functions inside
# its __init__ constructor, so we don't need to do that here.
#
# NOTE: the Adam argument is spelled `learning_rate`. The short alias `lr` is
# deprecated in tf.keras and is rejected by newer Keras releases.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001))

# Instantiate our custom callback to save a few example reconstructions and
# generated images after each epoch. It is passed to model.fit later on.
save_images_callback = GenerateImages(output_dir=output_dir,
                                      model=model,
                                      n_generated_images=10,
                                      n_latent_dims=model.n_latent_dims)

# the number of latent dims is 128
# model.build(input_shape=(60000, 32, 32, 1))

## 7. Train (fit) the model on the data  

In [None]:
# Train the model by calling fit, exactly as with any off-the-shelf Keras
# model. Under the hood, Keras calls the train_step method of our custom
# subclass on each mini-batch and loops through the training data for us. It
# also handles the details: converting numpy arrays to tensors, showing a
# progress bar, and tracking the loss over the epochs. The returned object is
# kept in `logs` so the recorded losses can be plotted afterwards.
train_images = data.images_train
logs = model.fit(
    train_images,
    batch_size=128,
    epochs=nEpochs,
    callbacks=[save_images_callback],
)

# 469 is the number of iterations/epoch...

## 8. Training saves results to disk, now also plot training curves

In [None]:
# Plot the training curves. logs.history is a dict whose keys are the metric
# names and whose values are the per-epoch arrays of that metric.
fig, ax = plt.subplots()
history = logs.history

# One line per loss, drawn against the epoch index.
for metric in ('gen_loss', 'disc_loss'):
    curve = history[metric]
    ax.plot(np.arange(len(curve)), curve, label=metric)

ax.set_xlabel('Epoch')
ax.legend()

# Write the figure next to the other outputs.
figure_path = os.path.join(output_dir, 'training_curves.png')
fig.savefig(figure_path, transparent=False)
# fig.show()

# Save the model
model_path = os.path.join(output_dir, 'gan')
model.save(model_path)