In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[ ]:


"""Code for training CycleGAN."""
from datetime import datetime

import numpy as np
import os
import random
from scipy.misc import imsave

import tensorflow as tf

import data_loader, losses, model
import cyclegan_datasets
import warnings
#warnings.filterwarnings('ignore')
#warnings.simplefilter('ignore', FutureWarning)
#warnings.filterwarnings(action="ignore",message=".*regex.*",category=DeprecationWarning)
#warnings.filterwarnings(action="ignore",message=".*regex.*",category=FutureWarning)

# Short alias for tf.contrib.slim (tf.contrib exists only in TF 1.x); used
# by CycleGAN.model_setup for get_or_create_global_step.
slim = tf.contrib.slim


class CycleGAN:
    """The CycleGAN module.

    Wraps graph construction (``model_setup`` / ``compute_losses``), the
    training loop (``train``) and checkpoint evaluation (``test``) for a
    CycleGAN whose generator/discriminator graph is built by
    ``model.get_outputs``.
    """

    def __init__(self, pool_size, lambda_a,
                 lambda_b, output_root_dir, to_restore,
                 base_lr, max_step,
                 dataset_name, checkpoint_dir, do_flipping, skip):
        """Store the hyper-parameters and allocate the fake-image pools.

        :param pool_size: Number of generated images kept in each history pool.
        :param lambda_a: Weight of the cycle-consistency loss on domain A.
        :param lambda_b: Weight of the cycle-consistency loss on domain B.
        :param output_root_dir: Root output dir; this run writes into a
            timestamped sub-folder of it.
        :param to_restore: Whether ``train`` restores the latest checkpoint
            from ``checkpoint_dir`` before training.
        :param base_lr: Initial learning rate.
        :param max_step: Epoch index at which training stops.
        :param dataset_name: Passed to ``data_loader.load_data`` and used as
            the key into ``cyclegan_datasets.DATASET_TO_SIZES``.
        :param checkpoint_dir: Directory searched for the latest checkpoint.
        :param do_flipping: Passed through to ``data_loader.load_data``.
        :param skip: Passed through to ``model.get_outputs``.
        """
        current_time = datetime.now().strftime("%Y%m%d-%H%M%S")

        self._pool_size = pool_size
        self._size_before_crop = 286
        self._lambda_a = lambda_a
        self._lambda_b = lambda_b
        self._output_dir = os.path.join(output_root_dir, current_time)
        self._images_dir = os.path.join(self._output_dir, 'imgs')
        self._num_imgs_to_save = 20
        self._to_restore = to_restore
        self._base_lr = base_lr
        self._max_step = max_step
        self._dataset_name = dataset_name
        self._checkpoint_dir = checkpoint_dir
        self._do_flipping = do_flipping
        self._skip = skip

        # History pools of generated images (see fake_image_pool). Each
        # entry keeps a leading axis of size 1, i.e. one generated batch.
        self.fake_images_A = np.zeros(
            (self._pool_size, 1, model.IMG_HEIGHT, model.IMG_WIDTH,
             model.IMG_CHANNELS)
        )
        self.fake_images_B = np.zeros(
            (self._pool_size, 1, model.IMG_HEIGHT, model.IMG_WIDTH,
             model.IMG_CHANNELS)
        )

    def model_setup(self):
        """
        This function sets up the model to train.

        self.input_a/self.input_b/self.input_ref -> Placeholders for one
        training image each (batch size 1).
        self.fake_images_a/self.fake_images_b -> Images generated by the
        corresponding generator.
        self.learning_rate -> Learning-rate placeholder, fed every step.
        self.cycle_images_a/self.cycle_images_b -> Images generated after
        feeding the fakes back through the opposite generator. These are
        used to calculate the cycle-consistency loss.
        """
        self.input_a = tf.placeholder(
            tf.float32, [
                1,
                model.IMG_HEIGHT,
                model.IMG_WIDTH,
                model.IMG_CHANNELS
            ], name="input_A")
        self.input_b = tf.placeholder(
            tf.float32, [
                1,
                model.IMG_HEIGHT,
                model.IMG_WIDTH,
                model.IMG_CHANNELS
            ], name="input_B")
        self.input_ref = tf.placeholder(
            tf.float32, [
                1,
                model.IMG_HEIGHT,
                model.IMG_WIDTH,
                model.IMG_CHANNELS
            ], name="input_ref")

        self.fake_pool_A = tf.placeholder(
            tf.float32, [
                None,
                model.IMG_HEIGHT,
                model.IMG_WIDTH,
                model.IMG_CHANNELS
            ], name="fake_pool_A")
        self.fake_pool_B = tf.placeholder(
            tf.float32, [
                None,
                model.IMG_HEIGHT,
                model.IMG_WIDTH,
                model.IMG_CHANNELS
            ], name="fake_pool_B")

        self.global_step = slim.get_or_create_global_step()

        # Number of generated images pushed through fake_image_pool so far.
        self.num_fake_inputs = 0

        self.learning_rate = tf.placeholder(tf.float32, shape=[], name="lr")

        inputs = {
            'images_a': self.input_a,
            'images_b': self.input_b,
            'images_ref': self.input_ref,
            'fake_pool_a': self.fake_pool_A,
            'fake_pool_b': self.fake_pool_B,
        }

        outputs = model.get_outputs(
            inputs, skip=self._skip)

        self.prob_real_a_is_real = outputs['prob_real_a_is_real']
        self.prob_real_b_is_real = outputs['prob_real_b_is_real']
        self.fake_images_a = outputs['fake_images_a']
        self.fake_images_b = outputs['fake_images_b']
        self.prob_fake_a_is_real = outputs['prob_fake_a_is_real']
        self.prob_fake_b_is_real = outputs['prob_fake_b_is_real']

        self.cycle_images_a = outputs['cycle_images_a']
        self.cycle_images_b = outputs['cycle_images_b']

        self.prob_fake_pool_a_is_real = outputs['prob_fake_pool_a_is_real']
        self.prob_fake_pool_b_is_real = outputs['prob_fake_pool_b_is_real']

    def compute_losses(self):
        """
        In this function we are defining the variables for loss calculations
        and training model.

        d_loss_A/d_loss_B -> loss for discriminator A/B
        g_loss_A/g_loss_B -> loss for generator A/B
        *_trainer -> Adam train ops for the above losses, each restricted to
        the variables whose name contains 'd_A', 'd_B', 'g_A' or 'g_B'
        *_summ -> Summary ops for the above losses
        """
        cycle_consistency_loss_a = \
            self._lambda_a * losses.cycle_consistency_loss(
                real_images=self.input_a, generated_images=self.cycle_images_a,
            )
        cycle_consistency_loss_b = \
            self._lambda_b * losses.cycle_consistency_loss(
                real_images=self.input_b, generated_images=self.cycle_images_b,
            )

        lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
        lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)

        g_loss_A = \
            cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
        g_loss_B = \
            cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a

        d_loss_A = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real,
        )
        d_loss_B = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real,
        )

        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)

        self.model_vars = tf.trainable_variables()

        # Partition the trainable variables by sub-network name so each
        # train op only updates its own network.
        d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]
        g_A_vars = [var for var in self.model_vars if 'g_A' in var.name]
        d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]
        g_B_vars = [var for var in self.model_vars if 'g_B' in var.name]

        self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
        self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
        self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)

        for var in self.model_vars:
            print(var.name)

        # Summary variables for tensorboard
        self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
        self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
        self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
        self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)

    @staticmethod
    def _to_uint8_image(tensor):
        """Map an image with values in [-1, 1] to a uint8 array for saving.

        A leading batch axis (rank-4 input) is dropped by taking element 0.
        Rank-3 input is used as-is. The real inputs are fed as ``[image]``
        into the rank-4 placeholders, so they arrive here without a batch
        axis, whereas the generated images keep it.
        """
        image = np.asarray(tensor)
        if image.ndim == 4:
            image = image[0]
        return ((image + 1) * 127.5).astype(np.uint8)

    @staticmethod
    def _html_image_src(index, image_name, source_filenames):
        """Return the ``src`` attribute for the index-th image in the report.

        The first two images (the real A and B inputs) link to their source
        files. Every generated image links to its copy under ``imgs/``.
        """
        if index < 2:
            return source_filenames[index]
        return os.path.join('imgs', image_name)

    def save_images(self, sess, epoch):
        """
        Saves input and output images, plus an HTML page that shows them.

        :param sess: The session.
        :param epoch: Current epoch (used in the output file names).
        """
        if not os.path.exists(self._images_dir):
            os.makedirs(self._images_dir)

        names = ['inputA_', 'inputB_', 'fakeA_',
                 'fakeB_', 'cycA_', 'cycB_']

        html_path = os.path.join(
            self._output_dir, 'epoch_' + str(epoch) + '.html')
        with open(html_path, 'w') as v_html:
            for i in range(0, self._num_imgs_to_save):
                print("Saving image {}/{}".format(i, self._num_imgs_to_save))
                # Fetch images and filenames in one run so both come from the
                # same step of the input pipeline and stay paired.
                inputs, filenames = sess.run([self.inputs, self.filenames])
                fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = sess.run([
                    self.fake_images_a,
                    self.fake_images_b,
                    self.cycle_images_a,
                    self.cycle_images_b
                ], feed_dict={
                    self.input_a: [inputs['images_i']],
                    self.input_b: [inputs['images_j']],
                    self.input_ref: [inputs['images_k']]
                })

                tensors = [inputs['images_i'], inputs['images_j'],
                           fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp]
                name_tensors = [str(filenames['filename_i'].decode()),
                                str(filenames['filename_j'].decode()),
                                str(filenames['filename_k'].decode())]
                for source_name in name_tensors:
                    v_html.write(source_name + "  ")
                v_html.write("<br>")
                for index, (name, tensor) in enumerate(zip(names, tensors)):
                    image_name = name + str(epoch) + "_" + str(i) + ".jpg"
                    imsave(os.path.join(self._images_dir, image_name),
                           self._to_uint8_image(tensor))
                    src = self._html_image_src(index, image_name, name_tensors)
                    v_html.write("<img src=\"" + src + "\">")
                    # The two real inputs go on their own row.
                    if index == 1:
                        v_html.write("<br>")
                v_html.write("<br>")

    def fake_image_pool(self, num_fakes, fake, fake_pool):
        """
        This function saves the generated image to corresponding
        pool of images.

        It keeps filling the pool until it is full. After that, with
        probability 1/2 it swaps the new image for a randomly chosen stored
        one and returns the stored image. Otherwise it returns the new image
        unchanged.
        """
        if num_fakes < self._pool_size:
            fake_pool[num_fakes] = fake
            return fake
        else:
            p = random.random()
            if p > 0.5:
                random_id = random.randint(0, self._pool_size - 1)
                temp = fake_pool[random_id]
                fake_pool[random_id] = fake
                return temp
            else:
                return fake

    def _learning_rate_for_epoch(self, epoch):
        """Return the learning rate used during ``epoch``.

        The rate is ``base_lr`` for epochs 0-99. It then decays linearly and
        reaches zero at epoch 200. It is clamped at zero so that runs with
        ``max_step > 200`` never use a negative learning rate.
        """
        if epoch < 100:
            return self._base_lr
        return max(0.0, self._base_lr - self._base_lr * (epoch - 100) / 100)

    def _latest_checkpoint(self):
        """Return the latest checkpoint path in ``checkpoint_dir``.

        :raises ValueError: if the directory holds no checkpoint.
        """
        chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
        if chkpt_fname is None:
            raise ValueError(
                "No checkpoint found in {}".format(self._checkpoint_dir))
        return chkpt_fname

    def train(self):
        """Training Function.

        Builds the input pipeline, network and losses, then trains from the
        current ``global_step`` up to ``max_step`` epochs. A checkpoint and a
        set of sample images are written at the start of every epoch.
        """
        # Load Dataset from the dataset folder
        self.inputs, self.filenames = data_loader.load_data(
            self._dataset_name, self._size_before_crop,
            self._do_flipping)

        # Build the network
        self.model_setup()

        # Loss function calculations
        self.compute_losses()

        # Initializing the global variables
        init = (tf.global_variables_initializer(),
                tf.local_variables_initializer())
        saver = tf.train.Saver()
        # Build the step-increment op once. Calling tf.assign inside the
        # loop would add a new op to the graph every epoch.
        increment_global_step = tf.assign_add(self.global_step, 1)

        max_images = cyclegan_datasets.DATASET_TO_SIZES[self._dataset_name]

        with tf.Session() as sess:
            sess.run(init)

            # Restore the model to run the model from last checkpoint
            if self._to_restore:
                saver.restore(sess, self._latest_checkpoint())

            if not os.path.exists(self._output_dir):
                os.makedirs(self._output_dir)
            writer = tf.summary.FileWriter(self._output_dir)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            try:
                # Training Loop
                for epoch in range(sess.run(self.global_step), self._max_step):
                    print("In the epoch ", epoch)
                    saver.save(sess, os.path.join(
                        self._output_dir, "cyclegan"), global_step=epoch)

                    curr_lr = self._learning_rate_for_epoch(epoch)

                    self.save_images(sess, epoch)

                    for i in range(0, max_images):
                        print("Processing batch {}/{}".format(i, max_images))
                        step = epoch * max_images + i

                        # Filenames are not needed for training, so they are
                        # not fetched. A separate sess.run(self.filenames) may
                        # (depending on data_loader) pull an extra, unpaired
                        # example from the pipeline.
                        inputs = sess.run(self.inputs)
                        feed = {
                            self.input_a: [inputs['images_i']],
                            self.input_b: [inputs['images_j']],
                            self.input_ref: [inputs['images_k']],
                            self.learning_rate: curr_lr,
                        }

                        # Optimizing the G_A network
                        _, fake_B_temp, summary_str = sess.run(
                            [self.g_A_trainer,
                             self.fake_images_b,
                             self.g_A_loss_summ],
                            feed_dict=feed)
                        writer.add_summary(summary_str, step)

                        fake_B_temp1 = self.fake_image_pool(
                            self.num_fake_inputs, fake_B_temp,
                            self.fake_images_B)

                        # Optimizing the D_B network
                        d_b_feed = dict(feed)
                        d_b_feed[self.fake_pool_B] = fake_B_temp1
                        _, summary_str = sess.run(
                            [self.d_B_trainer, self.d_B_loss_summ],
                            feed_dict=d_b_feed)
                        writer.add_summary(summary_str, step)

                        # Optimizing the G_B network
                        _, fake_A_temp, summary_str = sess.run(
                            [self.g_B_trainer,
                             self.fake_images_a,
                             self.g_B_loss_summ],
                            feed_dict=feed)
                        writer.add_summary(summary_str, step)

                        fake_A_temp1 = self.fake_image_pool(
                            self.num_fake_inputs, fake_A_temp,
                            self.fake_images_A)

                        # Optimizing the D_A network
                        d_a_feed = dict(feed)
                        d_a_feed[self.fake_pool_A] = fake_A_temp1
                        _, summary_str = sess.run(
                            [self.d_A_trainer, self.d_A_loss_summ],
                            feed_dict=d_a_feed)
                        writer.add_summary(summary_str, step)

                        writer.flush()
                        self.num_fake_inputs += 1

                    sess.run(increment_global_step)
            finally:
                # Always stop the queue-runner threads, even if training fails.
                coord.request_stop()
                coord.join(threads)

            writer.add_graph(sess.graph)
            writer.close()

    def test(self):
        """Test Function.

        Restores the latest checkpoint from ``checkpoint_dir`` and writes one
        report (epoch 0) covering every image of the dataset.
        """
        print("Testing the results")

        self.inputs, self.filenames = data_loader.load_data(
            self._dataset_name, self._size_before_crop,
            self._do_flipping)

        self.model_setup()
        saver = tf.train.Saver()
        # Same initialisation as train(): the input pipeline may create local
        # variables, and initialising them is harmless if it does not.
        init = (tf.global_variables_initializer(),
                tf.local_variables_initializer())

        with tf.Session() as sess:
            sess.run(init)
            saver.restore(sess, self._latest_checkpoint())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            try:
                self._num_imgs_to_save = cyclegan_datasets.DATASET_TO_SIZES[
                    self._dataset_name]
                self.save_images(sess, 0)
            finally:
                coord.request_stop()
                coord.join(threads)


