In [1]:
'''
This program aims at:
1. implementing a GAN and solving the MNIST problem
2. trying to understand how GANs work: VanillaGAN, InfoGAN, W-GAN, ...
3. trying DCGAN, and finding out whether sharing some of the feature spaces between generator and discriminator helps
4. trying different training processes to see which ones are more efficient and which ones fail and how
'''

'\nThis program is aim at:\n1. implementing GAN and solve MNIST problem\n2. trying to understand how GAN works\n3. trying DCGAN, and find out wether sharing some of the feature spaces bettwen generater and discriminater helps\n4. trying different training process to see which ones are more effecient and which ones fails and how\n'

In [2]:
import math
import numpy as np
import os

import tensorflow as tf

from lib.netgen import full_connected
from lib.ops import *

import lib.global_config as CONF
# Root directory under CONF.SYS_ROOT where train() writes its TensorBoard
# event files (one sub-directory per task name) and its model checkpoints.
TENSORBOARD_ROOT = os.path.join(CONF.SYS_ROOT, "tensorboard")


In [3]:
# MNIST dataset
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST data sets from <SYS_ROOT>/datasets/mnist_data/ with one-hot labels
mnist = input_data.read_data_sets(os.path.join(CONF.SYS_ROOT, "datasets/mnist_data/"), one_hot=True)

# Each input image is 28 x 28 = 784 values
IMAGE_WIDTH = 28
IMAGE_HEIGHT = 28
INPUT_SIZE = IMAGE_WIDTH * IMAGE_HEIGHT
# Each label is a 10-class vector
LABEL_SIZE = 10

Extracting /notebooks/datasets/mnist_data/train-images-idx3-ubyte.gz
Extracting /notebooks/datasets/mnist_data/train-labels-idx1-ubyte.gz
Extracting /notebooks/datasets/mnist_data/t10k-images-idx3-ubyte.gz
Extracting /notebooks/datasets/mnist_data/t10k-labels-idx1-ubyte.gz


In [5]:
########### Abstractions ###########
class AbsGAN(object):
    """Abstract base for a GAN made of one discriminator and one generator.

    Subclasses implement ``_setup()`` (build the graph and fill in the
    optimizer / loss attributes) and ``create_input_summary()``.  Callers
    use ``setup()``, which runs ``_setup()`` and then checks that every
    attribute the training loop reads has actually been set.
    """

    def __init__(self, name):
        """Store the model name and reset every attribute ``_setup`` must fill.

        Args:
            name: identifier of this GAN instance.
        """
        self.name = name

        self._discriminater = None
        self._generater = None

        self.d_optm = None
        self.g_optm = None
        self.d_loss = None
        # Real / fake parts of the discriminator loss.  They are checked by
        # setup() and returned by get_loss(), so they must exist from the
        # start; otherwise those methods fail with AttributeError.
        self.d_r_loss = None
        self.d_f_loss = None
        self.g_loss = None

        # Merged summaries, created lazily by get_summaries().
        self.d_summaries = None
        self.g_summaries = None

    def _setup(self):
        """Hook for subclasses: build the graph, losses and optimizers.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Not yet implemented!")

    def setup(self):
        """Run ``_setup()`` and verify that it populated every required attribute.

        Raises:
            AssertionError: if any model, optimizer or loss is still ``None``.
        """
        self._setup()

        assert self._discriminater is not None
        assert self._generater is not None
        assert self.d_optm is not None
        assert self.g_optm is not None
        assert self.d_loss is not None
        assert self.d_r_loss is not None
        assert self.d_f_loss is not None
        assert self.g_loss is not None

    def get_optms(self):
        """Return ``(d_optm, g_optm)``."""
        return self.d_optm, self.g_optm

    def get_loss(self):
        """Return ``(d_loss, d_r_loss, d_f_loss, g_loss)``."""
        return self.d_loss, self.d_r_loss, self.d_f_loss, self.g_loss

    def create_input_summary(self):
        """Hook for subclasses: return ``(d_input_summary, g_input_summary)``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Not yet implemented!")

    def get_summaries(self):
        """Return ``(d_summaries, g_summaries)``, building them on first use.

        Each one merges the corresponding model's own ``summary()`` with the
        input summary returned by ``create_input_summary()``.  The result is
        cached, so later calls return the same two objects.
        """
        if self.d_summaries is not None:
            assert self.g_summaries is not None

            return self.d_summaries, self.g_summaries

        d_input_summ, g_input_summ = self.create_input_summary()

        self.d_summaries = tf.summary.merge([
            self._discriminater.summary(),
            d_input_summ
        ])

        self.g_summaries = tf.summary.merge([
            self._generater.summary(),
            g_input_summ
        ])

        return self.d_summaries, self.g_summaries
    
    
