# infoGAN

In [None]:
import tensorflow as tf
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset from the "MNIST_data/" directory; one_hot=True asks
# for labels as one-hot rows rather than integer class ids.
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

In [None]:
# Sanity check: pull a single batch of size 1 and inspect what comes back.
a,b  = mnist.train.next_batch(1)
print(a.shape)
print(b.shape)
# The label row itself (the dataset was loaded with one_hot=True).
print(b)

# Utility functions

In [None]:
# save metrics
def save_metrics(metrics, epoch=None):
    """Plot the recorded loss curves and write them to PNG files.

    Three figures are written into ``metrics_gpu0/`` (created when absent):
    the discriminator loss alone, the generator loss alone, and both curves
    on one plot. ``str(epoch)`` is appended to every file stem.

    Args:
        metrics: mapping holding ``"d_loss"`` and ``"g_loss"`` sequences.
        epoch: value stringified into the file names.
    """
    out_dir = "metrics_gpu0"
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    suffix = str(epoch) + ".png"

    # (metrics key, legend label, line colour)
    d_curve = ("d_loss", "discriminative loss", "b")
    g_curve = ("g_loss", "generative loss", "r")

    # (file stem, curves drawn in this order) for each figure to emit
    figures = (
        ("dloss", (d_curve,)),
        ("g_loss", (g_curve,)),
        ("both_loss", (g_curve, d_curve)),
    )

    for stem, curves in figures:
        plt.figure(figsize=(10,8))
        for key, legend_label, colour in curves:
            plt.plot(metrics[key], label=legend_label, color=colour)
        plt.legend()
        plt.savefig(os.path.join(out_dir, stem + suffix))
        plt.close()

In [None]:
# plot images
def save_imgs(images, plot_dim=(10,10), size=(10,10), name=None):
    """Save a grid of 28x28 grayscale images as a single PNG file.

    The figure is written to ``generated_figures_gpu0/<name>.png``; the
    directory is created when it does not exist.

    Args:
        images: array-like indexed along its first axis; each entry must be
            reshapeable to ``(28, 28)``.
        plot_dim: ``(rows, cols)`` of the subplot grid.
        size: matplotlib figure size in inches.
        name: value stringified into the output file name.
    """
    out_dir = "generated_figures_gpu0"
    os.makedirs(out_dir, exist_ok=True)

    rows, cols = plot_dim[0], plot_dim[1]
    # Fill at most rows*cols cells and never index past the images supplied,
    # so grids other than 10x10 and short batches work as well.
    num_examples = min(rows * cols, len(images))

    plt.figure(figsize=size)
    for i in range(num_examples):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(images[i].reshape((28, 28)), cmap="gray")
        plt.axis("off")

    # Lay the grid out once after all axes exist; doing it per subplot
    # re-runs the layout pass on every iteration for no benefit.
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.savefig(os.path.join(out_dir, str(name) + ".png"))
    plt.close()

In [None]:
# data loading helpers (CIFAR-10 pickle batches)
import pickle
import numpy as np
import os

def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    The file is read in binary mode and decoded with ``encoding='latin1'``.

    Args:
        file: path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle even if pickle.load raises.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='latin1')

def one_hot_vec(label):
    """Return a length-10 float vector that is 1 at ``label`` and 0 elsewhere."""
    encoded = np.zeros(10)
    encoded[label] = 1
    return encoded

def load_data():
    """Load CIFAR-10 pickle batches and return the first 50000 samples.

    Reads ``data_batch_1`` .. ``data_batch_5`` followed by ``test_batch``
    from ``cifar-10-batches-py/``, rescales pixel values to ``value/255 - 0.5``,
    rearranges each row into a ``32 x 32 x 3`` image and one-hot encodes the
    labels.

    Returns:
        Tuple ``(X_train, Y_train)``: images of shape ``(50000, 32, 32, 3)``
        and the matching one-hot label rows.
    """
    batch_paths = ["cifar-10-batches-py/data_batch_" + str(idx) for idx in range(1, 6)]
    batch_paths.append('cifar-10-batches-py/test_batch')
    batches = [unpickle(batch_path) for batch_path in batch_paths]

    pixels = np.concatenate([batch['data'] for batch in batches])
    labels = np.concatenate([batch['labels'] for batch in batches])

    pixels = -0.5 + (pixels / np.float32(255))

    # Each row stores three consecutive 1024-value planes; stack them along a
    # new last axis before folding the 1024 values into a 32x32 grid.
    planes = (pixels[:, :1024], pixels[:, 1024:2048], pixels[:, 2048:])
    images = np.dstack(planes)
    images = images.reshape((images.shape[0], 32, 32, 3))

    one_hot_labels = np.array([one_hot_vec(label) for label in labels])

    # Only the first 50000 rows (the five data batches) are returned.
    return images[0:50000, :, :, :], one_hot_labels[0:50000]

