In [1]:
import tensorflow as tf
import os
import numpy as np
from model import *
from utility import *

model_name = "BEGAN_for_irasutoya"

In [2]:
class Generator:
    def __init__(self):
        self.reuse = False
        self.g_bn0 = BatchNormalization(name = 'g_bn0')
        self.g_bn1 = BatchNormalization(name = 'g_bn1')
        self.g_bn2 = BatchNormalization(name = 'g_bn2')
        self.g_bn3 = BatchNormalization(name = 'g_bn3')

    def __call__(self, z):
        with tf.variable_scope('g', reuse=self.reuse):
            fc0 = full_connection_layer(z, 1024, name="fc0")
            fc0 = self.g_bn0(fc0)
            fc0 = tf.nn.relu(fc0)

            fc1 = full_connection_layer(fc0, 6*6*512, name="fc1")
            fc1 = self.g_bn1(fc1)
            fc1 = tf.nn.relu(fc1)

            fc1 = tf.reshape(fc1, [-1, 6, 6, 512])

            batch_size = tf.shape(fc1)[0]
            deconv0 = deconv2d_layer(fc1, [batch_size, 12, 12, 256], kernel_size=5, name="deconv0")

            deconv0 = lrelu(deconv0, leak=0.3)

            deconv1 = deconv2d_layer(deconv0, [batch_size, 24, 24, 128], kernel_size=5, name="deconv1")
            deconv1 = self.g_bn2(deconv1)
            deconv1 = lrelu(deconv1, leak=0.3)

            deconv2 = deconv2d_layer(deconv1, [batch_size, 48, 48, 64], kernel_size=5, name="deconv2")
            deconv2 = self.g_bn3(deconv2)
            deconv2 = lrelu(deconv2, leak=0.3)

            deconv3 = deconv2d_layer(deconv2, [batch_size, 96, 96, 3], kernel_size=5, name="deconv3")

            output = tf.nn.tanh(deconv3)
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g')
        return output

In [3]:
class Discriminator:
    def __init__(self):
        self.reuse = False
        self.d_bn0 = BatchNormalization(name="d_bn0")
        self.d_bn1 = BatchNormalization(name="d_bn1")
        self.d_bn2 = BatchNormalization(name="d_bn2")
        self.d_bn3 = BatchNormalization(name="d_bn3")
        self.d_bn4 = BatchNormalization(name="d_bn4")
        self.d_bn5 = BatchNormalization(name="d_bn5")
        self.d_bn6 = BatchNormalization(name="d_bn6")
        self.d_bn7 = BatchNormalization(name="d_bn7")
        self.d_bn8 = BatchNormalization(name="d_bn8")

    def __call__(self, x):
        with tf.variable_scope('d', reuse=self.reuse):
            x = tf.reshape(x, [-1, 96, 96, 3])

            conv1 = conv2d_layer(x, 64, kernel_size=5, name="d_conv0")
            conv1 = self.d_bn0(conv1)
            conv1 = tf.nn.relu(conv1) # 48x48x64
            
            conv2 = conv2d_layer(conv1, 128, kernel_size=5, name="d_conv1")
            conv2 = self.d_bn1(conv2)
            conv2 = tf.nn.relu(conv2) # 24x24x128

            conv3 = conv2d_layer(conv2, 256, kernel_size=5, name="d_conv2")
            conv3 = self.d_bn2(conv3)
            conv3 = tf.nn.relu(conv3) # 12x12x256

            conv4 = conv2d_layer(conv3, 512, kernel_size=5, name="d_conv3")
            conv4 = self.d_bn3(conv4)
            conv4 = tf.nn.relu(conv4) # 6x6x512
            conv4 = tf.reshape(conv4, [-1, 6*6*512])

            fc0 = full_connection_layer(conv4, 64, name="fc0")
            fc0 = self.d_bn4(fc0)
            fc0 = tf.nn.relu(fc0)
            
            fc1 = full_connection_layer(fc0, 6*6*512, name="fc1")
            fc1 = self.d_bn5(fc1)
            fc1 = tf.nn.relu(fc1)
            fc1 = tf.reshape(fc1, [-1, 6, 6, 512])

            batch_size = tf.shape(fc1)[0]
            deconv0 = deconv2d_layer(fc1, [batch_size, 12, 12, 256], kernel_size=5, name="deconv0")
            deconv0 = self.d_bn6(deconv0)
            deconv0 = tf.nn.relu(deconv0)    

            deconv1 = deconv2d_layer(deconv0, [batch_size, 24, 24, 128], kernel_size=5, name="deconv1")
            deconv1 = self.d_bn7(deconv1)
            deconv1 = tf.nn.relu(deconv1)

            deconv2 = deconv2d_layer(deconv1, [batch_size, 48, 48, 64], kernel_size=5, name="deconv2")
            deconv2 = self.d_bn8(deconv2)
            deconv2 = tf.nn.relu(deconv2)

            deconv3 = deconv2d_layer(deconv2, [batch_size, 96, 96, 3], kernel_size=5, name="deconv3")
            output = tf.nn.tanh(deconv3)
            
            #batch_size = tf.shape(output)[0]
            #recon_error = tf.sqrt(2*tf.nn.l2_loss(output - x))/batch_size
            recon_error = tf.reduce_mean(tf.abs(output - x))

        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='d')

        return output, recon_error

