In [1]:
import tensorflow as tf
import os
import numpy as np

# Run identifier: passed to save_imgs/save_metrics and used to name the checkpoint directory.
model_name = "AC-DRAGAN_SRResNet_PixelCNN_for_IRASUTOYA"

In [2]:
from model import *
from utility import *

In [3]:
class Generator:
    """Maps a latent vector to a 128x128x3 image with values in (0, 1).

    Pipeline (all under variable scope 'g'):
      1. fully connected layer -> batch norm -> ReLU, reshaped to 16x16x64
      2. ``num_res_blocks`` residual blocks (conv-BN-ReLU-conv-BN + skip)
      3. one conv-BN-ReLU stage added to the output of step 1 (long skip)
      4. ``num_pixel_CNN_blocks`` conv + pixel-shuffle stages; the shape
         assertions require the three stages to take 16x16 up to 128x128
      5. 9x9 conv to 3 channels followed by a sigmoid

    The first call creates the variables; later calls reuse them.
    After a call, ``self.variables`` holds the trainable variables of scope 'g'.
    """

    def __init__(self):
        """Create the batch-normalisation objects used by ``__call__``."""
        self.reuse = False
        self.g_bn0 = BatchNormalization(name = 'g_bn0')

        self.num_res_blocks = 16
        self.num_pixel_CNN_blocks = 3

        # Two batch-norm layers per residual block, named res_0 .. res_(2n-1).
        self.res_bns = [
            BatchNormalization(name = "res_%d" % idx)
            for idx in range(2 * int(self.num_res_blocks))
        ]

        # One batch-norm layer per upsampling stage.
        self.ps_bns = [
            BatchNormalization(name = "ps_%d" % idx)
            for idx in range(int(self.num_pixel_CNN_blocks))
        ]

        self.g_bn1 = BatchNormalization(name = 'g_bn1')

    def __call__(self, z):
        """Build (or reuse) the generator graph for latent input ``z``.

        Args:
            z: 2-D tensor fed to the first fully connected layer.

        Returns:
            Tensor whose non-batch shape is [128, 128, 3] (sigmoid output).
        """
        n_res = int(self.num_res_blocks)
        n_up = int(self.num_pixel_CNN_blocks)

        with tf.variable_scope('g', reuse=self.reuse):

            # Project the latent vector and fold it into a 16x16x64 feature map.
            with tf.variable_scope('fc0'):
                stem = full_connection_layer(z, 64*16*16, name="fc0")
                stem = tf.nn.relu(self.g_bn0(stem))
                stem = tf.reshape(stem, [-1,16,16,64])

            assert stem.get_shape().as_list()[1:] == [16,16,64]

            # Residual trunk: `feat` always holds the most recent block output.
            feat = stem
            for blk in range(n_res):
                first, second = 2*blk, 2*blk + 1
                with tf.variable_scope('res_%d' % (blk+1)):
                    branch = conv2d_layer(feat, 64, kernel_size=3, strides=1, name="g_conv_res_%d" % first)
                    branch = tf.nn.relu(self.res_bns[first](branch))

                    branch = conv2d_layer(branch, 64, kernel_size=3, strides=1, name="g_conv_res_%d" % second)
                    branch = self.res_bns[second](branch)
                    feat = feat + branch

            assert feat.get_shape().as_list()[1:] == [16,16,64]

            # Closing conv of the trunk, joined with the stem by a long skip connection.
            with tf.variable_scope('conv17'):
                trunk = conv2d_layer(feat, 64, kernel_size=3, strides=1, name="g_conv_17")
                trunk = tf.nn.relu(self.g_bn1(trunk))
                feat = stem + trunk

            assert feat.get_shape().as_list()[1:] == [16, 16, 64]

            # Upsampling: conv to 256 channels, pixel shuffle back to 64 channels.
            for stage in range(n_up):
                with tf.variable_scope('pixel_CNN_%d' % (stage+1)):
                    up = conv2d_layer(feat, 256, kernel_size=3, strides=1, name="g_conv_ps_%d" % (stage))
                    up = pixel_shuffle_layer(up, 2, 64)
                    feat = tf.nn.relu(self.ps_bns[stage](up))

            assert feat.get_shape().as_list()[1:] == [128, 128, 64]

            with tf.variable_scope('output'):
                output = tf.nn.sigmoid(
                    conv2d_layer(feat, 3, kernel_size=9, strides=1, name="output")
                )

            assert output.get_shape().as_list()[1:] == [128, 128, 3]

        # From now on every call shares the variables created above.
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g')
        return output

