<a href="https://colab.research.google.com/github/Machine-Learning-Tokyo/Intro-to-GANs/blob/master/WassersteinGAN/DIY_WGAN_Solution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wasserstein GAN (WGAN): Solutions

**This notebook is the solution to the DIY_WGAN notebook. Note that lines containing important changes from the previous version are marked with a `#CHANGE` comment, like this:**
```python
this_line_has_changed  #CHANGE
```
---

This notebook shows how to implement a Wasserstein GAN for the Fashion MNIST dataset. We take our previous Conditional DCGAN implementation as a baseline and apply the changes needed to convert it into a Wasserstein GAN. This means that if you ran the DIY notebook without any modification, you would just be training a Conditional DCGAN on Fashion MNIST; the lines marked with `#CHANGE` are what turn it into a WGAN.

## What is a Wasserstein GAN?

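The Wasserstein GAN is a type of GAN that uses the Wasserstein distance (also known as the Earth Mover's distance) to measure the difference between the distribution of the real data and the distribution of the generated data, as opposed to the Jensen-Shannon divergence that the original GAN implicitly minimises.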

The nuances of this new approach are discussed in the original paper: [Wasserstein GAN](https://arxiv.org/abs/1701.07875).
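For reference, the WGAN paper uses the Kantorovich-Rubinstein duality to write the Wasserstein-1 distance between the real distribution $\mathbb{P}_r$ and the generated distribution $\mathbb{P}_g$ as a supremum over 1-Lipschitz functions $f$. The Critic D plays the role of $f$, and the training objective replaces the supremum with a maximisation over the critic's weights $w$ (kept in a compact set by weight clipping), with the generator $g_\theta$ fed by noise $z \sim p(z)$:

$$
W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g}[f(x)]
$$

$$
\max_{w} \; \mathbb{E}_{x \sim \mathbb{P}_r}[f_w(x)] - \mathbb{E}_{z \sim p(z)}[f_w(g_\theta(z))]
$$

If $f$ is only required to be K-Lipschitz, the same supremum gives $K \cdot W(\mathbb{P}_r, \mathbb{P}_g)$, which is why keeping D K-Lipschitz (for some fixed K) is enough for training. The `critic_loss` used in this notebook is the negative of the second expression, estimated on a batch.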

## What should we change in the code?

Although the mathematical formulation and derivation of the Wasserstein GAN are relatively complicated, the changes needed to turn a normal GAN into a WGAN are neither many nor complicated.

In total there are 4 things we need to change:

+ __Activation of D__. Unlike the original GAN, the WGAN Discriminator has a linear output activation, so we have to get rid of the sigmoid function at the output of D.
+ __Loss function__. The loss is just the difference between the average output of D for real samples and the average output of D for generated samples. Unlike in the original GAN, the Discriminator in a WGAN does not classify samples as real-looking or fake-looking; instead it outputs a score, and the gap between the average scores of real and generated samples measures how close the generated distribution is to the true one. For this reason it is often called the Critic instead of the Discriminator (although in practice it is common to keep calling it D in code). **Note that changes in this loss function will require changes in how we provide the targets**.
+ __Clipping weights and optimiser__. We need to guarantee that the function that D computes is K-Lipschitz. One way to ensure this is to clip the weights of D to a small range [-c, c]. This is not the only approach, and later papers propose alternatives ([Improved Training of Wasserstein GANs](https://arxiv.org/pdf/1704.00028.pdf) and [Improving the Improved Training of Wasserstein GANs: A Consistency Term and Its Dual Effect](https://arxiv.org/abs/1803.01541)). The original paper also uses RMSProp instead of a momentum-based optimiser such as Adam, since the authors found that momentum made training unstable at times.
+ __Training procedure: now D is trained more often__. In the WGAN it is not important to keep a balance between the training of D and G; training D more frequently than G is actually desirable, because a better-trained critic gives a better estimate of the Wasserstein distance.

In this exercise we follow the settings of the original paper.

# Fashion cond-dc-GAN

- ~~mnist gan~~
- ~~fashion gan~~
- ~~fashion (32, 32, 1) gan~~
- ~~fashion dc-gan~~
- ~~fashion cond-dc-gan~~
- ~~interpolation fashion cond-dc-gan~~
- fashion cond-w-dc-gan
- celebA cond-w-dc-gan
- interpolation celebA cond-w-dc-gan

### Things to change in WGAN

+ Activation of D
+ Loss function
+ Clipping weights
+ Training procedure: now D is trained more often

### Imports

In [0]:
from keras.models import Model
from keras.layers import Input, Dense, BatchNormalization, Reshape, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import fashion_mnist
from keras.optimizers import Adam, RMSprop

from keras.layers import Conv2D, UpSampling2D, concatenate, Lambda
from keras.initializers import RandomNormal
from keras.utils import to_categorical
import keras.backend as K

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image
from matplotlib import animation, rc
from IPython.display import Image as ipyImage
from IPython.display import HTML

### Function to build the generator

The Generator does not change for our WGAN: it stays exactly as we coded it for the Conditional DCGAN.

In [0]:
def build_generator(noise_size, img_shape, num_classes):
  
  filters = 512
  k_size = 5, 5
  k_init = RandomNormal(0, 0.007)
  
  noise = Input((noise_size,))
  labels = Input((num_classes,))
  
  model_input = concatenate([noise, labels])
  
  x = Dense(4*4*filters, kernel_initializer=k_init, activation='relu')(model_input)
  x = Reshape((4, 4, filters))(x)  # 4, 4
  x = BatchNormalization()(x)
  x = UpSampling2D()(x)  # 8, 8
  
  x = Conv2D(filters // 2, k_size, padding='same', kernel_initializer=k_init, activation='relu')(x)
  x = BatchNormalization()(x)
  x = UpSampling2D()(x)  # 16, 16
  
  x = Conv2D(filters // 4, k_size, padding='same', kernel_initializer=k_init, activation='relu')(x)
  x = BatchNormalization()(x)
  x = UpSampling2D()(x)  # 32, 32
  
  img = Conv2D(img_shape[-1], k_size, padding='same', kernel_initializer=k_init, activation='tanh')(x)  # 32, 32, 1
  
  generator = Model([noise, labels], img)
  return generator

### Function to build the discriminator

The Discriminator does change. We have to get rid of the `sigmoid` activation of the last layer of the Discriminator.

In [0]:
def build_discriminator(img_shape, num_classes):
  
  filters = 512
  k_size = 5, 5
  k_init = RandomNormal(0, 0.007)
  
  img = Input(img_shape)  # 32, 32, 1
  labels = Input((num_classes,))  # ?, 10
  
  n_labels = Reshape([1, 1, num_classes])(labels)  # ?, 1, 1, 10
  n_labels = Lambda(lambda x: K.tile(x, [1, img_shape[0], img_shape[1], 1]))(n_labels)  # ?, 32, 32, 10
  
  model_input = concatenate([img, n_labels])  # ?, 32, 32, 11
  
  x = Conv2D(filters // 4, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(model_input)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha=0.2)(x)  # 16, 16
  
  x = Conv2D(filters // 2, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha=0.2)(x)  # 8, 8
  
  x = Conv2D(filters, k_size, strides=(2, 2), padding='same', kernel_initializer=k_init)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(alpha=0.2)(x)  # 4, 4
  
  x = Flatten()(x)
  validity = Dense(1, kernel_initializer=k_init, activation='linear')(x) #CHANGE
  
  discriminator = Model([img, labels], validity)
  return discriminator

### Function to compile the models

Another change is needed here: we no longer want `binary_crossentropy` when compiling the Discriminator and the Combined model.

So we are going to implement our own loss function called `critic_loss`. Keras calls a custom loss as `loss(y_true, y_pred)`, so keep the arguments in that order. You can use the following template:

```python
def critic_loss(y_true, y_pred):
  """You must implement the value of `loss`.
  A one-line expression is enough."""
  return loss
```
**Note that changes in this loss function will require changes in how we provide the targets.**

The Critic should output a large (positive) value when the samples are real and a small (negative) value when they are fake, and the loss must reward exactly that behaviour. Also, remember that `y_pred` is a vector of values, so we are going to average over them to return a single value for the whole batch. Follow carefully, because this part is tricky. Here are the rules for the case with **real** samples and with **fake** samples:

+ In the case that the samples are **real** we expect `y_pred` to be positive, so we can define the loss as the negative of `y_pred`: if `y_pred` is positive (good), the loss is negative (good, we are encouraging this behaviour); if `y_pred` is negative (bad), the loss becomes positive (we are discouraging this behaviour).

  ```python
  real_loss = Average(-1 * y_pred)
  ```
+ Similar reasoning applies to the **fake** case, where we expect `y_pred` to be negative. So we can use `y_pred` itself as the loss: if `y_pred` is positive (bad), the loss is also positive (we are discouraging this behaviour); if `y_pred` is negative (good), the loss is also negative (we are encouraging this behaviour).

  ```python
  fake_loss = Average(+1 * y_pred)
  ```

An easy way of putting both rules together is to set the targets `y_true` to `-1` for the **real** case and `+1` for the **fake** case, multiply them by `y_pred`, and average the result with the Keras backend function `K.mean()`. Remember to import the Keras backend:

```python
import keras.backend as K
K.mean(something_to_average)
```
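The two rules and the ±1 targets can be checked numerically. A minimal sketch in NumPy, where `critic_loss_np` is a hypothetical NumPy stand-in for the Keras loss (`np.mean` plays the role of `K.mean`) and the critic scores are made-up numbers:

```python
import numpy as np

def critic_loss_np(y_true, y_pred):
    # Same formula as the Keras version: mean of target times critic output
    return np.mean(y_true * y_pred)

# Made-up critic outputs for a batch of 4 real and 4 generated samples
real_scores = np.array([0.8, 1.2, 0.5, 0.9])
fake_scores = np.array([-0.7, -1.1, 0.2, -0.4])

# Target -1 for real samples reproduces real_loss = Average(-1 * y_pred)
real_loss = critic_loss_np(-np.ones(4), real_scores)
# Target +1 for fake samples reproduces fake_loss = Average(+1 * y_pred)
fake_loss = critic_loss_np(np.ones(4), fake_scores)

print(real_loss, fake_loss)
```

Both losses come out negative here because this made-up critic already scores real samples high and fake samples low; pushing real scores up or fake scores down would lower them further.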

When you finish defining your loss function, add it as a compile option for the Discriminator (or Critic) and the Combined model. Also, computing the accuracy of the Discriminator (`metrics=['accuracy']`) does not make much sense in this case, so it should be removed.

In [0]:
def critic_loss(y_true, y_pred):
  # Keras passes (y_true, y_pred); targets are -1 for real and +1 for fake samples
  return K.mean(y_true * y_pred) #CHANGE

In [0]:
def get_compiled_models(generator, discriminator, noise_size, num_classes):
  
  optimizer = RMSprop(0.00005)  # RMSProp with learning rate 0.00005, as in the original WGAN paper
  
  discriminator.compile(optimizer, loss=critic_loss) #CHANGE
  discriminator.trainable = False
  
  noise = Input((noise_size,))
  labels = Input((num_classes,))
  
  img = generator([noise, labels])
  validity = discriminator([img, labels])
  combined = Model([noise, labels], validity)
  
  combined.compile(optimizer, loss=critic_loss) #CHANGE
  
  return generator, discriminator, combined

### Function to sample and save generated images

In [0]:
def sample_imgs(generator, noise_size, gen_loss_memory, step, plot_img=True, cond=False, num_classes=10):
  np.random.seed(0)
  
  r, c = num_classes, 10
  if cond:
    noise = np.random.normal(0, 1, (c, noise_size))
    noise = np.tile(noise, (r, 1))

    sampled_labels = np.arange(r).reshape(-1, 1)
    sampled_labels = to_categorical(sampled_labels, r)
    sampled_labels = np.repeat(sampled_labels, c, axis=0)

    imgs = generator.predict([noise, sampled_labels])
  else:
    noise = np.random.normal(0, 1, (r*c, noise_size))
    imgs = generator.predict_on_batch(noise)
  
  imgs = imgs / 2 + 0.5
  imgs = np.reshape(imgs, [r, c, imgs.shape[1], imgs.shape[2], -1])
  
  gs = gridspec.GridSpec(r, 2*c)
  
  figsize = 2 * c, 1 * r
  fig = plt.figure(figsize=figsize)
  
  for i in range(r):
    for j in range(c):
      img = imgs[i, j] if len(imgs.shape) == 4 else imgs[i, j, :, :, 0]
      plt.subplot(gs[i, j])
      plt.imshow(img, cmap='gray')
      plt.axis('off')
  plt.subplot(gs[:, c:])
  plt.plot(gen_loss_memory)
  plt.ylim(-0.05, 0.5)
  plt.gca().tick_params(axis='y', direction='in', pad=-20)
  #plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.savefig(f'/content/images/{step}.png')
  if plot_img:
    plt.show()
  plt.close()
  
  np.random.seed(None)

### Function to train the models

This function is the one that requires the most changes.
+ __Loss function__. In addition to the new loss defined above, we now need to change the targets `y_true` (`real_validity` and `gen_validity` in the code below): `-1` for real samples and `+1` for generated ones.

+ __Clipping weights and optimiser__. We need to guarantee that the function that D computes is K-Lipschitz. One way to ensure this is to clip the weights of D to a small range.

  After processing a full batch of real and fake samples and updating the weights accordingly, we need to iterate over the layers and clip their weights to the range [-c, c], with c = 0.01 in the original paper.

+ __Training procedure: now D is trained more often__. In the WGAN it is not important to keep a balance between the training of D and G; training D more frequently than G is actually desirable.

  We can use a `for` loop to train D `n_critic` times (5 in the original paper) for every update of G.
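The clipping and the critic schedule described above can be sketched in plain NumPy. This is a toy illustration, not the notebook's training loop: `clip_weights` is a hypothetical helper that applies the same `np.clip(w, -c, c)` used in `train()` to a list of arrays, and the random perturbation is only a stand-in for a real optimiser step.

```python
import numpy as np

def clip_weights(weight_list, c):
    # Clip every weight array into [-c, c], like the layer loop in train()
    return [np.clip(w, -c, c) for w in weight_list]

rng = np.random.default_rng(0)
# Hypothetical stand-in for one critic layer: a 3x3 kernel and a bias
critic_weights = [rng.normal(0, 0.05, (3, 3)), rng.normal(0, 0.05, (3,))]

n_critic, c, steps = 5, 0.01, 3
critic_updates = generator_updates = 0

for step in range(steps):
    for _ in range(n_critic):
        # Stand-in for one critic update on a real and a fake batch
        critic_weights = [w + rng.normal(0, 0.02, w.shape) for w in critic_weights]
        critic_weights = clip_weights(critic_weights, c)
        critic_updates += 1
    # One generator update per step (the critic weights are not touched here)
    generator_updates += 1

print(critic_updates, generator_updates)  # → 15 3
print(all(np.abs(w).max() <= c for w in critic_weights))  # → True
```

With `n_critic = 5` the critic gets five clipped updates for every generator update, matching the `n_critic` and `c` values set in the hyperparameter cell.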

In [0]:
def train(models, noise_size, img_shape, num_classes, batch_size, steps, n_critic, c): #CHANGE, add n_critic and c
  
  generator, discriminator, combined = models
  #get real data
  (X_train, Y_train), (X_val, Y_val) = fashion_mnist.load_data()
  fashion_mnist_imgs = np.concatenate((X_train, X_val)) / 127.5 - 1  # 70,000, 28, 28 (scaled to [-1, 1])
  fashion_mnist_imgs = np.pad(fashion_mnist_imgs, ((0, 0), (2, 2), (2, 2)), 'constant', constant_values=-1)  # 70,000, 32, 32
  fashion_mnist_imgs = np.expand_dims(fashion_mnist_imgs, axis=-1)  # 70,000, 32, 32, 1
  fashion_mnist_labels = np.concatenate((Y_train, Y_val))
  
  gen_loss_memory = [] # to save gen_loss during training
  
  for step in range(1, steps + 1):
    # train discriminator
    for _ in range(n_critic): #CHANGE, add loop
      inds = np.random.randint(0, fashion_mnist_imgs.shape[0], batch_size)

      labels = fashion_mnist_labels[inds]
      labels = to_categorical(labels, num_classes)

      real_imgs = fashion_mnist_imgs[inds]
      real_validity = -np.ones(batch_size) #CHANGE

      noise = np.random.normal(0, 1, (batch_size, noise_size))
      gen_imgs = generator.predict([noise, labels])
      gen_validity = np.ones(batch_size) #CHANGE

      r_loss = discriminator.train_on_batch([real_imgs, labels], real_validity)
      g_loss = discriminator.train_on_batch([gen_imgs, labels], gen_validity)
      disc_loss = np.add(r_loss, g_loss) / 2
      gen_loss_memory.append(g_loss)  # critic loss on generated samples, i.e. the mean critic score of the fakes
      
      #CHANGE, clip weights
      for layer in discriminator.layers:
        weights = layer.get_weights()
        clipped_weights = [np.clip(w, -c, c) for w in weights]
        layer.set_weights(clipped_weights)
    
    # train generator
    noise = np.random.normal(0, 1, (batch_size, noise_size))
    gen_validity = -np.ones(batch_size) #CHANGE, G wants its samples to be scored as real
    gen_loss = combined.train_on_batch([noise, labels], gen_validity)
    
    #print progress
    if step % 50 == 0:
      print('step: %d, D_loss: %f G_loss: %f' % (step, disc_loss, gen_loss))
    
    # save_samples
    if step % 50 == 0:
      sample_imgs(generator, noise_size, gen_loss_memory, step, cond=True)

### Define hyperparameters

In [0]:
%rm -r /content/images
%mkdir /content/images
noise_size = 100
img_shape = 32, 32, 1
num_classes = 10
batch_size = 64
steps = 5000

#CHANGE, add WGAN specific parameters
n_critic = 5
c = 0.01

### Generate the models

In [0]:
generator = build_generator(noise_size, img_shape, num_classes)
discriminator = build_discriminator(img_shape, num_classes)
compiled_models = get_compiled_models(generator, discriminator, noise_size, num_classes)

### Train the models

In [0]:
train(compiled_models, noise_size, img_shape, num_classes, batch_size, steps, n_critic, c)

### Display samples

Let's start by checking the images that we have stored.

In [0]:
%ls /content/images

You can check any image you wish by doing:

In [0]:
image_number = 50
ipyImage('/content/images/%d.png' % image_number)

### Do an animation

Probably the best way to show the training process is with an animation of all the saved images. The next cell builds it for you.

In [0]:
class AnimObject(object):
    def __init__(self, images):
        print(len(images))
        self.fig, self.ax = plt.subplots()
        self.ax.set_title("")
        self.fig.set_size_inches((20, 10))
        self.plot = plt.imshow(images[0])
        plt.tight_layout()
        self.images = images
        
    def init(self):
        self.plot.set_data(self.images[0])
        self.ax.grid(False)
        return (self.plot,)
      
    def animate(self, i):
        self.plot.set_data(self.images[i])
        self.ax.grid(False)
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.ax.set_title("index {}".format(i))
        return (self.plot,)

def get_figures(template, indices):
    import os.path
    images = []
    for index in indices:
        if os.path.isfile(template.format(index)):
            images.append(Image.open(template.format(index)))
    return images


from os import listdir  # needed below; listdir was not imported anywhere else

images = get_figures("/content/images/{}.png",
                     range(0, 50 * len(listdir('/content/images')) + 1, 50))
print(images)
animobject = AnimObject(images)
anim = animation.FuncAnimation(
              animobject.fig,
              animobject.animate,
              frames=len(animobject.images),
              interval=150,
              blit=True)

In [0]:
HTML(anim.to_jshtml())

## What do we make of the WGAN?

Here are some takes on the WGAN compared to the original GAN formulation.

+ **It trains slower**. When using the Wasserstein distance to train the Generator, we want a strong Discriminator. To get one, we train the Discriminator more often, so training a WGAN takes more time.

The following two points are well represented in this example, presumably because Fashion MNIST is a very easy dataset.

+ **It is more stable**. Sometimes GANs can _unlearn_ in just a few iterations what took them several minutes to learn. WGAN tends to be more stable.

+ **It provides an error indicator** that correlates with the quality of the images.

Also, because the Discriminator acts as a critic rather than a classifier, the generator loss (`G_loss`) can sometimes be negative. This happens when the critic gives the generated samples positive scores on average, i.e. it treats them as real instead of fake.
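To make the error indicator concrete: with the targets used here, `r_loss` is the batch mean of `-D(real)` and `g_loss` is the batch mean of `D(fake)`, so their negated sum is the critic's estimate of the Wasserstein distance (up to the Lipschitz constant K), which is the quantity the WGAN paper reports as correlating with sample quality. Since the printed `D_loss` is `(r_loss + g_loss) / 2`, the estimate is simply `-2 * D_loss`. The helper below is a hypothetical sketch, not part of the notebook above, and the loss values are made up:

```python
def wasserstein_estimate(r_loss, g_loss):
    # r_loss = mean(-D(real)) and g_loss = mean(+D(fake)) with targets -1 / +1,
    # so -(r_loss + g_loss) = mean(D(real)) - mean(D(fake))
    return -(r_loss + g_loss)

# Made-up batch losses: mean critic score 0.85 on real, -0.5 on fake samples
r_loss, g_loss = -0.85, -0.5
disc_loss = (r_loss + g_loss) / 2  # what train() prints as D_loss

print(wasserstein_estimate(r_loss, g_loss))
print(-2 * disc_loss)
```

To track it during training, you could append `wasserstein_estimate(r_loss, g_loss)` to a separate list inside the critic loop, next to `gen_loss_memory`.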