In [None]:
import tensorflow as tf
# Sanity check that the runtime exposes a GPU: the cell displays the device name
# (e.g. '/device:GPU:0'), or an empty string if TensorFlow found no GPU.
gpu_device = tf.test.gpu_device_name()
gpu_device

'/device:GPU:0'

In [None]:
from google.colab import drive
# Mount Google Drive so the dataset paths under '/content/drive/My Drive' resolve.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import cv2
import os
import random

##############################################################################################################
################################################# HELPER FUNCTIONS ###########################################
##############################################################################################################

In [None]:

class helper_functions:
    '''TF1 graph-building helpers shared by the generator and discriminator.

    Weights and biases are created with ``tf.get_variable`` under ``weights`` /
    ``biases`` variable scopes opened with ``AUTO_REUSE``, so asking for the same
    name again inside the same outer scope returns the existing variable.
    '''

    def get_weights(self, shape, name, stddev = 0.02):
        '''Create (or reuse) a weight variable initialized from N(0, stddev^2).

        Args:
            shape: full variable shape, e.g. [k, k, in_channels, out_channels].
            name: variable name inside the ``weights`` scope.
            stddev: std-dev of the normal initializer. Defaults to 0.02 (the value
                used by the CycleGAN paper). TF's own default of 1.0 saturates layers
                that have no normalization after them (the tanh output layer and
                the discriminator's final score conv).

        Returns:
            The variable for ``name``.
        '''
        with tf.variable_scope('weights', reuse = tf.AUTO_REUSE):
            wt_init = tf.random_normal_initializer(mean = 0.0, stddev = stddev)
            return tf.get_variable(name = name, shape = shape, initializer = wt_init)

    def get_bias(self, shape, name):
        '''Create (or reuse) a zero-initialized bias variable ``name`` of ``shape``.'''
        with tf.variable_scope('biases', reuse = tf.AUTO_REUSE):
            init = tf.constant_initializer(0)
            return tf.get_variable(name = name, shape = shape, initializer = init)

    def apply_activations(self, data, activation):
        '''Apply the named activation to ``data``.

        Args:
            data: input tensor.
            activation: one of 'relu', 'leaky_relu' (slope 0.2), 'tanh', 'sigmoid',
                or None for identity.

        Raises:
            AssertionError: for any other activation name.
        '''
        assert activation in ['relu', 'leaky_relu', 'tanh', 'sigmoid', None]

        if activation == 'relu':
            return tf.nn.relu(data)
        if activation == 'leaky_relu':
            # Plain op instead of instantiating a Keras layer object on every call.
            return tf.nn.leaky_relu(data, alpha = 0.2)
        if activation == 'tanh':
            return tf.tanh(data)
        if activation == 'sigmoid':
            return tf.sigmoid(data)
        return data

    def batch_normalization(self, data, is_training, norm_mode = None):
        '''Normalize ``data`` according to ``norm_mode``.

        Args:
            data: 4-D feature tensor; instance norm takes moments over axes 1 and 2
                (the spatial axes, assuming the default NHWC layout).
            is_training: bool or bool tensor forwarded to batch norm.
            norm_mode: 'instance', 'batch', or anything else for a pass-through.

        Returns:
            The normalized tensor (or ``data`` unchanged).
        '''
        if norm_mode == 'instance':
            with tf.variable_scope('instance_norm', reuse = tf.AUTO_REUSE):
                eps = 1e-5
                mean, variance = tf.nn.moments(data, [1, 2], keep_dims=True)
                # eps must sit inside the root: d/dv sqrt(v) is infinite at v == 0,
                # so `sqrt(v) + eps` still yields NaN gradients on constant maps.
                return (data - mean) * tf.rsqrt(variance + eps)

        if norm_mode == 'batch':
            with tf.variable_scope('batch_norm', reuse = tf.AUTO_REUSE):
                return tf.contrib.layers.batch_norm(data, decay = 0.99, center = True, scale = True,
                                                    is_training = is_training, updates_collections = None)

        return data

    def conv_2d_transpose(self, data, weights, strides, output_shape):
        ''''SAME'-padded transposed convolution.

        ``output_shape`` may be a list mixing Python ints and scalar tensors; it is
        converted (auto-packed) into a single shape tensor.
        '''
        return tf.nn.conv2d_transpose(data, weights, tf.convert_to_tensor(output_shape), strides, 'SAME')

    def conv_block(self, data, weights, name, strides, is_training, norm_mode, activation, bias = None):
        '''Convolution -> optional bias -> normalization -> activation, inside scope ``name``.

        Args:
            data: input tensor.
            weights: conv filter of shape [k, k, in_channels, out_channels].
            name: variable scope name for this block.
            strides: 4-element conv strides.
            is_training: forwarded to ``batch_normalization``.
            norm_mode: see ``batch_normalization``.
            activation: see ``apply_activations``.
            bias: optional bias vector; when None the conv is applied without bias.
        '''
        with tf.variable_scope(name, reuse = tf.AUTO_REUSE):
            # Previously the bias-less path called an undefined bare
            # `conv_2d_transpose()` and raised NameError; a bias is simply optional.
            res = tf.nn.conv2d(data, filter = weights, strides = strides, padding = 'SAME')
            if bias is not None:
                res = tf.nn.bias_add(res, bias)

            conv_bn = self.batch_normalization(res, is_training, norm_mode)
            return self.apply_activations(conv_bn, activation)

    def residual_block(self, data, weights, bias, name, strides, is_training, norm_mode, activation):
        '''Two-conv residual block: relu(norm(conv2(act(norm(conv1(x))))) + x).

        Args:
            data: block input; its channel count must equal the output channels of
                ``weights[1]`` for the skip addition.
            weights: [conv_1 filter, conv_2 filter].
            bias: [conv_1 bias, conv_2 bias].
            name: variable scope name for this block.
            strides: strides for both convs (must keep spatial size for the skip).
            is_training, norm_mode: forwarded to normalization.
            activation: activation after the first conv only.
        '''
        with tf.variable_scope(name, reuse = tf.AUTO_REUSE):
            with tf.variable_scope('res_1', reuse = tf.AUTO_REUSE):
                res_1 = self.conv_block(data, weights[0], 'conv_1', strides, is_training, norm_mode, activation, bias[0])

            with tf.variable_scope('res_2', reuse = tf.AUTO_REUSE):
                with tf.variable_scope('conv_2', reuse = tf.AUTO_REUSE):
                    conv_res_2      = tf.nn.conv2d(res_1, weights[1], strides, padding = 'SAME')
                    conv_res_2_bias = tf.nn.bias_add(conv_res_2, bias[1])
                    conv_res_2_bn   = self.batch_normalization(conv_res_2_bias, is_training, norm_mode)

            return tf.nn.relu(conv_res_2_bn + data)

    def deconv_block(self, data, weights, name, strides, is_training, norm_mode, activation, output_shape):
        '''Transposed conv -> normalization -> activation, inside scope ``name``.'''
        with tf.variable_scope(name, reuse = tf.AUTO_REUSE):
            deconv_res  = self.conv_2d_transpose(data, weights, strides, output_shape)
            deconv_norm = self.batch_normalization(deconv_res, is_training, norm_mode)
            return self.apply_activations(deconv_norm, activation)
    

