In [1]:
import tensorflow as tf
import os
import numpy as np
from model import *
from utility import *

from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from the local "MNIST_data/" directory with one-hot labels.
# `mnist` is read as a module-level global by GAN.train() to draw minibatches.
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

# Run identifier: passed to save_metrics/save_imgs and used to build the
# checkpoint directory name ("model_" + model_name) in GAN.train().
model_name = "WGAN_for_MNIST"

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
class Generator:
    """Callable generator network living in the TensorFlow variable scope 'g'.

    Calling an instance builds the graph that turns a latent tensor ``z`` into
    a tanh-squashed image tensor. The first call creates the variables; later
    calls reuse them (tracked through ``self.reuse``).

    Attributes:
        reuse: False until the first call, True afterwards.
        g_bn0: batch-normalisation helper applied after the first deconv layer.
        variables: trainable variables collected with scope filter 'g'
            (only set once the instance has been called).
    """

    def __init__(self):
        self.reuse = False
        self.g_bn0 = BatchNormalization(name='g_bn0')

    def __call__(self, z, training=False):
        """Build the generator graph for latent input ``z`` and return its output.

        Args:
            z: latent input tensor fed to the first fully connected layer.
            training: accepted for interface compatibility; not used in this body.

        Returns:
            The tanh-activated output of the last deconvolution layer, for which
            an output shape of [batch, 28, 28, 1] is requested.
        """
        with tf.variable_scope('g', reuse=self.reuse):
            # Project the latent vector and fold it into a 7x7x512 feature map.
            seed_map = tf.reshape(
                full_connection_layer(z, 7*7*512, name="fc0"),
                [-1, 7, 7, 512],
            )

            # Dynamic batch dimension, needed for the deconv output shapes.
            n_samples = tf.shape(seed_map)[0]

            # 7x7 -> 14x14, followed by batch norm and leaky ReLU.
            up14 = deconv2d_layer(
                seed_map, [n_samples, 14, 14, 256], kernel_size=4, name="deconv0"
            )
            up14 = lrelu(self.g_bn0(up14), leak=0.3)

            # 14x14 -> 28x28 single channel, squashed into [-1, 1] by tanh.
            up28 = deconv2d_layer(
                up14, [n_samples, 28, 28, 1], kernel_size=4, name="deconv1"
            )
            output = tf.nn.tanh(up28)

        # Every later call shares the variables created above.
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g')
        return output

In [3]:
class Discriminator:
    """Callable discriminator (critic) network living in variable scope 'd'.

    Calling an instance builds two conv/batch-norm/leaky-ReLU stages followed by
    three fully connected layers and returns the raw single-unit score (no
    sigmoid is applied). Variables are created on the first call and reused on
    every later one.

    Attributes:
        reuse: False until the first call, True afterwards.
        d_bn0, d_bn1: batch-normalisation helpers for the two conv stages.
        variables: trainable variables collected with scope filter 'd'
            (only set once the instance has been called).
    """

    def __init__(self):
        self.reuse = False
        self.d_bn0 = BatchNormalization(name="d_bn0")
        self.d_bn1 = BatchNormalization(name="d_bn1")

    def __call__(self, x, training=False, name=''):
        """Build the critic graph for input ``x`` and return its score tensor.

        Args:
            x: image tensor; reshaped to [-1, 28, 28, 1] before the first conv.
            training: accepted for interface compatibility; not used in this body.
            name: accepted for interface compatibility; not used in this body.

        Returns:
            Output of the final fully connected layer with one unit.
        """
        with tf.variable_scope('d', reuse=self.reuse):
            images = tf.reshape(x, [-1, 28, 28, 1])

            # Stage 1: conv -> batch norm -> leaky ReLU.
            stage0 = conv2d_layer(images, 128, kernel_size=4, name="d_conv0")
            stage0 = lrelu(self.d_bn0(stage0), leak=0.3)

            # Stage 2: conv -> batch norm -> leaky ReLU, then flatten for the dense head.
            stage1 = conv2d_layer(stage0, 256, kernel_size=4, name="d_conv1")
            stage1 = lrelu(self.d_bn1(stage1), leak=0.3)
            flat = tf.contrib.layers.flatten(stage1)

            # Dense head: 512 -> 128 -> 1.
            hidden0 = lrelu(full_connection_layer(flat, 512, name="fc0"))
            hidden1 = lrelu(full_connection_layer(hidden0, 128, name="fc1"))
            disc = full_connection_layer(hidden1, 1, name='disc')

        # Every later call shares the variables created above.
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='d')

        return disc

