<a href="https://colab.research.google.com/github/TA-aiacademy/course_3.0/blob/v2-5_gan/08_v2-5_GAN/Part1/03_WGAN_GP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WGAN-GP on MNIST

### 本章節內容大綱
* [WGAN-GP原理](#WGAN-GP原理)
* [使用 Gradient penalty 取代 weight clipping](#使用-Gradient-penalty-取代-weight-clipping)

# WGAN-GP原理

WGAN用了比較粗糙的作法——weight clipping，讓discriminator變得平滑而接近1-Lipschitz function，然而這種做法有以下兩點問題：<br>
第一、最後Discriminator的參數很容易收斂到clipping範圍的兩個邊界值，只能學習出簡單的function

<br>
<img src='https://hackmd.io/_uploads/Hye-8m-eT.jpg'>
<br>

第二、很容易因為clipping範圍些微差異的設定，在Discriminator network各層間出現梯度爆炸或梯度消失

<br>
<img src='https://hackmd.io/_uploads/r1JBI7-gT.jpg'>
<br>

所以之後就有了改良版——WGAN-GP(gradient penalty)，取代了weight clipping，WGAN-GP將Wasserstein distance改成下式：

$$W(P_{data}, P_G) = \max_D (E_{x\sim P_{data}}\; [D(x)] - E_{x\sim P_G}\;[D(x)] - \lambda E_{x\sim P_{penalty}}\;\;[\max(0,||\nabla_x D(x)|| -1)])$$

因為對可微分的 Discriminator 而言，$||\nabla_x D(x)|| \leq 1$（Discriminator 對 input 的 gradient norm 處處小於或等於 1）就等價於 1-Lipschitz function，所以 WGAN-GP 加上了「gradient norm 超過 1 的部分」作為懲罰項，以近似 1-Lipschitz 的限制。然而實際實驗過後，作者將其改寫成下式（懲罰 gradient norm 與 1 之間的差距，使其接近 1），訓練效果會更好：

$$W(P_{data}, P_G) = \max_D (E_{x\sim P_{data}}\; [D(x)] - E_{x\sim P_G}\;[D(x)] - \lambda E_{x\sim P_{penalty}}\;\;[(||\nabla_x D(x)|| -1)^2])$$

其中 $\lambda$ 是 penalty weight，$P_{penalty}$ 是在真實資料與生成資料之間的連線上隨機 sample 出來的資料點分佈；只需對這些點上的 gradient 作懲罰即可，而不需要對整個 data space 都計算 gradient。

# Import

In [None]:
''' basic package '''
import os
import time
import imageio
import glob
from IPython.display import display, Image

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    Dense, Conv2DTranspose, Conv2D, BatchNormalization,
    LeakyReLU, Dropout, Reshape, Flatten
)

from tensorflow.keras.losses import BinaryCrossentropy

from tensorflow.keras.optimizers import Adam, RMSprop

# Expose only GPU 0 to CUDA.
# NOTE(review): this runs after `import tensorflow`; it only takes effect if TF
# has not initialised its GPU devices yet -- confirm, or move it above the import.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

# Config

In [None]:
# ---- Hyper-parameters ----
BATCH_SIZE = 128
BUFFER_SIZE = 60000            # shuffle buffer size
z_dim = 100                    # length of the generator's latent vector
EPOCHS = 50
learning_rate = 1e-4
num_examples_to_generate = 16  # samples drawn from the fixed seed each epoch
gp_weight = 10                 # lambda: weight of the gradient-penalty term

# ---- Data: MNIST digits with a trailing channel axis, scaled to [-1, 1] ----
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images[..., np.newaxis].astype('float32')
train_images = (train_images - 127.5) / 127.5

train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Train

### 定義 generator & discriminator

In [None]:
# Define model


class Generator(Model):
    """Generator: maps a length-``z_dim`` latent vector to a 28x28x1 image.

    A Dense projection is reshaped to a 7x7x128 feature map and upsampled
    twice by stride-2 Conv2DTranspose layers; the final tanh keeps pixels in
    [-1, 1], the same range the training images are scaled to.
    """

    def __init__(self, z_dim):
        super().__init__()
        self.model = Sequential([
            # [z_dim] => [7, 7, 128]
            Dense(7 * 7 * 128, use_bias=False, input_shape=(z_dim,)),
            LeakyReLU(),
            Reshape((7, 7, 128)),
            # [7, 7, 128] => [14, 14, 64]
            Conv2DTranspose(64, 5, strides=2, padding='same', use_bias=False),
            LeakyReLU(),
            # [14, 14, 64] => [28, 28, 1]
            Conv2DTranspose(1, 5, strides=2, padding='same', activation='tanh', use_bias=False),
        ])

    def call(self, x):
        return self.model(x)


class Discriminator(Model):
    """Critic: maps a 28x28x1 image to one unbounded real-valued score.

    Three stride-2 Conv2D + LeakyReLU stages followed by a linear Dense(1).
    There is no sigmoid: the Wasserstein losses use the raw score directly.
    """

    def __init__(self):
        super().__init__()
        self.model = Sequential([
            # [28, 28, 1] => [14, 14, 64]
            Conv2D(64, 5, strides=2, padding='same', input_shape=(28, 28, 1)),
            LeakyReLU(),
            # [14, 14, 64] => [7, 7, 128]
            Conv2D(128, 5, strides=2, padding='same'),
            LeakyReLU(),
            # [7, 7, 128] => [4, 4, 256]
            Conv2D(256, 5, strides=2, padding='same'),
            LeakyReLU(),
            # [4, 4, 256] => [1]
            Flatten(),
            Dense(1),
        ])

    def call(self, x):
        return self.model(x)

# 使用 Gradient penalty 取代 weight clipping

跟 WGAN 或 Vanilla GAN 的最大差異，就是在 loss function 中多了 gradient penalty。基本上就是在真實資料與生成資料之間隨機內插出資料點，計算 discriminator 輸出對該資料點（input）的 gradient，取其 L2 norm 後減一再平方，就是我們要的 gradient penalty 了！

<br>
<img src="https://hackmd.io/_uploads/BJxwL7bxa.jpg" width=500  />

In [None]:
def gradient_penalty(real_images, fake_images, critic=None):
    """Return the WGAN-GP gradient penalty for a batch.

    Points x_hat are sampled uniformly on the straight line between each real
    image and its paired fake image; the penalty is the batch mean of
    (||grad_x D(x_hat)||_2 - 1)^2.

    Args:
        real_images: real batch, rank 4 (batch, height, width, channels).
        fake_images: generated batch with the same shape as ``real_images``.
        critic: model whose input gradient is penalized. Defaults to the
            module-level ``discriminator`` (the previous, implicit behavior).

    Returns:
        Scalar tensor holding the (unweighted) gradient penalty.
    """
    if critic is None:
        critic = discriminator

    # One interpolation weight in [0, 1) per sample, broadcast over H, W, C.
    # tf.shape gives the runtime batch size, so this also works when the static
    # batch dimension is unknown (None) inside a traced tf.function.
    batch_size = tf.shape(real_images)[0]
    epsilon = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0, dtype=real_images.dtype)
    x_hat = epsilon * real_images + (1 - epsilon) * fake_images

    with tf.GradientTape() as tape:
        # x_hat is a plain tensor, not a Variable, so the tape must watch it
        # explicitly to differentiate with respect to it.
        tape.watch(x_hat)
        d_hat = critic(x_hat)
    gradients = tape.gradient(d_hat, x_hat)

    # Per-sample L2 norm over every non-batch axis (H, W and C). The previous
    # axis=[1, 2] left the channel axis out, giving per-channel norms for
    # multi-channel images. The small constant keeps sqrt differentiable when a
    # gradient is exactly zero, which would otherwise inject NaNs into training.
    g_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean((g_norm - 1.0) ** 2)