##############################################################################################################
################################################# GENERATOR  #################################################
##############################################################################################################

In [None]:


class Generator:
    '''CycleGAN ResNet-style generator: conv encoder -> residual blocks -> transposed-conv decoder.

    All variables are created under the variable scope ``name`` (with AUTO_REUSE),
    so calling ``g_feed_forward`` several times shares one set of weights.
    '''

    def __init__(self, name, image_shape, is_training, norm_mode, activation, batch_size, helper_functions):
        '''
        Args:
            name: variable scope name for this generator.
            image_shape: [height, width, channels]; images up to 128 px use the
                smaller config (6 residual blocks), larger ones 9.
            is_training: bool (or bool tensor) forwarded to normalization.
            norm_mode: normalization mode for hidden layers (see helper).
            activation: activation for hidden layers (see helper).
            batch_size: nominal batch size (kept for callers; the decoder reads the
                runtime batch size from its input tensor).
            helper_functions: object providing the layer helpers.
        '''
        self.image_shape   = image_shape
        self.batch_size    = batch_size
        self.helper        = helper_functions
        self.no_res_blocks = 6 if self.image_shape[0] <= 128 else 9
        self.is_training   = is_training
        self.norm_mode     = norm_mode
        self.activation    = activation
        self.g_filters     = [32, 64, 128] if self.image_shape[0] <= 128 else [32, 128, 256]
        self.t_filter_size = 3
        self.t_num_filters = 128 if self.image_shape[0] <= 128 else 256
        self.t_resd_stride = [1, 1, 1, 1]
        self.g_name        = name

    ### ENCODING BLOCK ###

    def Encoding_phase(self, data):
        '''7x7 stride-1 conv, then two 3x3 stride-2 convs (4x spatial downsampling).'''
        self.g_en_wts_1  = self.helper.get_weights([7, 7, data.get_shape()[3].value, self.g_filters[0]], 'g_en_wts_1')
        self.g_en_bias_1 = self.helper.get_bias([self.g_filters[0]], 'g_en_bias_1')
        enc_conv_1       = self.helper.conv_block(data, self.g_en_wts_1, 'g_conv_1', [1, 1, 1, 1],
                                                  self.is_training, self.norm_mode, self.activation, self.g_en_bias_1)

        self.g_en_wts_2  = self.helper.get_weights([3, 3, self.g_filters[0], self.g_filters[1]], 'g_en_wts_2')
        self.g_en_bias_2 = self.helper.get_bias([self.g_filters[1]], 'g_en_bias_2')
        enc_conv_2       = self.helper.conv_block(enc_conv_1, self.g_en_wts_2, 'g_conv_2', [1, 2, 2, 1],
                                                  self.is_training, self.norm_mode, self.activation, self.g_en_bias_2)

        self.g_en_wts_3  = self.helper.get_weights([3, 3, self.g_filters[1], self.g_filters[2]], 'g_en_wts_3')
        self.g_en_bias_3 = self.helper.get_bias([self.g_filters[2]], 'g_en_bias_3')
        enc_conv_3       = self.helper.conv_block(enc_conv_2, self.g_en_wts_3, 'g_conv_3', [1, 2, 2, 1],
                                                  self.is_training, self.norm_mode, self.activation, self.g_en_bias_3)

        return enc_conv_3

    ### TRANSFORMATION BLOCK ###

    def add_resd_weights_to_class_obj(self, data):
        '''Create the residual-block weights/biases and store them as attributes.

        Attribute names are ``g_resd_wts_<block>_<sub>`` / ``g_resd_bias_<block>_<sub>``
        with sub in {0, 1}. Only the very first conv takes ``data``'s channel count
        as input; all others map t_num_filters -> t_num_filters.
        '''
        for res_block_no in range(0, self.no_res_blocks):
            for sub_block_no in range(0, 2):
                channels = self.t_num_filters
                if res_block_no == 0 and sub_block_no == 0:
                    channels = data.get_shape().as_list()[3]

                wts_name     = 'g_resd_wts_' + str(res_block_no) + '_' + str(sub_block_no)
                bias_name    = 'g_resd_bias_' + str(res_block_no) + '_' + str(sub_block_no)
                filter_shape = [self.t_filter_size, self.t_filter_size, channels, self.t_num_filters]

                setattr(self, wts_name, self.helper.get_weights(filter_shape, wts_name))
                setattr(self, bias_name, self.helper.get_bias([self.t_num_filters], bias_name))

    def Transformation_phase(self, data):
        '''Run ``data`` through ``no_res_blocks`` chained residual blocks.'''
        self.add_resd_weights_to_class_obj(data)

        t_output = data
        for res_block_no in range(0, self.no_res_blocks):
            weights_li = [getattr(self, 'g_resd_wts_' + str(res_block_no) + '_' + str(sub)) for sub in range(2)]
            bias_li    = [getattr(self, 'g_resd_bias_' + str(res_block_no) + '_' + str(sub)) for sub in range(2)]
            t_output = self.helper.residual_block(t_output, weights_li, bias_li, ('g_t_resd_' + str(res_block_no)),
                                                  self.t_resd_stride, self.is_training, self.norm_mode, self.activation)

        return t_output

    ### DECODING BLOCK ###

    def get_filters_and_op_shape(self, data, filter_size, num_filters, stride):
        '''Filter shape and output shape for a transposed conv that upsamples by ``stride``.

        Returns:
            (filter_shape, op_shape) where filter_shape is
            [filter_size, filter_size, num_filters, in_channels] (the transposed-conv
            layout) and op_shape is [batch, h * stride, w * stride, num_filters].
        '''
        _, h, w, c = data.get_shape().as_list()
        filter_shape = [filter_size, filter_size, num_filters, c]
        # Batch size is read from the runtime tensor: the input placeholders have a
        # None batch dimension, and a hard-coded 1 breaks any larger batch.
        op_shape = [tf.shape(data)[0], h * stride, w * stride, num_filters]
        return filter_shape, op_shape

    def Decoding_phase(self, data):
        '''Two stride-2 transposed convs (4x upsampling), then a 7x7 stride-1 layer to
        3 channels with tanh and no normalization.'''
        filter_shape, op_shape = self.get_filters_and_op_shape(data, 3, self.g_filters[1], 2)
        self.g_dec_wts_1 = self.helper.get_weights(filter_shape, 'g_dec_wts_1')
        dec_deconv_1     = self.helper.deconv_block(data, self.g_dec_wts_1, 'g_deconv_1', [1, 2, 2, 1],
                                                    self.is_training, self.norm_mode, self.activation, op_shape)

        filter_shape, op_shape = self.get_filters_and_op_shape(dec_deconv_1, 3, self.g_filters[0], 2)
        self.g_dec_wts_2 = self.helper.get_weights(filter_shape, 'g_dec_wts_2')
        dec_deconv_2     = self.helper.deconv_block(dec_deconv_1, self.g_dec_wts_2, 'g_deconv_2', [1, 2, 2, 1],
                                                    self.is_training, self.norm_mode, self.activation, op_shape)

        filter_shape, op_shape = self.get_filters_and_op_shape(dec_deconv_2, 7, 3, 1)
        self.g_dec_wts_3 = self.helper.get_weights(filter_shape, 'g_dec_wts_3')
        dec_deconv_3     = self.helper.deconv_block(dec_deconv_2, self.g_dec_wts_3, 'g_deconv_3', [1, 1, 1, 1],
                                                    self.is_training, None, 'tanh', op_shape)

        return dec_deconv_3

    ###  FORWARD PROPAGATION ###

    def g_feed_forward(self, data):
        '''Translate a batch of images; output values lie in [-1, 1] (tanh).

        Also records this generator's trainable variables in ``self.g_variables``.
        '''
        with tf.variable_scope(self.g_name, reuse = tf.AUTO_REUSE):
            encoding_res       = self.Encoding_phase(data)
            transformation_res = self.Transformation_phase(encoding_res)
            decoding_res       = self.Decoding_phase(transformation_res)

            # get_collection matches from the start of the variable name, so the
            # full (possibly nested, e.g. 'optim/Gen_XY') scope name is required;
            # filtering on the bare g_name returned an empty list.
            scope_prefix     = tf.get_variable_scope().name + '/'
            self.g_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope_prefix)
            return decoding_res
        
        
            

