In [1]:
import tensorflow as tf
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
import numpy as np
from model import *
from utility import *

model_name = "WGAN_for_irasutoya"

In [2]:
class Generator:
    def __init__(self):
        self.reuse = False
        self.g_bn0 = BatchNormalization(name = 'g_bn0')
        self.g_bn1 = BatchNormalization(name = 'g_bn1')
        self.g_bn2 = BatchNormalization(name = 'g_bn2')
        self.g_bn3 = BatchNormalization(name = 'g_bn3')

    def __call__(self, z):
        with tf.variable_scope('g', reuse=self.reuse):

            fc0 = full_connection_layer(z, 512*6*6, name="fc0")
            fc0 = tf.reshape(fc0, [-1, 6, 6, 512])

            batch_size = tf.shape(fc0)[0]
            deconv0 = deconv2d_layer(fc0, [batch_size, 12, 12, 256], kernel_size=5, name="deconv0")
            deconv0 = self.g_bn0(deconv0)
            deconv0 = lrelu(deconv0, leak=0.3)

            deconv1 = deconv2d_layer(deconv0, [batch_size, 24, 24, 128], kernel_size=5, name="deconv1")
            deconv1 = self.g_bn1(deconv1)
            deconv1 = lrelu(deconv1, leak=0.3)

            deconv2 = deconv2d_layer(deconv1, [batch_size, 48, 48, 64], kernel_size=5, name="deconv2")
            deconv2 = self.g_bn2(deconv2)
            deconv2 = lrelu(deconv2, leak=0.3)

            deconv3 = deconv2d_layer(deconv2, [batch_size, 96, 96, 3], kernel_size=5, name="deconv3")

            output = tf.nn.tanh(deconv3)
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g')
        return output

In [3]:
class Discriminator:
    def __init__(self):
        self.reuse = False
        self.d_bn0 = BatchNormalization(name="d_bn0")
        self.d_bn1 = BatchNormalization(name="d_bn1")
        self.d_bn2 = BatchNormalization(name="d_bn2")
        self.d_bn3 = BatchNormalization(name="d_bn3")
        self.d_bn4 = BatchNormalization(name="d_bn4")
        
    def __call__(self, x):
        with tf.variable_scope('d', reuse=self.reuse):
            x = tf.reshape(x, [-1, 96, 96, 3])

            conv1 = conv2d_layer(x, 64, kernel_size=5, name="d_conv0")
            conv1 = self.d_bn0(conv1)
            conv1 = lrelu(conv1, leak=0.3)
            
            conv2 = conv2d_layer(conv1, 128, kernel_size=5, name="d_conv1")
            conv2 = self.d_bn1(conv2)
            conv2 = lrelu(conv2, leak=0.3)

            conv3 = conv2d_layer(conv2, 256, kernel_size=5, name="d_conv2")
            conv3 = self.d_bn2(conv3)
            conv3 = lrelu(conv3, leak=0.3)

            conv4 = conv2d_layer(conv3, 512, kernel_size=5, name="d_conv3")
            conv4 = self.d_bn3(conv4)
            conv4 = lrelu(conv4, leak=0.3)
            
            disc = full_connection_layer(conv4, 1, name="disc")

        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='d')

        return disc