In [4]:
class Discriminator:
    """Convolutional critic with an auxiliary classification head.

    All variables live under scope 'd'. The network alternates stride-2
    4x4 convolutions with pairs of 3x3 residual blocks, using leaky ReLU
    (leak 0.2) throughout, and ends in two fully connected heads.

    The first call creates the variables; later calls reuse them.
    After a call, ``self.variables`` holds the trainable variables of scope 'd'.
    """

    def __init__(self, cat_size):
        """
        Args:
            cat_size: width of the auxiliary (category) output.
        """
        self.reuse = False
        # NOTE(review): these batch-norm objects are never applied in __call__;
        # they are kept as-is because their constructor side effects are not visible here.
        self.d_bn0 = BatchNormalization(name="d_bn0")
        self.d_bn1 = BatchNormalization(name="d_bn1")
        self.d_bn2 = BatchNormalization(name="d_bn2")
        self.d_bn3 = BatchNormalization(name="d_bn3")
        self.d_bn4 = BatchNormalization(name="d_bn4")

        self.cat_size = cat_size

    def __call__(self, x):
        """Build (or reuse) the discriminator graph for image batch ``x``.

        Args:
            x: tensor reshapeable to [-1, 128, 128, 3].

        Returns:
            (disc, aux): outputs of the two fully connected heads, of width 1
            and ``cat_size`` respectively. No activation is applied to them here.
        """
        def act(t):
            return lrelu(t, leak=0.2)

        def conv(t, filters, kernel, stride):
            # Every convolution in this network shares padding and initializer settings.
            return tf.layers.conv2d(
                t, filters, [kernel, kernel], [stride, stride],
                padding="SAME",
                kernel_initializer=tf.truncated_normal_initializer(0.0, 0.02))

        with tf.variable_scope('d', reuse=self.reuse):

            images = tf.reshape(x, [-1, 128, 128, 3])

            # --- 32-channel stage -------------------------------------------------
            with tf.variable_scope('conv1'):
                conv1 = act(conv(images, 32, 4, 2))

            # res1/res2: no activation between the second conv and the skip add.
            with tf.variable_scope('res1'):
                branch = act(conv(conv1, 32, 3, 1))
                branch = conv(branch, 32, 3, 1)
                res1 = act(branch + conv1)

            with tf.variable_scope('res2'):
                branch = act(conv(res1, 32, 3, 1))
                branch = conv(branch, 32, 3, 1)
                res2 = act(branch + res1)

            # --- 64-channel stage -------------------------------------------------
            with tf.variable_scope('conv2'):
                conv2 = act(conv(res2, 64, 4, 2))

            # From res3 onward the branch is activated before the skip add as well.
            with tf.variable_scope('res3'):
                branch = act(conv(conv2, 64, 3, 1))
                branch = act(conv(branch, 64, 3, 1))
                res3 = act(branch + conv2)

            with tf.variable_scope('res4'):
                branch = act(conv(res3, 64, 3, 1))
                branch = act(conv(branch, 64, 3, 1))
                res4 = act(branch + res3)

            # --- 128-channel entry ------------------------------------------------
            with tf.variable_scope('conv3'):
                conv3 = act(conv(res4, 128, 4, 2))

            # --- three (2 residual blocks + stride-2 conv) stages -----------------
            num_res_itr = 3
            depth = [128, 256, 512, 1024]

            feat = conv3
            for i in range(int(num_res_itr)):
                width = depth[i]

                with tf.variable_scope('res_%d_1' % (i+1+4)):
                    branch = act(conv(feat, width, 3, 1))
                    branch = act(conv(branch, width, 3, 1))
                    feat = act(feat + branch)

                with tf.variable_scope('res_%d_2' % (i+1+4)):
                    branch = act(conv(feat, width, 3, 1))
                    branch = act(conv(branch, width, 3, 1))
                    feat = act(feat + branch)

                # Downsample and widen to the next depth (created directly under 'd').
                feat = act(conv(feat, depth[i+1], 4, 2))

            disc = full_connection_layer(feat, 1, name="disc")
            aux = full_connection_layer(feat, self.cat_size, name="aux")

        # From now on every call shares the variables created above.
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='d')

        return disc, aux

