**Based on the TensorFlow CycleGAN tutorial:** https://www.tensorflow.org/tutorials/generative/cyclegan

In [None]:
import tornado

# Report the Tornado version available to this kernel.
print(f"{tornado.version}")

In [None]:
import os

# Enumerate CUDA devices in PCI bus order so that device index "0" is stable,
# then expose only that single GPU to the process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
import tensorflow as tf

print(tf.version.VERSION)

# NOTE(review): this forces every @tf.function (including train_step below) to
# run eagerly. That is handy for debugging but typically much slower than graph
# mode; confirm it is intended for full training runs.
tf.config.run_functions_eagerly(True)

# Main libraries

In [None]:
#allows to import generator and discriminator
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
import gc
import tensorflow as tf
#import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
from os import listdir
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array
from numpy import vstack
from numpy import asarray
from numpy import savez_compressed
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tensorflow import keras
import gc

# Sentinel passed to tf.data (num_parallel_calls / prefetch buffer_size below)
# so the runtime picks the level of parallelism and prefetch depth.
AUTOTUNE = tf.data.AUTOTUNE

# Load and preprocess dataset

**Required on the first run (label mappings and image loading)**

In [None]:
# Class names in label-index order. Both lookup tables are derived from this
# single list so they cannot drift out of sync.
CLASS_NAMES = ['CuNi1', 'CuNi2', 'CuNi3']

# class name -> integer label
str2idx = {name: position for position, name in enumerate(CLASS_NAMES)}

# integer label -> class name
idx2str = dict(enumerate(CLASS_NAMES))

In [None]:
def ohe_class(index, num_classes=3):
    """Return a one-hot integer vector with a single 1 at position `index`.

    Args:
        index: class index in ``[0, num_classes)``.
        num_classes: length of the vector. The default of 3 matches the three
            classes used in this notebook.

    Returns:
        numpy integer array of shape ``(num_classes,)``.

    Raises:
        IndexError: if ``index`` is outside ``[0, num_classes)``. Negative
            indices are rejected rather than silently wrapping around to the
            end of the vector.
    """
    if not 0 <= index < num_classes:
        raise IndexError(
            "class index {} out of range for {} classes".format(index, num_classes))
    ohe_label = np.zeros(num_classes, dtype=int)
    ohe_label[index] = 1
    return ohe_label

In [None]:
def load_images(path, size=(256,256), rgb=False):
    """Load every image in a directory together with a one-hot class label.

    The class is taken from the filename prefix before the first '_'
    (e.g. 'CuNi1_001.png' -> 'CuNi1') and must be a key of `str2idx`.

    Args:
        path: directory holding the images, with or without a trailing
            separator. Hidden entries (names starting with '.') are skipped;
            every other entry is assumed to be an image.
        size: (height, width) that each image is resized to.
        rgb: if True load 3-channel RGB, otherwise single-channel grayscale.

    Returns:
        (images, labels): `images` is a numpy array stacking the N loaded
        images; `labels` is a list of N one-hot vectors built by `ohe_class`.

    Raises:
        KeyError: if a filename prefix is not a class listed in `str2idx`.
    """
    color_mode = "rgb" if rgb else "grayscale"
    data_list, label_list = [], []

    # os.listdir() order is arbitrary; sorting makes the load order reproducible.
    filenames = sorted(name for name in os.listdir(path) if not name.startswith('.'))
    for filename in tqdm(filenames):
        # os.path.join works whether or not `path` ends with a separator.
        pixels = load_img(os.path.join(path, filename), target_size=size, color_mode=color_mode)
        data_list.append(img_to_array(pixels))

        class_name = filename.split('_')[0]
        label_list.append(ohe_class(str2idx[class_name]))

    return asarray(data_list), label_list

**Setup**

In [None]:
# Input-pipeline settings.
BATCH_SIZE = 1                # one image per batch
BUFFER_SIZE = 1000            # shuffle buffer size
IMG_HEIGHT = IMG_WIDTH = 256  # spatial size that images are cropped to

