## Generative Adversarial Networks

Original paper: Goodfellow et al., "Generative Adversarial Networks" (2014) — https://arxiv.org/abs/1406.2661

In [1]:
%matplotlib inline
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with one-hot labels. GAN training below is unsupervised, so only the images matter.
mnist = input_data.read_data_sets(train_dir='../data/mnist/data', one_hot=True)

Extracting ../data/mnist/data\train-images-idx3-ubyte.gz
Extracting ../data/mnist/data\train-labels-idx1-ubyte.gz
Extracting ../data/mnist/data\t10k-images-idx3-ubyte.gz
Extracting ../data/mnist/data\t10k-labels-idx1-ubyte.gz


In [2]:
####################
# hyper parameters #
####################
# Optimisation schedule
total_epoch = 100       # number of passes over the MNIST training set
batch_size = 100        # images per mini-batch
learning_rate = 0.0002  # Adam step size, shared by the Discriminator and Generator optimizers

# Neural Layer params
n_input = 28 * 28       # one MNIST image flattened to a vector (28x28 pixels)
n_hidden = 256          # width of the single hidden layer in both G and D
n_noise = 128           # size of the noise vector fed to the Generator as input

In [3]:
################
# Neural Layer #
################
# A GAN is trained without labels, so no Y placeholder is defined.
# X: batch of real images, each flattened to n_input values.
X = tf.placeholder(dtype=tf.float32, shape=[None, n_input])
# Z: batch of noise vectors from which the Generator produces fake images.
Z = tf.placeholder(dtype=tf.float32, shape=[None, n_noise])


# Variables of the Generator network (a two-layer MLP).
# G_W1 / G_b1 : input (noise, [n_noise]) -> hidden [n_hidden]
# G_W2 / G_b2 : hidden [n_hidden] -> output (fake image, [n_input] = 28*28)
# Weights use Xavier initialisation and biases start at zero. Every variable is
# given an explicit graph name matching its Python name; unnamed tf.Variables
# would otherwise show up as "Variable", "Variable_1", ... in the graph.
G_W1 = tf.get_variable(name="G_W1", shape=[n_noise, n_hidden],
                       initializer=tf.contrib.layers.xavier_initializer())
G_b1 = tf.Variable(tf.zeros([n_hidden]), name="G_b1")
G_W2 = tf.get_variable(name="G_W2", shape=[n_hidden, n_input],
                       initializer=tf.contrib.layers.xavier_initializer())
G_b2 = tf.Variable(tf.zeros([n_input]), name="G_b2")


# Variables of the Discriminator network (a two-layer MLP).
# D_W1 / D_b1 : input (MNIST image, [n_input] = 28*28) -> hidden [n_hidden]
# D_W2 / D_b2 : hidden [n_hidden] -> output ([1], ~0 = fake, ~1 = real)
# Weights use Xavier initialisation and biases start at zero. Every variable is
# given an explicit graph name matching its Python name; unnamed tf.Variables
# would otherwise show up as "Variable", "Variable_1", ... in the graph.
D_W1 = tf.get_variable(name="D_W1", shape=[n_input, n_hidden],
                       initializer=tf.contrib.layers.xavier_initializer())
D_b1 = tf.Variable(tf.zeros([n_hidden]), name="D_b1")
D_W2 = tf.get_variable(name="D_W2", shape=[n_hidden, 1],
                       initializer=tf.contrib.layers.xavier_initializer())
D_b2 = tf.Variable(tf.zeros([1]), name="D_b2")

In [4]:
#####################
# Generator         #
#   & Discriminator #
#####################

# Generator (G) network definition.
def generator(noise_z):
    """Build the Generator graph on top of the noise tensor `noise_z`.

    Architecture: matmul(G_W1) + G_b1 -> ReLU -> matmul(G_W2) + G_b2 -> sigmoid,
    using the module-level G_* variables. The sigmoid keeps every output value
    in [0, 1].

    Args:
        noise_z: noise tensor whose last dimension is n_noise (it is multiplied
            by G_W1, shaped [n_noise, n_hidden]).

    Returns:
        Tensor of fake images, n_input values per example.
    """
    hidden_pre = tf.matmul(noise_z, G_W1) + G_b1
    hidden_act = tf.nn.relu(hidden_pre)
    image_pre = tf.matmul(hidden_act, G_W2) + G_b2
    return tf.nn.sigmoid(image_pre)


# Discriminator (D) network definition.
def discriminator(inputs):
    """Build the Discriminator graph on top of the image tensor `inputs`.

    Architecture: matmul(D_W1) + D_b1 -> ReLU -> matmul(D_W2) + D_b2 -> sigmoid,
    using the module-level D_* variables, so every call shares the same weights.

    Args:
        inputs: image tensor whose last dimension is n_input (it is multiplied
            by D_W1, shaped [n_input, n_hidden]).

    Returns:
        Tensor with one value in [0, 1] per example (D_W2 is [n_hidden, 1]);
        values near 1 mean "judged real", near 0 "judged fake".
    """
    hidden_act = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
    real_score_pre = tf.matmul(hidden_act, D_W2) + D_b2
    real_score = tf.nn.sigmoid(real_score_pre)
    return real_score


# Helper that samples the random noise (Z) fed to the Generator.
def get_noise(batch_size, n_noise):
    """Return a (batch_size, n_noise) array of standard-normal samples.

    Draws from NumPy's global random state, so the output is reproducible
    only if np.random.seed has been called beforehand.
    """
    noise_shape = (batch_size, n_noise)
    return np.random.normal(loc=0.0, scale=1.0, size=noise_shape)