def gan_loss(d_real_output, d_fake_output, real_images, fake_images):
    """Compute the WGAN-GP critic and generator losses.

    Args:
        d_real_output: critic scores for the real batch.
        d_fake_output: critic scores for the generated batch.
        real_images, fake_images: the batches themselves, needed for the
            gradient penalty.

    Returns:
        (d_loss, g_loss) where
            d_loss = mean(D(fake)) - mean(D(real)) + gp_weight * penalty
            g_loss = -mean(D(fake))
    """
    # Unlike the vanilla GAN, no log/sigmoid is applied: the raw critic
    # outputs enter the losses directly.
    mean_fake = tf.reduce_mean(d_fake_output)
    mean_real = tf.reduce_mean(d_real_output)
    penalty = gradient_penalty(real_images, fake_images)

    d_loss = mean_fake - mean_real + penalty * gp_weight
    g_loss = tf.reduce_mean(-d_fake_output)
    return d_loss, g_loss

# Model, seed and checkpoint setting

In [None]:
generator = Generator(z_dim)
discriminator = Discriminator()

g_optimizer = RMSprop(learning_rate)
d_optimizer = RMSprop(learning_rate)

seed = tf.random.normal([num_examples_to_generate, z_dim])

save_dir = './saved_imgs_GP'
checkpoint_dir = './training_checkpoints_GP'

if not os.path.exists(save_dir):
    os.makedirs(save_dir)
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(g_optimizer=g_optimizer,
                                 d_optimizer=d_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [None]:
@tf.function
def train_step(real_images, generator, discriminator, g_optimizer, d_optimizer):
    """Run one simultaneous generator/critic update on a batch of real images.

    Both losses come from a single forward pass; each network is then stepped
    once with its own optimizer.

    Args:
        real_images: batch of real images scaled to [-1, 1].
        generator, discriminator: the two Keras models being trained.
        g_optimizer, d_optimizer: optimizers for the generator and the critic.

    Returns:
        (d_loss, g_loss) scalar tensors for this batch.
    """
    # Read the batch size at run time: the last batch of an epoch can be
    # smaller than BATCH_SIZE (the dataset is batched without drop_remainder),
    # and the static shape is None if the function is traced generically.
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal([batch_size, z_dim])

    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_images = generator(noise, training=True)
        d_real_logits = discriminator(real_images)
        d_fake_logits = discriminator(fake_images)

        # NOTE: gan_loss evaluates its gradient penalty on the module-level
        # `discriminator`, not necessarily on this function's argument.
        d_loss, g_loss = gan_loss(d_real_logits, d_fake_logits, real_images, fake_images)

    g_gradients = g_tape.gradient(g_loss, generator.trainable_variables)
    d_gradients = d_tape.gradient(d_loss, discriminator.trainable_variables)

    g_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))
    d_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))

    return d_loss, g_loss