def main(to_train, log_dir, checkpoint_dir):
    """Build a CycleGAN with the hard-coded hyper-parameters and run it.

    :param to_train: Run mode. 1: train from scratch; 2: resume training
        from the latest checkpoint in ``checkpoint_dir``; 0 (or any value
        <= 0): test.
    :param log_dir: Root dir for checkpoints, summaries and images. It is
        created if missing, and each run writes into a timestamped
        sub-folder of it.
    :param checkpoint_dir: Directory holding the checkpoint to restore. It
        is read when resuming (to_train == 2) and when testing
        (to_train <= 0); it is ignored for to_train == 1.
    """
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    # Hyper-parameters (fixed for this experiment).
    lambda_a = 10.0
    lambda_b = 10.0
    pool_size = 50

    to_restore = (to_train == 2)
    base_lr = 0.0002
    max_step = 200
    dataset_name = 'lipstick_data'
    do_flipping = False
    skip = False

    cyclegan_model = CycleGAN(pool_size, lambda_a, lambda_b, log_dir,
                              to_restore, base_lr, max_step,
                              dataset_name, checkpoint_dir, do_flipping, skip)

    if to_train > 0:
        cyclegan_model.train()
    else:
        cyclegan_model.test()
    

In [2]:

if __name__ == '__main__':
    # Mode 0 runs CycleGAN.test() against the checkpoint of a specific
    # earlier run stored under ./log.
    log_dir = "./log"
    main(to_train=0, log_dir=log_dir, checkpoint_dir="./log/20190219-163649")