In [None]:
class GAN:
    def __init__(self):
        self.batch_size = 64
        self.img_size = 96
        self.z_size = 100
        
        self.epochs = 50000
        self.epoch_saveMetrics = 1000
        self.epoch_saveSampleImg = 1000
        self.epoch_saveParamter = 5000
        self.losses = {"d_loss":[], "g_loss":[], "M_value":[]}

        # unrolled counts
        self.steps = 5

        self.dataset = np.load("irasutoya_face_1813x96x96x3_jpg.npy")
        self.dataset = (self.dataset/255) - 0.5

        self.X_tr = tf.placeholder(tf.float32, shape=[None, self.img_size, self.img_size, 3])
        self.z = tf.placeholder(tf.float32, [None, self.z_size])
        
        self.g = Generator()
        self.d = Discriminator()
        self.Xg = self.g(self.z)

        self.k = tf.Variable(0., trainable=False)
        self.lambda_ = 1e-3
        self.gamma = 0.75
        
    def loss(self):
        output_tr, recon_error_tr = self.d(self.X_tr)
        output_gen, recon_error_gen = self.d(self.Xg)
        
        loss_d = recon_error_tr - self.k*recon_error_gen
        loss_g = recon_error_gen

        self.M = recon_error_tr + tf.abs(self.gamma*recon_error_tr - recon_error_gen)
        self.update_k = self.k.assign(self.k + self.lambda_*(self.gamma*recon_error_tr - recon_error_gen))

        return loss_g, loss_d

    def train(self):
        # Optimizer
        d_lr = 2e-4
        d_beta1 = 0.5
        g_lr = 2e-4
        g_beta1 = 0.5

        self.L_g, self.L_d = self.loss()

        d_opt = tf.train.AdamOptimizer(learning_rate=d_lr)
        d_train_op = d_opt.minimize(self.L_d, var_list=self.d.variables)
        g_opt = tf.train.AdamOptimizer(learning_rate=g_lr)
        g_train_op = g_opt.minimize(self.L_g, var_list=self.g.variables)

        saver = tf.train.Saver()
        
        config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                visible_device_list= "0"
            )
        )
                
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            # preparing noise vec for test
            bs = 100
            test_z = np.random.uniform(-1, 1, size=[bs, self.z_size])

            for epoch in range(self.epochs):

                # extract images for training
                rand_index = np.random.randint(0, self.dataset.shape[0], size=self.batch_size)
                X_mb = self.dataset[rand_index, :].astype(np.float32)
                X_mb = np.reshape(X_mb, [-1, 96, 96, 3])

                z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_size])

                # train Discriminator
                _, d_loss_value = sess.run([d_train_op, self.L_d], feed_dict={
                    self.X_tr: X_mb,
                    self.z:z,
                })

                # train Generator
                _, g_loss_value = sess.run([g_train_op, self.L_g], feed_dict={
                    self.X_tr: X_mb,
                    self.z:z,
                })

                # update k
                M_value, _ = sess.run([self.M, self.update_k], feed_dict={
                    self.X_tr: X_mb,
                    self.z:z,
                })

                # append loss value for visualizing
                self.losses["d_loss"].append(np.sum(d_loss_value))
                self.losses["g_loss"].append(np.sum(g_loss_value))
                self.losses["M_value"].append(M_value)
                
                # print epoch
                if epoch % 100 == 0:
                    print('epoch:{0}, d_loss:{1}, g_loss:{2}, M:value:{3} '.format(epoch, d_loss_value, g_loss_value, M_value))
                
                # visualize loss
                if epoch % self.epoch_saveMetrics == 0:
                    save_metrics(model_name, self.losses, epoch)

                # visualize generated images during training
                if epoch % self.epoch_saveSampleImg == 0:
                    img = sess.run(self.Xg, feed_dict={self.z: test_z})
                    img = (img*0.5) + 0.5
                    save_imgs(model_name, img, name=str(epoch))

                # save model parameters 
                if epoch % self.epoch_saveParamter == 0:
                    dir_path = "model_" + model_name
                    if not os.path.isdir(dir_path):
                        os.makedirs(dir_path)

                    saver.save(sess, dir_path + "/" + str(epoch) + ".ckpt")

In [None]:
gan = GAN()
gan.train()



epoch:0, d_loss:0.42656436562538147, g_loss:0.2646520435810089, M:value:0.460518479347229 
epoch:100, d_loss:0.08670857548713684, g_loss:0.07593969255685806, M:value:0.09418926388025284 
epoch:200, d_loss:0.07160013169050217, g_loss:0.06475510448217392, M:value:0.07781708985567093 
epoch:300, d_loss:0.06462905555963516, g_loss:0.0506284236907959, M:value:0.06408795714378357 
epoch:400, d_loss:0.06083686277270317, g_loss:0.046965714544057846, M:value:0.06062028184533119 
epoch:500, d_loss:0.057797156274318695, g_loss:0.03979562222957611, M:value:0.06076148524880409 
epoch:600, d_loss:0.05294986441731453, g_loss:0.03910772129893303, M:value:0.05293484777212143 
epoch:700, d_loss:0.04869881644845009, g_loss:0.03582185506820679, M:value:0.049931496381759644 
epoch:800, d_loss:0.05000486224889755, g_loss:0.0343509204685688, M:value:0.05283728241920471 
epoch:900, d_loss:0.04677318409085274, g_loss:0.034027110785245895, M:value:0.047986630350351334 
epoch:1000, d_loss:0.04616539552807808, g_