<a href="https://colab.research.google.com/github/TA-aiacademy/course_3.0/blob/v2-5_gan/08_v2-5_GAN/Part1/02_WGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WGAN on MNIST

### 本章節內容大綱
* [WGAN原理](#WGAN原理)
* [1-Lipschitz 實作](#1-Lipschitz-實作)

# WGAN原理

在vanilla GAN裡面，是使用JS divergence來衡量兩個分佈的遠近，不過JS divergence有一個缺點，就是當兩個分佈沒有交集的時候，它的值永遠會是log2；而生成資料與真實資料皆可看成高維空間中的低維manifolds(流形)，它們的交集基本上可以被忽略（或者說可以找到一個discriminator將其輕易地分開），這將會使JS divergence永遠維持在log2，而讓generator的gradient停在0無法更新。

<br>
<img src="https://hackmd.io/_uploads/SJJPXXWla.png" width=500  />

WGAN作者使用了Wasserstein distance(又作Earth mover's distance，EMD)來取代原本的JS divergence，簡單來說，EMD就是將P分佈變成Q分佈所需要的最小代價：

$$B(\gamma) = \sum_{x_p,x_q}\gamma(x_p,x_q)||x_p-x_q|| $$
$$W(P,Q) = \min_{\gamma \in \Pi} B(\gamma)$$

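其中 $\gamma$ 是一個 moving plan，$\Pi$ 為所有可能的 moving plan 所構成的集合。不過這個方法要窮舉所有的moving plan，實際上無法計算；WGAN作者透過複雜的數學推導（Kantorovich-Rubinstein duality），將下式作為 discriminator 的 objective function 直接衡量 $P_G$ 與 $P_{data}$ 之間的 Wasserstein distance：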

$$V(G,D) = \max_{D \in \text{1-Lipschitz}} \left( E_{x \sim P_{data}}[D(x)] - E_{x \sim P_{G}}[D(x)] \right)$$

這邊的1-Lipschitz function是指對任意 $x_1, x_2$ 都滿足 $|D(x_1) - D(x_2)| \le \|x_1 - x_2\|$ 的函數，簡單來說就是要discriminator變得平滑（輸出不能隨輸入劇烈變化），讓generator能夠依循它的gradient更新。

# Import

In [None]:
''' basic package '''
import os
import time
import imageio
import glob
from IPython.display import display, Image

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    Dense, Conv2DTranspose, Conv2D, BatchNormalization,
    LeakyReLU, Dropout, Reshape, Flatten
)

from tensorflow.keras.losses import BinaryCrossentropy

from tensorflow.keras.optimizers import Adam, RMSprop

# Restrict the process to GPU 0.
# NOTE(review): this runs after `import tensorflow`; it only takes effect if
# TensorFlow has not yet initialised its GPU devices. Confirm, or move this line
# above the TensorFlow import to be safe.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Config

In [None]:
# ---- Hyper-parameters ----
BATCH_SIZE = 128
BUFFER_SIZE = 60000  # tf.data shuffle buffer size; 60000 spans the whole training set
z_dim = 100  # length of the latent/noise vector z
EPOCHS = 50
learning_rate = 1e-4
num_examples_to_generate = 16
clip = [-0.05, 0.05]  # critic weights are clipped into [-0.05, +0.05]


# ---- Data: MNIST training split (the test split is discarded) ----
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(
    train_images.shape[0], 28, 28, 1
).astype('float32')

# Rescale pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.
train_images = (train_images - 127.5) / 127.5

train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Train

### 定義 generator & discriminator

In [None]:
# Define model


class Generator(Model):
    """Convolutional generator mapping a latent vector to a 28x28x1 image.

    Args:
        z_dim: length of the input latent vector.
    """

    def __init__(self, z_dim):
        super().__init__()
        self.model = Sequential([
            # [z_dim] -> [7 * 7 * 128] -> [7, 7, 128]
            Dense(7 * 7 * 128, input_shape=(z_dim,)),
            LeakyReLU(),
            Reshape((7, 7, 128)),
            # [7, 7, 128] -> [14, 14, 64]
            Conv2DTranspose(64, 5, strides=2, padding='same'),
            LeakyReLU(),
            # [14, 14, 64] -> [28, 28, 1]; tanh keeps pixels in [-1, 1]
            Conv2DTranspose(1, 5, strides=2, padding='same', activation='tanh'),
        ])

    def call(self, x):
        """Decode latent batch ``x`` into images."""
        return self.model(x)


class Discriminator(Model):
    """Convolutional critic scoring 28x28x1 images with one raw value each.

    The final Dense layer has no activation: the WGAN loss consumes the
    unbounded score directly instead of a probability.
    """

    def __init__(self):
        super().__init__()
        self.model = Sequential([
            # [28, 28, 1] -> [14, 14, 64]
            Conv2D(64, 5, strides=2, padding='same', input_shape=(28, 28, 1)),
            LeakyReLU(),
            # [14, 14, 64] -> [7, 7, 128]
            Conv2D(128, 5, strides=2, padding='same'),
            LeakyReLU(),
            # [7, 7, 128] -> [4, 4, 256]
            Conv2D(256, 5, strides=2, padding='same'),
            LeakyReLU(),
            # [4, 4, 256] -> [4096] -> [1]
            Flatten(),
            Dense(1),
        ])

    def call(self, x):
        """Return the critic score for each image in ``x``."""
        return self.model(x)

In [None]:
def gan_loss(d_real_output, d_fake_output):
    """Compute the WGAN critic and generator losses.

    Unlike the vanilla GAN there is no log/sigmoid: the raw critic scores are
    averaged directly.

    Args:
        d_real_output: critic scores for real images.
        d_fake_output: critic scores for generated images.

    Returns:
        Tuple ``(d_loss, g_loss)`` of scalar tensors.
    """
    real_score = tf.reduce_mean(d_real_output)
    fake_score = tf.reduce_mean(d_fake_output)

    # The critic maximises E[D(real)] - E[D(fake)], so it minimises the negation.
    d_loss = fake_score - real_score

    # The generator wants its samples to receive high critic scores.
    g_loss = tf.reduce_mean(-d_fake_output)

    return d_loss, g_loss

In [None]:
# Build both networks and one RMSprop optimizer per network.
generator = Generator(z_dim)
discriminator = Discriminator()

g_optimizer = RMSprop(learning_rate)
d_optimizer = RMSprop(learning_rate)

# A fixed latent batch, reused every epoch so sample grids stay comparable over training.
seed = tf.random.normal([num_examples_to_generate, z_dim])

save_dir = './saved_imgs_wgan'
checkpoint_dir = './training_checkpoints_wgan'

# Create the output folders only if nothing exists at those paths yet.
for directory in (save_dir, checkpoint_dir):
    if not os.path.exists(directory):
        os.makedirs(directory)

checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    g_optimizer=g_optimizer,
    d_optimizer=d_optimizer,
    generator=generator,
    discriminator=discriminator,
)

# 1-Lipschitz 實作

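因為 WGAN 作者在當時還沒想到比較好的方法實現 1-Lipschitz function 的限制，所以在 WGAN 中是直接使用 weight clipping：每次更新後把 discriminator 的所有 weight 限制在 $[-c, c]$ 之間（本例中 `clip = [-0.05, 0.05]`，即 $c = 0.05$），讓 discriminator 變得平滑。如果 c 選得夠小的話，的確可以讓 discriminator 是 1-Lipschitz function（只要斜率小於等於 1 就滿足條件；嚴格來說 weight clipping 保證的是 K-Lipschitz，K 取決於 c 與網路架構，但這只會讓估計出的 Wasserstein distance 差一個常數倍），但這同時也限制了 Discriminator 的能力。所以實務上我們會想辦法兩者兼顧地調整 c ，讓 Discriminator 盡量接近 1-Lipschitz function 同時又保有一定的能力。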

In [None]:
@tf.function
def train_step(real_images, generator, discriminator, g_optimizer, d_optimizer):
    """Run one simultaneous WGAN update of the critic and the generator.

    Both networks are updated from the same forward pass. Afterwards every
    trainable variable of the critic is clipped into ``[clip[0], clip[1]]``
    (WGAN weight clipping).

    Args:
        real_images: batch of real images. The batch dimension may be smaller
            than ``BATCH_SIZE`` for the last batch of an epoch.
        generator: model mapping latent vectors of length ``z_dim`` to images.
        discriminator: critic model returning a score per image.
        g_optimizer: optimizer applied to the generator's variables.
        d_optimizer: optimizer applied to the critic's variables.

    Returns:
        Tuple ``(d_loss, g_loss)`` of scalar tensors for this batch.
    """
    # Size the noise batch from the real batch rather than BATCH_SIZE. The
    # dataset is batched without drop_remainder, so the final batch of each
    # epoch is smaller (96 images for 60000 / 128). Without this, those real
    # images would be compared against a full 128-sample fake batch.
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal([batch_size, z_dim])

    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_images = generator(noise)

        d_real_output = discriminator(real_images)
        d_fake_output = discriminator(fake_images)

        d_loss, g_loss = gan_loss(d_real_output, d_fake_output)

    g_gradients = g_tape.gradient(g_loss, generator.trainable_variables)
    d_gradients = d_tape.gradient(d_loss, discriminator.trainable_variables)

    g_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))
    d_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))

    # Weight clipping (WGAN) after the critic update. The assignments are pure
    # side effects, so a plain loop is used instead of building a throwaway list.
    for var in discriminator.trainable_variables:
        var.assign(tf.clip_by_value(var, clip[0], clip[1]))

    return d_loss, g_loss

In [None]:
def train(dataset, epochs):
    """Train the module-level generator/critic pair for ``epochs`` epochs.

    After every epoch it prints the timing and the losses of the last batch,
    then draws samples from the fixed ``seed``. A checkpoint is written every
    25 epochs. A final sample grid and checkpoint are written once training ends.

    Args:
        dataset: iterable of real-image batches.
        epochs: number of passes over ``dataset``.
    """
    for epoch_idx in range(epochs):
        epoch_no = epoch_idx + 1
        tic = time.time()

        for batch in dataset:
            d_loss, g_loss = train_step(batch, generator, discriminator,
                                        g_optimizer, d_optimizer)

        elapsed = time.time() - tic
        print('Time for epoch {} is {} sec'.format(epoch_no, elapsed))
        print('discriminator loss: %.5f' % d_loss)
        print('generator loss: %.5f' % g_loss)
        generate_and_save_images(generator, epoch_no, seed, save_dir)

        # Persist the model every 25 epochs.
        if epoch_no % 25 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

    # One more sample grid and checkpoint after the last epoch.
    generate_and_save_images(generator, epochs, seed, save_dir)
    checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
def generate_and_save_images(model, epoch, test_input, save_path):
    """Plot a grid of generated samples and save it to disk every 5 epochs.

    Args:
        model: generator whose outputs lie in [-1, 1] with shape
            (N, H, W, 1), e.g. the tanh-terminated ``Generator``.
        epoch: 1-based epoch number. Callers pass ``epoch + 1`` from the
            training loop, and ``epochs`` for the final grid.
        test_input: latent batch fed to ``model``; one subplot per row.
        save_path: directory where the PNG snapshots are written.
    """
    predictions = model(test_input, training=False)
    num_images = predictions.shape[0]
    # Smallest square grid that fits every sample (4 x 4 for 16 samples).
    grid = int(np.ceil(np.sqrt(num_images)))

    fig = plt.figure(figsize=(4, 4))

    for i in range(num_images):
        plt.subplot(grid, grid, i + 1)
        # Undo the [-1, 1] normalisation for display.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    # Save every 5 epochs. `epoch` is already 1-based, so test it directly.
    # The previous `(epoch + 1) % 5` saved epochs 4, 9, ..., 49 and never the
    # final epoch.
    if epoch % 5 == 0:
        plt.savefig(os.path.join(save_path, 'image_at_epoch_{:04d}.png'.format(epoch)))

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
%%time
# Run the full training loop; the %%time cell magic reports the cell's wall-clock time.
train(train_dataset, EPOCHS)

In [None]:
# 使用imageio製作gif圖
anim_file = 'saved_imgs_wgan/wgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:

    filenames = glob.glob('saved_imgs_wgan/image*.png')
    filenames = sorted(filenames)
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2*(i**0.5)
        if round(frame) > round(last):
            last = frame
        else:
            continue
        image = imageio.imread(filename)
        writer.append_data(image)
    image = imageio.imread(filename)
    writer.append_data(image)

display(Image(filename=anim_file))

In [None]:
# Reload the most recent checkpoint written during training.
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_ckpt)

In [None]:
# Decode one random latent vector with the trained generator.
noise = tf.random.normal([1, z_dim])
img = generator(noise, training=False)

# Map the [-1, 1] output back to pixel range and show it.
pixels = img[0, :, :, 0] * 127.5 + 127.5
plt.imshow(pixels, cmap='gray')
plt.axis('off')
plt.show()