# model

In [None]:
# convolution/pool stride
# Strides handed to tf.nn.conv2d by conv2d_layer: step 2 along the two middle
# axes (presumably height and width under the default NHWC layout).
_CONV_KERNEL_STRIDES_ = [1, 2, 2, 1]
# Strides handed to tf.nn.conv2d_transpose by deconv2d_layer.
_DECONV_KERNEL_STRIDES_ = [1, 2, 2, 1]
# Scale given to tf.contrib.layers.l2_regularizer for the weight variable of
# the conv, deconv and fully connected layers below.
_REGULAR_FACTOR_ = 1.0e-4

def conv2d_layer(input_layer, output_dim, kernel_size = 3, stddev = 0.02, name = 'conv2d'):
    """Apply a strided 2-D convolution followed by a bias add.

    Variables are created under ``tf.variable_scope(name)`` and are named
    ``name + 'weight'`` and ``name + 'bias'``. The convolution uses
    ``_CONV_KERNEL_STRIDES_`` and ``'SAME'`` padding.

    Args:
        input_layer: input tensor; its last dimension is used as the number
            of input channels.
        output_dim: number of output channels.
        kernel_size: side length of the square kernel.
        stddev: standard deviation of the truncated-normal weight initializer.
        name: variable scope and variable-name prefix.

    Returns:
        The convolved tensor with the bias added (no activation applied).
    """
    with tf.variable_scope(name):
        in_channels = input_layer.get_shape()[-1]
        kernel = tf.get_variable(
            name = name + 'weight',
            shape = [kernel_size, kernel_size, in_channels, output_dim],
            initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = stddev, dtype = tf.float32),
            regularizer = tf.contrib.layers.l2_regularizer(_REGULAR_FACTOR_))
        offset = tf.get_variable(
            name = name + 'bias',
            shape = [output_dim],
            initializer = tf.constant_initializer(0.0))
        strided = tf.nn.conv2d(input_layer, kernel, _CONV_KERNEL_STRIDES_, padding = 'SAME')
        return tf.nn.bias_add(strided, offset)

def deconv2d_layer(input_layer, output_shape, kernel_size = 2, stddev = 0.02, name = 'deconv'):
    """Apply a strided transposed 2-D convolution followed by a bias add.

    Variables are created under ``tf.variable_scope(name)`` and are named
    ``name + 'weight'`` and ``name + 'bias'``. The transposed convolution uses
    ``_DECONV_KERNEL_STRIDES_`` and ``'SAME'`` padding.

    Args:
        input_layer: input tensor; its last dimension is used as the number
            of input channels.
        output_shape: full shape of the result; its last entry is the number
            of output channels.
        kernel_size: side length of the square kernel.
        stddev: standard deviation of the truncated-normal weight initializer.
        name: variable scope and variable-name prefix.

    Returns:
        The transposed-convolution result with the bias added (no activation).
    """
    with tf.variable_scope(name):
        out_channels = output_shape[-1]
        in_channels = input_layer.get_shape()[-1]
        # conv2d_transpose expects the filter as [h, w, out_channels, in_channels].
        kernel = tf.get_variable(
            name = name + 'weight',
            shape = [kernel_size, kernel_size, out_channels, in_channels],
            initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = stddev, dtype = tf.float32),
            regularizer = tf.contrib.layers.l2_regularizer(_REGULAR_FACTOR_))
        offset = tf.get_variable(
            name = name + 'bias',
            shape = [out_channels],
            initializer = tf.constant_initializer(0.0))
        upsampled = tf.nn.conv2d_transpose(input_layer, kernel, output_shape, strides = _DECONV_KERNEL_STRIDES_, padding = 'SAME')
        return tf.nn.bias_add(upsampled, offset)

