# TODO:
- [x] remove strategy scopes
- [ ] format loss functions : skipped : directly use [official example](https://www.tensorflow.org/tutorials/generative/cyclegan)
- [x] **change network names as per checkpoint need** (simpler name like gen_x,gen_y and so on)<- same as official
- [x] take model definitions from example (pix2pix). No need to increase notebook length <- same as official
- [x] add visuals (epochwise improvement )
- [x] Modify Official Example for summary writer <- not in official

In [None]:
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
from tensorflow_examples.models.pix2pix import pix2pix

# Imports

In [None]:
import os
import time
import matplotlib.pyplot as plt
import tensorflow as tf

import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from IPython.display import clear_output
from kaggle_datasets import KaggleDatasets

# AUTOTUNE is now exposed directly as tf.data.AUTOTUNE (formerly tf.data.experimental.AUTOTUNE)
AUTOTUNE = tf.data.AUTOTUNE
tf.__version__

# Params and DataPath

In [None]:
# Side length and channel count every decoded image is reshaped to.
IMG_DIM     = 256   # @param
NB_CHANNEL  = 3     # @param

# Number of passes over the dataset, shuffle-buffer size and batch size.
EPOCHS      = 50    # @param
BUFFER_SIZE = 2048  # @param
BATCH_SIZE  = 1     # @param

# GCS location returned by KaggleDatasets (presumably the dataset attached to
# this notebook); the TFRecord files are read from its 'single/' sub-folder.
GCS_PATH=KaggleDatasets().get_gcs_path()
GCS_PATH=f'{GCS_PATH}/single/'
# Bare expression so the notebook displays the resolved path.
GCS_PATH

# Dataset (no prefetch; could use cache)

In [None]:


def data_input_fn():
    '''
      Build the training dataset from the TFRecord files under GCS_PATH.

      Each record is parsed into an (image, target) pair of float32 tensors
      of shape (IMG_DIM, IMG_DIM, NB_CHANNEL), scaled to the [-1, 1] range.
      The pairs are shuffled with a buffer of BUFFER_SIZE elements
      (reshuffled on every iteration) and grouped into batches of BATCH_SIZE.

      Returns:
          A tf.data.Dataset yielding (image_batch, target_batch) tuples.

      Raises:
          FileNotFoundError: if no '*.tfrecord' file is found under GCS_PATH.
    '''

    def _decode(png_bytes):
        '''Decode one serialized PNG into a normalised, fixed-shape tensor.'''
        pixels = tf.image.decode_png(png_bytes, channels=NB_CHANNEL)
        # gan normal: map uint8 pixel values to [-1, 1]
        pixels = (tf.cast(pixels, tf.float32) / 127.5) - 1
        # The reshape pins the static shape and fails on wrongly sized images.
        return tf.reshape(pixels, (IMG_DIM, IMG_DIM, NB_CHANNEL))

    def _parser(example):
        '''Parse one serialized tf.train.Example into (image, target).'''
        # 'label' stays in the schema (the records are expected to carry it)
        # even though it is not returned.
        feature = {'image'  : tf.io.FixedLenFeature((), tf.string),
                   'target' : tf.io.FixedLenFeature((), tf.string),
                   'label'  : tf.io.FixedLenFeature((), tf.string)}
        parsed_example = tf.io.parse_single_example(example, feature)
        return _decode(parsed_example['image']), _decode(parsed_example['target'])

    gcs_pattern = os.path.join(GCS_PATH, '*.tfrecord')
    file_paths = tf.io.gfile.glob(gcs_pattern)
    if not file_paths:
        # Fail here with a clear message instead of later, when the first
        # batch is requested from a dataset that has nothing to read.
        raise FileNotFoundError(f'No TFRecord files match {gcs_pattern!r}')

    dataset = tf.data.TFRecordDataset(file_paths)
    # Decode several records in parallel; map() keeps the element order
    # deterministic by default, so only throughput changes.
    dataset = dataset.map(_parser, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)
    dataset = dataset.batch(BATCH_SIZE)
    return dataset



ds = data_input_fn()

# Sanity check: show the first image/target pair of one batch.
for x, y in ds.take(1):
    for batch in (x, y):
        # Pixels are stored in [-1, 1]; map them back to [0, 1] before
        # plotting, because imshow clips float RGB data outside that range.
        data = np.squeeze(batch[0]) * 0.5 + 0.5
        plt.imshow(data)
        plt.show()

    print('Image Batch Shape:',x.shape)
    print('Target Batch Shape:',y.shape)

# Sample to observe during training

In [None]:
# Keep one input batch aside; it is fed to the generator after each epoch so
# progress is judged on the same picture every time.
sample_img, _ = next(iter(ds))

# Undo the [-1, 1] normalisation for display.
plt.imshow(sample_img[0] * 0.5 + 0.5)
plt.show()

# Networks (pix2pix)

In [None]:
OUTPUT_CHANNELS = 3

# Two U-Net generators (G: X -> Y, F: Y -> X), created in that order.
generator_g, generator_f = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2)
)

# One discriminator per domain; target=False means they are built without a
# second (target) input.
discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2)
)

# Loss Functions 

In [None]:
# Weight applied to the cycle-consistency loss; the identity loss uses half of it.
LAMBDA = 10
# Binary cross-entropy computed on raw logits; used by the adversarial losses.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
    '''
      Adversarial loss of a discriminator.

      `real` holds the discriminator output for real images (labelled 1) and
      `generated` its output for generated images (labelled 0). Returns half
      of the sum of the two `loss_obj` terms.
    '''
    loss_on_real = loss_obj(tf.ones_like(real), real)
    loss_on_generated = loss_obj(tf.zeros_like(generated), generated)
    return (loss_on_real + loss_on_generated) * 0.5

