**Resource**: [Monet CycleGAN Tutorial](https://www.kaggle.com/amyjang/monet-cyclegan-tutorial)

# Imports and TPU init

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np

# Try to attach to a TPU; if anything in the TPU bring-up fails, fall back to
# whatever distribution strategy TensorFlow uses by default.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as err:
    # `Exception` (not a bare `except`) so Ctrl-C / SystemExit still propagate;
    # the reason is printed so a silent fallback to CPU/GPU is noticeable.
    print('TPU not available, using the default strategy:', err)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Shorthand for letting tf.data pick parallelism / buffer sizes.
AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)

# PARAMS

In [None]:
# Geometry the input pipeline reshapes every decoded picture to.
IMG_DIM = 256       # @param
NB_CHANNEL = 3      # @param

# Training length and tf.data shuffle-buffer / batch sizes.
EPOCHS = 250        # @param
BUFFER_SIZE = 2048  # @param
BATCH_SIZE = 1      # @param

# TFRecords are read from the 'single/' folder of the attached Kaggle dataset.
GCS_PATH = f'{KaggleDatasets().get_gcs_path()}/single/'
GCS_PATH

# Dataset 
### **THIS FUNCTION WILL CHANGE**

In [None]:
import os
import matplotlib.pyplot as plt
%matplotlib inline