In [None]:
class GAN:
    """Training harness that wires Generator and Discriminator into an auxiliary-classifier GAN.

    The latent input ``z`` is the concatenation of ``rand_size`` uniform noise
    values and ``cat_size`` category values. The discriminator loss adds a
    gradient penalty evaluated between real images and perturbed copies of them.
    """

    def __init__(self):
        """Set hyper-parameters, create input placeholders and build the generator graph."""
        self.batch_size = 64
        self.img_size = 128
        self.rand_size = 100   # length of the noise part of z
        self.cat_size = 4      # length of the category part of z / width of the aux head
        self.z_size = self.rand_size + self.cat_size

        self.epochs = 50000
        self.epoch_saveMetrics = 50      # save loss curves every N epochs
        self.epoch_saveSampleImg = 50    # save sample images every N epochs
        self.epoch_saveParamter = 5000   # save a checkpoint every N epochs
        self.losses = {"d_loss":[], "g_loss":[]}

        # number of discriminator updates performed before each generator update
        self.steps = 5

        # Real images (flattened), their category labels, latent input and
        # perturbed real images used for the gradient penalty.
        self.X_tr = tf.placeholder(tf.float32, shape=[None, self.img_size*self.img_size*3])
        self.Y_tr = tf.placeholder(tf.float32, shape=[None, self.cat_size])
        self.z = tf.placeholder(tf.float32, [None, self.z_size])
        self.X_per = tf.placeholder(tf.float32, shape=[None, self.img_size*self.img_size*3])

        self.g = Generator()
        self.d = Discriminator(self.cat_size)
        self.Xg = self.g(self.z)
        # Data source; train() calls its extract(batch_size, img_size) method.
        self.irasutoya = IRASUTOYA()

    def loss(self):
        """Build the generator and discriminator loss tensors.

        Returns:
            (loss_g, loss_d): scalar tensors.
              loss_g = lambda_adv * BCE(D(G(z)), 1) + classification loss
              loss_d = lambda_adv * (BCE(D(x), 1) + BCE(D(G(z)), 0))
                       + lambda_gp * gradient penalty + classification loss
        """
        disc_tr, aux_tr = self.d(self.X_tr)
        disc_gen, aux_gen = self.d(self.Xg)

        lambda_adv = 34
        lambda_gp = 0.5

        loss_g = lambda_adv*tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_gen, labels=tf.ones_like(disc_gen)))

        # Gradient penalty: sample points on the segment between each real image
        # and its perturbed copy and push the discriminator's gradient norm at
        # those points towards 1.
        diff = self.X_per - self.X_tr
        # One interpolation coefficient per example; sized from the fed batch
        # rather than self.batch_size so any batch size works.
        batch = tf.shape(self.X_tr)[0]
        alpha = tf.random_uniform(shape=tf.stack([batch, 1]), minval=0., maxval=1.)
        interpolates = self.X_tr + (alpha*diff)
        disc_interpolates, _ = self.d(interpolates)
        gradients = tf.gradients(disc_interpolates, [interpolates])[0]
        # interpolates is [batch, H*W*3], so axis 1 covers every pixel of an example.
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
        gradient_penalty = tf.reduce_mean((slopes-1.)**2)

        loss_d_tr = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_tr, labels=tf.ones_like(disc_tr)))
        loss_d_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_gen, labels=tf.zeros_like(disc_gen)))
        loss_d = lambda_adv*(loss_d_tr + loss_d_gen)
        loss_d += lambda_gp*gradient_penalty

        # Auxiliary classifier: both real and generated images should be
        # classified as the category in Y_tr (train() also feeds Y_tr as the
        # category part of z).
        loss_c_tr = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=aux_tr, labels=self.Y_tr))
        loss_c_gen = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=aux_gen, labels=self.Y_tr))
        loss_c = (loss_c_tr + loss_c_gen)

        loss_g += loss_c
        loss_d += loss_c
        return loss_g, loss_d

    @staticmethod
    def _category_sweep_inputs(rand, cat_size):
        """Build latent vectors that sweep each category dimension from -1 to 1.

        Args:
            rand: array [num_steps, rand_dim]; the same noise rows are reused for
                every category so the sweeps differ only in the category part.
            cat_size: number of category dimensions.

        Returns:
            Array [cat_size * num_steps, rand_dim + cat_size]. Rows
            k*num_steps .. (k+1)*num_steps-1 vary category k linearly over
            [-1, 1] and hold every other category at 0.
        """
        num_steps = rand.shape[0]
        sweep = np.linspace(-1, 1, num=num_steps)
        groups = []
        for cat in range(cat_size):
            cat_part = np.zeros([num_steps, cat_size])
            cat_part[:, cat] = sweep
            groups.append(np.concatenate((rand, cat_part), axis=1))
        return np.concatenate(groups, axis=0)

    def train(self):
        """Build the optimizers and run the training loop.

        Each "epoch" performs ``self.steps`` discriminator updates on fresh
        mini-batches followed by one generator update on the last of those
        mini-batches. Sample images, loss curves and checkpoints are written
        periodically.
        """
        # Optimizer
        d_lr = 1e-4
        d_beta1 = 0.5
        g_lr = 1e-4
        g_beta1 = 0.5

        self.L_g, self.L_d = self.loss()

        # beta1 is passed explicitly; without it Adam silently uses its default of 0.9.
        d_opt = tf.train.AdamOptimizer(learning_rate=d_lr, beta1=d_beta1)
        d_train_op = d_opt.minimize(self.L_d, var_list=self.d.variables)
        g_opt = tf.train.AdamOptimizer(learning_rate=g_lr, beta1=g_beta1)
        g_train_op = g_opt.minimize(self.L_g, var_list=self.g.variables)

        saver = tf.train.Saver()

        config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                visible_device_list= "0"
            )
        )

        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            # Fixed latent vectors so sample images are comparable across epochs.
            bs = 100
            test_z = np.random.uniform(-1, 1, size=[bs, self.z_size])

            # One row of images per category: shared noise, one category swept -1..1.
            sweep_steps = 10
            sweep_rand = np.random.uniform(-1, 1, size=[sweep_steps, self.rand_size])
            test_cat = self._category_sweep_inputs(sweep_rand, self.cat_size)

            for epoch in range(self.epochs):

                # visualize generated images during training
                if epoch % self.epoch_saveSampleImg == 0:
                    # rand
                    img = sess.run(self.Xg, feed_dict={self.z: test_z})
                    save_imgs(model_name, img, name=str(epoch)+"_rand")

                    # cat
                    img = sess.run(self.Xg, feed_dict={self.z: test_cat})
                    save_imgs(model_name, img, plot_dim=(self.cat_size, sweep_steps), size=(20, 8), name=str(epoch)+"_cat")

                for step in range(self.steps):
                    # extract images for training
                    X_mb, Y_mb = self.irasutoya.extract(self.batch_size, self.img_size)
                    X_mb = np.reshape(X_mb, [self.batch_size, -1])
                    # Perturbed copy of the batch for the gradient penalty.
                    X_mb_per = X_mb + 0.5*np.std(X_mb)*np.random.random(X_mb.shape)

                    rand = np.random.uniform(-1, 1, size=[self.batch_size, self.rand_size])
                    z = np.hstack((rand, Y_mb))

                    # train Discriminator
                    _, d_loss_value = sess.run([d_train_op, self.L_d], feed_dict={
                        self.X_tr: X_mb,
                        self.z:z,
                        self.Y_tr: Y_mb,
                        self.X_per: X_mb_per,
                    })

                # train Generator (reuses the last discriminator mini-batch)
                _, g_loss_value = sess.run([g_train_op, self.L_g], feed_dict={
                    self.X_tr: X_mb,
                    self.z:z,
                    self.Y_tr: Y_mb,
                    self.X_per: X_mb_per,
                })

                # append loss value for visualizing
                self.losses["d_loss"].append(np.sum(d_loss_value))
                self.losses["g_loss"].append(np.sum(g_loss_value))

                # print epoch
                print('epoch:{0}, d_loss:{1}, g_loss: {2} '.format(epoch, d_loss_value, g_loss_value))

                # visualize loss
                if epoch % self.epoch_saveMetrics == 0:
                    save_metrics(model_name, self.losses, epoch)

                # save model parameters
                if epoch % self.epoch_saveParamter == 0:
                    dir_path = "model_" + model_name
                    if not os.path.isdir(dir_path):
                        os.makedirs(dir_path)

                    saver.save(sess, dir_path + "/" + str(epoch) + ".ckpt")