##############################################################################################################
#################################################  DISCRIMINATOR ##########################################
##############################################################################################################

In [None]:

class Discriminator:
    '''Fully-convolutional discriminator.

    Four 4x4 stride-2 conv layers (64, 128, 256, 512 filters) with normalization
    and activation, then a 4x4 stride-1 conv to a 1-channel score map with no
    normalization or activation. The map is averaged into one score per image.
    '''

    def __init__(self, name, is_training, norm, activation, helper_functions):
        '''
        Args:
            name: variable scope name for this discriminator.
            is_training: bool (or bool tensor) forwarded to normalization.
            norm: normalization mode for the hidden layers (see helper).
            activation: activation for the hidden layers (see helper).
            helper_functions: object providing the layer helpers.
        '''
        self.d_name        = name
        self.is_training   = is_training
        self.norm_mode     = norm
        self.activation    = activation
        self.d_filter_size = 4
        self.d_num_filters = [64, 128, 256, 512, 1]
        self.conv_stride   = [1, 2, 2, 1]
        self.helper        = helper_functions

    def d_feed_forward(self, data):
        '''Score a batch of images; returns a [batch] tensor of mean patch scores.

        Creates/reuses variables ``d_wts_<i>`` / ``d_bias_<i>`` (i = 1..5), stores
        them as attributes of the same name, and records this discriminator's
        trainable variables in ``self.d_variables``.
        '''
        with tf.variable_scope(self.d_name, reuse = tf.AUTO_REUSE):
            num_layers  = len(self.d_num_filters)
            in_channels = data.get_shape()[3].value
            layer_out   = data

            for layer_no, out_channels in enumerate(self.d_num_filters, start = 1):
                is_last   = layer_no == num_layers
                wts_name  = 'd_wts_' + str(layer_no)
                bias_name = 'd_bias_' + str(layer_no)

                weights = self.helper.get_weights([self.d_filter_size, self.d_filter_size, in_channels, out_channels],
                                                  wts_name)
                bias    = self.helper.get_bias([out_channels], bias_name)
                setattr(self, wts_name, weights)
                setattr(self, bias_name, bias)

                # The final score layer keeps the resolution and emits raw scores.
                layer_out = self.helper.conv_block(layer_out, weights, 'd_conv_' + str(layer_no),
                                                   [1, 1, 1, 1] if is_last else self.conv_stride,
                                                   self.is_training,
                                                   None if is_last else self.norm_mode,
                                                   None if is_last else self.activation,
                                                   bias)
                in_channels = out_channels

            conv_5_res = tf.reduce_mean(layer_out, axis = [1, 2, 3])

            # get_collection matches from the start of the variable name, so the
            # full (possibly nested, e.g. 'optim/D_X') scope name is required;
            # filtering on the bare d_name returned an empty list.
            scope_prefix     = tf.get_variable_scope().name + '/'
            self.d_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope_prefix)

            return conv_5_res