In [None]:
def train(dataset, epochs, checkpoint_every=25):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    After each epoch it prints the epoch time and the losses of the epoch's
    last batch, then draws (and periodically saves) samples from the fixed
    ``seed``.

    Args:
        dataset: iterable of real-image batches (e.g. ``train_dataset``).
        epochs: number of epochs to run.
        checkpoint_every: write a checkpoint every this many epochs. The final
            state is always checkpointed, but never twice in a row.
    """
    d_loss = g_loss = None
    last_checkpointed_epoch = None

    # 1-based epoch numbers, used for logging and for the saved file names.
    for epoch in range(1, epochs + 1):
        start = time.time()

        for image_batch in dataset:
            d_loss, g_loss = train_step(image_batch, generator, discriminator, g_optimizer, d_optimizer)

        print('Time for epoch {} is {} sec'.format(epoch, time.time() - start))
        # d_loss stays None only if the dataset has never yielded a batch.
        if d_loss is not None:
            print('discriminator loss: %.5f' % d_loss)
            print('generator loss: %.5f' % g_loss)
        generate_and_save_images(generator, epoch, seed, save_dir)

        if epoch % checkpoint_every == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            last_checkpointed_epoch = epoch

    # No extra sample grid is drawn here: the last loop iteration already
    # rendered it for epoch == epochs with the same generator and seed.
    # Checkpoint the final weights unless the loop just did so.
    if last_checkpointed_epoch != epochs:
        checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
def generate_and_save_images(model, epoch, test_input, save_path, save_every=5):
    """Plot a grid of images generated from ``test_input``, saving it periodically.

    Args:
        model: generator whose outputs lie in [-1, 1].
        epoch: 1-based epoch number as passed by ``train()``; used in the file
            name and to decide whether to save.
        test_input: batch of latent vectors to decode (e.g. the fixed ``seed``).
        save_path: directory the PNG files are written to.
        save_every: save the figure when ``epoch`` is a multiple of this.
    """
    predictions = model(test_input, training=False)
    n_images = predictions.shape[0]

    # Near-square grid; 16 samples still give the original 4 x 4 layout.
    n_cols = max(1, int(np.ceil(np.sqrt(n_images))))
    n_rows = max(1, int(np.ceil(n_images / n_cols)))

    fig = plt.figure(figsize=(4, 4))
    for i in range(n_images):
        plt.subplot(n_rows, n_cols, i + 1)
        # Map pixels from [-1, 1] back to [0, 255] for display.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    # `epoch` is already 1-based, so test it directly. The previous
    # `(epoch + 1) % 5` saved epochs 4, 9, ..., 49 and never the final one.
    if epoch % save_every == 0:
        plt.savefig(os.path.join(save_path, 'image_at_epoch_{:04d}.png'.format(epoch)))

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
%%time
# Run the full training loop; the %%time cell magic reports this cell's wall time.
train(train_dataset, EPOCHS)

In [None]:
# Build an animated GIF from the saved sample grids with imageio.
anim_file = os.path.join(save_dir, 'wgan-gp.gif')
# Zero-padded epoch numbers in the file names make lexical order chronological.
filenames = sorted(glob.glob(os.path.join(save_dir, 'image*.png')))

if not filenames:
    print('No sample images found in {}; skipping GIF creation.'.format(save_dir))
else:
    with imageio.get_writer(anim_file, mode='I') as writer:
        last = -1
        for i, filename in enumerate(filenames):
            # Keep a frame only when round(2 * sqrt(i)) increases, so frames
            # become progressively sparser towards later epochs.
            frame = 2 * (i ** 0.5)
            if round(frame) <= round(last):
                continue
            last = frame
            writer.append_data(imageio.imread(filename))
        # Always end on the newest image (appended even if skipped above).
        writer.append_data(imageio.imread(filenames[-1]))

    display(Image(filename=anim_file))

In [None]:
# Reload the most recent checkpoint written to checkpoint_dir by train().
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_ckpt)

In [None]:
# Decode one random latent vector and display the resulting digit.
noise = tf.random.normal([1, z_dim])
img = generator(noise)
pixels = img[0, :, :, 0] * 127.5 + 127.5  # undo the [-1, 1] scaling for display

plt.imshow(pixels, cmap='gray')
plt.axis('off')
plt.show()