In [4]:
class GAN:
    def __init__(self):
        self.batch_size = 64
        self.img_size = 96
        self.z_size = 100
        
        self.epochs = 50000
        self.epoch_saveMetrics = 1000
        self.epoch_saveSampleImg = 1000
        self.epoch_saveParamter = 5000
        self.losses = {"d_loss":[], "g_loss":[]}

        # unrolled counts
        self.steps = 5

        self.dataset = np.load("irasutoya_face_1813x96x96x3_jpg.npy")
        self.dataset = (self.dataset/255)# - 0.5

        self.X_tr = tf.placeholder(tf.float32, shape=[None, self.img_size, self.img_size, 3])
        self.z = tf.placeholder(tf.float32, [None, self.z_size])
        
        self.g = Generator()
        self.d = Discriminator()
        self.Xg = self.g(self.z)



    def loss(self):
        disc_tr = self.d(self.X_tr)
        disc_gen = self.d(self.Xg)
        
        loss_d = -tf.reduce_mean(disc_tr) + tf.reduce_mean(disc_gen)
        loss_g = -tf.reduce_mean(disc_gen)

        return loss_g, loss_d

    def train(self):
        # Optimizer

        self.L_g, self.L_d = self.loss()

        d_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        d_train_op = d_opt.minimize(self.L_d, var_list=self.d.variables)
        g_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        g_train_op = g_opt.minimize(self.L_g, var_list=self.g.variables)

        self.clip_updates = [w.assign(tf.clip_by_value(w, -0.01, 0.01)) for w in self.d.variables]

        saver = tf.train.Saver()
        
        config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                visible_device_list= "0"
            )
        )
                
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            # preparing noise vec for test
            bs = 100
            test_z = np.random.uniform(-1, 1, size=[bs, self.z_size])

            for epoch in range(self.epochs):
                for step in range(self.steps):
                    # extract images for training
                    rand_index = np.random.randint(0, self.dataset.shape[0], size=self.batch_size)
                    X_mb = self.dataset[rand_index, :].astype(np.float32)
                    X_mb = np.reshape(X_mb, [-1, 96, 96, 3])
                    z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_size])

                    sess.run([self.clip_updates])

                    # train Discriminator
                    _, d_loss_value = sess.run([d_train_op, self.L_d], feed_dict={
                        self.X_tr: X_mb,
                        self.z:z,
                    })
         
                # extract images for training
                rand_index = np.random.randint(0, self.dataset.shape[0], size=self.batch_size)
                X_mb = self.dataset[rand_index, :].astype(np.float32)
                X_mb = np.reshape(X_mb, [-1, 96, 96, 3])
                z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_size])

                # train Generator
                _, g_loss_value = sess.run([g_train_op, self.L_g], feed_dict={
                    self.X_tr: X_mb,
                    self.z:z,
                })

                # append loss value for visualizing
                self.losses["d_loss"].append(np.sum(d_loss_value))
                self.losses["g_loss"].append(np.sum(g_loss_value))
                
                # print epoch
                if epoch % 100 == 0:
                    print('epoch:{0}, d_loss:{1}, g_loss{2} '.format(epoch, d_loss_value, g_loss_value))
                
                # visualize loss
                if epoch % self.epoch_saveMetrics == 0:
                    save_metrics(model_name, self.losses, epoch)

                # visualize generated images during training
                if epoch % self.epoch_saveSampleImg == 0:
                    img = sess.run(self.Xg, feed_dict={self.z: test_z})
                    #img = (img+ 0.5)*1.0
                    save_imgs(model_name, img, name=str(epoch))

                # save model parameters 
                if epoch % self.epoch_saveParamter == 0:
                    dir_path = "model_" + model_name
                    if not os.path.isdir(dir_path):
                        os.makedirs(dir_path)

                    saver.save(sess, dir_path + "/" + str(epoch) + ".ckpt")

In [None]:
gan = GAN()
gan.train()


Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
epoch:0, d_loss:-0.0008500061230733991, g_loss-0.0006087312940508127 


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

epoch:100, d_loss:-3.3866782188415527, g_loss1.725684404373169 
epoch:200, d_loss:-3.426732063293457, g_loss1.7325007915496826 
epoch:300, d_loss:-3.380202054977417, g_loss1.6623528003692627 
epoch:400, d_loss:-3.392364501953125, g_loss1.6626920700073242 
epoch:500, d_loss:-3.1641602516174316, g_loss1.6485841274261475 
epoch:600, d_loss:-3.352792501449585, g_loss1.6465719938278198 
epoch:700, d_loss:-3.040391445159912, g_loss1.5147663354873657 
epoch:800, d_loss:-3.1910400390625, g_loss1.6020570993423462 
epoch:900, d_loss:-2.5679636001586914, g_loss1.605576515197754 
epoch:1000, d_loss:-3.1543943881988525, g_loss1.5741395950317383 


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