def generator_loss(generated):
    '''
      Adversarial loss of a generator: `loss_obj` between an all-ones label
      tensor and the discriminator output for the generated images.
    '''
    wanted_labels = tf.ones_like(generated)
    return loss_obj(wanted_labels, generated)

def calc_cycle_loss(real_image, cycled_image):
    '''
      Cycle-consistency loss: mean absolute difference between an image and
      its reconstruction after a round trip, scaled by LAMBDA.
    '''
    absolute_error = tf.abs(real_image - cycled_image)
    return LAMBDA * tf.reduce_mean(absolute_error)

def identity_loss(real_image, same_image):
    '''
      Identity loss: mean absolute difference between an image and the
      generator output for that same image, scaled by LAMBDA * 0.5.
    '''
    absolute_error = tf.abs(real_image - same_image)
    return LAMBDA * 0.5 * tf.reduce_mean(absolute_error)

# Optimizers

In [None]:
# Every network gets its own Adam optimizer with identical hyper-parameters
# (learning rate 2e-4, beta_1 0.5), created in the order G, F, D_x, D_y.
generator_g_optimizer, generator_f_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(2)
)

discriminator_x_optimizer, discriminator_y_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(2)
)

# logs and checkpoints

In [None]:
import datetime
# Root directory for TensorBoard event files (also passed to %tensorboard).
log_dir="logs/"
# for tensorboard: each run writes into its own timestamped sub-directory
summary_writer = tf.summary.create_file_writer(log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

# Directory where the checkpoint manager stores its files.
checkpoint_path = "./checkpoints/train"

# Track all four networks together with their optimizers, so that training can
# be resumed from a saved state. The keyword names are the names under which
# the objects are stored in the checkpoint.
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

# Keep only the five most recent checkpoints on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)

# Visualize During Training 

In [None]:
def generate_images(model, test_input):
    '''
      Run `model` on `test_input` and plot the first input image next to the
      first predicted image.

      Both images are rescaled with `* 0.5 + 0.5` before being shown.
    '''
    prediction = model(test_input)

    plt.figure(figsize=(12, 12))

    panels = (('Input Image', test_input[0]),
              ('Predicted Image', prediction[0]))
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

# Custom loop 
* **still using the `@tf.function` decorator** 
* **where is mirrored strategy?** 

In [None]:
@tf.function
def _train_step_graph(real_x, real_y, step):
    '''
      Graph-compiled body of one CycleGAN optimisation step.

      Runs both generators and both discriminators on one batch, applies one
      optimizer update to each of the four networks and writes the individual
      loss values to `summary_writer` under `step`.

      Args:
          real_x: batch of images from domain X.
          real_y: batch of images from domain Y.
          step:   int64 scalar tensor used as the summary step.
    '''
    # persistent is set to True because the tape is used more than
    # once to calculate the gradients.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y
        # Generator F translates Y -> X.

        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # same_x and same_y are used for identity loss.
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # calculate the loss
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

        # Total generator loss = adversarial loss + cycle loss + identity loss
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # Calculate the gradients for generator and discriminator
    generator_g_gradients = tape.gradient(total_gen_g_loss,
                                          generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss,
                                          generator_f.trainable_variables)

    discriminator_x_gradients = tape.gradient(disc_x_loss,
                                              discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss,
                                              discriminator_y.trainable_variables)

    # Apply the gradients to the optimizer
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
                                              generator_g.trainable_variables))

    generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
                                              generator_f.trainable_variables))

    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                  discriminator_x.trainable_variables))

    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                  discriminator_y.trainable_variables))

    # Every call made with the same step value logs under that same step.
    with summary_writer.as_default():
        tf.summary.scalar('gen_g_loss', gen_g_loss, step=step)
        tf.summary.scalar('gen_f_loss', gen_f_loss, step=step)
        tf.summary.scalar('total_cycle_loss', total_cycle_loss, step=step)
        tf.summary.scalar('total_gen_g_loss', total_gen_g_loss, step=step)
        tf.summary.scalar('total_gen_f_loss', total_gen_f_loss, step=step)
        tf.summary.scalar('disc_x_loss', disc_x_loss, step=step)
        tf.summary.scalar('disc_y_loss', disc_y_loss, step=step)


def train_step(real_x, real_y,epoch):
    '''
      Run one training step on a batch and log the losses for `epoch`.

      Args:
          real_x: batch of images from domain X.
          real_y: batch of images from domain Y.
          epoch:  current epoch (Python int or integer tensor); used as the
                  step value of the TensorBoard summaries.
    '''
    # tf.function treats a Python int argument as a constant and builds a new
    # graph for every distinct value, i.e. once per epoch. Passing the epoch
    # as an int64 tensor lets a single traced graph serve all epochs.
    _train_step_graph(real_x, real_y, tf.cast(epoch, tf.int64))
        


# Tensorboard 

In [None]:
%load_ext tensorboard
%tensorboard --logdir {log_dir}

# Training 
- [ ] save version before full execution

In [None]:

for epoch in range(EPOCHS):
    start = time.time()

    # One pass over the dataset; print a progress mark every 100 batches.
    for n, (image_x, image_y) in enumerate(ds):
        train_step(image_x, image_y, epoch)
        if n % 100 == 0:
            print ('>', end='')

    clear_output(wait=True)
    # The same sample image is used after every epoch so that the progress
    # of the model is clearly visible.
    generate_images(generator_g, sample_img)

    # Save a checkpoint every fifth epoch.
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))

    elapsed = time.time() - start
    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, elapsed))