class AbsModel(object):
    """Base for one network of a GAN: scope naming, variable reuse and summaries."""

    def __init__(self, name, type_alias):
        """
        Args:
            name: name of the owning GAN; prefix of every scope name.
            type_alias: short tag of this network (the file uses "D" and "G").
        """
        self.name = name
        self.type_alias = type_alias

        # Individual summaries collected through add_summary().
        self.summaries = []
        # Cached result of merging self.summaries; None when out of date.
        self.merged_summary = None

        # Becomes True after the first reuse_vars() call.
        self.vars_initialized = False

    def namescope(self):
        """Return the name-scope string, "<name>_<type_alias>"."""
        return "{}_{}".format(self.name, self.type_alias)

    def varscope(self):
        """Return the variable-scope string, "<name>_<type_alias>_vars"."""
        return "{}_{}_vars".format(self.name, self.type_alias)

    def reuse_vars(self, vs):
        """Switch ``vs`` to variable reuse on every call except the first.

        Args:
            vs: scope object exposing ``name`` and ``reuse_variables()``
                (callers in this file pass the value bound by
                ``tf.variable_scope(...) as vs``).
        """
        # Single-argument print(...) prints the same text on Python 2 and is
        # also valid syntax on Python 3.
        if self.vars_initialized:
            print(">> Reusing variables in {} <<\n".format(vs.name))
            vs.reuse_variables()
        else:
            print(">> Initializing variables in {} <<\n".format(vs.name))
            self.vars_initialized = True

    def add_summary(self, *s):
        """Register one or more summaries to be included by ``summary()``."""
        self.summaries.extend(s)
        if s:
            # Drop the cached merge so a later summary() call includes the
            # summaries just added instead of returning a stale op.
            self.merged_summary = None

    def summary(self):
        """Return the merge of all registered summaries, caching the result."""
        if self.merged_summary is None:
            self.merged_summary = tf.summary.merge(self.summaries)
        return self.merged_summary
    
    
class GeneraterSampler(object):
    """Produces an image summary of generator output for freshly drawn noise."""

    def __init__(self, z, sample_num, z_dim):
        """
        Args:
            z: the object used as feed_dict key for the noise input.
            sample_num: number of noise rows drawn per sample() call; also
                passed to tf.summary.image as its third argument.
            z_dim: number of columns of the noise matrix.
        """
        self.z, self.sample_num, self.z_dim = z, sample_num, z_dim

    def setup(self, fake_images):
        """Create the "Fake_Images_Sample" image summary over ``fake_images``."""
        self.fake_image_sammary = tf.summary.image(
            "Fake_Images_Sample", fake_images, self.sample_num
        )

    def sample(self, sess, i):
        """Evaluate the image summary on uniform noise drawn from [-1, 1).

        Args:
            sess: object whose ``run(fetches, feed_dict=...)`` evaluates the summary.
            i: unused by this method.

        Returns:
            Whatever ``sess.run`` returns for the image summary.
        """
        noise_shape = (self.sample_num, self.z_dim)
        feed = {self.z: np.random.uniform(-1, 1, size=noise_shape)}
        return sess.run(self.fake_image_sammary, feed_dict=feed)