In [4]:
class GAN:
    """Wasserstein GAN (weight-clipping variant) set up for 28x28 MNIST images.

    Constructing an instance creates the input placeholders, the generator, the
    discriminator (critic) and the generated-image tensor ``Xg``. ``train()``
    builds the losses and optimizers and runs the training loop.

    Relies on the notebook-level globals ``mnist`` (data source) and
    ``model_name`` (used for output file/directory names).
    """

    def __init__(self):
        self.batch_size = 64
        self.img_size = 28
        self.z_size = 50

        # NOTE: "epoch" here means one generator update (preceded by
        # `self.steps` critic updates), not a full pass over the dataset.
        self.epochs = 50000
        self.epoch_saveMetrics = 1000
        self.epoch_saveSampleImg = 1000
        self.epoch_saveParamter = 5000
        self.losses = {"d_loss":[], "g_loss":[]}

        # number of critic (discriminator) updates per generator update
        self.steps = 5

        self.X_tr = tf.placeholder(tf.float32, shape=[None, self.img_size, self.img_size, 1])
        self.z = tf.placeholder(tf.float32, [None, self.z_size])

        self.g = Generator()
        self.d = Discriminator()
        self.Xg = self.g(self.z)

    def loss(self):
        """Build the Wasserstein losses.

        Returns:
            (loss_g, loss_d): generator loss ``-E[D(G(z))]`` and critic loss
            ``-E[D(x)] + E[D(G(z))]``, both scalar tensors.
        """
        disc_tr = self.d(self.X_tr)
        disc_gen = self.d(self.Xg)

        loss_d = - tf.reduce_mean(disc_tr) + tf.reduce_mean(disc_gen)
        loss_g = - tf.reduce_mean(disc_gen)

        return loss_g, loss_d

    def _sample_z(self, n):
        """Draw `n` latent vectors uniformly from [-1, 1)."""
        return np.random.uniform(-1, 1, size=[n, self.z_size])

    def _next_real_batch(self):
        """Fetch the next MNIST training minibatch shaped [batch, 28, 28, 1]."""
        X_mb, _ = mnist.train.next_batch(self.batch_size)
        return np.reshape(X_mb, [-1, 28, 28, 1])

    def train(self, visible_device_list="1"):
        """Build the optimizers and run the WGAN training loop.

        Per iteration: `self.steps` critic updates (each followed by weight
        clipping to [-0.01, 0.01]), then one generator update. Losses are
        recorded every iteration; metrics, sample images and checkpoints are
        written at the intervals configured in ``__init__``.

        Args:
            visible_device_list: value passed to ``tf.GPUOptions`` to select
                which GPU(s) the session may use. Defaults to "1", the value
                that was previously hard-coded.

        Raises:
            ValueError: if ``self.steps`` is smaller than 1 (there would be no
                critic loss to record).
        """
        if self.steps < 1:
            raise ValueError("self.steps must be >= 1, got {0}".format(self.steps))

        # Optimizer
        self.L_g, self.L_d = self.loss()

        d_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        d_train_op = d_opt.minimize(self.L_d, var_list=self.d.variables)
        g_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        g_train_op = g_opt.minimize(self.L_g, var_list=self.g.variables)

        self.clip_updates = [w.assign(tf.clip_by_value(w, -0.01, 0.01)) for w in self.d.variables]
        # Single op that runs every clip assignment without fetching the
        # clipped weight values back to Python on each step.
        clip_op = tf.group(*self.clip_updates)

        saver = tf.train.Saver()

        config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                visible_device_list=visible_device_list
            )
        )

        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            # Start from clipped critic weights so the very first critic
            # gradient is already computed under the weight constraint.
            sess.run(clip_op)

            # preparing noise vec for test
            bs = 100
            test_z = self._sample_z(bs)

            for epoch in range(self.epochs):
                for _ in range(self.steps):
                    # extract images for training
                    X_mb = self._next_real_batch()
                    z = self._sample_z(self.batch_size)

                    # train Discriminator
                    _, d_loss_value = sess.run([d_train_op, self.L_d], feed_dict={
                        self.X_tr: X_mb,
                        self.z:z,
                    })

                    # Clip AFTER the update (WGAN Algorithm 1) so the critic
                    # is within bounds when the generator step uses it and
                    # when a checkpoint is saved.
                    sess.run(clip_op)

                # extract images for training
                # (L_g as built in loss() does not reference X_tr; the batch is
                # still drawn and fed so the data order matches earlier runs.)
                X_mb = self._next_real_batch()
                z = self._sample_z(self.batch_size)

                # train Generator
                _, g_loss_value = sess.run([g_train_op, self.L_g], feed_dict={
                    self.X_tr: X_mb,
                    self.z:z,
                })

                # append loss value for visualizing
                self.losses["d_loss"].append(np.sum(d_loss_value))
                self.losses["g_loss"].append(np.sum(g_loss_value))

                # print epoch
                if epoch % 50 == 0:
                    print('epoch:{0}, d_loss:{1}, g_loss:{2} '.format(epoch, d_loss_value, g_loss_value))

                # visualize loss
                if epoch % self.epoch_saveMetrics == 0:
                    save_metrics(model_name, self.losses, epoch)

                # visualize generated images during training
                if epoch % self.epoch_saveSampleImg == 0:
                    img = sess.run(self.Xg, feed_dict={self.z: test_z})
                    save_imgs(model_name, img, name=str(epoch))

                # save model parameters
                if epoch % self.epoch_saveParamter == 0:
                    dir_path = "model_" + model_name
                    if not os.path.isdir(dir_path):
                        os.makedirs(dir_path)

                    saver.save(sess, dir_path + "/" + str(epoch) + ".ckpt")