Testing the results
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(input_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensors(tensor).repeat(num_epochs)`.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.TextLineDataset`.
Instructions for 

`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.


Saving image 1/630
Saving image 2/630
Saving image 3/630
Saving image 4/630
Saving image 5/630
Saving image 6/630
Saving image 7/630
Saving image 8/630
Saving image 9/630
Saving image 10/630
Saving image 11/630
Saving image 12/630
Saving image 13/630
Saving image 14/630
Saving image 15/630
Saving image 16/630
Saving image 17/630
Saving image 18/630
Saving image 19/630
Saving image 20/630
Saving image 21/630
Saving image 22/630
Saving image 23/630
Saving image 24/630
Saving image 25/630
Saving image 26/630
Saving image 27/630
Saving image 28/630
Saving image 29/630
Saving image 30/630
Saving image 31/630
Saving image 32/630
Saving image 33/630
Saving image 34/630
Saving image 35/630
Saving image 36/630
Saving image 37/630
Saving image 38/630
Saving image 39/630
Saving image 40/630
Saving image 41/630
Saving image 42/630
Saving image 43/630
Saving image 44/630
Saving image 45/630
Saving image 46/630
Saving image 47/630
Saving image 48/630
Saving image 49/630
Saving image 50/630
Saving im

Saving image 399/630
Saving image 400/630
Saving image 401/630
Saving image 402/630
Saving image 403/630
Saving image 404/630
Saving image 405/630
Saving image 406/630
Saving image 407/630
Saving image 408/630
Saving image 409/630
Saving image 410/630
Saving image 411/630
Saving image 412/630
Saving image 413/630
Saving image 414/630
Saving image 415/630
Saving image 416/630
Saving image 417/630
Saving image 418/630
Saving image 419/630
Saving image 420/630
Saving image 421/630
Saving image 422/630
Saving image 423/630
Saving image 424/630
Saving image 425/630
Saving image 426/630
Saving image 427/630
Saving image 428/630
Saving image 429/630
Saving image 430/630
Saving image 431/630
Saving image 432/630
Saving image 433/630
Saving image 434/630
Saving image 435/630
Saving image 436/630
Saving image 437/630
Saving image 438/630
Saving image 439/630
Saving image 440/630
Saving image 441/630
Saving image 442/630
Saving image 443/630
Saving image 444/630
Saving image 445/630
Saving image 