# CycleGAN

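This notebook is an attempt to re-implement the CycleGAN paper, *Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks* (Zhu et al., 2017).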

Paper: https://arxiv.org/pdf/1703.10593.pdf

Other resources:
* TensorFlow CycleGAN tutorial: https://www.tensorflow.org/tutorials/generative/cyclegan

In [1]:
import glob
import os

import tensorflow as tf

In [11]:
image_size = 256   # side length (pixels) every image is resized to
channels = 3       # colour channels decoded from each JPEG
num_images = 500   # maximum number of images loaded per domain
batch_size = 1     # the paper trains with a batch size of 1

In [12]:
def load_images(image_path):
    """Decode one JPEG file into a float32 image tensor scaled to [-1, 1].

    The file is decoded with ``channels`` colour channels, rescaled from
    [0, 255] to [-1, 1], and resized to ``image_size`` x ``image_size`` with
    nearest-neighbour sampling (which never blends pixel values, so the
    [-1, 1] range is preserved).
    """
    raw = tf.io.read_file(image_path)
    pixels = tf.cast(tf.image.decode_jpeg(raw, channels = channels), tf.float32)
    scaled = pixels / 127.5 - 1
    target_size = [image_size, image_size]
    return tf.image.resize(scaled, target_size, method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)

def load_data(path, num_images = num_images):
    """Build a shuffled, batched ``tf.data`` pipeline from a glob of image files.

    Args:
        path: glob pattern matching the JPEG files of one domain
            (e.g. ``.../Train/Horse/*.jpg``).
        num_images: maximum number of files to use; the first ``num_images``
            matches in sorted order are kept.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``batch_size`` images as
        produced by ``load_images``, reshuffled on every pass.

    Raises:
        ValueError: if ``path`` matches no files.
    """
    # Sort so the subset chosen by the slice does not depend on filesystem order.
    image_list = sorted(glob.glob(path))[:num_images]
    if not image_list:
        raise ValueError(f'No images matched {path!r}')

    data = tf.data.Dataset.from_tensor_slices(image_list)
    # Shuffle the (cheap) file names before decoding: shuffling after `map`
    # would keep up to `num_images` full-size float32 images in the buffer.
    data = data.shuffle(len(image_list))
    data = data.map(load_images, num_parallel_calls = tf.data.experimental.AUTOTUNE)
    data = data.batch(batch_size)
    # Prepare the next batch while the current one is being trained on.
    return data.prefetch(tf.data.experimental.AUTOTUNE)

In [13]:
# Raw string: '\I' and '\H' are invalid escape sequences in a normal string
# literal (DeprecationWarning, SyntaxWarning since Python 3.12); the raw string
# keeps exactly the same path value.
DIR = r'E:\Image Datasets\Horses_2_Zebras'
# One pipeline per domain; os.path.join inserts the platform's path separator.
train_horses = load_data(os.path.join(DIR, 'Train', 'Horse', '*.jpg'))
train_zebras = load_data(os.path.join(DIR, 'Train', 'Zebra', '*.jpg'))
# len(train_horses), len(train_zebras)

In [14]:
class CycleGAN(object):
    """Shared layer-building helpers for the CycleGAN generators and discriminators.

    Each helper appends a few Keras layers to the symbolic tensor ``inp`` and
    returns the resulting tensor, so subclasses can compose functional models.
    """

    def __init__(self, image_size, channels):
        # Weights drawn from N(0, 0.02), the initialization used in the paper.
        self.init = tf.random_normal_initializer(0, 0.02)
        # Shape of a single input image: (height, width, channels).
        self.image_shape = (image_size, image_size, channels)

    def encoder(self, inp, filters, kernel_size, strides):
        """Conv2D -> BatchNormalization -> ReLU with 'same' padding."""
        x = tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = strides,
                                   padding = 'same', kernel_initializer = self.init, use_bias = True)(inp)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        return x

    def decoder(self, inp, filters, kernel_size, strides):
        """Conv2DTranspose -> BatchNormalization -> ReLU with 'same' padding."""
        x = tf.keras.layers.Conv2DTranspose(filters = filters, kernel_size = kernel_size, strides = strides,
                                            padding = 'same', kernel_initializer = self.init, use_bias = True)(inp)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        return x

    def residual_block(self, inp, filters = 256, kernel_size = (3, 3), strides = (1, 1)):
        """Residual block computing ``shortcut(inp) + F(inp)``.

        ``F`` is Conv -> BN -> ReLU -> Conv -> BN. The skip connection is an
        element-wise ``Add`` (not a ``Concatenate``), so the output has exactly
        ``filters`` channels and stacking blocks does not widen the network.

        ``strides`` is applied to the first convolution only. When the input
        channel count differs from ``filters`` or ``strides`` is not 1, the
        shortcut is a 1x1 convolution so both branches have matching shapes.
        """
        stride_pair = (strides, strides) if isinstance(strides, int) else tuple(strides)

        x = tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = strides,
                                   padding = 'same', kernel_initializer = self.init, use_bias = True)(inp)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = (1, 1),
                                   padding = 'same', kernel_initializer = self.init, use_bias = True)(x)
        x = tf.keras.layers.BatchNormalization()(x)

        shortcut = inp
        if inp.shape[-1] != filters or stride_pair != (1, 1):
            # Project the identity path so it can be added to the residual branch.
            shortcut = tf.keras.layers.Conv2D(filters = filters, kernel_size = (1, 1), strides = strides,
                                              padding = 'same', kernel_initializer = self.init,
                                              use_bias = True)(inp)
        return tf.keras.layers.Add()([x, shortcut])

    # for discriminator
    def downsample(self, inp, filters, kernel_size = (4, 4), strides = (2, 2), padding = 'same'):
        """Conv2D -> BatchNormalization -> LeakyReLU(0.2); discriminator building block."""
        x = tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = strides,
                                   padding = padding, kernel_initializer = self.init, use_bias = True)(inp)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        return x
    
    
class Generator(CycleGAN):
    """ResNet-style CycleGAN generator.

    Architecture: 7x7 conv (64) -> two stride-2 convs (128, 256) -> residual
    blocks (256) -> two stride-2 transposed convs (128, 64) -> 7x7 conv + tanh.
    The paper uses 6 residual blocks for 128x128 inputs and 9 for 256x256;
    ``num_residual_blocks`` defaults to 6 as in the original notebook.
    """

    def __init__(self, image_size, channels, num_residual_blocks = 6):
        super().__init__(image_size, channels)
        self.num_residual_blocks = num_residual_blocks

    def _build_generator(self):
        """Build one generator model mapping an image to an image of the same shape."""
        inp = tf.keras.layers.Input(shape = self.image_shape)

        x = self.encoder(inp, 64, (7, 7), (1, 1))
        x = self.encoder(x, 128, (3, 3), (2, 2))
        x = self.encoder(x, 256, (3, 3), (2, 2))

        for _ in range(self.num_residual_blocks):
            x = self.residual_block(x)

        x = self.decoder(x, 128, (3, 3), (2, 2))
        x = self.decoder(x, 64, (3, 3), (2, 2))

        # Output layer: plain conv + tanh so generated pixels lie in [-1, 1], the
        # range load_images normalizes to. Using `decoder` here (BN + ReLU) would
        # force every output pixel to be >= 0.
        x = tf.keras.layers.Conv2D(filters = self.image_shape[-1], kernel_size = (7, 7), strides = (1, 1),
                                   padding = 'same', kernel_initializer = self.init,
                                   activation = 'tanh')(x)

        return tf.keras.Model(inputs = inp, outputs = x)

    def generator_a2b(self):
        """Return a fresh generator translating domain A images to domain B."""
        return self._build_generator()

    def generator_b2a(self):
        """Return a fresh generator translating domain B images to domain A."""
        return self._build_generator()
    
    
    
class Discriminator(CycleGAN):
    """PatchGAN-style discriminator producing a spatial map of real/fake logits."""

    def __init__(self, image_size, channels):
        super().__init__(image_size, channels)

    def _build_discriminator(self):
        """Build one discriminator model returning a 1-channel map of raw logits."""
        inp = tf.keras.layers.Input(shape = self.image_shape)

        # First layer has no normalization, as in the paper's C64 layer.
        x = tf.keras.layers.Conv2D(filters = 64, kernel_size = (4, 4), strides = (2, 2),
                                   padding = 'same', kernel_initializer = self.init, use_bias = True)(inp)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = self.downsample(x, 128)
        x = self.downsample(x, 256)
        x = self.downsample(x, 512, (3, 3), (1, 1))

        # Final layer emits raw logits (the losses use from_logits=True), so it
        # must not be followed by BatchNormalization or an activation.
        x = tf.keras.layers.Conv2D(filters = 1, kernel_size = (3, 3), strides = (1, 1),
                                   padding = 'valid', kernel_initializer = self.init)(x)

        return tf.keras.Model(inputs = inp, outputs = x)

    def discriminator_a(self):
        """Return a fresh discriminator for domain A images."""
        return self._build_discriminator()

    def discriminator_b(self):
        """Return a fresh discriminator for domain B images."""
        return self._build_discriminator()
        
        
# One generator per translation direction, one discriminator per domain.
gen_a2b, gen_b2a = (
    Generator(image_size, channels).generator_a2b(),
    Generator(image_size, channels).generator_b2a(),
)

disc_a, disc_b = (
    Discriminator(image_size, channels).discriminator_a(),
    Discriminator(image_size, channels).discriminator_b(),
)

In [15]:
class Losses(object):
    """CycleGAN loss terms: adversarial, cycle-consistency and identity.

    Weights follow the paper: cycle-consistency is weighted by ``LAMBDA``
    (10 by default) and the identity loss by ``0.5 * LAMBDA``.
    """

    def __init__(self, LAMBDA = 10):
        # Adversarial loss on the discriminators' raw (pre-sigmoid) outputs.
        self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits = True)
        # Weight of the cycle-consistency loss.
        self.LAMBDA = LAMBDA

    def generator_loss(self, disc_gen_out):
        """Adversarial loss for a generator: its fakes should be classified as real."""
        return self.loss_object(tf.ones_like(disc_gen_out), disc_gen_out)

    def discriminator_loss(self, disc_real_out, disc_gen_out):
        """Discriminator loss: real -> 1, generated -> 0, halved.

        The paper divides the discriminator objective by 2 to slow D's learning
        relative to the generators.
        """
        real_loss = self.loss_object(tf.ones_like(disc_real_out), disc_real_out)
        gen_loss = self.loss_object(tf.zeros_like(disc_gen_out), disc_gen_out)
        return (real_loss + gen_loss) * 0.5

    def cycle_loss(self, real_image, cycled_image):
        """Mean absolute reconstruction error after a full A->B->A (or B->A->B) cycle, times LAMBDA."""
        return tf.math.reduce_mean(tf.math.abs(real_image - cycled_image)) * self.LAMBDA

    def identity_loss(self, real_image, identity_image):
        """Mean absolute error between an image and its identity mapping, times 0.5 * LAMBDA."""
        return tf.math.reduce_mean(tf.math.abs(real_image - identity_image)) * self.LAMBDA * 0.5

l = Losses()

In [16]:
# Four independent Adam optimizers (lr = 2e-4, beta_1 = 0.5), unpacked in the
# order: generator A->B, generator B->A, discriminator A, discriminator B.
gen_a2b_optimizer, gen_b2a_optimizer, disc_a_optimizer, disc_b_optimizer = (
    tf.keras.optimizers.Adam(learning_rate = 2e-4, beta_1 = 0.5) for _ in range(4)
)

In [17]:
@tf.function
def train_step(image_a, image_b):
    """Run one joint update of both generators and both discriminators.

    Every forward pass is recorded on a single persistent tape; afterwards each
    network is stepped by its own optimizer on its own loss.

    Args:
        image_a: batch of domain-A images.
        image_b: batch of domain-B images.

    Returns:
        ``(tot_gen_loss_a2b, tot_gen_loss_b2a, disc_loss_a, disc_loss_b)``.
    """
    with tf.GradientTape(persistent = True) as tape:
        # The call order below is kept fixed: the models contain
        # BatchNormalization layers whose moving statistics update on every
        # training=True call.
        fake_b = gen_a2b(image_a, training = True)
        fake_a = gen_b2a(image_b, training = True)

        recon_a = gen_b2a(fake_b, training = True)
        recon_b = gen_a2b(fake_a, training = True)

        same_a = gen_b2a(image_a, training = True)
        same_b = gen_a2b(image_b, training = True)

        logits_real_a = disc_a(image_a, training = True)
        logits_fake_a = disc_a(fake_a, training = True)

        logits_real_b = disc_b(image_b, training = True)
        logits_fake_b = disc_b(fake_b, training = True)

        adv_a2b = l.generator_loss(logits_fake_b)
        adv_b2a = l.generator_loss(logits_fake_a)

        # Both generators share the full (two-direction) cycle-consistency term.
        cycle_term = l.cycle_loss(image_b, recon_b) + l.cycle_loss(image_a, recon_a)

        loss_g_a2b = adv_a2b + cycle_term + l.identity_loss(image_b, same_b)
        loss_g_b2a = adv_b2a + cycle_term + l.identity_loss(image_a, same_a)

        loss_d_a = l.discriminator_loss(logits_real_a, logits_fake_a)
        loss_d_b = l.discriminator_loss(logits_real_b, logits_fake_b)

    # (loss, network, optimizer) in the order the updates are applied.
    updates = (
        (loss_g_a2b, gen_a2b, gen_a2b_optimizer),
        (loss_g_b2a, gen_b2a, gen_b2a_optimizer),
        (loss_d_a, disc_a, disc_a_optimizer),
        (loss_d_b, disc_b, disc_b_optimizer),
    )
    for loss, network, optimizer in updates:
        grads = tape.gradient(loss, network.trainable_variables)
        optimizer.apply_gradients(zip(grads, network.trainable_variables))

    return loss_g_a2b, loss_g_b2a, loss_d_a, loss_d_b

In [18]:
def train(train_A, train_B, epochs = 1, log_every = 10):
    """Train the CycleGAN for ``epochs`` passes over the two domain datasets.

    Batches from ``train_A`` and ``train_B`` are drawn in lockstep with
    ``zip``, so each epoch stops at the end of the shorter dataset. A progress
    dot is printed every ``log_every`` steps, and the losses of the epoch's
    last step are printed when the epoch ends.

    Args:
        train_A: iterable of domain-A batches.
        train_B: iterable of domain-B batches.
        epochs: number of passes over the data.
        log_every: print a progress dot every this many steps (must be >= 1).

    Raises:
        ValueError: if ``log_every`` < 1, or if an epoch yields no batches
            (there would be no losses to report).
    """
    if log_every < 1:
        raise ValueError(f'log_every must be >= 1, got {log_every}')

    for e in range(epochs):
        print(f'Epoch {e} starts')

        n = -1
        for n, (A, B) in enumerate(zip(train_A, train_B)):
            tot_gen_loss_a2b, tot_gen_loss_b2a, disc_loss_a, disc_loss_b = train_step(A, B)

            if n % log_every == 0:
                # Flush so the dot appears immediately even without a newline.
                print('.', end='', flush=True)

        if n < 0:
            raise ValueError(f'Epoch {e}: no batches to train on (empty dataset?)')

        print('\n\nGenerator Loss A2B: {} \t Discriminator Loss B: {} \nGenerator Loss B2A: {} \t Discriminator Loss A: {}\n'.format(
            tot_gen_loss_a2b, disc_loss_b, tot_gen_loss_b2a, disc_loss_a))
        print(f'Epoch {e} ends \n\n\n')

In [19]:
# Horses are domain A, zebras domain B; runs the default single epoch.
train(train_A = train_horses, train_B = train_zebras)

Epoch 0 starts
....

KeyboardInterrupt: 