# GAN on MNIST


We will still use the same MNIST dataset, with each example shaped as a `(28,28,1)` array.
But this time there is no need to split it into train/valid datasets.

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plt
%matplotlib inline
import datetime as dt

from pathlib import Path

from sklearn.datasets import fetch_openml
from sklearn import preprocessing

import io
import scipy

# Load MNIST as NumPy arrays, downloading it from OpenML on first use and
# caching it locally as an .npz file so later runs skip the download.
data_home = '/tmp/scikit_learn_data/'  # scikit-learn download cache
datafile = '/tmp/mnist.npz'            # local NumPy copy of the dataset

datapath = Path(datafile)
if not datapath.exists():
    print("Data File not found... downloading it")
    Xmnist, ymnist = fetch_openml('mnist_784',
                                  version=1,
                                  return_X_y=True,
                                  data_home=data_home)
    # MNIST pixel intensities (0..255) and labels (0..9) fit in one byte.
    # 'u1' is uint8; the former 'u8' is uint64 and made the cache file
    # (and the temporary arrays) eight times larger than necessary.
    np.savez(datapath.as_posix(),
             X=np.array(Xmnist, dtype='u1'),
             y=np.array(ymnist, dtype='u1'))
    print("Data File downloaded and saved")
    del Xmnist, ymnist

print("Data File found... loading it into memory")
# np.load returns a file-backed archive; the context manager closes it once
# the arrays have been materialised (both expressions below create new arrays).
with np.load(datapath.as_posix()) as data:
    Xmnist = data['X']/255.                         # rescale pixels by 1/255
    ymnist = keras.utils.to_categorical(data['y'])  # one-hot encode labels
print("Data File loaded")

# First 60000 examples are used for training, last 10000 for testing.
Xtrain, Ytrain = Xmnist[:60000], ymnist[:60000]
Xtest, Ytest = Xmnist[-10000:], ymnist[-10000:]

# Reshape flat rows into (n_examples, 28, 28, 1) image tensors.
Xtrain = Xtrain.reshape((Xtrain.shape[0], 28, 28, 1))
Xtest = Xtest.reshape((Xtest.shape[0], 28, 28, 1))


## Training GAN in KERAS


A GAN combines two networks: a `generator` and a `discriminator`.
Nevertheless,
- while the generator is being trained, the weights of the discriminator need to be fixed,
- and while the discriminator is being trained, the weights of the generator also need to be fixed.
That's why there is no easy way to use the `fit` method provided by keras.
You have to write the main training loop yourself with the `train_on_batch` method.

Here is an example code to show you how to combine the generator and the discriminator:
```python
    
    generator = [...] # create your own generator
    generator.build(input_shape=[...])
    # no need to compile the generator as it will not be learnt by itself
    generator.summary()

    discriminator = [...] # create your own discriminator
    discriminator.build(input_shape=[...])
    discriminator.compile(loss='binary_crossentropy', optimizer='adam')
    discriminator.summary()

    # we freeze the discriminator weights
    # this must be called after the discriminator compilation
    # see the definition in the next cell
    set_trainable(discriminator, False) 
    
    # Combine the generator and the discriminator into a gan model
    # here the discriminator weights are fixed
    gan = keras.Sequential(name='gan')
    gan.add(generator)
    gan.add(discriminator)
    gan.build(input_shape=[...])
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    gan.summary()

```

Now at each batch iteration:
- the discriminator should be trained with fake examples labeled 0 and real examples labeled 1. 
- the gan should be trained with only fake examples labeled 1. 

You can produce fake examples using the `predict` method of the generator.

```python

    for epochs
        for batch
            #[...] 
            Xfake_batch = generator.predict(some_noise, verbose=0)
            #[...]
            # Xdiscriminator_batch contains both fake and real examples
            # respectively labelled 0 and 1
            discriminator.train_on_batch(Xdiscriminator_batch, Ydiscriminator_batch)
            # Zgan_batch contains only noise, which the generator turns into fake examples
            # labelled 1
            gan.train_on_batch(Zgan_batch, Ygan_batch)
            
```


In order to help you, the following functions are given:
- `set_trainable` to change the trainable state of a whole model
- `batch_iterator`, an iterator producing slice indices to split a dataset into batches

See examples below.

In [None]:
def set_trainable(model, trainable=True):
    """Set the ``trainable`` flag on a model and on each of its layers.

    Args:
        model: object exposing a ``trainable`` attribute and a ``layers``
            iterable (e.g. a Keras model).
        trainable: value assigned to ``model.trainable`` and to the
            ``trainable`` attribute of every element of ``model.layers``.
    """
    model.trainable = trainable
    for layer in model.layers:
        layer.trainable = trainable

def batch_iterator(nx, batch_size, shuffle=True):
    """Yield index arrays that split a dataset of ``nx`` examples into batches.

    It works even if ``nx`` is not divisible by ``batch_size``: the last
    batch is then simply shorter.

    Args:
        nx: number of examples in the set.
        batch_size: the desired batch size (must be a positive integer).
        shuffle: whether to shuffle the example indices or not.

    Yields:
        A tuple ``(indices, batch, n_batchs)``: the indices of the current
        batch, the current batch index, and the total number of batches.

    Raises:
        ValueError: if ``batch_size`` is not strictly positive.
    """
    if batch_size <= 0:
        # Previously 0 raised an obscure ZeroDivisionError and negative
        # values silently produced no batch at all.
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    idx = np.arange(nx)
    if shuffle:
        np.random.shuffle(idx)
    # Integer ceil-division: exact, with no float round-trip.
    n_batchs = -(-nx // batch_size)
    for batch in range(n_batchs):
        start = batch * batch_size
        # Slicing past the end is clamped, so the last batch may be shorter.
        yield idx[start:start + batch_size], batch, n_batchs

# Demonstrate batch_iterator: ordered batches, shuffled batches, and a batch
# size that does not divide the number of examples.
print("Batch iterator examples")
example_calls = [
    ((100, 10), {"shuffle": False}),
    ((100, 10), {}),
    ((100, 15), {"shuffle": False}),
]
for position, (call_args, call_kwargs) in enumerate(example_calls):
    if position:
        # blank line separating consecutive examples
        print("")
    for aSlice, batch, n_batch in batch_iterator(*call_args, **call_kwargs):
        print(aSlice, batch, n_batch)

## Exercise

1) Build a GAN to produce MNIST like images.

2) Build a Pac-GAN to produce more diversity (2 images at the input of the discriminator).

3) Build a Conditional GAN to produce MNIST like images conditioned on the class label.