# GAN

In previous notebooks, we met [RBMs](./rbm.ipynb) and [VAEs](./autoEnc.ipynb) as tools that let our AI generate data. However, both approaches come with constraints that we have to be aware of.

The single-layer architecture of RBMs is not powerful enough to generate data of high inherent complexity. For instance, with a refined training process, RBMs can be made to produce realistic-looking digits, but there is no hope of producing an image of a cat or a bird.

VAEs are far more expressive than RBMs and can generate promising candidate images of high inherent complexity, such as faces. However, the images obtained in this manner tend to be blurry:
<img id="gab" src="images/celebA.png"  width="500">
Image source: https://github.com/yzwxx/vae-celebA

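The deeper reason becomes clear when we analyze the ELBO that the VAE is trained to optimize. Its reconstruction error is $-\mathbb{E}_{x \sim \mathsf{data}}\mathbb{E}_{z \sim q(z|x)}[\log p(x|z)]$. This term incurs a massive penalty whenever a data point has very low probability under $p(\cdot|z)$. To avoid that penalty, the decoder spreads its probability mass so that every plausible data point is covered. For a Gaussian decoder whose mean is used as the generated image, the output therefore ends up close to an average over several sharp candidates, and such an average is what a blurry image looks like.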
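The averaging effect can be illustrated with a toy NumPy sketch. It assumes a Gaussian decoder $p(x|z)$ with fixed variance. This is an assumption: the VAE in the earlier notebook may use a different likelihood. It also assumes two sharp, hypothetical "images" that are equally likely for the same $z$. Up to an additive constant, the expected negative log-likelihood is smallest for the pixel-wise average, which is sharp in neither case.

```python
import numpy as np

def expected_gaussian_nll(mu, candidates, sigma=1.0):
    """Average negative log-likelihood (up to an additive constant) of the
    candidate images under an isotropic Gaussian decoder with mean `mu`."""
    sq_err = [np.sum((x - mu) ** 2) for x in candidates]
    return np.mean(sq_err) / (2 * sigma ** 2)

# Two sharp, equally likely "images" for the same latent code z
# (hypothetical toy data): an edge on the left versus an edge on the right.
x1 = np.array([1.0, 1.0, 0.0, 0.0])
x2 = np.array([0.0, 0.0, 1.0, 1.0])
candidates = [x1, x2]

# The pixel-wise average is a "blurry" compromise between x1 and x2.
blurry = (x1 + x2) / 2

for name, mu in [("x1", x1), ("x2", x2), ("average", blurry)]:
    print(name, expected_gaussian_nll(mu, candidates))
```

Predicting either sharp image costs 1.0 here, while the blurry average costs only 0.5.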

A radically new concept for generative modeling is the **Generative Adversarial Network (GAN)**, introduced in a [landmark paper](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf) by Ian Goodfellow et al. in 2014. The idea is as simple as it is intriguing. The generative task is set up as an adversarial competition between two neural networks. The goal of the **generator** is to produce realistic-looking images of cats, while the job of the **discriminator** is to tell actual images from generated ones. The generator can then use the discriminator's judgments as cues on how to produce better images.

<img id="gab" src="images/ganGame.png"  width="900">

The code in this notebook is adapted from a wonderful [AC-GAN (auxiliary classifier GAN) implementation](https://github.com/keras-team/keras/blob/master/examples/mnist_acgan.py) in the Keras example section.

## Generator versus Discriminator

The concept behind GANs can be explained in one sentence to a layperson, and its mathematical formulation is just as elegant as $E = mc^2$! It is based on the following minimax problem:
\begin{align} \min_G \max_D \mathbb{E}_{x \sim p_{\mathsf{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_{\mathsf{noise}}}[\log( 1 - D(G(z)))],
\end{align}
where $G$ and $D$ denote the generator and the discriminator, respectively.

For a fixed generator, the inner maximization means that the discriminator is trained to minimize the binary cross-entropy for the classification problem of real vs. fake images, where $D(x)$ is the predicted probability that $x$ is real. The outer minimization means that the generator tries to make this task as difficult as possible.
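This equivalence can be checked numerically. In the NumPy sketch below, the discriminator outputs are random placeholders rather than a trained network. For a batch with equally many real and fake samples, the empirical value function equals $-2$ times the binary cross-entropy, so maximizing one is the same as minimizing the other.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical discriminator outputs on n real and n generated samples.
n = 5
d_real = rng.uniform(0.05, 0.95, n)   # D(x) for x ~ p_data
d_fake = rng.uniform(0.05, 0.95, n)   # D(G(z)) for z ~ p_noise

# Monte Carlo estimate of the GAN value function V(D, G).
value = np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake))

# Binary cross-entropy of the balanced batch with labels 1 (real) and 0 (fake).
p = np.concatenate([d_real, d_fake])
y = np.concatenate([np.ones(n), np.zeros(n)])
bce = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

print(value, -2 * bce)  # identical up to floating-point rounding
```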

Heuristically speaking, training the generator is more difficult than training the discriminator. For instance, for a fixed generator, the inner maximization problem can be solved explicitly.


**Theorem**

Fix $G$. The solution of the optimization problem
$$ \max_D \mathbb{E}_{x \sim p_{\mathsf{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_{\mathsf{noise}}}[\log(1 - D(G(z)))]$$
is given by 
$$D(x) = \frac{p_\mathsf{data}(x)}{p_\mathsf{data}(x) + p_{G}(x)},$$
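where $p_{G}$ is the probability density of $G(Z)$ with $Z \sim p_{\mathsf{noise}}$. Writing $m = (p_{\mathsf{data}} + p_G)/2$, the optimal value of the maximization problem is given by
$$\mathsf{KL}(p_\mathsf{data}\,\|\,m) + \mathsf{KL}(p_G\,\|\,m) - \log 4 = 2\,\mathsf{JS}(p_{\mathsf{data}}, p_G) - \log 4,$$
where $\mathsf{JS}$ denotes the **Jensen-Shannon divergence** between $p_{\mathsf{data}}$ and $p_G$. Since $\mathsf{JS} \ge 0$ with equality if and only if $p_G = p_{\mathsf{data}}$, the outer minimization over $G$ is solved by $p_G = p_{\mathsf{data}}$.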
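Without giving away the homework, the theorem can at least be sanity-checked numerically. This NumPy sketch uses two hypothetical discrete distributions in place of densities, so sums replace the expectations. It compares the value at $D^*$ with $\mathsf{KL}(p_\mathsf{data}\|m) + \mathsf{KL}(p_G\|m) - \log 4$ and confirms that random perturbations of $D^*$ never increase the value.

```python
import numpy as np

def gan_value(d, p_data, p_g):
    """V(D, G) for discrete distributions on a finite set of points."""
    return np.sum(p_data * np.log(d)) + np.sum(p_g * np.log(1 - d))

def kl(p, q):
    """Kullback-Leibler divergence KL(p || q) for strictly positive vectors."""
    return np.sum(p * np.log(p / q))

# Two hypothetical distributions on four points (all entries positive).
p_data = np.array([0.1, 0.4, 0.3, 0.2])
p_g = np.array([0.3, 0.3, 0.2, 0.2])

d_star = p_data / (p_data + p_g)   # optimal discriminator from the theorem
m = (p_data + p_g) / 2
js = 0.5 * kl(p_data, m) + 0.5 * kl(p_g, m)

print(gan_value(d_star, p_data, p_g), 2 * js - np.log(4))  # the two numbers agree

# V is concave in each D(x), so perturbing D* should never increase the value.
rng = np.random.default_rng(1)
for _ in range(5):
    d = np.clip(d_star + 0.05 * rng.standard_normal(4), 1e-6, 1 - 1e-6)
    assert gan_value(d, p_data, p_g) <= gan_value(d_star, p_data, p_g)
```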

Typically, both the discriminator and the generator are implemented as deep nets. Here are the architectures that we use in the code example below to generate MNIST digits.

### Discriminator

In [49]:
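from keras.models import Sequential
from keras.layers import Conv2D, Dense, Dropout, Flatten, Input, LeakyReLU

def disc_cnn():
    """Convolutional feature extractor shared by both discriminator heads."""
    return Sequential([
        # 28x28x1 -> 14x14x32
        Conv2D(32, 3, padding='same', strides=2, input_shape=(28, 28, 1)),
        LeakyReLU(0.2),
        Dropout(0.3),

        # 14x14x32 -> 14x14x64
        Conv2D(64, 3, padding='same', strides=1),
        LeakyReLU(0.2),
        Dropout(0.3),

        # 14x14x64 -> 7x7x128
        Conv2D(128, 3, padding='same', strides=2),
        LeakyReLU(0.2),
        Dropout(0.3),

        # 7x7x128 -> 7x7x256
        Conv2D(256, 3, padding='same', strides=1),
        LeakyReLU(0.2),
        Dropout(0.3),

        # 7x7x256 -> feature vector of length 12544
        Flatten()
    ])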


### Generator

In [51]:
from keras.models import Sequential
from keras.layers import BatchNormalization, Conv2DTranspose, Dense, Reshape

def gen_cnn(latent_size=100):
    """Upsamples a vector of length `latent_size` to a 28x28x1 image."""
    return Sequential([
        # project the latent vector and reshape it to 3x3x384
        Dense(3 * 3 * 384, input_dim=latent_size, activation='relu'),
        Reshape((3, 3, 384)),

        # 3x3x384 -> 7x7x192
        Conv2DTranspose(192, 5, strides=1, padding='valid',
                        activation='relu',
                        kernel_initializer='glorot_normal'),
        BatchNormalization(),

        # 7x7x192 -> 14x14x96
        Conv2DTranspose(96, 5, strides=2, padding='same',
                        activation='relu',
                        kernel_initializer='glorot_normal'),
        BatchNormalization(),

        # 14x14x96 -> 28x28x1; tanh maps pixel values to [-1, 1]
        Conv2DTranspose(1, 5, strides=2, padding='same',
                        activation='tanh',
                        kernel_initializer='glorot_normal')
    ])
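Why does a $3 \times 3$ seed end up as a $28 \times 28$ image? The small sketch below traces the three upsampling steps using the output-size rule for transposed convolutions. It assumes no `output_padding`, and mirrors the rule Keras applies in that case to the best of our understanding. `conv_transpose_output` is a hypothetical helper, not a Keras function.

```python
def conv_transpose_output(size, kernel, stride, padding):
    """Output length along one axis of a transposed convolution
    (assuming no output_padding)."""
    if padding == 'valid':
        return size * stride + max(kernel - stride, 0)
    if padding == 'same':
        return size * stride
    raise ValueError('unknown padding: ' + padding)

# The three Conv2DTranspose layers of gen_cnn: (kernel, stride, padding).
layers = [(5, 1, 'valid'), (5, 2, 'same'), (5, 2, 'same')]

size = 3
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_output(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)
```

The sizes grow 3, 7, 14, 28. Only the first, `'valid'`, layer adds the extra border of `kernel - stride = 4` pixels.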

## Example: MNIST

First, we initialize the discriminator.

In [48]:
from keras.models import Sequential, Model
from keras.layers import Dense, Input


def build_discriminator(num_classes=10):
    image = Input(shape=(28, 28, 1))
    features = disc_cnn()(image)

    # head 1: probability that the image is real rather than generated
    fake = Dense(1, activation='sigmoid', name='generation')(features)
    # head 2 (AC-GAN): class probabilities for the digit label
    aux = Dense(num_classes, activation='softmax', name='auxiliary')(features)
    return Model(image, [fake, aux])


In [52]:
from keras.optimizers import Adam

# Adam parameters suggested in https://arxiv.org/abs/1511.06434
adam_lr = 0.0002
adam_beta_1 = 0.5

# build the discriminator
print('Discriminator model:')
discriminator = build_discriminator()
discriminator.summary()
discriminator.compile(
    optimizer=Adam(learning_rate=adam_lr, beta_1=adam_beta_1),
    # one loss per output: real/fake head and auxiliary class head
    loss=['binary_crossentropy', 'sparse_categorical_crossentropy']
)



Second, the generator...

In [50]:
from keras.models import Model
from keras.layers import Embedding, Flatten, Input, multiply

def build_generator(latent_size=100, num_classes=10):
    # we will map a pair of (z, L), where z is a latent vector and L is a
    # label drawn from P_c, to image space (..., 28, 28, 1)

    # this is the z space commonly referred to in GAN papers
    latent = Input(shape=(latent_size,))

    # this will be our label
    image_class = Input(shape=(1,), dtype='int32')

    # learned class embedding of the label, flattened to length latent_size
    cls = Flatten()(Embedding(num_classes, latent_size,
                              embeddings_initializer='glorot_normal')(image_class))

    # Hadamard product between z-space and a class-conditional embedding
    h = multiply([latent, cls])

    fake_image = gen_cnn(latent_size)(h)

    return Model([latent, image_class], fake_image)
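The conditioning step can be mimicked in plain NumPy. An `Embedding` layer is essentially a lookup into a trainable table with one row per class. `multiply` then scales each noise coordinate by the corresponding entry of that row. The table below is random rather than learned and the sizes are shrunk, so this only illustrates the data flow.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_size, num_classes, batch = 4, 10, 3   # small sizes for illustration

# Stand-in for the learned Embedding weights: one row per class.
embedding_table = rng.normal(size=(num_classes, latent_size))

z = rng.uniform(-1, 1, size=(batch, latent_size))   # noise vectors
labels = np.array([[7], [0], [7]])                  # shape (batch, 1), as fed to the model

cls = embedding_table[labels[:, 0]]   # lookup + flatten -> (batch, latent_size)
h = z * cls                           # element-wise (Hadamard) product

print(h.shape)
```

Two samples with the same label (rows 0 and 2 here) are scaled by the same class vector. The network can therefore learn a class-specific modulation of the noise.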

In [53]:
# build the generator
generator = build_generator()


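... and, third, the combined model, which stacks the frozen discriminator on top of the generator and is used to train the generator.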

In [54]:
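latent_size = 100

# get a fake image
latent = Input(shape=(latent_size,))
image_class = Input(shape=(1,), dtype='int32')
fake_image = generator([latent, image_class])

# put it into the discriminator. Freezing the discriminator here means that
# training `combined` only updates the generator. The discriminator itself
# was compiled above while still trainable, so (with tf.keras 2.x semantics,
# where `trainable` is taken into account at compile time)
# `discriminator.train_on_batch` still updates its weights.
discriminator.trainable = False
fake, aux = discriminator(fake_image)
combined = Model([latent, image_class], [fake, aux])

combined.compile(
    optimizer=Adam(learning_rate=adam_lr, beta_1=adam_beta_1),
    loss=['binary_crossentropy', 'sparse_categorical_crossentropy']
)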

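Now, we load the MNIST data and rescale the pixel values from [0, 255] to [-1, 1], which matches the range of the generator's `tanh` output.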

In [55]:
from keras.datasets import mnist
import numpy as np

(x_train, y_train), _ = mnist.load_data()
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)
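The forward rescaling and its inverse are easy to check in isolation. `gen_pic` applies the inverse, `* 127.5 + 127.5`, before saving a PNG. The pixel values below are arbitrary examples.

```python
import numpy as np

pixels = np.array([0, 64, 127.5, 255], dtype=np.float32)   # example raw values

scaled = (pixels - 127.5) / 127.5      # forward map used for x_train
restored = scaled * 127.5 + 127.5      # inverse map used when saving images

print(scaled)
print(restored)
```

The extremes 0 and 255 land exactly on -1 and 1, the range of `tanh`.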

In [None]:
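# Note: train_disc, train_gen and gen_pic are defined in the cells below;
# run those cells before starting this loop.
epochs = 100
batch_size = 100

for epoch in range(1, epochs + 1):
    print('Epoch {}/{}'.format(epoch, epochs))
    num_batches = x_train.shape[0] // batch_size

    for idx in range(num_batches):
        # train discriminator on one mini-batch of real and fake images
        train_disc(idx, x_train, y_train, discriminator, generator)

        # train generator through the frozen discriminator
        train_gen(idx, x_train, y_train, discriminator, generator)

    # save a grid of generated and real digits after each epoch
    gen_pic(generator, x_train, y_train)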

    



In [35]:
def train_disc(idx, x_train, y_train, discriminator, generator, batch_size=100):
    # generate mini-batch data: batch_size real images followed by
    # batch_size generated images
    x, y, aux_y = gen_disc_data(idx, x_train, generator, batch_size=batch_size)

    # set weights so that the discriminator does not try to infer the
    # class from fake images: in the auxiliary loss, real images get
    # weight 2 and fake images weight 0
    disc_sample_weight = [np.ones(2 * batch_size),
                          np.concatenate((np.ones(batch_size) * 2,
                                          np.zeros(batch_size)))]
    # train discriminator
    discriminator.train_on_batch(x, [y, aux_y], sample_weight=disc_sample_weight)
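The effect of these sample weights on the auxiliary loss can be illustrated with a small NumPy sketch. It assumes the weighted per-sample losses are simply averaged over the batch, which may differ in detail from the reduction Keras applies. It also uses random softmax outputs instead of a real discriminator, and `weighted_sparse_ce` is a hypothetical helper. Because the fake half carries weight 0, the class labels attached to the generated images have no influence on the loss. This is why the random labels produced by `gen_class_targets` for the fake half are harmless.

```python
import numpy as np

def weighted_sparse_ce(probs, labels, weights):
    """Per-sample sparse categorical cross-entropy, weighted and averaged
    over the batch (an assumed reduction, for illustration only)."""
    per_sample = -np.log(probs[np.arange(len(labels)), labels])
    return np.mean(weights * per_sample)

rng = np.random.default_rng(0)
batch_size, num_classes = 4, 10

# Random softmax outputs for batch_size real + batch_size fake images.
probs = rng.dirichlet(np.ones(num_classes), size=2 * batch_size)
labels = rng.integers(0, num_classes, size=2 * batch_size)
weights = np.concatenate((np.ones(batch_size) * 2, np.zeros(batch_size)))

loss = weighted_sparse_ce(probs, labels, weights)

# Changing the labels of the fake half leaves the loss unchanged.
labels_other = labels.copy()
labels_other[batch_size:] = (labels_other[batch_size:] + 1) % num_classes
print(np.isclose(loss, weighted_sparse_ce(probs, labels_other, weights)))
```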

In [40]:
def train_gen(idx, x_train, y_train, discriminator, generator,
              num_classes=10, batch_size=100, latent_size=100):
    # make new noise. we generate 2 * batch size here such that we have
    # the generator optimize over an identical number of images as the
    # discriminator
    noise = np.random.uniform(-1, 1, (2 * batch_size, latent_size))
    sampled_labels = np.random.randint(0, num_classes, 2 * batch_size)

    # we want to train the generator to trick the discriminator:
    # for the generator, we want all the {fake, not-fake} labels to say
    # not-fake (0.95 rather than 1, matching the smoothed "real" target)
    trick = np.ones(2 * batch_size) * .95
    # `combined` is the global model built above; its discriminator is
    # frozen, so this step only updates the generator
    combined.train_on_batch(
        [noise, sampled_labels.reshape((-1, 1))],
        [trick, sampled_labels])
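Training the generator on the "not-fake" label means that, up to the 0.95 target, it minimizes $-\log D(G(z))$ instead of literally minimizing $\log(1 - D(G(z)))$ from the minimax objective. The original GAN paper also suggests this swap. The hand-written NumPy sketch below compares the gradients of both losses with respect to the discriminator's pre-sigmoid logit $a$, with $D = \sigma(a)$. When the discriminator confidently rejects a fake ($D$ close to 0), the minimax loss gives almost no gradient, while the "trick" loss still does.

```python
import numpy as np

def sigmoid(a):
    return 1 / (1 + np.exp(-a))

def grad_minimax(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a)
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da [-log sigmoid(a)] = -(1 - sigmoid(a))
    return -(1 - sigmoid(a))

for a in [-6.0, 0.0, 6.0]:   # D(G(z)) close to 0, equal to 0.5, close to 1
    print(f"D={sigmoid(a):.4f}  minimax grad={grad_minimax(a):+.4f}  "
          f"non-saturating grad={grad_non_saturating(a):+.4f}")
```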

### Discriminator Details

In [38]:
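def gen_disc_data(idx, x_train, generator, num_classes=10, batch_size=100):
    # sample class labels for the fake images
    sampled_labels = np.random.randint(0, num_classes, batch_size)

    # real images followed by fake images conditioned on the sampled labels
    x = gen_img_pair(idx, x_train, sampled_labels, generator, batch_size=batch_size)
    # one-sided label smoothing: target 0.95 for real images, 0 for fake ones
    y = np.array([.95] * batch_size + [0] * batch_size)
    # class targets (uses the global y_train loaded above)
    aux_y = gen_class_targets(idx, y_train, batch_size=batch_size)

    return x, y, aux_y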
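Why use 0.95 instead of 1 as the target for real images? With a hard target of 1, the cross-entropy keeps rewarding ever more confident outputs. A target of 0.95 is minimized exactly at $D(x) = 0.95$, which discourages an overconfident discriminator. A grid search in NumPy illustrates this; the grid itself is an arbitrary choice.

```python
import numpy as np

def bce(target, p):
    """Binary cross-entropy of a prediction p against a (soft) target."""
    return -(target * np.log(p) + (1 - target) * np.log(1 - p))

p_grid = np.linspace(0.01, 0.99, 99)   # candidate outputs D(x) in steps of 0.01

best = {}
for target in [1.0, 0.95]:
    best[target] = p_grid[np.argmin(bce(target, p_grid))]
    print(target, best[target])
```

With target 1 the best grid point is the largest one available. With target 0.95 the optimum sits at 0.95 itself.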

In [37]:
def gen_img_pair(idx, x_train, sampled_labels, generator, batch_size=100,
                 latent_size=100, num_classes=10):
    # generate a new batch of noise
    noise = np.random.uniform(-1, 1, (batch_size, latent_size))

    # get a batch of real images
    image_batch = x_train[idx * batch_size:(idx + 1) * batch_size]

    # generate a batch of fake images, using the generated labels as a
    # conditioner. We reshape the sampled labels to be
    # (batch_size, 1) so that we can feed them into the embedding
    # layer as a length one sequence
    generated_images = generator.predict(
        [noise, sampled_labels.reshape((-1, 1))], verbose=0)

    return np.concatenate((image_batch, generated_images))

In [45]:
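def gen_class_targets(idx, y_train, num_classes=10, batch_size=100):
    # the labels for the fake half are random placeholders: train_disc gives
    # them zero weight in the auxiliary loss, so they do not affect training
    sampled_labels = np.random.randint(0, num_classes, batch_size)
    label_batch = y_train[idx * batch_size:(idx + 1) * batch_size]
    return np.concatenate((label_batch, sampled_labels), axis=0)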

## Visualization

In [46]:
from PIL import Image

# generate some digits to display
def gen_pic(generator, x_train, y_train, num_rows=40, latent_size=100, num_classes=10):
    # Note: reads the global `epoch` of the training loop to pick which
    # real images to show and to name the output file
    noise = np.tile(np.random.uniform(-1, 1, (num_rows, latent_size)),
                    (num_classes, 1))

    sampled_labels = np.array([
        [i] * num_rows for i in range(num_classes)
    ]).reshape(-1, 1)

    # get a batch to display
    generated_images = generator.predict(
        [noise, sampled_labels], verbose=0)

    # prepare real images sorted by class label
    real_labels = y_train[(epoch - 1) * num_rows * num_classes:
                          epoch * num_rows * num_classes]
    indices = np.argsort(real_labels, axis=0)
    real_images = x_train[(epoch - 1) * num_rows * num_classes:
                          epoch * num_rows * num_classes][indices]

    # display generated images, white separator, real images
    img = np.concatenate(
        (generated_images,
         np.repeat(np.ones_like(x_train[:1]), num_rows, axis=0),
         real_images))

    # arrange them into a grid and map pixels from [-1, 1] back to [0, 255]
    img = (np.concatenate([r.reshape(-1, 28)
                           for r in np.split(img, 2 * num_classes + 1)
                           ], axis=-1) * 127.5 + 127.5).astype(np.uint8)
    Image.fromarray(img).save('plot_epoch_{:03d}_generated.png'.format(epoch))
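The grid layout in `gen_pic` is easiest to understand with dummy arrays in place of real and generated digits. The constant values below are placeholders: -1 for "generated", 1 for the separator, 0 for "real". The stack of 840 images is split into 21 chunks of 40. Each chunk becomes one vertical strip of 40 stacked $28 \times 28$ images. The first ten strips hold the generated digits of classes 0 to 9, then comes the white separator strip, then ten strips of real images sorted by label.

```python
import numpy as np

num_rows, num_classes = 40, 10

# Placeholder images: -1 for "generated", 1 for the separator, 0 for "real".
generated = -np.ones((num_rows * num_classes, 28, 28, 1), dtype=np.float32)
separator = np.ones((num_rows, 28, 28, 1), dtype=np.float32)
real = np.zeros((num_rows * num_classes, 28, 28, 1), dtype=np.float32)
img = np.concatenate((generated, separator, real))

# Split into 2 * num_classes + 1 chunks of num_rows images and turn each
# chunk into a vertical strip of stacked 28x28 images.
strips = [r.reshape(-1, 28) for r in np.split(img, 2 * num_classes + 1)]
grid = np.concatenate(strips, axis=-1)

print(grid.shape)
```

The resulting grid is 1120 pixels tall (40 images of 28 rows) and 588 pixels wide (21 strips of 28 columns).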

## Training Issues and Solutions

## Metrics

## Wasserstein GAN

## Homework

Prove the theorem about the optimal discriminator.