def data_input_fn():
    """Build the batched training dataset from the TFRecords under ``GCS_PATH``.

    Every ``*.tfrecord`` file matched directly under ``GCS_PATH`` is read.
    Each record must carry PNG-encoded ``image`` and ``target`` byte strings;
    a ``label`` string is required by the parse spec as well but is not used.
    Both pictures are decoded to ``NB_CHANNEL`` channels, scaled from
    [0, 255] to [-1, 1] and reshaped to ``(IMG_DIM, IMG_DIM, NB_CHANNEL)``,
    so the stored pictures must already have exactly that many pixels.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, target)`` float32 batches of
        ``BATCH_SIZE`` examples, shuffled at example level with a buffer of
        ``BUFFER_SIZE`` and reshuffled on every iteration.

    Raises:
        FileNotFoundError: if no ``*.tfrecord`` file is found under
            ``GCS_PATH``.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature((), tf.string),
        'target': tf.io.FixedLenFeature((), tf.string),
        'label': tf.io.FixedLenFeature((), tf.string),
    }

    def _decode(png_bytes):
        # PNG bytes -> float32 tensor in [-1, 1] with a fully defined shape.
        pixels = tf.image.decode_png(png_bytes, channels=NB_CHANNEL)
        pixels = (tf.cast(pixels, tf.float32) / 127.5) - 1
        return tf.reshape(pixels, (IMG_DIM, IMG_DIM, NB_CHANNEL))

    def _parser(example):
        # One serialized tf.train.Example -> (image, target) pair.
        parsed_example = tf.io.parse_single_example(example, feature_spec)
        return _decode(parsed_example['image']), _decode(parsed_example['target'])

    gcs_pattern = os.path.join(GCS_PATH, '*.tfrecord')
    file_paths = tf.io.gfile.glob(gcs_pattern)
    if not file_paths:
        # Fail here with a clear message instead of deep inside tf.data.
        raise FileNotFoundError(f'No TFRecord files match {gcs_pattern}')

    dataset = tf.data.TFRecordDataset(file_paths)
    # Decode several records in parallel; the map keeps its default
    # deterministic ordering and the result is shuffled right after anyway.
    dataset = dataset.map(_parser,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset



ds = data_input_fn()

# Preview the first pair of the first batch and report the batch shapes.
for x, y in ds.take(1):
    # The pipeline emits pixels in [-1, 1]; imshow expects floats in [0, 1]
    # and clips everything else, so map the values back before plotting.
    data = np.squeeze(x[0]) * 0.5 + 0.5
    plt.imshow(data)
    plt.show()
    data = np.squeeze(y[0]) * 0.5 + 0.5
    plt.imshow(data)
    plt.show()

    print('Image Batch Shape:',x.shape)
    print('Target Batch Shape:',y.shape)
    

# Networks

In [None]:
def downsample(filters, size, apply_instancenorm=True):
    """Return a stride-2 convolution block used on the encoder side.

    Layer order: Conv2D (``padding='same'``, no bias) -> optional
    InstanceNormalization -> LeakyReLU.

    Args:
        filters: number of output channels of the convolution.
        size: kernel size of the convolution.
        apply_instancenorm: when false, the normalization layer is left out.

    Returns:
        A ``keras.Sequential`` holding the layers in the order above.
    """
    stack = [
        layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        )
    ]

    if apply_instancenorm:
        stack.append(tfa.layers.InstanceNormalization(
            gamma_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)))

    stack.append(layers.LeakyReLU())

    return keras.Sequential(stack)

def upsample(filters, size, apply_dropout=False):
    """Return a stride-2 transposed-convolution block used on the decoder side.

    Layer order: Conv2DTranspose (``padding='same'``, no bias) ->
    InstanceNormalization -> optional Dropout(0.5) -> ReLU.

    Args:
        filters: number of output channels of the transposed convolution.
        size: kernel size of the transposed convolution.
        apply_dropout: when true, a 50% dropout layer is inserted before the
            activation.

    Returns:
        A ``keras.Sequential`` holding the layers in the order above.
    """
    stack = [
        layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
        tfa.layers.InstanceNormalization(
            gamma_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)),
    ]

    if apply_dropout:
        stack.append(layers.Dropout(0.5))

    stack.append(layers.ReLU())

    return keras.Sequential(stack)

def Generator():
    """Build the encoder/decoder generator with skip connections.

    The input is pushed through eight stride-2 ``downsample`` blocks, then
    through seven ``upsample`` blocks, each of whose outputs is concatenated
    with the encoder output of matching resolution. A final transposed
    convolution with ``tanh`` restores the input resolution, so outputs lie in
    [-1, 1] like the pipeline's normalized pictures.

    The input/output shape is ``(IMG_DIM, IMG_DIM, NB_CHANNEL)`` so the model
    always agrees with what the data pipeline produces.

    Returns:
        A ``keras.Model`` mapping a picture batch to a same-shaped batch.

    Raises:
        ValueError: if ``IMG_DIM`` is not a multiple of 2 ** 8; the skip
            connections would not line up otherwise.
    """
    inputs = layers.Input(shape=[IMG_DIM, IMG_DIM, NB_CHANNEL])

    # bs = batch size; the shapes below are for IMG_DIM = 256
    down_stack = [
        downsample(64, 4, apply_instancenorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    # Shapes are given after concatenation with the skip connection.
    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

    # Each encoder stage halves the resolution (rounding up) while each decoder
    # stage exactly doubles it, so sizes only match when every halving is exact.
    required_multiple = 2 ** len(down_stack)
    if IMG_DIM % required_multiple != 0:
        raise ValueError(
            f'IMG_DIM must be a multiple of {required_multiple} for the skip '
            f'connections to line up, got {IMG_DIM}')

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(NB_CHANNEL, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh') # (bs, IMG_DIM, IMG_DIM, NB_CHANNEL)

    x = inputs

    # Downsampling through the model, remembering every stage's output
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The bottleneck itself is not a skip; pair the rest deepest-first
    # with the decoder stages.
    for up, skip in zip(up_stack, reversed(skips[:-1])):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return keras.Model(inputs=inputs, outputs=x)

def Discriminator():
    """Build the convolutional discriminator.

    Three stride-2 ``downsample`` blocks are followed by two zero-padded
    stride-1 convolutions. The last convolution has a single channel and no
    activation, so the model outputs a grid of raw scores (30 x 30 x 1 for a
    256-pixel input) that the loss functions treat as logits.

    The input shape is ``(IMG_DIM, IMG_DIM, NB_CHANNEL)`` so the model always
    agrees with what the data pipeline and the generators produce.

    Returns:
        A ``tf.keras.Model`` mapping a picture batch to the score grid.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[IMG_DIM, IMG_DIM, NB_CHANNEL], name='input_image')

    x = inp

    # bs = batch size; the shapes below are for IMG_DIM = 256
    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

    norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)

    leaky_relu = layers.LeakyReLU()(norm1)

    zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

    # No activation: the scores are consumed with from_logits=True.
    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=last)