##############################################################################################################
################################################# TRAINING CYCLEGAN ##########################################
##############################################################################################################

In [None]:

class cycleGAN:
    '''Builds, trains and evaluates a CycleGAN between two unpaired image domains X and Y.

    Data layout: ``<train_data_path>/trainA`` and ``/trainB`` hold the X and Y
    training images; ``<test_data_path>/testA`` and ``/testB`` hold the test images.
    Pixels are fed to the networks scaled to [-1, 1] (the generators' tanh output
    range) and mapped back to [0, 255] when results are written to disk.
    '''

    # Learning-rate schedule: constant before LR_DECAY_START_EPOCH, then linear
    # decay to zero over LR_DECAY_EPOCHS epochs.
    LR_DECAY_START_EPOCH = 100
    LR_DECAY_EPOCHS      = 100

    def __init__(self, image_shape, batch_size, learning_rate, train_data_path, test_data_path, num_epochs,
                                             to_restore, helper_functions, output_dir):
        '''
        Args:
            image_shape: [height, width, channels] of the network inputs.
            batch_size: images per training step.
            learning_rate: base Adam learning rate.
            train_data_path: directory containing trainA/ and trainB/.
            test_data_path: directory containing testA/ and testB/.
            num_epochs: total number of training epochs.
            to_restore: resume from the latest checkpoint in ``<output_dir>/checkpoints/``.
            helper_functions: layer-helper object passed to the networks.
            output_dir: root for checkpoints and test outputs.
        '''
        self.image_shape     = image_shape
        self.batch_size      = batch_size
        self.is_training     = tf.placeholder(tf.bool, name = 'is_training')
        self.lr              = learning_rate
        self.X               = tf.placeholder(tf.float32, [None, self.image_shape[0], self.image_shape[1], self.image_shape[2]])
        self.Y               = tf.placeholder(tf.float32, [None, self.image_shape[0], self.image_shape[1], self.image_shape[2]])
        self.train_data_path = train_data_path
        self.test_data_path  = test_data_path
        self.helper          = helper_functions
        self.num_of_epochs   = num_epochs
        self.to_restore      = to_restore
        self.output_dir      = output_dir
        self.check_point_dir = self.output_dir + '/checkpoints/'
        self.global_step     = tf.Variable(0, name = 'global_step', trainable = False)

    @staticmethod
    def _normalize(image):
        '''Map uint8 pixel values in [0, 255] to float32 values in [-1, 1].'''
        return image.astype(np.float32) / 127.5 - 1.0

    @staticmethod
    def _denormalize(images):
        '''Map network values in [-1, 1] back to uint8 pixels in [0, 255] (clipped).'''
        return np.clip(np.rint((np.asarray(images) + 1.0) * 127.5), 0, 255).astype(np.uint8)

    @classmethod
    def _decayed_learning_rate(cls, base_lr, epoch):
        '''Learning rate for ``epoch``: ``base_lr`` before the decay start, then
        linearly decayed to 0 over LR_DECAY_EPOCHS epochs (never negative).'''
        if epoch < cls.LR_DECAY_START_EPOCH:
            return base_lr
        remaining = 1.0 - float(epoch - cls.LR_DECAY_START_EPOCH) / cls.LR_DECAY_EPOCHS
        return base_lr * max(0.0, remaining)

    def get_input_batches(self, images_batch_names, path):
        '''Load, resize to ``image_shape`` and normalize a batch of images.

        Args:
            images_batch_names: file names inside ``path``.
            path: directory prefix (ends with '/').

        Returns:
            float32 array of shape [len(names), height, width, 3] in [-1, 1]
            (channels in OpenCV's BGR order).

        Raises:
            IOError: if a file cannot be decoded as an image.
        '''
        height, width = self.image_shape[0], self.image_shape[1]
        images_arr = []
        for image_name in images_batch_names:
            image_path = path + image_name
            image = cv2.imread(image_path)
            if image is None:
                raise IOError('could not read image ' + image_path)
            # cv2.resize takes the target size as (width, height).
            image = cv2.resize(image, (width, height))
            images_arr.append(self._normalize(image))
        return np.asarray(images_arr, dtype = np.float32)

    def get_optimizers(self):
        '''Build both generators and discriminators, their losses and Adam train ops.

        Gen_XY transforms images from X to Y (e.g., Horse -> Zebra)
        Gen_YX transforms images from Y to X (e.g., Zebra -> Horse)

        Disc_X: scores how real an image of X looks (e.g. does this image look like a Horse?)
        Disc_Y: scores how real an image of Y looks (e.g. does this image look like a Zebra?)

        Sets ``gen_xy`` / ``gen_yx`` (generator outputs), the loss tensors, the four
        ``*_optim`` train ops, and ``lr_placeholder`` (learning rate fed per step,
        defaulting to ``self.lr``).
        '''
        with tf.variable_scope('optim', reuse = tf.AUTO_REUSE):
            Gen_XY = Generator('Gen_XY', self.image_shape, self.is_training, 'instance', 'relu', self.batch_size, self.helper)
            Gen_YX = Generator('Gen_YX', self.image_shape, self.is_training, 'instance', 'relu', self.batch_size, self.helper)

            Disc_X = Discriminator('D_X', self.is_training, 'instance', 'leaky_relu', self.helper)
            Disc_Y = Discriminator('D_Y', self.is_training, 'instance', 'leaky_relu', self.helper)

            # Fed per step so the learning-rate schedule actually reaches Adam; a
            # Python float passed to the optimizer is frozen into the graph.
            self.lr_placeholder = tf.placeholder_with_default(
                tf.constant(self.lr, dtype = tf.float32), shape = [], name = 'learning_rate')

            self.gen_xy = Gen_XY.g_feed_forward(self.X)   # fake Y generated from real X
            self.gen_yx = Gen_YX.g_feed_forward(self.Y)   # fake X generated from real Y

            gen_xyx = Gen_YX.g_feed_forward(self.gen_xy)  # fake Y mapped back to X
            gen_yxy = Gen_XY.g_feed_forward(self.gen_yx)  # fake X mapped back to Y

            disc_real_x = Disc_X.d_feed_forward(self.X)        # score of real X
            disc_real_y = Disc_Y.d_feed_forward(self.Y)        # score of real Y
            disc_fake_x = Disc_X.d_feed_forward(self.gen_yx)   # score of fake X (from Gen_YX)
            disc_fake_y = Disc_Y.d_feed_forward(self.gen_xy)   # score of fake Y (from Gen_XY)

            # Loss terms:
            #   - discriminators should score real images (X and Y) as 1 ...
            #   - ... and generated images (gen_xy, gen_yx) as 0;
            #   - generators should make the discriminators score their outputs as 1;
            #   - cycle consistency: X -> Gen_XY -> Gen_YX should reproduce X (and
            #     likewise for Y), so the translation keeps the input's content.

            # Discriminator (least-squares) loss
            loss_disc_real_x = tf.reduce_mean(tf.squared_difference(disc_real_x, 1))
            loss_disc_real_y = tf.reduce_mean(tf.squared_difference(disc_real_y, 1))

            loss_disc_fake_x = tf.reduce_mean(tf.square(disc_fake_x))
            loss_disc_fake_y = tf.reduce_mean(tf.square(disc_fake_y))

            self.disc_X_total_loss = (loss_disc_real_x + loss_disc_fake_x) / 2
            self.disc_Y_total_loss = (loss_disc_real_y + loss_disc_fake_y) / 2

            # Generator loss: push the discriminator's score on fakes towards 1,
            # i.e. minimize (Disc_X(Gen_YX(y)) - 1)^2.
            loss_gen_yx = tf.reduce_mean(tf.squared_difference(disc_fake_x, 1))
            loss_gen_xy = tf.reduce_mean(tf.squared_difference(disc_fake_y, 1))

            # Both sides are in [-1, 1] now that inputs are normalized to the tanh range.
            self.loss_cycle = tf.reduce_mean(tf.abs(self.X - gen_xyx) + tf.abs(self.Y - gen_yxy))

            # Cycle loss is weighted 10x so it dominates the adversarial term.
            self.gen_xy_total_loss = loss_gen_xy + (10 * self.loss_cycle)
            self.gen_yx_total_loss = loss_gen_yx + (10 * self.loss_cycle)

            trainable_variables = tf.trainable_variables()
            G_XY_vars = [var for var in trainable_variables if 'Gen_XY' in var.name]
            G_YX_vars = [var for var in trainable_variables if 'Gen_YX' in var.name]
            D_X_vars  = [var for var in trainable_variables if 'D_X' in var.name]
            D_Y_vars  = [var for var in trainable_variables if 'D_Y' in var.name]

            def adam():
                return tf.train.AdamOptimizer(learning_rate = self.lr_placeholder, beta1 = 0.5)

            self.Gen_XY_optim = adam().minimize(self.gen_xy_total_loss, var_list = G_XY_vars)
            self.Gen_YX_optim = adam().minimize(self.gen_yx_total_loss, var_list = G_YX_vars)
            self.Disc_X_optim = adam().minimize(self.disc_X_total_loss, var_list = D_X_vars)
            self.Disc_Y_optim = adam().minimize(self.disc_Y_total_loss, var_list = D_Y_vars)

    def train_cycleGAN(self):
        '''Train both generators and discriminators, checkpointing after every epoch.

        Checkpoints go to ``check_point_dir`` as ``cycleGAN-<epoch>``; ``global_step``
        stores the next epoch to run so ``to_restore`` resumes where training stopped.
        Leaves the training session open as ``self.session``.
        '''
        print('training the cycle Gan')
        self.get_optimizers()

        train_X_images_path = self.train_data_path + '/trainA/'
        train_Y_images_path = self.train_data_path + '/trainB/'

        train_X_images_names = os.listdir(train_X_images_path)
        train_Y_images_names = os.listdir(train_Y_images_path)

        num_of_train_images = min(len(train_X_images_names), len(train_Y_images_names))
        num_of_batches      = num_of_train_images // self.batch_size

        # Built once: creating tf.assign inside the loop adds a new op every step.
        epoch_value     = tf.placeholder(tf.int32, shape = [], name = 'epoch_value')
        set_global_step = tf.assign(self.global_step, epoch_value)

        self.create_directory(self.check_point_dir)
        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # Restore once, before reading global_step, so training actually resumes.
        if self.to_restore:
            checkpoint_file = tf.train.latest_checkpoint(self.check_point_dir)
            if checkpoint_file is None:
                raise IOError('no checkpoint found in ' + self.check_point_dir)
            saver.restore(self.session, checkpoint_file)

        start_epoch = self.session.run(self.global_step)
        for epoch in range(start_epoch, self.num_of_epochs):
            random.shuffle(train_X_images_names)
            random.shuffle(train_Y_images_names)
            current_lr = self._decayed_learning_rate(self.lr, epoch)

            # Step over every full batch (batch start indices, not batch counts).
            for index in range(0, num_of_batches * self.batch_size, self.batch_size):
                train_x = self.get_input_batches(train_X_images_names[index : index + self.batch_size], train_X_images_path)
                train_y = self.get_input_batches(train_Y_images_names[index : index + self.batch_size], train_Y_images_path)

                feed = {self.X : train_x,
                        self.Y : train_y,
                        self.is_training : True,
                        self.lr_placeholder : current_lr}

                self.session.run(self.Gen_XY_optim, feed_dict = feed)  # X -> fake Y generator
                self.session.run(self.Gen_YX_optim, feed_dict = feed)  # Y -> fake X generator
                self.session.run(self.Disc_X_optim, feed_dict = feed)
                self.session.run(self.Disc_Y_optim, feed_dict = feed)

            print('epoch ' + str(epoch))
            self.session.run(set_global_step, feed_dict = {epoch_value : epoch + 1})
            saver.save(self.session, self.check_point_dir + 'cycleGAN', global_step = epoch)

    def create_directory(self, directory):
        '''Create ``directory`` (and parents) if it does not exist.'''
        if not os.path.exists(directory):
            os.makedirs(directory)

    def apply_transformtions_on_test_images(self):
        '''Translate up to 100 test images per domain with the latest checkpoint.

        Writes fake_X/, fake_Y/, test_X/ and test_Y/ PNGs under ``<output_dir>/test/``.

        Raises:
            IOError: if no checkpoint exists in ``check_point_dir``.
        '''
        print('testing the test images')

        test_images_output_dir = self.output_dir + '/test/'
        fake_X_path = test_images_output_dir + 'fake_X/'
        fake_Y_path = test_images_output_dir + 'fake_Y/'
        test_X_path = test_images_output_dir + 'test_X/'
        test_Y_path = test_images_output_dir + 'test_Y/'

        for directory in (test_images_output_dir, fake_X_path, fake_Y_path, test_X_path, test_Y_path):
            self.create_directory(directory)

        test_X_images_path = self.test_data_path + '/testA/'
        test_Y_images_path = self.test_data_path + '/testB/'

        test_X_images_names = os.listdir(test_X_images_path)
        test_Y_images_names = os.listdir(test_Y_images_path)

        num_of_test_images = min(100, len(test_X_images_names), len(test_Y_images_names))

        # Reuse the graph built for training (building it if training was not run in
        # this process) instead of importing a second copy from a hard-coded .meta file.
        if not hasattr(self, 'gen_xy'):
            self.get_optimizers()
        saver = tf.train.Saver()
        checkpoint_file = tf.train.latest_checkpoint(self.check_point_dir)
        if checkpoint_file is None:
            raise IOError('no checkpoint found in ' + self.check_point_dir)

        with tf.Session() as session:
            saver.restore(session, checkpoint_file)

            random.shuffle(test_X_images_names)
            random.shuffle(test_Y_images_names)
            for index in range(0, num_of_test_images):
                # One image per step so every test image is written exactly once.
                test_x = self.get_input_batches(test_X_images_names[index : index + 1], test_X_images_path)
                test_y = self.get_input_batches(test_Y_images_names[index : index + 1], test_Y_images_path)

                fake_X, fake_Y = session.run([self.gen_yx, self.gen_xy], feed_dict = {self.X : test_x,
                                                                                    self.Y : test_y,
                                                                                    self.is_training : False})

                # Networks work in [-1, 1]; map back to uint8 pixels before saving.
                cv2.imwrite((fake_X_path + str(index) + '_fx.png'), self._denormalize(fake_X[0]))
                cv2.imwrite((fake_Y_path + str(index) + '_fy.png'), self._denormalize(fake_Y[0]))
                cv2.imwrite((test_X_path + str(index) + '_tx.png'), self._denormalize(test_x[0]))
                cv2.imwrite((test_Y_path + str(index) + '_ty.png'), self._denormalize(test_y[0]))


In [None]:


if __name__ == '__main__':
    # Experiment configuration.
    image_shape     = [256, 256, 3]
    batch_size      = 1
    learning_rate   = 0.0002
    train_data_path = '/content/drive/My Drive/horse2zebra'
    test_data_path  = '/content/drive/My Drive/horse2zebra'
    num_epochs      = 10
    to_restore      = False
    output_dir      = './outputs'
    seed            = 42

    # Start from an empty graph so re-running this cell does not stack a second
    # copy of the model (placeholders, variables, savers) onto the previous one.
    tf.reset_default_graph()

    # Seed every RNG in play: Python's (file-list shuffling), NumPy's, and TF's
    # graph-level seed (weight initialization).
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)

    hf = helper_functions()

    cycle_GAN = cycleGAN(image_shape, batch_size, learning_rate, train_data_path, test_data_path, num_epochs,
                         to_restore, hf, output_dir)
    cycle_GAN.train_cycleGAN()
    cycle_GAN.apply_transformtions_on_test_images()