def lrelu(input_layer, leak = 0.2, name = 'lrelu'):
    """Leaky ReLU expressed as a blend of ``x`` and ``|x|``.

    ``0.5*(1+leak)*x + 0.5*(1-leak)*|x|`` evaluates to ``x`` for ``x >= 0``
    and to ``leak*x`` for ``x < 0``.

    Args:
        input_layer: tensor to activate.
        leak: slope applied to negative inputs.
        name: variable scope the ops are created under.
    """
    with tf.variable_scope(name):
        identity_coeff = 0.5 * (1 + leak)
        magnitude_coeff = 0.5 * (1 - leak)
        linear_term = identity_coeff * input_layer
        magnitude_term = magnitude_coeff * abs(input_layer)
        return linear_term + magnitude_term

def full_connection_layer(input_layer, output_dim, stddev = 0.02, name = 'fc'):
    """Flatten ``input_layer`` per sample and apply a dense layer (``xW + b``).

    Variables are created under ``tf.variable_scope(name)`` and are named
    ``name + 'weight'`` and ``name + 'bias'``.

    Args:
        input_layer: input tensor; every dimension after the first must be
            statically known, because their product sizes the weight matrix.
        output_dim: number of output units.
        stddev: standard deviation of the truncated-normal weight initializer.
        name: variable scope and variable-name prefix.

    Returns:
        Tensor of shape ``[batch, output_dim]`` holding ``xW + b`` (no
        activation applied).
    """
    # Product of all non-batch dimensions = length of the flattened sample.
    input_dimension = 1
    for dim in input_layer.get_shape().as_list()[1:]:
        input_dimension *= dim

    with tf.variable_scope(name):
        init_weight = tf.truncated_normal_initializer(mean = 0.0, stddev = stddev, dtype = tf.float32)
        weight = tf.get_variable(
            name = name + 'weight',
            shape = [input_dimension, output_dim],
            initializer = init_weight,
            regularizer = tf.contrib.layers.l2_regularizer(_REGULAR_FACTOR_))
        bias = tf.get_variable(
            name = name + 'bias',
            shape = [output_dim],
            initializer = tf.constant_initializer(0.0))
        input_layer_reshape = tf.reshape(input_layer, [-1, input_dimension])
        fc = tf.matmul(input_layer_reshape, weight)
        # Return the biased result: returning the bare matmul would leave the
        # bias variable created above without any effect on the output.
        return tf.nn.bias_add(fc, bias)

class BatchNormalization:
    """Callable that applies ``tf.contrib.layers.batch_norm`` under a fixed scope.

    The instance only stores hyperparameters; the batch-norm variables are
    created (or reused) by ``batch_norm`` itself when the object is called.
    """

    def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
        """Remember the scope name, the decay (``momentum``) and ``epsilon``."""
        with tf.variable_scope(name):
            self.name = name
            self.momentum = momentum
            self.epsilon = epsilon

    def __call__(self, x, train=True):
        """Batch-normalise ``x``; ``train`` is forwarded as ``is_training``."""
        return tf.contrib.layers.batch_norm(
            x,
            scope=self.name,
            is_training=train,
            scale=True,
            epsilon=self.epsilon,
            # None makes the moving-average updates run in place with the op.
            updates_collections=None,
            decay=self.momentum)


In [None]:
class Generator:
    """Generator network: latent vector -> 28x28x1 image with values from tanh.

    Layers, all under variable scope ``'g'``: fc(1024) -> fc(7*7*128) ->
    deconv to 14x14x64 -> deconv to 28x28x1, with batch norm + ReLU after
    every layer except the last.
    """

    def __init__(self):
        self.reuse = False
        self.initializer = tf.contrib.layers.xavier_initializer()
        self.X_dim = 28*28*1 
        self.z_dim = 22

        # One batch-norm wrapper per normalised layer.
        self.g_bn0 = BatchNormalization(name = 'g_bn0')
        self.g_bn1 = BatchNormalization(name = 'g_bn1')
        self.g_bn2 = BatchNormalization(name = 'g_bn2')

    def __call__(self, z, training=False):
        """Build (or reuse) the generator graph for latent batch ``z``.

        NOTE(review): ``training`` is accepted but not forwarded, so the
        batch-norm wrappers are always called with their default ``train=True``.

        Returns:
            Tensor produced by ``tanh`` on the final deconvolution, requested
            with shape ``[batch, 28, 28, 1]``.
        """
        with tf.variable_scope('g', reuse=self.reuse):
            hidden = full_connection_layer(z, 1024, name="fc0")
            hidden = tf.nn.relu(self.g_bn0(hidden))

            feature_map = full_connection_layer(hidden, 7*7*128, name="fc1")
            feature_map = tf.reshape(feature_map, [-1, 7, 7, 128])
            feature_map = tf.nn.relu(self.g_bn1(feature_map))

            # Batch size is only known at run time; read it from the tensor.
            n_samples = tf.shape(feature_map)[0]

            upsampled = deconv2d_layer(feature_map, [n_samples, 14, 14, 64], name="deconv0")
            upsampled = tf.nn.relu(self.g_bn2(upsampled))

            pre_activation = deconv2d_layer(upsampled, [n_samples, 28, 28, 1], name="deconv1")
            output = tf.nn.tanh(pre_activation)

        # Later calls share the variables created by the first one.
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g')
        return output

