# Chapter 9: GAN (Generative Adversarial Network) on MNIST

In [1]:
import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from ./mnist/data/ with one-hot encoded labels.
mnist = input_data.read_data_sets(train_dir='./mnist/data/', one_hot=True)

Extracting ./mnist/data/train-images-idx3-ubyte.gz
Extracting ./mnist/data/train-labels-idx1-ubyte.gz
Extracting ./mnist/data/t10k-images-idx3-ubyte.gz
Extracting ./mnist/data/t10k-labels-idx1-ubyte.gz


In [2]:
total_epoch = 100        # number of passes over the training set
batch_size = 100         # images per mini-batch
learning_rate = 0.0002   # Adam step size, shared by the G and D optimizers
n_hidden = 256           # width of the single hidden layer in G and D
n_input = 28 * 28        # flattened MNIST image size
n_noise = 128            # dimensionality of the generator's noise input z

# Seed both random number generators for repeatable runs: NumPy draws the
# noise batches (np.random.normal), TensorFlow's graph-level seed drives the
# weight initializers. This must run before the variables below are created.
random_seed = 777
np.random.seed(random_seed)
tf.set_random_seed(random_seed)

## Placeholders

In [3]:
# X: a batch of flattened real images; Z: a batch of generator noise vectors.
# The leading None leaves the batch size open (training and sampling feed
# different batch sizes into Z).
X = tf.placeholder(dtype=tf.float32, shape=[None, n_input])
Z = tf.placeholder(dtype=tf.float32, shape=[None, n_noise])

## Variables

### Generator (G) variables

In [4]:
# Generator parameters: noise (n_noise) -> hidden (n_hidden) -> image (n_input).
# Weights start from a narrow normal distribution (stddev 0.01); biases start at zero.
G_W1 = tf.Variable(tf.random_normal(shape=[n_noise, n_hidden], stddev=0.01))
G_b1 = tf.Variable(tf.zeros(shape=[n_hidden]))
G_W2 = tf.Variable(tf.random_normal(shape=[n_hidden, n_input], stddev=0.01))
G_b2 = tf.Variable(tf.zeros(shape=[n_input]))

### Discriminator (D) variables

In [5]:
# Discriminator parameters: image (n_input) -> hidden (n_hidden) -> single score.
# Weights start from a narrow normal distribution (stddev 0.01); biases start at zero.
D_W1 = tf.Variable(tf.random_normal(shape=[n_input, n_hidden], stddev=0.01))
D_b1 = tf.Variable(tf.zeros(shape=[n_hidden]))
D_W2 = tf.Variable(tf.random_normal(shape=[n_hidden, 1], stddev=0.01))
D_b2 = tf.Variable(tf.zeros(shape=[1]))


## GAN Architecture

In [6]:
def generator(noise_z):
    """Map a batch of noise vectors to a batch of flattened images.

    Two fully connected layers: a ReLU hidden layer followed by a sigmoid
    output layer, so every generated pixel lies in [0, 1].

    Args:
        noise_z: noise tensor; presumably shape (batch, n_noise) to match G_W1.

    Returns:
        Tensor of generated images, shape (batch, n_input) given G_W2's shape.
    """
    hidden_pre = tf.matmul(noise_z, G_W1) + G_b1
    hidden_act = tf.nn.relu(hidden_pre)
    logits = tf.matmul(hidden_act, G_W2) + G_b2
    return tf.nn.sigmoid(logits)

def discriminator(inputs):
    """Score a batch of flattened images with one sigmoid output per image.

    Two fully connected layers: a ReLU hidden layer followed by a single
    sigmoid unit. The losses treat the output as the probability that the
    input is a real image.

    Args:
        inputs: image tensor; presumably shape (batch, n_input) to match D_W1.

    Returns:
        Tensor of scores in [0, 1], shape (batch, 1) given D_W2's shape.
    """
    hidden_pre = tf.matmul(inputs, D_W1) + D_b1
    hidden_act = tf.nn.relu(hidden_pre)
    logit = tf.matmul(hidden_act, D_W2) + D_b2
    return tf.nn.sigmoid(logit)

# Noise generator
def get_noise(batch_size, n_noise):
    """Sample a batch of noise vectors to feed the generator.

    Args:
        batch_size: number of noise vectors (rows) to draw.
        n_noise: dimensionality of each noise vector (columns).

    Returns:
        NumPy array of shape (batch_size, n_noise) with entries drawn
        i.i.d. from the standard normal distribution.
    """
    # Use the argument, not the global n_noise, so the requested width is honored.
    return np.random.normal(size=(batch_size, n_noise))

# Build the graph: generate fake images from Z, then score the fake and the real
# images with the same discriminator (both calls read the shared D_* variables).
G = generator(Z)
D_gene, D_real = discriminator(G), discriminator(X)