In [None]:
# Load the four image splits from disk. Domain X is "dry", domain Y is "wet".
# Every split is read as 3-channel RGB (rgb=True), and labels come from the
# filename prefix via load_images().
path = "../../nanoespumas/data/"

# Domain X: dry
train_dry_imgs, train_dry_labels = load_images(path + 'train_dry/', rgb=True)
test_dry_imgs, test_dry_labels = load_images(path + 'test_dry/', rgb=True)

# Domain Y: wet
train_wet_imgs, train_wet_labels = load_images(path + 'train_wet/', rgb=True)
test_wet_imgs, test_wet_labels = load_images(path + 'test_wet/', rgb=True)

for _domain, _split_imgs, _split_labels in (("dry", train_dry_imgs, train_dry_labels),
                                            ("wet", train_wet_imgs, train_wet_labels)):
    print("train images " + _domain + ": ", _split_imgs.shape, " labels: ", len(_split_labels))
print("test dry: ", test_dry_imgs.shape)
print("test wet: ", test_wet_imgs.shape)

**Data augmentation techniques**

In [None]:
def random_crop(image):
    """Randomly crop a 3-channel image to IMG_HEIGHT x IMG_WIDTH."""
    crop_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
    return tf.image.random_crop(image, size=crop_shape)

def normalize(image):
    """Cast to float32 and rescale pixel values from [0, 255] to [-1, 1]."""
    return tf.cast(image, tf.float32) / 127.5 - 1