In [None]:
# Create the four networks under the distribution strategy's scope.
with strategy.scope():
    image_generator = Generator() # maps target-domain pictures to the image domain
    target_generator = Generator() # maps image-domain pictures to the target domain

    image_discriminator = Discriminator() # differentiates real images and generated images
    target_discriminator = Discriminator() # differentiates real targets and generated targets


# Cycle GAN

In [None]:
class CycleGan(keras.Model):
    """Two generators and two discriminators trained jointly.

    ``img_gen`` turns targets into images and ``tgt_gen`` turns images into
    targets; ``img_disc`` / ``tgt_disc`` score pictures of their respective
    domain. ``lambda_cycle`` is the weight handed to the cycle and identity
    loss functions.
    """

    def __init__(
        self,
        image_generator,
        target_generator,
        image_discriminator,
        target_discriminator,
        lambda_cycle=10,
    ):
        """Store the four sub-models and the cycle-loss weight."""
        super().__init__()
        self.img_gen = image_generator
        self.tgt_gen = target_generator
        self.img_disc = image_discriminator
        self.tgt_disc = target_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        img_gen_optimizer,
        tgt_gen_optimizer,
        img_disc_optimizer,
        tgt_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Register one optimizer per sub-model and the four loss callables.

        ``gen_loss_fn(disc_fake)``, ``disc_loss_fn(disc_real, disc_fake)``,
        ``cycle_loss_fn(real, cycled, weight)`` and
        ``identity_loss_fn(real, same, weight)`` are called from
        ``train_step`` with exactly those arguments.
        """
        super().compile()
        self.img_gen_optimizer = img_gen_optimizer
        self.tgt_gen_optimizer = tgt_gen_optimizer
        self.img_disc_optimizer = img_disc_optimizer
        self.tgt_disc_optimizer = tgt_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimisation step on an ``(image, target)`` batch.

        All four losses are computed under a single persistent tape, the
        gradients of each loss are taken with respect to its own network,
        and only then are the four optimizers applied.

        Returns:
            A dict with the total generator losses and the discriminator
            losses, keyed ``image_gen_loss``, ``target_gen_loss``,
            ``image_disc_loss`` and ``target_disc_loss``.
        """
        real_image, real_target = batch_data
        weight = self.lambda_cycle

        # persistent=True: the tape is queried four times below.
        with tf.GradientTape(persistent=True) as tape:
            # target -> image -> target
            fake_image = self.img_gen(real_target, training=True)
            cycled_target = self.tgt_gen(fake_image, training=True)

            # image -> target -> image
            fake_target = self.tgt_gen(real_image, training=True)
            cycled_image = self.img_gen(fake_target, training=True)

            # each generator fed a picture that is already in its output domain
            same_image = self.img_gen(real_image, training=True)
            same_target = self.tgt_gen(real_target, training=True)

            # discriminator scores for real pictures
            disc_real_image = self.img_disc(real_image, training=True)
            disc_real_target = self.tgt_disc(real_target, training=True)

            # discriminator scores for generated pictures
            disc_fake_image = self.img_disc(fake_image, training=True)
            disc_fake_target = self.tgt_disc(fake_target, training=True)

            # adversarial part of the generator losses
            image_gen_loss = self.gen_loss_fn(disc_fake_image)
            target_gen_loss = self.gen_loss_fn(disc_fake_target)

            # cycle consistency, shared by both generators
            image_cycle = self.cycle_loss_fn(real_image, cycled_image, weight)
            target_cycle = self.cycle_loss_fn(real_target, cycled_target, weight)
            total_cycle_loss = image_cycle + target_cycle

            # total generator losses: adversarial + cycle + identity
            image_identity = self.identity_loss_fn(real_image, same_image, weight)
            total_image_gen_loss = image_gen_loss + total_cycle_loss + image_identity
            target_identity = self.identity_loss_fn(real_target, same_target, weight)
            total_target_gen_loss = target_gen_loss + total_cycle_loss + target_identity

            # discriminator losses
            image_disc_loss = self.disc_loss_fn(disc_real_image, disc_fake_image)
            target_disc_loss = self.disc_loss_fn(disc_real_target, disc_fake_target)

        # (optimizer, loss, network) triples, in update order.
        updates = (
            (self.img_gen_optimizer, total_image_gen_loss, self.img_gen),
            (self.tgt_gen_optimizer, total_target_gen_loss, self.tgt_gen),
            (self.img_disc_optimizer, image_disc_loss, self.img_disc),
            (self.tgt_disc_optimizer, target_disc_loss, self.tgt_disc),
        )

        # Take every gradient before any weight is modified.
        all_gradients = [
            tape.gradient(loss, network.trainable_variables)
            for _, loss, network in updates
        ]

        for (optimizer, _, network), gradients in zip(updates, all_gradients):
            optimizer.apply_gradients(zip(gradients, network.trainable_variables))

        return {
            "image_gen_loss": total_image_gen_loss,
            "target_gen_loss": total_target_gen_loss,
            "image_disc_loss": image_disc_loss,
            "target_disc_loss": target_disc_loss
        }


