# WGAN

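- **Problems with DCGAN:** 
    1) It is hard to train while keeping the discriminator and the generator in balance
    2) Mode dropping still occurs even after training has finished (**mode collapse**: the generator's outputs concentrate on a single mode of the data)
- **Cause:** the discriminator does not provide a useful enough training signal, so the model is not trained all the way to the optimum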


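- **What sets Wasserstein GAN apart:**
    1) It replaces the discriminator with a newly defined critic. The discriminator applies a 'sigmoid' to judge Real/Fake, so its output is a predicted probability
    2) The critic instead outputs an unbounded scalar score (linear output), which is used to estimate the EM (Earth Mover) distance
    3) The EM distance measures the distance between probability distributions (the earlier measures, KL and JS divergence, are very strict: when the distributions do not overlap they become infinite or constant, i.e. discontinuous, which makes training difficult)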

$
\textbf{Entropy}: H(q) = - \sum\limits^C_{c=1}q(y_c)\log(q(y_c)) \\
\rightarrow C:\text{number of classes} \\
\quad q:\text{probability mass function} \\
\textbf{Cross-entropy}: H_p(q) = - \sum\limits^C_{c=1}q(y_c)\log(p(y_c)) \\
\rightarrow q: \text{true distribution} \\
\quad p: \text{predicted distribution} \\
\quad \text{On the training data q is known, so it can be used to measure the dissimilarity between p and q} \\
\quad \textbf{Cross-entropy} \geq \textbf{Entropy} \text{ (equality iff } p = q) \\
\textbf{Cross-entropy loss}: -\cfrac{1}{n}\sum\limits^n_{i=1}\sum\limits^C_{c=1}L_{ic}\log(P_{ic}) \\
\quad L_{ic}: \text{one-hot label of sample i}, \; P_{ic}: \text{predicted probability of class c for sample i}
$
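To make the definitions above concrete, here is a small NumPy sketch, assuming discrete distributions stored as arrays and the natural log. The helper names `entropy`, `cross_entropy` and `cross_entropy_loss` are illustrative, not from any library. It checks that the cross-entropy is never below the entropy and that the loss formula averages -log P_ic over the true classes.

```python
import numpy as np

def entropy(q):
    """H(q) = -sum_c q(y_c) log q(y_c) (natural log; 0 log 0 treated as 0)."""
    q = np.asarray(q, dtype=float)
    nz = q > 0
    return float(-np.sum(q[nz] * np.log(q[nz])))

def cross_entropy(q, p):
    """H_p(q) = -sum_c q(y_c) log p(y_c): q is the true, p the predicted distribution."""
    q = np.asarray(q, dtype=float)
    p = np.asarray(p, dtype=float)
    nz = q > 0
    return float(-np.sum(q[nz] * np.log(p[nz])))

def cross_entropy_loss(labels, probs):
    """-(1/n) sum_i sum_c L_ic log P_ic for one-hot labels L and predicted probabilities P."""
    labels = np.asarray(labels, dtype=float)
    probs = np.asarray(probs, dtype=float)
    return float(-np.mean(np.sum(labels * np.log(probs), axis=1)))

q = np.array([0.7, 0.2, 0.1])   # hypothetical true distribution (C = 3 classes)
p = np.array([0.5, 0.3, 0.2])   # hypothetical predicted distribution
print(entropy(q), cross_entropy(q, p))   # the cross-entropy is the larger value

L = np.array([[1, 0, 0], [0, 1, 0]])              # one-hot labels, n = 2
P = np.array([[0.8, 0.1, 0.1], [0.3, 0.6, 0.1]])  # predicted probabilities
print(cross_entropy_loss(L, P))   # equals -(log 0.8 + log 0.6) / 2
```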

$
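\textbf{Kullback-Leibler Divergence}: D_{KL}(q||p) = -\sum\limits^C_{c=1}q(y_c)[\log(p(y_c))-\log(q(y_c))] = H_p(q) - H(q) \\
\rightarrow H_p(q) \geq H(q) \text{, so } D_{KL}(q||p) \geq 0 \\
\quad \text{The goal is to bring the predicted distribution p close to the true distribution q} \\
\quad \text{H(q) does not depend on p, so minimizing the cross-entropy is the same as minimizing the KL divergence} \\
\quad \text{This is the practical objective of a predictive model that aims to control uncertainty}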
$
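A quick numerical check of the identity D_KL(q||p) = H_p(q) - H(q), sketched with NumPy on a hypothetical three-class example. `kl_divergence` is an illustrative helper; it returns infinity when q puts mass where p has none.

```python
import numpy as np

def kl_divergence(q, p):
    """D_KL(q || p) = sum_c q(y_c) [log q(y_c) - log p(y_c)] (natural log)."""
    q = np.asarray(q, dtype=float)
    p = np.asarray(p, dtype=float)
    nz = q > 0                      # terms with q(y_c) = 0 contribute 0
    if np.any(p[nz] == 0):
        return float("inf")         # q puts mass where p has none
    return float(np.sum(q[nz] * (np.log(q[nz]) - np.log(p[nz]))))

q = np.array([0.7, 0.2, 0.1])       # true distribution
p = np.array([0.5, 0.3, 0.2])       # predicted distribution

H_q = -np.sum(q * np.log(q))        # entropy H(q)
H_p_q = -np.sum(q * np.log(p))      # cross-entropy H_p(q)

# D_KL(q || p) = H_p(q) - H(q); H(q) does not depend on p, so minimizing
# the cross-entropy over p also minimizes the KL divergence
print(kl_divergence(q, p), H_p_q - H_q)
```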

$
\textbf{KL divergence }\text{is not symmetric} \\
\rightarrow D_{KL}(P||Q) \ne D_{KL}(Q||P) \\
\quad \text{so it cannot be used as a distance metric}
$

$
\textbf{Jensen - Shannon Divergence}: JSD(P,Q) = \cfrac{1}{2}D_{KL}(P||M) + \cfrac{1}{2}D_{KL}(Q||M) \\
\qquad \qquad \qquad \qquad \qquad \qquad \text{where} \; M = \cfrac{1}{2}(P+Q) \\
\rightarrow JSD(P,Q) = JSD(Q, P), \quad 0 \leq JSD(P,Q) \leq \log 2 \\
$
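The asymmetry of KL and the symmetry of JSD can be seen on two small distributions. This is a minimal sketch with hypothetical helpers `kl` and `jsd`; since M = (P+Q)/2 is nonzero wherever P or Q is, the JSD always stays finite.

```python
import numpy as np

def kl(p, q):
    """D_KL(p || q) for discrete distributions (natural log)."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    nz = p > 0
    if np.any(q[nz] == 0):
        return float("inf")
    return float(np.sum(p[nz] * np.log(p[nz] / q[nz])))

def jsd(p, q):
    """JSD(P, Q) = 1/2 KL(P || M) + 1/2 KL(Q || M) with M = (P + Q) / 2."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

P = [0.9, 0.1]
Q = [0.5, 0.5]
print(kl(P, Q), kl(Q, P))      # different values: KL is not symmetric
print(jsd(P, Q), jsd(Q, P))    # identical values: JSD is symmetric
print(jsd([1, 0], [0, 1]))     # disjoint supports reach the upper bound log 2
```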

$
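\textbf{The Total Variation (TV) distance}: \delta(\mathbb{P}_r,\mathbb{P}_g) = \sup\limits_{A\in\Sigma}|\mathbb{P}_r(A) - \mathbb{P}_g(A)| \\
\rightarrow \text{the largest difference between the probabilities the two distributions assign to the same event A } (\Sigma: \text{the set of measurable events}) \\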
$
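For distributions over a finite set, the supremum over events can be checked by brute force over every subset A. The sketch below (helper names are hypothetical) compares that with the well-known closed form for the discrete case, half the L1 distance between the probability vectors.

```python
import itertools
import numpy as np

def tv_brute_force(p, q):
    """sup over every event A (subset of outcomes) of |P(A) - Q(A)|."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    n = len(p)
    best = 0.0
    for r in range(n + 1):
        for A in itertools.combinations(range(n), r):
            idx = list(A)
            best = max(best, abs(p[idx].sum() - q[idx].sum()))
    return best

def tv_closed_form(p, q):
    """For discrete distributions the supremum equals half the L1 distance."""
    return 0.5 * float(np.abs(np.asarray(p, float) - np.asarray(q, float)).sum())

P_r = [0.1, 0.4, 0.3, 0.2]   # hypothetical "real" distribution
P_g = [0.3, 0.3, 0.1, 0.3]   # hypothetical "generated" distribution
print(tv_brute_force(P_r, P_g), tv_closed_form(P_r, P_g))
```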

$
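\textbf{Earth-Mover (EM) distance or Wasserstein-1}: W(\mathbb{P}_r,\mathbb{P}_g) = \inf\limits_{\gamma\in\Pi(\mathbb{P}_r,\mathbb{P}_g)} \mathbb{E}_{(x,y)\thicksim\gamma}[\lVert x-y\rVert] \\
\rightarrow \Pi(\mathbb{P}_r,\mathbb{P}_g)\text{: the set of all joint distributions } \gamma \text{ whose marginals are } \mathbb{P}_r \text{ and } \mathbb{P}_g \\
\rightarrow W \text{ is the smallest expected distance between x and y over all such } \gamma \\
\rightarrow \gamma(x,y) \text{ says how much mass must be moved from x to y over a distance } d=\lVert x-y \rVert \text{; W is the minimum total transport cost}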
$
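In general, computing the EM distance means solving an optimal transport problem, but in one dimension it has a closed form: the area between the two CDFs. The sketch below relies on that and is limited to 1-D discrete distributions on a shared sorted grid; the grid, the distributions and the `wasserstein_1d` helper are illustrative assumptions.

```python
import numpy as np

def wasserstein_1d(xs, p, q):
    """W_1 between two discrete distributions on the same sorted 1-D grid xs.

    In one dimension the optimal plan moves mass in order, so
    W_1 = integral of |F_p(x) - F_q(x)| dx, where F_* are the CDFs.
    """
    xs = np.asarray(xs, float)
    cdf_gap = np.abs(np.cumsum(p) - np.cumsum(q))[:-1]  # CDF gap on each interval
    return float(np.sum(cdf_gap * np.diff(xs)))

xs = np.arange(6.0)                           # grid 0, 1, ..., 5
p = np.array([0.5, 0.5, 0.0, 0.0, 0.0, 0.0])  # mass at 0 and 1
q = np.array([0.0, 0.0, 0.5, 0.5, 0.0, 0.0])  # same shape, shifted by 2
print(wasserstein_1d(xs, p, q))               # every unit of mass travels distance 2
```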

### The distance function in GAN
$
L^{(D)} = - \int_x p_{data}(x)\log D(x)dx - \int_x p_g(x)\log(1-D(x))dx \\
\rightarrow L^{(D)} = - \int_x(p_{data}(x)\log D(x) + p_g(x)\log(1-D(x)))dx \\
\quad (\text{for fixed } a, b > 0: \; y \rightarrow a\log y + b\log(1-y)) \\
\rightarrow \text{attains its maximum at } y = \cfrac{a}{a+b} \\
\rightarrow D^{*}(x) = \cfrac{p_{data}(x)}{p_{data}(x) + p_{g}(x)} \\
L^{(D^*)} = -\mathbb{E}_{x\thicksim p_{data}}\log\cfrac{p_{data}}{p_{data}+p_g} - \mathbb{E}_{x\thicksim p_{g}}\log\cfrac{p_{g}}{p_{data}+p_g} \\
\rightarrow L^{(D^*)} = 2\log2 - D_{KL}\bigg[p_{data}\Vert \cfrac{p_{data}+p_g}{2}\bigg] - D_{KL}\bigg[p_{g}\Vert \cfrac{p_{data}+p_g}{2}\bigg] \\
\rightarrow L^{(D^*)} = 2\log2-2D_{JS}(p_{data} \Vert p_g) \\
\rightarrow \text{minimizing } L^{(D^*)} \text{ means maximizing } D_{JS}(p_{data} \Vert p_g) \\
\quad \text{i.e., the discriminator correctly separates real data from fake data} \\
\quad \text{conversely, the generator maximizes } L^{(D^*)} \text{, i.e. it minimizes } D_{JS}(p_{data} \Vert p_g)
$
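The derivation can be checked numerically on discrete distributions, treating the integrals as sums. This NumPy sketch (helper names and distributions are hypothetical) confirms that a log y + b log(1-y) peaks at a/(a+b), that plugging D* into L^(D) gives 2 log 2 - 2 D_JS, and that D* = 1/2 when p_g = p_data.

```python
import numpy as np

def kl(a, b):
    """D_KL(a || b) for discrete distributions (natural log, 0 log 0 = 0)."""
    nz = a > 0
    return float(np.sum(a[nz] * np.log(a[nz] / b[nz])))

def jsd(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def discriminator_loss(d, p_data, p_g):
    """Discrete version of L^(D) = -sum p_data log D - sum p_g log(1 - D)."""
    return float(-np.sum(p_data * np.log(d)) - np.sum(p_g * np.log(1 - d)))

# 1) For fixed a, b > 0, y -> a log y + b log(1 - y) peaks at y = a / (a + b)
a, b = 0.3, 0.7
ys = np.linspace(1e-4, 1 - 1e-4, 100001)
y_best = ys[np.argmax(a * np.log(ys) + b * np.log(1 - ys))]
print(y_best, a / (a + b))

# 2) At D* = p_data / (p_data + p_g), L^(D*) = 2 log 2 - 2 JSD(p_data || p_g)
p_data = np.array([0.5, 0.3, 0.2])    # hypothetical discrete distributions
p_g = np.array([0.2, 0.2, 0.6])
d_star = p_data / (p_data + p_g)
print(discriminator_loss(d_star, p_data, p_g), 2 * np.log(2) - 2 * jsd(p_data, p_g))

# 3) If p_g = p_data, then D* = 1/2 everywhere and L^(D*) = 2 log 2 (about 1.386)
d_half = p_data / (p_data + p_data)
print(discriminator_loss(d_half, p_data, p_data))
```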

$
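\text{The optimal generator is obtained when the generator distribution equals the real data distribution} \\
\text{i.e., } G^{(*)}(x) \rightarrow p_g = p_{data} \\
\text{Given the optimal generator, the optimal discriminator is} \\
\rightarrow D^{(*)}(x) = \cfrac{p_{data}}{p_{data} + p_{g}} = \cfrac{1}{2} \\
\rightarrow L^{(*)} = 2\log2 \approx 1.386 \quad (\text{natural log; } D_{JS} = 0)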
$

### When the two distributions do not overlap
$
p_{data} = (x,y) \quad where \; x = 0, \; y \thicksim U(0,1) \\
p_g = (x,y) \quad where \; x = \theta, y \thicksim U(0,1) \\
$

$
\cdot D_{KL}(p_g \Vert p_{data}) = \mathbb{E}_{x=\theta,y\thicksim U(0,1)}\log\cfrac{p_g(x,y)}{p_{data}(x,y)} = \sum 1 \log\cfrac{1}{0} = + \infty \quad (\theta \neq 0) \\
\cdot D_{JS}(p_g \Vert p_{data}) = \cfrac{1}{2}\mathbb{E}_{x=0, y\thicksim U(0,1)}\log\cfrac{p_{data}(x,y)}{\cfrac{p_{data}(x,y) + p_g{(x,y)}}{2}} + \cfrac{1}{2}\mathbb{E}_{x=\theta,y\thicksim U(0,1)}\log\cfrac{p_g(x,y)}{\cfrac{p_{data}(x,y)+p_g(x,y)}{2}} \\
\qquad\qquad\qquad = \cfrac{1}{2}\sum1\log\cfrac{1}{\frac{1}{2}} + \cfrac{1}{2}\sum1\log\cfrac{1}{\frac{1}{2}} = \log2 \quad (\theta \neq 0) \\
\cdot W(p_{data},p_g) = |\theta| \\
\rightarrow \text{At } \theta = 0 \text{ all three are 0. For } \theta \neq 0 \text{, KL and JS do not change with } \theta \text{ (no useful gradient), while } W \text{ is continuous in } \theta
$
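Both distributions share the same y ~ U(0,1), so only the x-coordinate differs and the example reduces to two point masses on the x-axis. That simplification is an assumption of the sketch below, and the helper `point_mass_divergences` is hypothetical. It reproduces the three results: KL is infinite and JSD is stuck at log 2 for every θ ≠ 0, while W_1 equals |θ| and shrinks smoothly toward 0.

```python
import numpy as np

def point_mass_divergences(theta):
    """KL(p_g || p_data), JSD and W_1 for p_data = delta at x=0 and p_g = delta at x=theta."""
    xs = np.unique([0.0, theta])                  # sorted support of both distributions
    p_data = (xs == 0.0).astype(float)
    p_g = (xs == theta).astype(float)
    m = 0.5 * (p_data + p_g)

    def kl(a, b):
        nz = a > 0
        return np.inf if np.any(b[nz] == 0) else float(np.sum(a[nz] * np.log(a[nz] / b[nz])))

    js = 0.5 * kl(p_data, m) + 0.5 * kl(p_g, m)
    # 1-D W_1: area between the two CDFs
    cdf_gap = np.abs(np.cumsum(p_data) - np.cumsum(p_g))[:-1]
    w = float(np.sum(cdf_gap * np.diff(xs)))
    return kl(p_g, p_data), js, w

for theta in [-1.0, -0.1, 0.0, 0.1, 1.0]:
    print(theta, point_mass_divergences(theta))
```

Only W gives a signal that reflects how far θ is from 0, which is the property WGAN relies on.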

# MNIST

In [11]:
%reset


In [12]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import os
import math
%matplotlib inline
%config InlineBackend.figure_format='retina'

from tensorflow.keras.layers import concatenate, Dense, Reshape, BatchNormalization, Activation, Conv2DTranspose
from tensorflow.keras.layers import Conv2D, LeakyReLU, Flatten

In [13]:
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

In [14]:
def plot_images(generator,
                noise_input,
                noise_label=None,
                noise_codes=None,
                show=False,
                step=0,
                model_name="gan"):
    """
    # Arguments
        generator (Model)
        noise_input (ndarray)
        show (bool)
        step (int)
        model_name (string)

    """
    filepath = os.path.join(model_name, "generated")
    os.makedirs(filepath, exist_ok=True)
    
    filename = os.path.join(filepath, "%05d.png" %step)
    
    rows = int(math.sqrt(noise_input.shape[0]))
    if noise_label is not None:
        noise_input = [noise_input, noise_label]
        if noise_codes is not None:
            noise_input += noise_codes

    images = generator.predict(noise_input, verbose=0)
    plt.figure(figsize=(2.2, 2.2))
    num_images = images.shape[0]
    image_size = images.shape[1]
    for i in range(num_images):
        plt.subplot(rows, rows, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')

In [15]:
def build_generator(inputs, image_size):
    # Stack: BN-ReLU-Conv2DTranspose
    # inputs: z-vector(noise)
    # image_size: Target size
    # return: Model
    
    image_resize = image_size // 4
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]
    
    x = Dense(image_resize*image_resize*layer_filters[0])(inputs)
    x = Reshape((image_resize, image_resize, layer_filters[0]))(x)
    
    for filters in layer_filters:
        # 1st, 2nd Conv layers: strides = 2
        # 3rd, 4th Conv layers: strides = 1
        if filters > layer_filters[-2]: #128, 64
            strides = 2
        else:
            strides = 1
        # BN-ReLU-Conv2DTranspose
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            strides=strides,
                            padding='same'
                           )(x)
        
    x = Activation('sigmoid')(x)
    generator = keras.Model(inputs, x, name='generator')
    return generator

In [16]:
def build_discriminator(inputs, activation='sigmoid'):
    # Does not converge with BN, so no BatchNormalization is used here
    # Stack: LeakyReLU-Conv2D
    # inputs: Image
    # return: Model
    
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]
    
    x = inputs
    for filters in layer_filters:
        # 1st, 2nd, 3rd Conv layers: strides = 2
        # 4th Conv layers: strides = 1
        if filters == layer_filters[-1]:
            strides = 1
        else:
            strides = 2
        # LeakyReLU-Conv2D
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding='same'
                  )(x)
    x = Flatten()(x)
    x = Dense(1)(x)
    x = Activation(activation)(x)
    discriminator = keras.Model(inputs, x, name='discriminator')
    return discriminator
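Tracing the spatial sizes shows how the two networks fit together. The pure-Python sketch below mirrors the stride rules in `build_generator` and `build_discriminator`; the helper names are hypothetical, and it assumes Keras' `padding='same'` rules (Conv2DTranspose multiplies the size by the stride, Conv2D gives ceil(size / stride)). The first discriminator shapes agree with the summary printed further down (14x14x32, then 7x7x64).

```python
def conv_output_size(size, stride):
    """Spatial size after a Conv2D with padding='same': ceil(size / stride)."""
    return -(-size // stride)

def trace_generator(image_size=28, layer_filters=(128, 64, 32, 1)):
    size = image_size // 4                       # Dense + Reshape start at 7 x 7
    shapes = [(size, size, layer_filters[0])]
    for filters in layer_filters:
        strides = 2 if filters > layer_filters[-2] else 1
        size = size * strides                    # Conv2DTranspose, padding='same'
        shapes.append((size, size, filters))
    return shapes

def trace_discriminator(image_size=28, layer_filters=(32, 64, 128, 256)):
    size = image_size
    shapes = []
    for filters in layer_filters:
        strides = 1 if filters == layer_filters[-1] else 2
        size = conv_output_size(size, strides)   # Conv2D, padding='same'
        shapes.append((size, size, filters))
    return shapes

print(trace_generator())
print(trace_discriminator())
```

With a 28x28 input the discriminator ends at 4x4x256, so `Flatten` feeds 4096 values into the final `Dense(1)`.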

In [17]:
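def wasserstein_loss(y_label, y_pred):
    # y_label is +1 for real and -1 for fake samples; minimizing -mean(y_label * y_pred)
    # pushes the critic score up on real samples and down on fake ones
    return -K.mean(y_label * y_pred)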
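The sign convention of `wasserstein_loss` is easiest to see with plain numbers. This NumPy sketch mirrors the loss; the critic scores are random, hypothetical values rather than outputs of the model. It shows that the averaged critic loss used in `train()`, 0.5 * (real_loss + fake_loss), equals minus half of mean D(real) - mean D(fake), which is the critic's estimate of the EM distance up to the scale set by weight clipping.

```python
import numpy as np

def wasserstein_loss_np(y_label, y_pred):
    # NumPy mirror of wasserstein_loss: -mean(y_label * y_pred)
    return -np.mean(y_label * y_pred)

rng = np.random.default_rng(0)
critic_real = rng.normal(1.0, 0.1, size=(64, 1))   # hypothetical critic scores on real images
critic_fake = rng.normal(-0.5, 0.1, size=(64, 1))  # hypothetical critic scores on fake images
ones = np.ones((64, 1))

real_loss = wasserstein_loss_np(ones, critic_real)    # label +1 -> -mean(D(real))
fake_loss = wasserstein_loss_np(-ones, critic_fake)   # label -1 -> +mean(D(fake))
critic_loss = 0.5 * (real_loss + fake_loss)

# -2 * critic_loss is the critic's score gap mean(D(real)) - mean(D(fake))
print(-2 * critic_loss, critic_real.mean() - critic_fake.mean())

# generator step: fake images labelled +1, so the generator minimizes -mean(D(G(z)))
gen_loss = wasserstein_loss_np(ones, critic_fake)
print(gen_loss)
```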

In [18]:
def build_and_train_models():
    (x_train, _), (_, _) = keras.datasets.mnist.load_data()
    
    # Reshape & Normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32')/255
    
    model_name = 'MNIST_WGAN'
    
    # Additional wgan params
    n_critic = 5
    clip_value = 0.01
    
    # Network Params
    latent_size = 100 # z-vector dimension
    batch_size = 64
    train_steps = 40000
    lr = 5e-5 # DCGAN: 2e-4
    # decay = 6e-8 # decays the learning rate over time; DCGAN
    input_shape = (image_size, image_size, 1)

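    # Since TF/Keras 2.11 the new optimizers no longer accept 'decay' ('weight_decay' is a different, decoupled weight decay); the legacy optimizers still accept 'decay'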
    # Discriminator model
    inputs = keras.Input(shape=input_shape, name='discriminator_input')
    #discriminator = gan.discriminator(inputs, activation='linear')
    discriminator = build_discriminator(inputs, activation='linear')
    optimizer = keras.optimizers.legacy.RMSprop(learning_rate=lr)
    discriminator.compile(loss=wasserstein_loss,
                          optimizer=optimizer,
                          metrics=['accuracy']  # not meaningful for a linear critic with ±1 labels; kept for logging only
                         )
    discriminator.summary()
    
    # Generator model
    input_shape = (latent_size, )
    inputs = keras.Input(shape=input_shape, name='z_input')
    #generator = gan.generator(inputs, image_size)
    generator = build_generator(inputs, image_size)
    generator.summary()
    
    # Adversarial model
    #optimizer = keras.optimizers.RMSprop(learning_rate=lr*0.5, decay=decay*0.5,)
    discriminator.trainable = False
    ## Adversarial = Generator + Discriminator
    adversarial = keras.Model(inputs, discriminator(generator(inputs)), name='adversarial')
    adversarial.compile(loss=wasserstein_loss,
                        optimizer=optimizer,
                        metrics=['accuracy']
                       )
    adversarial.summary()
    
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, n_critic, clip_value, train_steps, model_name)
    train(models, x_train, params)

In [19]:
def train(models, x_train, params):
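    # Train the discriminator (critic) and the adversarial model alternately, batch by batch
    ## The discriminator is trained on real and fake images with their correct labels
    ## The adversarial model is trained on fake images labelled as real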
    
    # GAN Model
    generator, discriminator, adversarial = models
    
    # Network Params
    # batch_size, latent_size, train_steps, model_name = params
    batch_size, latent_size, n_critic, clip_value, train_steps, model_name = params
    
    # Save generator images every 500 training steps
    save_interval = 500
    
    # Noise_Input
    noise_input = np.random.uniform(-1.0, 1.0, size=[16, latent_size]) # 16 x 100
    
    train_size = x_train.shape[0]
    real_labels = np.ones((batch_size, 1))
    for i in range(train_steps): # train_steps: 40,000
        
        # First - Discriminator Train
        # train discriminator n_critic times
        loss = 0
        acc = 0
        for _ in range(n_critic):
            # 1 batch of real (label=1.0) and fake images (label = -1.0)
            # randomly pick real images from dataset
            rand_indexes = np.random.randint(0, train_size, size=batch_size)
            real_images = x_train[rand_indexes]
            # generate fake images from noise using generator
            # generate noise using uniform distribution
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
            fake_images = generator.predict(noise, verbose=0)
        
            # train the discriminator network
            # real data label=1, fake data label=-1
            # instead of 1 combined batch of real and fake images,
            # train with 1 batch of real data first, then 1 batch of fake images.
            # this tweak prevents the gradient from vanishing due to opposite signs of real and fake data labels
            # and small magnitude of weights due to clipping.
            real_loss, real_acc = discriminator.train_on_batch(real_images, real_labels)
            fake_loss, fake_acc = discriminator.train_on_batch(fake_images, -real_labels)
            
            # accumulate average loss and accuracy
            loss += 0.5 * (real_loss + fake_loss)
            acc += 0.5 * (real_acc + fake_acc)
            
            # clip discriminator weights to satisfy Lipschitz constraint
            for layer in discriminator.layers:
                weights = layer.get_weights()
                weights = [np.clip(weight, -clip_value, clip_value) for weight in weights]
                layer.set_weights(weights)
        
        # average loss and accuracy per n_critic training iterations
        loss /= n_critic
        acc /= n_critic
        log = "%d: [discriminator loss: %f, acc: %f]" %(i, loss, acc)
        
        # Second - Adversarial Train
        # train the adversarial network for 1 batch
        # 1 batch of fake images with label=1.0
        # since the discriminator weights are frozen in adversarial network
        # only the generator is trained
        # generate noise using uniform distribution
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])

        # train the adversarial network
        # note that unlike in discriminator training,
        # we do not save the fake images in a variable
        # the fake images go to the discriminator input of the adversarial
        # for classification
        # fake images are labelled as real
        # log the loss and accuracy
        loss, acc = adversarial.train_on_batch(noise, real_labels)
        log = "%s: [adversarial loss: %f, acc: %f]" %(log, loss, acc)
        if i % 100 == 0:
            print(log)
        
        # Save generator images every 500 steps (and show them on the last step)
        if (i+1) % save_interval == 0: # 500
            if (i+1) == train_steps: # 40,000
                show = True
            else:
                show = False
                
            # Plot and save generator images every 500 steps
            plot_images(generator,
                        noise_input=noise_input,
                        show=show,
                        step=(i+1),
                        model_name=model_name
                       )
    
    # save the model after training the generator
    # the trained generator can be reloaded for future MNIST digit generation
    generator.save(model_name + "/mnist_wgan.h5")
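Why weight clipping yields the EM distance only up to a constant factor can be seen in a 1-D toy problem. This sketch is not the MNIST model: the data, the learning rate, plain gradient ascent (instead of RMSprop) and the linear critic f(x) = w·x are all assumptions. Gradient ascent pushes w to the clip bound, so the critic's score gap becomes clip_value times |mean_real - mean_fake|. For two unit-variance Gaussians that differ only by a shift, that mean gap is exactly the EM distance.

```python
import numpy as np

rng = np.random.default_rng(0)
real = rng.normal(2.0, 1.0, size=10_000)   # hypothetical 1-D "real" data
fake = rng.normal(0.0, 1.0, size=10_000)   # hypothetical 1-D "generated" data

clip_value = 0.01                           # same value as in build_and_train_models()
lr = 5e-3
w = 0.0                                     # linear critic f(x) = w * x (a bias would cancel out)

for _ in range(100):
    # gradient of E_real[f] - E_fake[f] with respect to w
    grad = real.mean() - fake.mean()
    w += lr * grad                                   # gradient ascent on the critic objective
    w = float(np.clip(w, -clip_value, clip_value))   # weight clipping, as in train()

estimate = w * real.mean() - w * fake.mean()         # critic's score gap
print(w, estimate, estimate / clip_value)            # rescaled gap is close to the true W = 2
```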

In [None]:
if __name__=='__main__':
    build_and_train_models()

"""
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load generator h5 model with trained weights"
    parser.add_argument("-g", "--generator", help=help_)
    args = parser.parse_args() # args에 위 내용 저장
    if args.generator:
        generator = load_model(args.generator)
        gan.test_generator(generator)
    else:
        build_and_train_models()
"""

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 discriminator_input (InputL  [(None, 28, 28, 1)]      0         
 ayer)                                                           
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 28, 28, 1)         0         
                                                                 
 conv2d_4 (Conv2D)           (None, 14, 14, 32)        832       
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 14, 14, 32)        0         
                                                                 
 conv2d_5 (Conv2D)           (None, 7, 7, 64)          51264     
                                                                 
 leaky_re_lu_6 (LeakyReLU)   (None, 7, 7, 64)          0         
                                                     

In [None]:
from keras.models import load_model

generator = load_model("./MNIST_WGAN/mnist_wgan.h5")
noise = np.random.uniform(-1.0, 1.0, size=[16, 100])
plot_images(generator,
            noise_input=noise,
            show=True,
            model_name="./MNIST_WGAN/test_image"
           )