def random_jitter(image):
    """Training augmentation: upscale to 286 x 286 (nearest neighbour),
    random-crop back to IMG_HEIGHT x IMG_WIDTH, then randomly mirror left/right."""
    upscaled = tf.image.resize(image, [286, 286],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return tf.image.random_flip_left_right(random_crop(upscaled))

**Preprocess splits**

In [None]:
def preprocess_image_train(image):
    """Training-time preprocessing: random jitter, then scaling to [-1, 1]."""
    jittered = random_jitter(image)
    return normalize(jittered)

def preprocess_image_test(image):
    """Test-time preprocessing: scaling to [-1, 1] only (no augmentation)."""
    return normalize(image)

In [None]:
# Ensure every image split is a plain numpy array before wrapping it in tf.data.
(train_dry_array, test_dry_array,
 train_wet_array, test_wet_array) = map(np.asarray, (train_dry_imgs, test_dry_imgs,
                                                     train_wet_imgs, test_wet_imgs))

In [None]:
# Wrap each image split in a tf.data pipeline (one element per image; batching
# happens later, after preprocessing). Training labels are batched right away so
# their batches line up with the image batches they are zipped with later.
train_dry_ds, train_wet_ds, test_dry_ds, test_wet_ds = [
    tf.data.Dataset.from_tensor_slices(split)
    for split in (train_dry_array, train_wet_array, test_dry_array, test_wet_array)
]
train_dry_label_ds, train_wet_label_ds = [
    tf.data.Dataset.from_tensor_slices(tf.cast(split_labels, tf.int64)).batch(BATCH_SIZE)
    for split_labels in (train_dry_labels, train_wet_labels)
]

In [None]:
# Build the preprocessed, batched pipelines. Images and labels are batched
# identically and zipped *before* shuffling, so every image stays paired with
# its own label. The datasets are rebuilt from the in-memory arrays instead of
# re-mapping the *_ds objects in place, so re-running this cell does not apply
# preprocessing to already-processed datasets.

def _map_and_batch(images, preprocess_fn):
    """Slice `images`, apply `preprocess_fn` to each element and batch the result."""
    return (tf.data.Dataset.from_tensor_slices(images)
            .map(preprocess_fn, num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE))


def _shuffle_and_prefetch(ds):
    """Shuffle a finite dataset with a buffer covering all of it, then prefetch.

    reshuffle_each_iteration=True draws a new order on every pass. As a result,
    each epoch pairs dry and wet batches differently when the two domains are
    zipped for training. With False, the same pairs would repeat every epoch.
    """
    return (ds.shuffle(buffer_size=len(ds), reshuffle_each_iteration=True)
            .prefetch(buffer_size=AUTOTUNE))


train_dry_ds = _map_and_batch(train_dry_array, preprocess_image_train)
train_wet_ds = _map_and_batch(train_wet_array, preprocess_image_train)

# The datasets are in the same order, so zipping yields (image, label) pairs.
train_dry_image_label_ds = _shuffle_and_prefetch(
    tf.data.Dataset.zip((train_dry_ds, train_dry_label_ds)))
train_wet_image_label_ds = _shuffle_and_prefetch(
    tf.data.Dataset.zip((train_wet_ds, train_wet_label_ds)))

# Test pipelines: deterministic scaling only, with no augmentation or shuffling.
test_dry_ds = _map_and_batch(test_dry_array, preprocess_image_test)
test_wet_ds = _map_and_batch(test_wet_array, preprocess_image_test)

In [None]:
# Peek at the first (image, label) batch of each training pipeline.
(img_sample_dry, lab_sample_dry), (img_sample_wet, lab_sample_wet) = (
    next(iter(train_dry_image_label_ds)),
    next(iter(train_wet_image_label_ds)),
)

for _domain, _img, _lab in (("dry", img_sample_dry, lab_sample_dry),
                            ("wet", img_sample_wet, lab_sample_wet)):
    print(f"{_domain} sample info:")
    print(f"shape: {_img.shape}, label: {_lab} ")

In [None]:
# Pixel-value distribution of one raw wet image, as loaded from disk.
b = train_wet_array[0]
fig, ax = plt.subplots()
ax.hist(b.ravel())
ax.set_title("Before scaling")
plt.show()

In [None]:
# Pixel-value distribution of one preprocessed wet sample (jitter + [-1, 1] scaling).
a = np.array(img_sample_wet[0])
fig, ax = plt.subplots()
ax.hist(a.ravel())
ax.set_title("After scaling")
plt.show()

In [None]:
# Show the first dry sample next to a freshly jittered copy of it.
# `* 0.5 + 0.5` maps the [-1, 1] scaled values back to [0, 1] for display.
print(img_sample_dry[0].shape)
idx = lab_sample_dry.numpy().argmax()
dry_panels = [
    ('Dry', img_sample_dry[0]),
    ('Dry with random jitter', random_jitter(img_sample_dry[0])),
]
for pos, (caption, panel_img) in enumerate(dry_panels, start=1):
    plt.subplot(1, 2, pos)
    plt.title(caption)
    plt.imshow(np.squeeze(panel_img) * 0.5 + 0.5, cmap='gray')
    plt.xlabel(idx2str[idx])

In [None]:
# Show the first wet sample next to a freshly jittered copy of it.
# `* 0.5 + 0.5` maps the [-1, 1] scaled values back to [0, 1] for display.
idx = lab_sample_wet.numpy().argmax()
wet_panels = [
    ('Wet', img_sample_wet[0]),
    ('Wet light with random jitter', random_jitter(img_sample_wet[0])),
]
for pos, (caption, panel_img) in enumerate(wet_panels, start=1):
    plt.subplot(1, 2, pos)
    plt.title(caption)
    plt.imshow(np.squeeze(panel_img) * 0.5 + 0.5, cmap='gray')
    plt.xlabel(idx2str[idx])

**Preview a grid of random wet training samples**

In [None]:
# Draw N_PREVIEW_SAMPLES random wet training samples for the grid below.
# A single shuffled pass is truncated, rather than re-creating a full-size
# shuffle buffer per sample. The per-sample approach re-ran the preprocessing
# map over the whole dataset once for every sample and could return the same
# image twice. .repeat() only matters if the dataset is smaller than the grid.
N_PREVIEW_SAMPLES = 25
preview_ds = (train_wet_image_label_ds
              .shuffle(buffer_size=len(train_wet_imgs))
              .repeat()
              .take(N_PREVIEW_SAMPLES))

images, labels = [], []
for imgs_samples, labels_samples in tqdm(preview_ds, total=N_PREVIEW_SAMPLES):
    images.append(imgs_samples)
    labels.append(labels_samples)

images = np.asarray(images)
print("images: {}, amount of labels: {}".format(images.shape, len(labels)))

In [None]:
# 5x5 grid of the previewed wet samples, each captioned with its class name.
fig, axes = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    # squeeze drops the batch axis of size 1; * 0.5 + 0.5 undoes the [-1, 1] scaling.
    ax.imshow(np.squeeze(images[i]) * 0.5 + 0.5)
    ax.set_xlabel("label: {}".format(idx2str[labels[i].numpy().argmax()]))
plt.show()

# Import and reuse the Pix2Pix models

In [None]:
OUTPUT_CHANNELS = 3

# Generators (U-Net, instance norm): G maps dry -> wet, F maps wet -> dry.
generator_g, generator_f = [
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm') for _ in range(2)
]

# Unconditioned discriminators (target=False): X scores dry images, Y scores wet images.
discriminator_x, discriminator_y = [
    pix2pix.discriminator(norm_type='instancenorm', target=False) for _ in range(2)
]

In [None]:
# Run both generators on one sample per domain to sanity-check value ranges and
# shapes. img_sample_dry is also reused by the training loop's progress plots.
img_sample_dry, label_sample_dry = next(iter(train_dry_image_label_ds))
img_sample_wet, label_sample_wet = next(iter(train_wet_image_label_ds))

print("info de real data")
# Report the label of the sample drawn in this cell, not a variable left over
# from an earlier cell.
print("img shape: {}, label: {}".format(img_sample_dry.shape, label_sample_dry))
print("min: {}, max: {}".format(tf.reduce_min(img_sample_dry).numpy(), tf.reduce_max(img_sample_dry).numpy()))
print("min: {}, max: {}".format(tf.reduce_min(img_sample_wet).numpy(), tf.reduce_max(img_sample_wet).numpy()))

to_wet = generator_g([img_sample_dry])
to_dry = generator_f([img_sample_wet])

print("info de fake data")
print("min: {}, max: {}".format(tf.reduce_min(to_wet).numpy(), tf.reduce_max(to_wet).numpy()))
print("min: {}, max: {}".format(tf.reduce_min(to_dry).numpy(), tf.reduce_max(to_dry).numpy()))

plt.figure(figsize=(8, 8))
# Generated images are amplified by `contrast` so low-amplitude outputs
# (e.g. from an untrained generator) are visible.
contrast = 8

imgs = [img_sample_dry, to_wet, img_sample_wet, to_dry]
title = ['dry', 'To wet', 'wet', 'To dry']

for i in range(len(imgs)):
    plt.subplot(2, 2, i+1)
    plt.title(title[i])
    if i % 2 == 0:
        # real image: undo the [-1, 1] scaling
        plt.imshow(imgs[i][0] * 0.5 + 0.5)
    else:
        # generated image: amplified for visibility
        plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()

In [None]:
# Shapes of the wet image batch and its one-hot label batch, one per line.
print(img_sample_wet.shape, label_sample_wet.shape, sep="\n")

In [None]:
# Visualise the last channel of each discriminator's output map for a real sample.
plt.figure(figsize=(8, 8))

disc_panels = [
    ('Is a real wet?', discriminator_y, img_sample_wet),
    ('Is a real dry?', discriminator_x, img_sample_dry),
]
for pos, (caption, disc, sample) in enumerate(disc_panels, start=1):
    plt.subplot(1, 2, pos)
    plt.title(caption)
    plt.imshow(disc([sample])[0, ..., -1], cmap='RdBu_r')

plt.show()

## **Loss functions**

In [None]:
# Weight of the cycle-consistency term; identity_loss uses half of it.
LAMBDA = 10

# Adversarial BCE applied to raw discriminator outputs (logits), plus a
# categorical cross-entropy on probabilities.
# NOTE(review): class_loss_obj is not referenced by any visible cell.
loss_obj, class_loss_obj = (
    tf.keras.losses.BinaryCrossentropy(from_logits=True),
    tf.keras.losses.CategoricalCrossentropy(from_logits=False),
)

In [None]:
def _mean_abs_error(a, b):
    """Mean absolute (L1) difference between two tensors."""
    return tf.reduce_mean(tf.abs(a - b))


def discriminator_loss(real, generated):
    """Half the sum of BCE(real logits vs 1) and BCE(generated logits vs 0)."""
    real_term = loss_obj(tf.ones_like(real), real)
    fake_term = loss_obj(tf.zeros_like(generated), generated)
    return (real_term + fake_term) * 0.5


def generator_loss(generated):
    """Adversarial generator loss: BCE pushing discriminator logits on fakes towards 1."""
    target = tf.ones_like(generated)
    return loss_obj(target, generated)


def calc_cycle_loss(real_image, cycled_image):
    """LAMBDA-weighted L1 distance between an image and its round-trip reconstruction."""
    return LAMBDA * _mean_abs_error(real_image, cycled_image)


def identity_loss(real_image, same_image):
    """(LAMBDA / 2)-weighted L1 distance between an image and the generator's output on it."""
    return LAMBDA * 0.5 * _mean_abs_error(real_image, same_image)

## **Initializing optimizers, generator and discriminators**

In [None]:
lr = 2e-4


def _make_adam():
    """Adam with learning rate `lr` and beta_1 = 0.5."""
    return tf.keras.optimizers.Adam(lr, beta_1=0.5)


# One independent optimizer per network.
generator_g_optimizer = _make_adam()
generator_f_optimizer = _make_adam()
discriminator_x_optimizer = _make_adam()
discriminator_y_optimizer = _make_adam()

## **Checkpoints**

In [None]:

checkpoint_path = "../models/"

# Track all four networks and their optimizers so training can resume.
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

# Keep only the two most recent checkpoints on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=2)

# Restore the newest checkpoint exactly once, if one exists.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("Restored from {}".format(ckpt_manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# **Training**

In [None]:
# Number of passes over the zipped dry/wet training data in train_and_checkpoint().
EPOCHS = 20

In [None]:
def generate_images(model, test_input):
    """Plot the first image of `test_input` beside `model`'s translation of it.

    Both panels are rescaled with x * 0.5 + 0.5 (the inverse of `normalize`)
    so the values land in the displayable range.
    """
    prediction = model(test_input)
    plt.figure(figsize=(12, 12))

    panels = zip(['Input Image', 'Predicted Image'], [test_input[0], prediction[0]])
    for pos, (caption, img) in enumerate(panels, start=1):
        plt.subplot(1, 2, pos)
        plt.title(caption)
        plt.imshow(img * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
@tf.function
def train_step(real_x, real_y):
    """Run one CycleGAN update on a dry batch (domain X) and a wet batch (domain Y).

    Args:
        real_x: (images, labels) element from the dry training dataset.
        real_y: (images, labels) element from the wet training dataset.

    Only the images (index 0) are used; the labels do not enter any loss here.
    Updates generator_g (X -> Y), generator_f (Y -> X), discriminator_x and
    discriminator_y in place through their respective optimizers.
    """
    real_x_img = real_x[0]
    real_y_img = real_y[0]

    # persistent=True because the same tape is queried for four gradients below.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y (dry -> wet).
        # Generator F translates Y -> X (wet -> dry).
        fake_y = generator_g(real_x_img, training=True)
        cycled_x = generator_f(fake_y, training=True)
        # Same round trip in the opposite direction.
        fake_x = generator_f(real_y_img, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # same_x and same_y are used for the identity loss.
        same_x = generator_f(real_x_img, training=True)
        same_y = generator_g(real_y_img, training=True)

        disc_real_x = discriminator_x(real_x_img, training=True)
        disc_real_y = discriminator_y(real_y_img, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # Adversarial generator losses.
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        total_cycle_loss = calc_cycle_loss(real_x_img, cycled_x) + calc_cycle_loss(real_y_img, cycled_y)

        # Total generator loss = adversarial + cycle-consistency + identity.
        # The identity term penalises a generator for altering an image that is
        # already in its output domain (e.g. G applied to a real Y image).
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y_img, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x_img, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # Gradients of each network's own loss w.r.t. its trainable variables.
    generator_g_gradients = tape.gradient(total_gen_g_loss,
                                          generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss,
                                          generator_f.trainable_variables)

    discriminator_x_gradients = tape.gradient(disc_x_loss,
                                              discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss,
                                              discriminator_y.trainable_variables)

    # Apply each set of gradients with its network's optimizer.
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
                                              generator_g.trainable_variables))

    generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
                                              generator_f.trainable_variables))

    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                  discriminator_x.trainable_variables))

    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                  discriminator_y.trainable_variables))