In [None]:
# Build the placeholders and generator graph, then run the training loop.
gan = GAN()
gan.train()

init IRASUTOYA
1839




epoch:0, d_loss:10.717164039611816, g_loss: 0.45469027757644653 
epoch:1, d_loss:10.66978931427002, g_loss: 0.8071359395980835 
epoch:2, d_loss:6.59810733795166, g_loss: 1.2604871988296509 
epoch:3, d_loss:4.8326735496521, g_loss: 5.1893744468688965 
epoch:4, d_loss:4.323659420013428, g_loss: 0.3473781645298004 
epoch:5, d_loss:6.667032241821289, g_loss: 0.7388228178024292 
epoch:6, d_loss:1.5299060344696045, g_loss: 0.8058099746704102 
epoch:7, d_loss:0.7448770999908447, g_loss: 1.3019134998321533 
epoch:8, d_loss:0.30956876277923584, g_loss: 2.12503719329834 
epoch:9, d_loss:0.6747624278068542, g_loss: 2.5885472297668457 
epoch:10, d_loss:0.3586234450340271, g_loss: 4.15969705581665 
epoch:11, d_loss:0.670871376991272, g_loss: 3.7476446628570557 
epoch:12, d_loss:1.5405986309051514, g_loss: 1.1674647331237793 
epoch:13, d_loss:2.2786359786987305, g_loss: 2.4205222129821777 
epoch:14, d_loss:3.8509740829467773, g_loss: 1.8217016458511353 
epoch:15, d_loss:2.854445457458496, g_loss: 0.

###### 