In [None]:
class Discriminator:
    """Discriminator with shared trunk plus real/fake and code-prediction heads.

    The trunk (variable scope ``'d'``) is conv(64) -> conv(128)+batch norm ->
    fc(1024) -> fc(128). Three sets of heads read the 128-d feature:

    * ``'disc'``: one real/fake logit,
    * ``'q_us'``: categorical and continuous code predictions ("us" heads),
    * ``'q_ss'``: a second pair of categorical/continuous predictions
      ("ss" heads).
    """

    # Variable-scope prefixes whose trainable variables belong to this network.
    _VARIABLE_SCOPES = ('d/', 'disc/', 'q_us/', 'q_ss/')

    def __init__(self, cat_size, con_size):
        """Store the head sizes.

        Args:
            cat_size: number of categorical logits per categorical head.
            con_size: number of outputs per continuous head.
        """
        self.reuse = False

        self.cat_size = cat_size
        self.con_size = con_size        

        self.d_bn0 = BatchNormalization(name="d_bn0")
        
    def __call__(self, x,training=False, name=''):
        """Build (or reuse) the discriminator graph for image batch ``x``.

        NOTE(review): ``training`` is accepted but not forwarded, so the
        batch-norm wrapper is always called with its default ``train=True``.

        Args:
            x: images reshapeable to ``[-1, 28, 28, 1]``.
            training: unused (see note above).
            name: suffix appended to the name scopes of this call.

        Returns:
            Tuple ``(disc, us_cat, us_con, ss_cat, ss_con)`` of head outputs
            (raw values, no activation applied).
        """
        with tf.name_scope('d' + name), tf.variable_scope('d', reuse=self.reuse):
            x = tf.reshape(x, [-1, 28, 28, 1])

            conv1 = conv2d_layer(x, 64, name="d_conv0")
            conv1 = lrelu(conv1)
            
            conv2 = conv2d_layer(conv1, 128, name="d_conv1")
            conv2 = self.d_bn0(conv2)
            conv2 = lrelu(conv2)

            fc0 = full_connection_layer(conv2, 1024, name="fc0")
            fc0 = tf.nn.relu(fc0)
            
            fc1 = full_connection_layer(fc0, 128, name="fc1")
            fc1 = tf.nn.relu(fc1)

        with tf.name_scope('disc' + name), tf.variable_scope('disc', reuse=self.reuse):
            disc = full_connection_layer(fc1, 1, name = 'disc')

        with tf.name_scope('q_us' + name), tf.variable_scope('q_us', reuse=self.reuse):
            us_cat = full_connection_layer(fc1, self.cat_size, name = 'us_cat')
            us_con = full_connection_layer(fc1, self.con_size, name = 'us_con')

        with tf.name_scope('q_ss' + name), tf.variable_scope('q_ss', reuse=self.reuse):
            ss_cat = full_connection_layer(fc1, self.cat_size, name = 'ss_cat')
            ss_con = full_connection_layer(fc1, self.con_size, name = 'ss_con')      

        print('discriminator ouput dis:', disc.get_shape())
        print('discriminator ouput cat:', ss_cat.get_shape())
        print('discriminator ouput cont:', ss_con.get_shape())

        # Later calls share the variables created by the first one.
        self.reuse = True

        # get_collection filters by name prefix, so a bare 'd' picks up the
        # 'd/' and 'disc/' variables but silently drops the 'q_us/' and
        # 'q_ss/' heads, which would then never be handed to an optimizer.
        # Collect every scope explicitly instead.
        self.variables = [
            variable
            for scope_prefix in self._VARIABLE_SCOPES
            for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_prefix)
        ]
        return disc, us_cat, us_con, ss_cat, ss_con

In [None]:
def one_hot_vec(label):
    """Build the 10-way one-hot encoding of ``label`` as a float vector."""
    one_hot = np.zeros(10)
    one_hot[label] = 1
    return one_hot