In [None]:
def train_and_checkpoint(ckpt_manager=None, save_every=5):
    """Train the CycleGAN for EPOCHS epochs, showing progress and checkpointing.

    Args:
        ckpt_manager: tf.train.CheckpointManager used to resume from and save
            to. If None, training continues from the current in-memory weights
            and no checkpoints are read or written (previously None crashed).
        save_every: save a checkpoint every `save_every` epochs. The final
            epoch is always saved too, so progress is not lost when EPOCHS is
            not a multiple of `save_every`. Values <= 0 disable periodic saves.

    After every epoch the progress plot of generator_g on img_sample_dry is
    redrawn, and a dot is printed every 10 training steps.
    """
    if ckpt_manager is not None:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        if ckpt_manager.latest_checkpoint:
            print("Restored from {}".format(ckpt_manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

    for epoch in range(EPOCHS):
        start = time.time()

        # zip() stops at the shorter of the two domains.
        paired_ds = tf.data.Dataset.zip((train_dry_image_label_ds, train_wet_image_label_ds))
        for n, (image_x, image_y) in enumerate(paired_ds):
            train_step(image_x, image_y)
            if n % 10 == 0:
                print('.', end='')

        clear_output(wait=True)
        # Always plot the same dry sample so progress is comparable across epochs.
        generate_images(generator_g, img_sample_dry)

        epoch_number = epoch + 1
        periodic_save = save_every > 0 and epoch_number % save_every == 0
        if ckpt_manager is not None and (periodic_save or epoch_number == EPOCHS):
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(epoch_number, ckpt_save_path))

        print('Time taken for epoch {} is {} sec\n'.format(epoch_number, time.time()-start))

In [None]:
# Train (resuming from the latest checkpoint in ckpt_manager, if any).
train_and_checkpoint(ckpt_manager=ckpt_manager)

# Testing on a single video

In [None]:
# NOTE(review): identical to the generate_images defined before training and
# overrides it; presumably repeated so this section can be run on its own.
def generate_images(model, test_input):
    """Show test_input[0] beside model(test_input)[0], both rescaled with x * 0.5 + 0.5."""
    prediction = model(test_input)
    plt.figure(figsize=(12, 12))

    for pos, caption, img in ((1, 'Input Image', test_input[0]),
                              (2, 'Predicted Image', prediction[0])):
        plt.subplot(1, 2, pos)
        plt.title(caption)
        plt.imshow(img * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

# Generate using test dataset

In [None]:
# frames from video path
path =  "../../../../../data/polyp_original/WL/adenoma_WL/video_1/"
# load dataset white light
# here A: white light, B: nbi light
adenoma_WL = load_images(path, rgb=True)
print("Adenoma WL video_1: ", adenoma_WL.shape)

In [None]:
adenoma_WL_array = np.asarray(adenoma_WL)

# Test-time preprocessing only. There is deliberately no shuffle: frames are kept
# in the order they were loaded, so predictions can be reviewed sequentially
# like the video itself.
adenoma_WL_ds = (tf.data.Dataset.from_tensor_slices(adenoma_WL_array)
                 .map(preprocess_image_test, num_parallel_calls=AUTOTUNE)
                 .cache()
                 .batch(BATCH_SIZE))

In [None]:
# Translate every preprocessed frame batch with generator G and plot it. The
# dataset holds at most adenoma_WL.shape[0] batches, so iterating it directly
# visits every batch.
for frame_batch in adenoma_WL_ds:
    generate_images(generator_g, frame_batch)