epoch:1100, d_loss:-3.1767451763153076, g_loss1.5984090566635132 
epoch:1200, d_loss:-3.214693069458008, g_loss1.6140341758728027 
epoch:1300, d_loss:-1.5086851119995117, g_loss1.607201099395752 
epoch:1400, d_loss:-3.2003185749053955, g_loss1.6188175678253174 
epoch:1500, d_loss:-3.216465473175049, g_loss1.628318428993225 
epoch:1600, d_loss:-3.018012046813965, g_loss1.5358792543411255 
epoch:1700, d_loss:-1.2702953815460205, g_loss1.6120343208312988 
epoch:1800, d_loss:-3.18013334274292, g_loss1.6453540325164795 
epoch:1900, d_loss:-3.093172550201416, g_loss1.652710199356079 
epoch:2000, d_loss:-3.2414674758911133, g_loss1.659667730331421 


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

epoch:2100, d_loss:-3.3052401542663574, g_loss1.6640225648880005 
epoch:2200, d_loss:-2.630314350128174, g_loss1.6716670989990234 
epoch:2300, d_loss:-3.001194715499878, g_loss0.903390645980835 
epoch:2400, d_loss:-3.2874653339385986, g_loss1.6642711162567139 
epoch:2500, d_loss:-3.2365236282348633, g_loss1.6152950525283813 
epoch:2600, d_loss:-3.1944897174835205, g_loss1.5616952180862427 
epoch:2700, d_loss:-3.2751827239990234, g_loss1.6680381298065186 
epoch:2800, d_loss:-3.1420364379882812, g_loss1.6755542755126953 
epoch:2900, d_loss:-3.3386073112487793, g_loss1.6904711723327637 
epoch:3000, d_loss:-3.2997493743896484, g_loss1.6703888177871704 


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

epoch:3100, d_loss:-3.062669038772583, g_loss1.690253496170044 
epoch:3200, d_loss:-2.4660940170288086, g_loss-0.23381896317005157 
epoch:3300, d_loss:-2.0798873901367188, g_loss-0.9440329074859619 
epoch:3400, d_loss:-3.307492733001709, g_loss1.6811426877975464 
epoch:3500, d_loss:-3.25783634185791, g_loss1.6546530723571777 
epoch:3600, d_loss:-1.179602026939392, g_loss0.7919972538948059 
epoch:3700, d_loss:-3.2749404907226562, g_loss1.6661837100982666 
epoch:3800, d_loss:-2.5333712100982666, g_loss1.3817596435546875 
epoch:3900, d_loss:-3.2288386821746826, g_loss1.6568715572357178 
epoch:4000, d_loss:-3.232605457305908, g_loss1.6622910499572754 


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

epoch:4100, d_loss:-3.248871088027954, g_loss1.6684553623199463 
epoch:4200, d_loss:-1.9791247844696045, g_loss-0.07900112867355347 
epoch:4300, d_loss:-3.2681241035461426, g_loss1.6699167490005493 
epoch:4400, d_loss:-3.25211238861084, g_loss1.6714924573898315 
epoch:4500, d_loss:-3.219465732574463, g_loss1.6465368270874023 
epoch:4600, d_loss:-3.204529047012329, g_loss1.6778008937835693 
epoch:4700, d_loss:-2.477304697036743, g_loss0.435303270816803 
epoch:4800, d_loss:-3.242654323577881, g_loss1.6714305877685547 
epoch:4900, d_loss:-1.5853631496429443, g_loss1.5607523918151855 
epoch:5000, d_loss:-3.2471675872802734, g_loss1.6554279327392578 


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

epoch:5100, d_loss:-3.2521800994873047, g_loss1.6583064794540405 
epoch:5200, d_loss:-3.1109256744384766, g_loss1.5417931079864502 
epoch:5300, d_loss:-3.178222179412842, g_loss1.6181814670562744 