In [6]:
############## Basic GAN, only distinguish fake/real ##############
class VanillaGAN(AbsGAN):
    """Basic GAN whose discriminator only separates real from generated images.

    Life cycle, as driven by ``train``: construct, call ``init(...)`` to
    create the placeholders and the D / G models, then ``setup()``
    (inherited from ``AbsGAN``) which calls ``_setup()`` here and verifies
    that all optimizers and losses were produced.
    """

    # Learning rate handed to the discriminator's AdamOptimizer.
    D_LEARNING_RATE = 0.00005
    # Feature depth of the first "conv_in" layer; doubled for each following one.
    D_FEATURE_DEPTH_STARTER = 32
    # Size argument passed to the convolution helpers
    # (presumably the kernel size -- confirm in lib.ops).
    D_KSIZE = 3
    # Number of "conv_in" layers, counting the last one built with conv2d_norm.
    D_FEATURE_SPACE_DEPTH = 5

    # Learning rate handed to the generator's AdamOptimizer.
    G_LEARNING_RATE = 0.00005
    # Feature depth associated with the largest feature layer; doubled per halving.
    G_FEATURE_DEPTH_STARTER = 32
    # Size argument passed to deconv2d (presumably the kernel size -- confirm in lib.ops).
    G_KSIZE = 3
    # Number of times the image shape is halved to find the innermost feature layer.
    G_FEATURE_SPACE_DEPTH = 5

    ## Discriminator ##
    class D(AbsModel):
        """Discriminator network."""

        def __init__(self, name, image_size, n_channels, dim_y, dim_z):
            """
            Args:
                name: name of the owning GAN (scope prefix).
                image_size: side length images are resized to before the first layer.
                n_channels: stored; not read by this class.
                dim_y: number of units of the final fully-connected layer.
                dim_z: stored; not read by this class.
            """
            AbsModel.__init__(self, name, "D")

            self.image_size = image_size
            self.n_channels = n_channels
            self.dim_y = dim_y
            self.dim_z = dim_z

        def define_graph(self, images):
            """Build the discriminator on ``images`` (variables reused after the first call).

            Every call also registers a new set of histogram summaries on
            this model.

            Args:
                images: image tensor; dimensions 1 and 2 of its static shape
                    are compared with ``image_size``.

            Returns:
                The first value returned by the fully-connected generator's
                ``gen`` call (the layer output).
            """
            with tf.variable_scope(self.varscope()) as vs:
                self.reuse_vars(vs)

                conv_in_activations = []
                conv_in_weights = []
                conv_in_bias = []
                conv_out_activations = []
                conv_out_weights = []
                conv_out_bias = []

                # FIXME: should not let tf do reshaping
                input_image_shape = images.get_shape().as_list()
                needs_resize = (input_image_shape[1] != self.image_size
                                or input_image_shape[2] != self.image_size)
                if needs_resize:
                    head = tf.image.resize_images(
                        images, [self.image_size, self.image_size]
                    )
                else:
                    head = images

                #### in ####
                # Each conv2d_pool call also returns a `layer` value that is
                # kept and concatenated back in on the way out.
                featured_layers = []
                feature_depth = VanillaGAN.D_FEATURE_DEPTH_STARTER
                for _ in range(VanillaGAN.D_FEATURE_SPACE_DEPTH - 1):
                    head, layer, a, w, b = conv2d_pool("conv_in_{}".format(feature_depth),
                                                       head,
                                                       VanillaGAN.D_KSIZE, feature_depth,
                                                       activation=tf.nn.sigmoid,
                                                       more=True)

                    conv_in_activations.append(a)
                    conv_in_weights.append(w)
                    conv_in_bias.append(b)

                    featured_layers.append(layer)
                    feature_depth *= 2

                #### turning ####
                head, a, w, b = conv2d_norm("conv_in_{}".format(feature_depth),
                                            head,
                                            VanillaGAN.D_KSIZE, feature_depth,
                                            activation=tf.nn.sigmoid,
                                            more=True)

                conv_in_activations.append(a)
                conv_in_weights.append(w)
                conv_in_bias.append(b)

                #### out ####
                # Walk the kept layers from the deepest back to the first.
                for skip_layer in reversed(featured_layers):
                    # Floor division keeps the depth an int, so layer names
                    # stay "conv_out_256" rather than "conv_out_256.0".
                    feature_depth //= 2
                    head, a, w, b = cat_conv2d("conv_out_{}".format(feature_depth),
                                               [skip_layer, head],
                                               VanillaGAN.D_KSIZE, feature_depth,
                                               activation=tf.nn.relu,
                                               more=True)

                    conv_out_activations.append(a)
                    conv_out_weights.append(w)
                    conv_out_bias.append(b)

                #### flatten ####
                flattened, node_num, flatten_w, flatten_b = conv2d_one("flatten",
                                                                       head,
                                                                       activation=tf.nn.sigmoid,
                                                                       more=True)

                #### full-connected ####
                fc = full_connected([
                    (self.dim_y, tf.nn.relu)
                ])

                outputs, fc_ws, fc_bs, fc_outs = fc.gen(flattened, node_num, name="fc")

            # summary
            with tf.name_scope(self.namescope()):
                self.add_summary(
                    tf.summary.histogram("weights_flatten", flatten_w),
                    tf.summary.histogram("biases_flatten", flatten_b),
                    tf.summary.histogram("activations_flatten", flattened)
                )

                histogram_groups = (
                    ("weights_fc", fc_ws),
                    ("biases_fc", fc_bs),
                    ("outputs_fc", fc_outs),
                    ("conv_in_activation", conv_in_activations),
                    ("conv_in_weights", conv_in_weights),
                    ("conv_in_bias", conv_in_bias),
                    ("conv_out_activation", conv_out_activations),
                    ("conv_out_weights", conv_out_weights),
                    ("conv_out_bias", conv_out_bias),
                )
                for tag, tensors in histogram_groups:
                    for tensor in tensors:
                        self.add_summary( tf.summary.histogram(tag, tensor) )

            return outputs

        def define_optimize(self, real_output, fake_output):
            """Create the discriminator loss and its Adam training op.

            The loss is ``cross_entropy(real_output, ones) +
            cross_entropy(fake_output, zeros)``; only variables whose name
            contains this model's variable-scope string are trained.

            Args:
                real_output: discriminator output for real images.
                fake_output: discriminator output for generated images.

            Returns:
                ``(optimizer_op, loss, real_loss, fake_loss)``.
            """
            with tf.name_scope(self.namescope() + "_real_loss"):
                real_loss = cross_entropy(real_output, tf.ones_like(real_output))

            with tf.name_scope(self.namescope() + "_fake_loss"):
                fake_loss = cross_entropy(fake_output, tf.zeros_like(fake_output))

            # D loss
            with tf.name_scope(self.namescope() + "_loss"):
                loss = real_loss + fake_loss

            # find trainable variables
            trainable = [var for var in tf.trainable_variables() if self.varscope() in var.name]
            print("Training variables in {}:".format(self.varscope()))
            for var in trainable:
                print(var)

            with tf.name_scope(self.namescope() + "_optimizer"):
                optim = tf.train.AdamOptimizer(
                    VanillaGAN.D_LEARNING_RATE
                ).minimize(
                    loss, var_list=trainable
                )

            # summary
            with tf.name_scope(self.namescope()):
                self.add_summary( tf.summary.scalar("loss_real_discriminater", real_loss) )
                self.add_summary( tf.summary.scalar("loss_fake_discriminater", fake_loss) )
                self.add_summary( tf.summary.scalar("loss_discriminater", loss) )

            self.loss = loss

            return optim, self.loss, real_loss, fake_loss
        ######### End Discriminater #########


    ## Generator ##
    class G(AbsModel):
        """Generator network."""

        def __init__(self, name, image_size, n_channels, dim_z, batch_size):
            """
            Args:
                name: name of the owning GAN (scope prefix).
                image_size: side length of the generated images.
                n_channels: number of channels of the generated images.
                dim_z: size passed to the fully-connected layer as its input width.
                batch_size: leading dimension used in every deconv output shape.
            """
            AbsModel.__init__(self, name, "G")

            self.image_size = image_size
            self.n_channels = n_channels
            self.dim_z = dim_z
            self.batch_size = batch_size


        def define_graph(self, z):
            """Build the generator on noise ``z`` and return the generated images.

            ``z`` goes through a fully-connected layer, is reshaped to the
            innermost feature-layer shape, passes through a chain of
            ``deconv2d`` calls and finally ``tanh``.

            The deconv output shapes use the fixed ``self.batch_size``;
            presumably ``z`` must therefore be fed with exactly that many
            rows -- confirm against deconv2d in lib.ops.
            """
            deconv_activations = []
            deconv_weights = []
            deconv_bias = []

            with tf.variable_scope(self.varscope()) as vs:
                self.reuse_vars(vs)

                # find out each de-conv's shape
                fake_image_shape = [self.image_size, self.image_size, self.n_channels]
                feature_layers_shapes = []

                def shape_conv2d(input_shape, feature_depth):
                    # Floor division: shapes must stay integers.
                    return [input_shape[0] // 2, input_shape[1] // 2, feature_depth]

                prev_shape = fake_image_shape
                feature_depth = VanillaGAN.G_FEATURE_DEPTH_STARTER
                for _ in range(VanillaGAN.G_FEATURE_SPACE_DEPTH):
                    next_shape = shape_conv2d(prev_shape, feature_depth)
                    feature_layers_shapes.append(next_shape)
                    prev_shape = next_shape
                    feature_depth *= 2

                # Smallest (innermost) shape first, largest last.
                feature_layers_shapes.reverse()

                # full-connect to the last feature layers
                last_layer_shape = feature_layers_shapes[0]
                last_layer_len = last_layer_shape[0] * last_layer_shape[1] * last_layer_shape[2]

                fc = full_connected([
                    (last_layer_len, tf.nn.relu)
                ])

                last_layer, fc_ws, fc_bs, fc_outs = fc.gen(z, self.dim_z, name="fc")

                # un-flatten
                conv_layer = tf.reshape(last_layer, [-1] + last_layer_shape)

                # de-conv
                for output_image_shape in feature_layers_shapes[1:]:
                    output_shape = [self.batch_size] + output_image_shape

                    conv_layer, a, w, b = deconv2d("{}x{}x{}".format(*output_image_shape),
                                                   conv_layer, output_shape, VanillaGAN.G_KSIZE,
                                                   activation=tf.nn.relu,
                                                   more=True)

                    deconv_activations.append(a)
                    deconv_weights.append(w)
                    deconv_bias.append(b)

                # final image
                fake_images, fi_a, fi_w, fi_b = deconv2d("fake_images",
                                                conv_layer,
                                                [self.batch_size] + fake_image_shape,
                                                VanillaGAN.G_KSIZE,
                                                more=True)

                fake_images = tf.nn.tanh(fake_images)

            # summary
            with tf.name_scope(self.namescope()):
                histogram_groups = (
                    ("fc_activations", fc_outs),
                    ("fc_weights", fc_ws),
                    ("fc_bias", fc_bs),
                    ("deconv_activations", deconv_activations),
                    ("deconv_weights", deconv_weights),
                    ("deconv_bias", deconv_bias),
                )
                for tag, tensors in histogram_groups:
                    for tensor in tensors:
                        self.add_summary( tf.summary.histogram(tag, tensor) )

                self.add_summary( tf.summary.histogram("fake_image_activations", fi_a) )
                self.add_summary( tf.summary.histogram("fake_image_weights", fi_w) )
                self.add_summary( tf.summary.histogram("fake_image_bias", fi_b) )

            return fake_images

        def define_optimize(self, fake_output):
            """Create the generator loss and its Adam training op.

            The loss is ``cross_entropy(fake_output, ones)``; only variables
            whose name contains this model's variable-scope string are trained.

            Args:
                fake_output: discriminator output for generated images.

            Returns:
                ``(optimizer_op, loss)``.
            """
            with tf.name_scope(self.namescope() + "_loss"):
                loss = cross_entropy(
                    fake_output, tf.ones_like(fake_output)
                )

            # find trainable variables
            trainable = [var for var in tf.trainable_variables() if self.varscope() in var.name]
            print("Training variables in {}:".format(self.varscope()))
            for var in trainable:
                print(var)

            with tf.name_scope(self.namescope() + "_optimize"):
                optim = tf.train.AdamOptimizer(
                    VanillaGAN.G_LEARNING_RATE
                ).minimize(
                    loss,
                    var_list=trainable
                )

            with tf.name_scope(self.namescope()):
                self.add_summary(tf.summary.scalar("loss_generater", loss))

            self.loss = loss

            return optim, self.loss
        ################ End Generater ###############


    def __init__(self, name):
        """Create an un-initialised GAN; ``init`` must be called before ``setup``."""
        super(VanillaGAN, self).__init__(name)
        # Pre-declare what init() / _setup() fill in, so the assertions in
        # create_input_summary() fail as assertions when init() was skipped.
        self.z = None
        self.real_images = None
        self.fake_images = None
        self._sampler = None

    def init(self, raw_image_width, raw_image_height, image_size, n_channels, dim_y, dim_z, batch_size):
        """Create the input placeholders and the D, G and sampler objects.

        Args:
            raw_image_width, raw_image_height: dimensions 1 and 2 of the
                real-image placeholder.
            image_size: side length used inside D and produced by G.
            n_channels: channel count of real and generated images.
            dim_y: width of the discriminator's final layer.
            dim_z: width of the noise input.
            batch_size: fixed batch size used by G's deconv shapes and as
                the sampler's sample count.
        """
        with tf.name_scope("inputs"):
            self.z = tf.placeholder(tf.float32,
                                    shape=[None, dim_z],
                                    name="Z")
            self.real_images = tf.placeholder(tf.float32,
                                              shape=[None, raw_image_width, raw_image_height, n_channels],
                                              name="real_images")

        self._discriminater = VanillaGAN.D(self.name, image_size, n_channels, dim_y, dim_z)
        self._generater = VanillaGAN.G(self.name, image_size, n_channels, dim_z, batch_size)
        self._sampler = GeneraterSampler(self.z, batch_size, dim_z)

    def _setup(self):
        """Build G, D (on real and on generated images), both optimizers and the sampler.

        Implemented as the ``_setup`` hook so that the inherited
        ``AbsGAN.setup()`` also runs its post-condition checks; callers keep
        calling ``setup()``.
        """
        self.fake_images = self._generater.define_graph(self.z)

        real_output = self._discriminater.define_graph(self.real_images)
        fake_output = self._discriminater.define_graph(self.fake_images)

        self.g_optm, self.g_loss = self._generater.define_optimize(fake_output)
        self.d_optm, self.d_loss, self.d_r_loss, self.d_f_loss = self._discriminater.define_optimize(real_output, fake_output)

        self._sampler.setup(self.fake_images)

    def create_input_summary(self):
        """Return ``(d_summary, g_summary)`` describing the inputs of each training step.

        The D summary covers Z, real images and generated images; the G
        summary covers Z and generated images.
        """
        assert self.z is not None
        assert self.real_images is not None
        assert self.fake_images is not None

        # summary fake images
        summary_fake_images_d = tf.summary.image(
            "Fake_Images_Discriminated",
            self.fake_images, 5)
        summary_fake_images_g = tf.summary.image(
            "Fake_Images_Generated",
            self.fake_images, 5)

        hist_fake_image_d = tf.summary.histogram("Fake_Image_Hist_D", self.fake_images)
        hist_fake_image_g = tf.summary.histogram("Fake_Image_Hist_G", self.fake_images)

        summary_z = tf.summary.histogram("Z", self.z)
        summary_real_images = tf.summary.image("Real_Images", self.real_images, 5)

        d_summ = tf.summary.merge([
            summary_z,
            summary_real_images,
            summary_fake_images_d,
            hist_fake_image_d
        ])
        g_summ = tf.summary.merge([
            summary_z,
            summary_fake_images_g,
            hist_fake_image_g
        ])

        return d_summ, g_summ

    def feed_discriminater(self, **kwarg):
        """Build the feed dict for a D step from keyword arguments ``images`` and ``z``."""
        images = kwarg["images"]
        z = kwarg["z"]

        return {self.real_images: images, self.z: z}

    def feed_generater(self, **kwarg):
        """Build the feed dict for a G step from keyword argument ``z``."""
        z = kwarg["z"]

        return {self.z: z}

    def sample(self, sess, i):
        """Delegate to the sampler: evaluate the sample image summary in ``sess``."""
        return self._sampler.sample(sess, i)
    
   

In [7]:
############## Info-GAN, distinguish fake/real & labels ##############
class InfoGAN:
    """Placeholder for the Info-GAN variant; nothing is implemented yet."""

In [None]:
# Training configurations
# Defaults for train(); the `conf` argument of train() is combined with
# these through merge_conf().
DEFAULT_CONF = {
    # number of iterations of the training loop
    "TRAINING_TIMES": 30000,
    # write summaries and print the loss every this many iterations
    "SUMMARY_FREQ": 50,
    # save a checkpoint every this many iterations
    "MODEL_SAVE_FREQ": 500,
    # rows per MNIST batch and per noise batch; also passed to gan.init
    "BATCH_SIZE": 150,
    # shape each flat MNIST row is reshaped to before being fed
    "RAW_IMAGE_WIDTH": 28,
    "RAW_IMAGE_HEIGHT": 28,
    # passed to gan.init as image_size
    "IMAGE_SIZE": 32,
    # number of image channels
    "N_CHAN": 1,
    # passed to gan.init as dim_y
    "Y_DIM": 10,
    # width of the noise vector
    "Z_DIM": 10,
    # write a generator sample summary every this many iterations
    "SAMPLE_FREQ": 200,
    # any value other than "LOSS_THRESHOLD" and "D1G2" is treated as "D1G1"
    "TRAINING_MODE": "D1G2", # "D1G1", "D1G2", "LOSS_THRESHOLD" ...
    # loss threshold arguments
    # initial threshold; when None, train() derives it from the current loss
    # multiplied by LOSS_THRESHOLD_FACTOR
    "LOSS_THRESHOLD_STARTER": None,
    "LOSS_THRESHOLD_FACTOR": 0.8,
    # whether LOSS_THRESHOLD mode starts with the discriminator
    "DISCRIMINATER_FIRST": True
}

# Training
def train(task_name, gan, conf=None):
    """Train ``gan`` on MNIST, logging to TensorBoard and saving checkpoints.

    Args:
        task_name: name of the run; used as the TensorBoard sub-directory
            under TENSORBOARD_ROOT and inside the checkpoint file name.
        gan: object offering ``init``, ``setup``, ``get_optms``,
            ``get_loss``, ``get_summaries``, ``feed_discriminater``,
            ``feed_generater`` and ``sample`` (e.g. a VanillaGAN).
        conf: optional dict of overrides, combined with DEFAULT_CONF through
            ``merge_conf``.  ``None`` means "no overrides".

    Training modes (``conf["TRAINING_MODE"]``):
        "LOSS_THRESHOLD": train only D or only G per iteration and switch
            sides once the active side's loss drops below the threshold.
        "D1G2": one D step followed by two G steps per iteration.
        anything else: one D step and one G step per iteration.
    """
    # configurations
    # A fresh dict per call instead of a shared mutable default argument.
    conf = merge_conf({} if conf is None else conf, DEFAULT_CONF)

    tf.reset_default_graph()

    # initialize gan
    #TODO: need to switch gan types here while doing gan.init
    gan.init(conf["RAW_IMAGE_WIDTH"],
             conf["RAW_IMAGE_HEIGHT"],
             conf["IMAGE_SIZE"],
             conf["N_CHAN"],
             conf["Y_DIM"],
             conf["Z_DIM"],
             conf["BATCH_SIZE"])

    threshold_mode = conf["TRAINING_MODE"] == "LOSS_THRESHOLD"
    loss_threshold = conf["LOSS_THRESHOLD_STARTER"] if threshold_mode else None
    D_turn = conf["DISCRIMINATER_FIRST"] if threshold_mode else None

    # Progress is reported about 50 times per run.  Clamp to 1 so a run with
    # fewer than 50 iterations does not compute `i % 0`.
    progress_freq = max(1, int(conf["TRAINING_TIMES"] / 50))

    with tf.Session() as sess:
        print("Setup...")
        gan.setup()

        # prepare D & G
        D_optm, G_optm = gan.get_optms()
        # Only the total D loss and the G loss are used below.
        D_l, _, _, G_l = gan.get_loss()
        D_summ, G_summ = gan.get_summaries()

        # tensorboard & saver
        tensorboard_save_path = os.path.join(TENSORBOARD_ROOT, task_name)
        checkpoint_path = os.path.join(TENSORBOARD_ROOT, "model_{}.ckpt".format(task_name))
        writer = tf.summary.FileWriter(tensorboard_save_path)
        saver = tf.train.Saver()

        # The step helpers RETURN the loss.  Assigning it inside the helper
        # would only bind a local there and leave the caller with None.
        def optimize_d(step, feed_dict):
            """Run one discriminator update and return the total D loss."""
            if step % conf["SUMMARY_FREQ"] == 0:
                _, loss, s = sess.run([D_optm, D_l, D_summ], feed_dict=feed_dict)
                writer.add_summary(s, step)
                print("D loss:[{}]".format(loss))
            else:
                _, loss = sess.run([D_optm, D_l], feed_dict=feed_dict)
            return loss

        def optimize_g(step, feed_dict):
            """Run one generator update and return the G loss."""
            if step % conf["SUMMARY_FREQ"] == 0:
                _, loss, s = sess.run([G_optm, G_l, G_summ], feed_dict=feed_dict)
                writer.add_summary(s, step)
                print("G loss:[{}]".format(loss))
            else:
                _, loss = sess.run([G_optm, G_l], feed_dict=feed_dict)
            return loss

        try:
            # record graph
            writer.add_graph(sess.graph)

            # initialize variables
            init = tf.global_variables_initializer()
            init.run()

            print("Now training...")
            print("Run `tensorboard --logdir=%s` to see more information." % TENSORBOARD_ROOT)
            # training process
            for i in xrange(conf["TRAINING_TIMES"]):
                # batching data & format data
                batch_inputs, batch_labels = mnist.train.next_batch(conf["BATCH_SIZE"])
                batch_images = np.reshape(batch_inputs,
                                          (conf["BATCH_SIZE"],
                                           conf["RAW_IMAGE_WIDTH"],
                                           conf["RAW_IMAGE_HEIGHT"],
                                           conf["N_CHAN"]))

                batch_z_D = np.random.uniform(-1, 1, size=(conf["BATCH_SIZE"], conf["Z_DIM"]))
                batch_z_G = np.random.uniform(-1, 1, size=(conf["BATCH_SIZE"], conf["Z_DIM"]))

                feed_dict_D = gan.feed_discriminater(images=batch_images, z=batch_z_D)
                feed_dict_G = gan.feed_generater(z=batch_z_G)

                # ---------- training mode : LOSS_THRESHOLD ----------
                if threshold_mode:
                    if D_turn:
                        current_loss = optimize_d(i, feed_dict_D)
                    else:
                        current_loss = optimize_g(i, feed_dict_G)

                    # decide turning
                    if loss_threshold is None:
                        loss_threshold = current_loss * conf["LOSS_THRESHOLD_FACTOR"]

                    if current_loss < loss_threshold:
                        next_side = "G" if D_turn else "D"
                        print("Turning to {} at {} with loss [{}]".format(next_side, i, current_loss))
                        D_turn = not D_turn
                        # Recomputed from the other side's loss next iteration.
                        loss_threshold = None

                # ---------- training mode : D1G2 ----------
                elif conf["TRAINING_MODE"] == "D1G2":
                    optimize_d(i, feed_dict_D)
                    optimize_g(i, feed_dict_G)
                    optimize_g(i, feed_dict_G)

                # ---------- training mode : D1G1 ----------
                else:
                    optimize_d(i, feed_dict_D)
                    optimize_g(i, feed_dict_G)

                # sampling
                if i % conf["SAMPLE_FREQ"] == 0:
                    s = gan.sample(sess, i)
                    writer.add_summary(s, i)

                # save model
                if i % conf["MODEL_SAVE_FREQ"] == 0:
                    saver.save(sess, checkpoint_path, i)

                if i % progress_freq == 0:
                    # Percentage truncated to two decimals.
                    print("processing[{}%]...".format(i * 10000 // conf["TRAINING_TIMES"] / 100.0))
        finally:
            # Flush and release the event file even when training fails.
            writer.close()

        print("All processes done.")
                

In [None]:
# Train a VanillaGAN named "GAN3" with the default configuration; the first
# argument is the task name used for the TensorBoard directory and checkpoints.
train("VanillaGAN-20171023-1505", VanillaGAN("GAN3"))

Setup...
>> Initializing variables in GAN3_G_vars <<

>> Initializing variables in GAN3_D_vars <<

generating convolution layer:conv_in_32
generating convolution layer:conv_in_64
generating convolution layer:conv_in_128
generating convolution layer:conv_in_256
generating convolution layer:conv_in_512
generating convolution layer:conv_out_256
generating convolution layer:conv_out_128
generating convolution layer:conv_out_64
generating convolution layer:conv_out_32
generating convolution layer:flatten
>> Reusing variables in GAN3_D_vars <<

generating convolution layer:conv_in_32
generating convolution layer:conv_in_64
generating convolution layer:conv_in_128
generating convolution layer:conv_in_256
generating convolution layer:conv_in_512
generating convolution layer:conv_out_256
generating convolution layer:conv_out_128
generating convolution layer:conv_out_64
generating convolution layer:conv_out_32
generating convolution layer:flatten
Training variables in GAN3_G_vars:
<tf.Variable '