In [7]:
# GAN objectives, both written as quantities to MAXIMIZE (the optimizers below
# minimize their negation):
#   loss_D = E[log D(x)] + E[log(1 - D(G(z)))]
#   loss_G = E[log D(G(z))]   (non-saturating generator objective)
# A saturated float32 sigmoid can return exactly 0 or 1, and log(0) = -inf would
# turn the gradients into NaN. So every log argument is clamped to [log_eps, 1];
# values in the normal range are unchanged.
log_eps = 1e-8
loss_D = tf.reduce_mean(tf.log(tf.clip_by_value(D_real, log_eps, 1.0)) +
                        tf.log(tf.clip_by_value(1 - D_gene, log_eps, 1.0)))
loss_G = tf.reduce_mean(tf.log(tf.clip_by_value(D_gene, log_eps, 1.0)))

In [8]:
# Each optimizer updates only its own network's parameters.
D_var_list = [D_W1, D_b1, D_W2, D_b2]
G_var_list = [G_W1, G_b1, G_W2, G_b2]


def _maximize(objective, var_list):
    """Return an Adam training op that maximizes `objective` over `var_list`.

    Adam only minimizes, so the objective is negated before being passed in.
    """
    return tf.train.AdamOptimizer(learning_rate).minimize(-objective, var_list=var_list)


train_D = _maximize(loss_D, D_var_list)
train_G = _maximize(loss_G, G_var_list)

## Session

In [9]:
# Open a session and initialize every variable created above.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Mini-batches per epoch; integer division drops any partial final batch.
total_batch = mnist.train.num_examples // batch_size
# Most recent loss values, reported once per epoch by the training loop.
loss_val_D = loss_val_G = 0

## Learning Process

In [None]:
# Directory for the generated-sample snapshots; savefig does not create it.
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)
sample_size = 10  # number of generated digits saved per snapshot

for epoch in range(total_epoch):
    for _ in range(total_batch):
        # Labels are not used: the GAN is trained on the images alone.
        batch_xs, _labels = mnist.train.next_batch(batch_size)
        noise = get_noise(batch_size, n_noise)

        # One discriminator step, then one generator step, on the same noise batch.
        _, loss_val_D = sess.run([train_D, loss_D], feed_dict={X: batch_xs, Z: noise})
        _, loss_val_G = sess.run([train_G, loss_G], feed_dict={Z: noise})

    # The reported losses are from the epoch's last mini-batch, not an epoch average.
    print('Epoch:', '%04d' % (epoch + 1),
          'D loss: {:.4}'.format(loss_val_D),
          'G loss: {:.4}'.format(loss_val_G))

    # Save a strip of generated samples after the first epoch and every 10th epoch.
    if epoch == 0 or (epoch + 1) % 10 == 0:
        noise = get_noise(sample_size, n_noise)
        samples = sess.run(G, feed_dict={Z: noise})

        fig, axes = plt.subplots(1, sample_size, figsize=(sample_size, 1))
        for idx, ax in enumerate(axes):
            ax.set_axis_off()
            ax.imshow(np.reshape(samples[idx], (28, 28)))

        # Save through the figure handle rather than pyplot's "current figure".
        fig.savefig('samples/{}.png'.format(str(epoch).zfill(3)), bbox_inches='tight')
        plt.close(fig)  # free the figure; many are created over training
print('Optimization has been done!')

Epoch: 0001 D loss: -0.3069 G loss: -2.381
Epoch: 0002 D loss: -0.1007 G loss: -3.224
Epoch: 0003 D loss: -0.6004 G loss: -1.551
Epoch: 0004 D loss: -0.3725 G loss: -1.694
Epoch: 0005 D loss: -0.5326 G loss: -1.831
Epoch: 0006 D loss: -0.36 G loss: -2.184
Epoch: 0007 D loss: -0.3415 G loss: -2.438
Epoch: 0008 D loss: -0.3073 G loss: -2.744
Epoch: 0009 D loss: -0.311 G loss: -2.218
Epoch: 0010 D loss: -0.4089 G loss: -2.302
Epoch: 0011 D loss: -0.4294 G loss: -2.354
Epoch: 0012 D loss: -0.5052 G loss: -2.048
Epoch: 0013 D loss: -0.4266 G loss: -1.905
Epoch: 0014 D loss: -0.4111 G loss: -2.13
Epoch: 0015 D loss: -0.5616 G loss: -1.708
Epoch: 0016 D loss: -0.5562 G loss: -2.141
Epoch: 0017 D loss: -0.4431 G loss: -2.38
Epoch: 0018 D loss: -0.4699 G loss: -2.405
Epoch: 0019 D loss: -0.4736 G loss: -2.407
Epoch: 0020 D loss: -0.3418 G loss: -2.519
Epoch: 0021 D loss: -0.4724 G loss: -2.166
Epoch: 0022 D loss: -0.5251 G loss: -2.153
Epoch: 0023 D loss: -0.6858 G loss: -2.069
Epoch: 0024 D lo