# Wire the networks together:
#   G      : fake images produced by the Generator from noise Z
#   D_gene : Discriminator output for the fake images G
#   D_real : Discriminator output for the real images X
# Both discriminator() calls read the same D_* variables, so real and fake
# images are scored by one shared Discriminator.
G = generator(noise_z=Z)
D_gene = discriminator(inputs=G)
D_real = discriminator(inputs=X)

In [5]:
#################
# Loss Function #
#################

# Small constant added inside every log. A saturated sigmoid can return exactly
# 0.0 or 1.0 in float32; without it, log(0) = -inf and the gradients turn into NaN.
LOG_EPS = 1e-8

# loss_D: objective the Discriminator maximises. It grows when D_real -> 1
# (real images judged real) and D_gene -> 0 (fake images judged fake).
loss_D = tf.reduce_mean(tf.log(D_real + LOG_EPS) + tf.log(1 - D_gene + LOG_EPS))

# loss_G: objective the Generator maximises. It pushes D_gene towards 1, i.e.
# it tries to make the Discriminator judge fake images as real.
loss_G = tf.reduce_mean(tf.log(D_gene + LOG_EPS))

# Each optimizer updates only its own network's variables: loss_D trains only
# the Discriminator, and loss_G trains only the Generator.
D_var_list = [D_W1, D_b1, D_W2, D_b2]
G_var_list = [G_W1, G_b1, G_W2, G_b2]

# Optimize: minimize() descends, so the objectives are negated to maximise them.
# var_list takes the flat list of variables directly (not wrapped in another list).
train_D = tf.train.AdamOptimizer(learning_rate).minimize(-loss_D,
                                                         var_list=D_var_list)
train_G = tf.train.AdamOptimizer(learning_rate).minimize(-loss_G,
                                                         var_list=G_var_list)

In [8]:
######################
# GAN model Training #
######################
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Full mini-batches per epoch (integer division, same as the original int(a / b)).
total_batch = mnist.train.num_examples // batch_size
# Latest loss_D / loss_G values, reported once per epoch.
loss_val_D, loss_val_G = 0, 0

# Generated sample strips are saved here; create the directory once, up front,
# instead of checking for it on every sampling epoch.
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)
sample_size = 10  # number of images in each saved strip

for epoch in range(total_epoch):
    for step in range(total_batch):
        # Labels are discarded: GAN training is unsupervised.
        batch_xs, _ = mnist.train.next_batch(batch_size)
        noise = get_noise(batch_size, n_noise)

        # Train the Discriminator and then the Generator on the same noise batch.
        _, loss_val_D = sess.run([train_D, loss_D],
                                 feed_dict={X: batch_xs, Z: noise})
        _, loss_val_G = sess.run([train_G, loss_G],
                                 feed_dict={Z: noise})

    print('Epoch:', '%04d' % epoch,
          'D loss: {:.4}'.format(loss_val_D),
          'G loss: {:.4}'.format(loss_val_G))

    # Periodically generate and save sample images to check training progress.
    if epoch == 0 or (epoch + 1) % 10 == 0:
        noise = get_noise(sample_size, n_noise)
        samples = sess.run(G, feed_dict={Z: noise})

        fig, ax = plt.subplots(1, sample_size, figsize=(sample_size, 1))
        for k in range(sample_size):
            ax[k].set_axis_off()
            ax[k].imshow(np.reshape(samples[k], (28, 28)))

        # Save through the figure handle rather than pyplot's implicit current figure.
        fig.savefig(os.path.join(sample_dir, '{}.png'.format(str(epoch).zfill(3))),
                    bbox_inches='tight')
        plt.close(fig)  # release the figure so memory does not grow across epochs

print('최적화 완료!')

Epoch: 0000 D loss: -0.1416 G loss: -3.409
Epoch: 0001 D loss: -0.053 G loss: -4.053
Epoch: 0002 D loss: -0.05139 G loss: -4.372
Epoch: 0003 D loss: -0.01849 G loss: -5.322
Epoch: 0004 D loss: -0.02052 G loss: -5.565
Epoch: 0005 D loss: -0.01538 G loss: -5.122
Epoch: 0006 D loss: -0.1109 G loss: -3.944
Epoch: 0007 D loss: -0.03203 G loss: -4.435
Epoch: 0008 D loss: -0.1434 G loss: -3.865
Epoch: 0009 D loss: -0.09122 G loss: -3.956
Epoch: 0010 D loss: -0.161 G loss: -3.858
Epoch: 0011 D loss: -0.1149 G loss: -3.869
Epoch: 0012 D loss: -0.1571 G loss: -4.008
Epoch: 0013 D loss: -0.1327 G loss: -3.93
Epoch: 0014 D loss: -0.1901 G loss: -3.597
Epoch: 0015 D loss: -0.1769 G loss: -4.077
Epoch: 0016 D loss: -0.1671 G loss: -3.876
Epoch: 0017 D loss: -0.1662 G loss: -3.99
Epoch: 0018 D loss: -0.204 G loss: -3.704
Epoch: 0019 D loss: -0.2226 G loss: -3.626
Epoch: 0020 D loss: -0.4628 G loss: -2.669
Epoch: 0021 D loss: -0.2551 G loss: -3.95
Epoch: 0022 D loss: -0.06858 G loss: -4.263
Epoch: 002