In [None]:
class GAN:
    """Semi-supervised InfoGAN on MNIST: graph construction and training loop.

    The latent vector fed to the generator is the concatenation of
    ``rand_size`` noise values, a ``cat_size``-way one-hot categorical code
    and ``con_size`` continuous codes.
    """

    def __init__(self):
        """Set hyperparameters, create the placeholders and build G and D."""
        self.batch_size = 128
        self.img_size = 28
        self.cat_size = 10
        self.con_size = 2
        self.rand_size = 62

        self.epochs = 100000
        self.epoch_saveMetrics = 3000
        self.epoch_saveSampleImg = 3000
        self.epoch_saveParamter = 10000
        self.losses = {"d_loss":[], "g_loss":[]}

        # Real images and their integer / one-hot labels.
        self.X_tr = tf.placeholder(tf.float32, shape=[None, self.img_size, self.img_size, 1])
        self.cat_tr_label = tf.placeholder(tf.int32, [None])
        self.cat_tr = tf.placeholder(tf.float32, [None, self.cat_size])
        # Codes the generated images were sampled with.
        self.cat_label = tf.placeholder(tf.int32, [None])
        self.cat = tf.placeholder(tf.float32, [None, self.cat_size])
        self.con = tf.placeholder(tf.float32, [None, self.con_size])
        # Full latent vector: noise + categorical one-hot + continuous codes.
        self.Z1 = tf.placeholder(tf.float32, [None, self.rand_size+self.cat_size+self.con_size])
        
        self.g = Generator()
        self.d = Discriminator(self.cat_size, self.con_size)
        self.Xg = self.g(self.Z1)

        # Fraction of each batch treated as labelled in the supervised terms.
        self.ss_prob = 0.1

    def loss(self):
        """Build the generator and discriminator cost tensors.

        Both costs share the code-reconstruction terms (weight 0.8) and the
        supervised categorical terms (weight 2); they differ only in the
        adversarial term.

        Returns:
            Tuple ``(g_cost, d_cost)``.
        """
        disc_tr, _, _, ss_cat_tr, ss_con_tr = self.d(self.X_tr)
        disc_gen, us_cat_gen, us_con_gen, _, _ = self.d(self.Xg)
        
        # Adversarial terms: D labels real as 1 and generated as 0; G wants 1.
        loss_d_tr = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_tr, labels=tf.ones_like(disc_tr)))
        loss_d_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_gen, labels=tf.zeros_like(disc_gen)))
        loss_d = (loss_d_tr + loss_d_gen)
        
        loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_gen, labels=tf.ones_like(disc_gen)))
        
        # Code reconstruction on generated images.
        loss_us_cat = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=us_cat_gen, labels=self.cat_label))
        loss_us_con = tf.reduce_mean(tf.squared_difference(us_con_gen, self.con))

        # Supervised terms use only the first ss_prob fraction of the batch.
        part_batch_size = int(self.batch_size*self.ss_prob)
        loss_ss_cat_tr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=ss_cat_tr[:part_batch_size], labels=self.cat_tr_label[:part_batch_size]))
        # Generated images must be scored against the categorical codes they
        # were generated from (cat_label); the real images' labels
        # (cat_tr_label) are unrelated to them once the codes are sampled
        # independently of the real batch.
        loss_ss_cat_gen = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=us_cat_gen[:part_batch_size], labels=self.cat_label[:part_batch_size]))

        info_term = (loss_us_cat + loss_us_con)*0.8
        supervised_term = (loss_ss_cat_tr + loss_ss_cat_gen)*2

        d_cost = loss_d + info_term + supervised_term
        g_cost = loss_g + info_term + supervised_term
    
        return g_cost, d_cost

    def _code_sweep_latents(self, varied_con, bs=100):
        """Return ``bs`` latent vectors for a category x continuous-code grid.

        Sample ``i`` gets category ``i // (bs // 10)`` and a value from
        ``linspace(-1, 1, 10)`` (cycling with ``i``) in continuous code
        ``varied_con``; the other continuous codes are 0. Noise is N(0, 1).
        """
        z_rand = np.random.normal(0, 1, size=[bs, self.rand_size])
        # Integer division: np.repeat rejects a float repeat count.
        cat_label = np.repeat(np.arange(10), bs // 10)
        z_cat = np.array(list(map(one_hot_vec, cat_label)))
        z_con = np.zeros([bs, self.con_size])
        z_con[:, varied_con] = np.tile(np.linspace(-1.0, 1.0, num=10), bs // 10)
        return np.concatenate((z_rand, z_cat, z_con), axis=1)

    def train(self):
        """Build the optimizers and run the training loop.

        Every iteration performs one discriminator step and one generator step
        on the same batch. Loss plots, sample grids and checkpoints are written
        at the intervals configured in ``__init__``.
        """
        # Optimizer
        d_lr = 1e-4
        g_lr = 1e-3

        self.L_g, self.L_d = self.loss()

        d_opt = tf.train.AdamOptimizer(learning_rate=d_lr)
        d_train_op = d_opt.minimize(self.L_d, var_list=self.d.variables)
        g_opt = tf.train.AdamOptimizer(learning_rate=g_lr)
        g_train_op = g_opt.minimize(self.L_g, var_list=self.g.variables)

        saver = tf.train.Saver()

        config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                visible_device_list="0"
            )
        )
                
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            # Offset added to the epoch in logs and checkpoint names; set it
            # (and restore a checkpoint here) when resuming a previous run.
            epoch_pre = 0

            # Fixed latent grids used for the periodic sample images: one
            # sweeps continuous code 1, the other continuous code 2.
            con1_test_z = self._code_sweep_latents(varied_con=0)
            con2_test_z = self._code_sweep_latents(varied_con=1)

            for epoch in range(self.epochs):              
                # train D (one step per generator step)
                for _ in range(1):
                    # Draw a mini-batch of training data.
                    X_mb, Y_mb = mnist.train.next_batch(self.batch_size)
                    X_mb = np.reshape(X_mb, [-1, 28, 28, 1])
                    # Maps pixels assumed to lie in [0, 1] onto [-1, 1], the
                    # range of the generator's tanh output.
                    X_mb = (X_mb-0.5)*2.0
                    Y_mb_label = np.argmax(Y_mb, axis=1)

                    # Noise is drawn from N(0, 1), the same distribution the
                    # fixed visualisation latents use.
                    z_rand = np.random.normal(0, 1, size=[self.batch_size, self.rand_size])
                    z_cat_label = np.random.randint(0, 10, [self.batch_size])
                    z_cat = np.array(list(map(one_hot_vec, z_cat_label)))

                    z_con = np.random.normal(0, 1, size=[self.batch_size, self.con_size])
                    z = np.concatenate((z_rand,z_cat,z_con), axis=1)

                    feed = {
                        self.X_tr: X_mb,
                        self.Z1: z,
                        self.cat_tr_label: Y_mb_label,
                        self.cat_label: z_cat_label,
                        self.con: z_con,
                    }
                    _, d_loss_value = sess.run([d_train_op, self.L_d], feed_dict=feed)

                # train G on the last batch the discriminator saw
                _, g_loss_value = sess.run([g_train_op, self.L_g], feed_dict=feed)

                # Record the losses of this iteration.
                self.losses["d_loss"].append(np.sum(d_loss_value))
                self.losses["g_loss"].append(np.sum(g_loss_value))
                
                if epoch % 100 == 0:
                    print("epoch:" + str(epoch+epoch_pre))

                # Plot the loss curves.
                if epoch % self.epoch_saveMetrics == 1:
                    save_metrics(self.losses, epoch)

                # Render sample images from the fixed latent grids.
                if epoch % self.epoch_saveSampleImg == 0:
                    for tag, test_z in (("_con1", con1_test_z), ("_con2", con2_test_z)):
                        sample = sess.run(self.Xg, feed_dict={self.Z1: test_z})
                        # tanh output in [-1, 1] -> [0, 1] for display
                        sample = (sample*0.5) + 0.5
                        save_imgs(sample, name=str(epoch)+tag)

                # Save the model parameters.
                if epoch % self.epoch_saveParamter == 1:
                    path = "model_gpu0"
                    os.makedirs(path, exist_ok=True)

                    saver.save(sess, path+"/dcgan_model" + str(epoch+epoch_pre) + ".ckpt")

    def sample_images(self, row=5, col=12, inputs=None, epoch=None):
        """Return the generator output tensor for ``inputs``.

        ``row``, ``col`` and ``epoch`` are accepted but not used.
        """
        images = self.g(inputs, training=True)
        return images

In [None]:
# Build the model graph, then run the training loop (GAN.train iterates
# gan.epochs times and writes plots, sample grids and checkpoints to disk).
gan = GAN()
gan.train()