# Loss Functions

In [None]:
with strategy.scope():
    def _bce_from_logits(labels, logits):
        """Binary cross-entropy on raw scores with ``Reduction.NONE``."""
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )
        return bce(labels, logits)

    def discriminator_loss(real, generated):
        """Half the sum of the 'real is 1' and 'generated is 0' losses.

        ``real`` and ``generated`` are the discriminator's raw scores for
        real and generated pictures respectively.
        """
        real_loss = _bce_from_logits(tf.ones_like(real), real)
        generated_loss = _bce_from_logits(tf.zeros_like(generated), generated)
        return (real_loss + generated_loss) * 0.5

    def generator_loss(generated):
        """Loss for labelling the discriminator's scores of fakes as 1."""
        return _bce_from_logits(tf.ones_like(generated), generated)

    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """``LAMBDA`` times the mean absolute difference of the two inputs."""
        return LAMBDA * tf.reduce_mean(tf.abs(real_image - cycled_image))

    def identity_loss(real_image, same_image, LAMBDA):
        """Half of ``LAMBDA`` times the mean absolute difference of the inputs."""
        return LAMBDA * 0.5 * tf.reduce_mean(tf.abs(real_image - same_image))

# Optimizers

In [None]:
# Each network gets its own Adam instance; all four share the same settings.
with strategy.scope():
    _adam_settings = dict(learning_rate=2e-4, beta_1=0.5)

    image_generator_optimizer = tf.keras.optimizers.Adam(**_adam_settings)
    target_generator_optimizer = tf.keras.optimizers.Adam(**_adam_settings)

    image_discriminator_optimizer = tf.keras.optimizers.Adam(**_adam_settings)
    target_discriminator_optimizer = tf.keras.optimizers.Adam(**_adam_settings)

# Training

In [None]:
# Assemble the trainable CycleGan from the networks, optimizers and losses
# defined above; lambda_cycle is left at its default.
with strategy.scope():
    cycle_gan_model = CycleGan(
        image_generator=image_generator,
        target_generator=target_generator,
        image_discriminator=image_discriminator,
        target_discriminator=target_discriminator,
    )

    _optimizers = dict(
        img_gen_optimizer=image_generator_optimizer,
        tgt_gen_optimizer=target_generator_optimizer,
        img_disc_optimizer=image_discriminator_optimizer,
        tgt_disc_optimizer=target_discriminator_optimizer,
    )
    _loss_fns = dict(
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )
    cycle_gan_model.compile(**_optimizers, **_loss_fns)


In [None]:
# cycle_gan_model.fit(ds,epochs=EPOCHS)

# Save Weights

In [None]:

# image_generator.save_weights("img_gen.h5")
# target_generator.save_weights("tgt_gen.h5") # might come in handy later

    