In [5]:
gan = GAN()
gan.train()



epoch:0, d_loss:7.80474510975182e-06, g_loss4.3148029362782836e-05 
epoch:50, d_loss:-0.03192339837551117, g_loss0.014821100980043411 
epoch:100, d_loss:-2.600034236907959, g_loss1.644270658493042 
epoch:150, d_loss:-4.017401218414307, g_loss2.507479190826416 
epoch:200, d_loss:-2.3247578144073486, g_loss1.544884204864502 
epoch:250, d_loss:-2.342836618423462, g_loss2.1563286781311035 
epoch:300, d_loss:-2.61594295501709, g_loss1.780001163482666 
epoch:350, d_loss:-2.146009683609009, g_loss2.1056604385375977 
epoch:400, d_loss:-2.3375797271728516, g_loss1.012495994567871 
epoch:450, d_loss:-2.00949764251709, g_loss0.7387288808822632 
epoch:500, d_loss:-1.9537532329559326, g_loss0.7506390810012817 
epoch:550, d_loss:-2.028782606124878, g_loss1.0892536640167236 
epoch:600, d_loss:-1.847143530845642, g_loss0.6226832866668701 
epoch:650, d_loss:-1.5504164695739746, g_loss1.971895694732666 
epoch:700, d_loss:-1.6994699239730835, g_loss1.8043723106384277 
epoch:750, d_loss:-1.731435775756836

KeyboardInterrupt: 