# Main libraries

In [None]:
# Installs tensorflow_examples, which provides the pix2pix generator and discriminator models imported below
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
# Select the GPU visible to this process. Both variables are set here,
# before TensorFlow is imported in the next cell.
import os

os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

In [None]:
import tensorflow as tf
print(tf.__version__)
# Make every @tf.function (including train_step further down) execute eagerly
# instead of as a traced graph.
tf.config.run_functions_eagerly(True)

In [None]:
import gc
from tensorflow_examples.models.pix2pix import pix2pix
from os import listdir
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array
from numpy import vstack
from numpy import asarray
from numpy import savez_compressed
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tensorflow import keras
import gc

AUTOTUNE = tf.data.AUTOTUNE

# Load and preprocess dataset

**Needed the first time**

In [None]:
# Class name -> integer index. The names are matched against the prefix of
# each image file name (the text before the first '_') in load_images.
str2idx = dict(adenoma=0, hiperplastic=1, serrated=2)

# Inverse lookup (integer index -> class name), derived from str2idx so the
# two tables always agree.
idx2str = {index: name for name, index in str2idx.items()}

In [None]:
def ohe_class(index, num_classes=2):
    """
    One-hot encodes a class label represented by an index.

    Parameters:
    index (int): The index of the class label to encode. Must satisfy
                 0 <= index < num_classes.
    num_classes (int, optional): Length of the returned vector. Default is 2,
                                 the binary setting used by this notebook.

    Returns:
    ohe_label (numpy.ndarray): An integer array of length `num_classes` with a 1 at
                               position `index` and 0 everywhere else.

    Raises:
    IndexError: If `index` is outside [0, num_classes). Negative indices are rejected
                explicitly; NumPy would otherwise wrap them around and silently encode
                the wrong class.
    """
    if not 0 <= index < num_classes:
        raise IndexError(
            "class index {} is out of range for {} classes".format(index, num_classes)
        )
    ohe_label = np.zeros(num_classes, dtype=int)
    ohe_label[index] = 1
    return ohe_label

In [None]:
def load_images(path, size=(256, 256), rgb=False):
    """
    Loads the images of a directory into memory together with their one-hot labels.

    Parameters:
    path (str): The directory containing the images. A trailing separator is optional.
    size (tuple, optional): Target size the images are resized to. Default is (256, 256).
    rgb (bool, optional): If True, images are loaded in "rgb" color mode; otherwise in
                          "grayscale". Default is False.

    Returns:
    data (numpy.ndarray): The loaded images stacked into one array.
    labels (list): One one-hot encoded label (see `ohe_class`) per loaded image, in the
                   same order as `data`.

    Raises:
    KeyError: If a file name prefix (the text before the first '_') is not a key of
              `str2idx`.

    Notes:
    - Every entry of the directory is assumed to be an image file.
    - The class of an image is the part of its file name before the first '_'.
    - Files whose class is 'serrated' are skipped.
    - File names are processed in sorted order, because `os.listdir` returns them in
      arbitrary order; sorting makes the result reproducible between runs.
    """
    color_mode = "rgb" if rgb else "grayscale"
    data_list = []
    label_list = []

    for filename in tqdm(sorted(os.listdir(path))):
        clase = filename.split('_')[0]
        if clase == 'serrated':
            continue

        # os.path.join works whether or not `path` ends with a separator.
        pixels = load_img(os.path.join(path, filename), target_size=size, color_mode=color_mode)
        data_list.append(img_to_array(pixels))
        label_list.append(ohe_class(str2idx[clase]))

    return np.asarray(data_list), label_list


**Set up**

In [None]:
# Shuffle buffer size used when building the single-video test dataset.
BUFFER_SIZE = 1000
# Number of frames per batch in every dataset of this notebook.
BATCH_SIZE = 1
# Width and height of the crop produced by random_crop.
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Experiment name; appended to the checkpoint directory path.
experiment = 'fineTunedExperimentName' 

In [None]:
"""Frames loading.
rgb parameter sets as False for work whit grayscale images
"""
# Root folder of the fold used in this run.
# NOTE: every split below is loaded with rgb=True, i.e. in colour, not in grayscale.
path = '../../../data/binary/fold1/'

# White-light (WL) train and test splits
(train_WL_imgs, train_WL_labels), (test_WL_imgs, test_WL_labels) = (
    load_images(path + split, rgb=True) for split in ('train_WL/', 'test_WL/')
)

# NBI train and test splits
(train_NBI_imgs, train_NBI_labels), (test_NBI_imgs, test_NBI_labels) = (
    load_images(path + split, rgb=True) for split in ('train_NBI/', 'test_NBI/')
)

# Report the array shape and label count of both training splits.
print("train images WL: ", train_WL_imgs.shape, " labels: ", len(train_WL_labels))
print("train images NBI: ", train_NBI_imgs.shape, " labels: ", len(train_NBI_labels))

**Data augmentation techniques**

In [None]:
def random_crop(image):
    """Return a random IMG_HEIGHT x IMG_WIDTH crop of a 3-channel image."""
    crop_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
    return tf.image.random_crop(image, size=crop_shape)

def normalize(image):
    """Cast `image` to float32 and rescale it linearly so that 0 maps to -1 and 255 to 1."""
    as_float = tf.cast(image, tf.float32)
    return (as_float / 127.5) - 1

def random_jitter(image):
    """
    Augment one image: enlarge it, take a random crop, then randomly mirror it.

    The image is resized to 286 x 286 with nearest-neighbour interpolation, cropped
    by `random_crop`, and flipped left-right at random.
    """
    enlarged = tf.image.resize(
        image,
        [286, 286],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    cropped = random_crop(enlarged)
    return tf.image.random_flip_left_right(cropped)

**Preprocess splits**

In [None]:
def preprocess_image_train(image):
    """Training preprocessing: random jitter followed by normalization."""
    return normalize(random_jitter(image))

def preprocess_image_test(image):
    """Test preprocessing: normalization only, no augmentation."""
    normalized = normalize(image)
    return normalized

In [None]:
# Convert the loaded image collections to NumPy arrays
# (np.asarray does not copy an input that is already an ndarray).
train_WL_array, test_WL_array, train_NBI_array, test_NBI_array = (
    np.asarray(imgs)
    for imgs in (train_WL_imgs, test_WL_imgs, train_NBI_imgs, test_NBI_imgs)
)

In [None]:
# Create the WL and NBI datasets.

# Image datasets: one element per frame, not batched yet.
train_WL_ds = tf.data.Dataset.from_tensor_slices(train_WL_array)
train_NBI_ds = tf.data.Dataset.from_tensor_slices(train_NBI_array)
test_WL_ds = tf.data.Dataset.from_tensor_slices(test_WL_array)
test_NBI_ds = tf.data.Dataset.from_tensor_slices(test_NBI_array)

# Label datasets for the training splits: int64 one-hot labels, batched.
train_WL_label_ds = tf.data.Dataset.from_tensor_slices(
    tf.cast(train_WL_labels, tf.int64)
).batch(BATCH_SIZE)
train_NBI_label_ds = tf.data.Dataset.from_tensor_slices(
    tf.cast(train_NBI_labels, tf.int64)
).batch(BATCH_SIZE)

In [None]:
# Training images: jitter + normalize each frame, then batch.
train_WL_ds = train_WL_ds.map(preprocess_image_train, 
                              num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)

train_NBI_ds = train_NBI_ds.map(preprocess_image_train,
                                num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)

# The image and label datasets come from the same load_images() call and are
# batched with the same BATCH_SIZE, so zipping them yields aligned
# (image batch, label batch) pairs.

train_WL_image_label_ds = tf.data.Dataset.zip((train_WL_ds, train_WL_label_ds))
train_NBI_image_label_ds = tf.data.Dataset.zip((train_NBI_ds, train_NBI_label_ds))

# Shuffle the zipped pairs (so image and label stay together) with a buffer as
# large as the dataset. reshuffle_each_iteration=False keeps the shuffled order
# the same on every pass over the dataset.
train_WL_image_label_ds = train_WL_image_label_ds.shuffle(buffer_size=len(train_WL_image_label_ds),
                                                          reshuffle_each_iteration=False)
train_WL_image_label_ds = train_WL_image_label_ds.prefetch(buffer_size=AUTOTUNE)

train_NBI_image_label_ds = train_NBI_image_label_ds.shuffle(buffer_size=len(train_NBI_image_label_ds),
                                                          reshuffle_each_iteration=False)
train_NBI_image_label_ds = train_NBI_image_label_ds.prefetch(buffer_size=AUTOTUNE)


# Test images: normalize only (no augmentation, no shuffle), then batch.
test_WL_ds = test_WL_ds.map(preprocess_image_test,
                            num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)

test_NBI_ds = test_NBI_ds.map(preprocess_image_test,
                              num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)

In [None]:
# Pull the first (image batch, label batch) pair from each training dataset.
img_sample_WL, lab_sample_WL = next(iter(train_WL_image_label_ds))
img_sample_NBI, lab_sample_NBI = next(iter(train_NBI_image_label_ds))

# Print the batch shape and the one-hot label of both samples.
print("WL sample info:")
print(f"shape: {img_sample_WL.shape}, label: {lab_sample_WL} ")
print("NBI sample info:")
print(f"shape: {img_sample_NBI.shape}, label: {lab_sample_NBI} ")

In [None]:
# Histogram of the pixel values of the first raw NBI training frame,
# i.e. before normalize() is applied.
b = train_NBI_array[0]
plt.hist(b.ravel())
plt.title("Before scaling")
plt.show()

In [None]:
# Histogram of one preprocessed NBI sample; normalize() maps 0..255 onto -1..1.
a = np.array(img_sample_NBI[0])
plt.hist(a.ravel())
plt.title("After scaling")
plt.show()

In [None]:
# Left panel: the preprocessed WL sample, mapped back from [-1, 1] to [0, 1] for display.
plt.subplot(121)
plt.title('White light')
print(img_sample_WL[0].shape)
# NOTE(review): the frames were loaded with rgb=True, so cmap='gray' is presumably
# ignored by imshow for 3-channel data — confirm.
plt.imshow(np.squeeze(img_sample_WL[0]) * 0.5 + 0.5, cmap='gray')
# argmax over the one-hot label gives the class index.
idx = lab_sample_WL.numpy().argmax()
plt.xlabel(idx2str[idx])

# Right panel: the same sample (already jittered and normalized by the training
# pipeline) passed through random_jitter once more.
plt.subplot(122)
plt.title('White light with random jitter')
plt.imshow(np.squeeze(random_jitter(img_sample_WL[0])) * 0.5 + 0.5, cmap='gray')
idx = lab_sample_WL.numpy().argmax()
plt.xlabel(idx2str[idx])

In [None]:
# Left panel: the preprocessed NBI sample, mapped back from [-1, 1] to [0, 1] for display.
plt.subplot(121)
plt.title('NBI light')
plt.imshow(np.squeeze(img_sample_NBI[0]) * 0.5 + 0.5, cmap='gray')
# argmax over the one-hot label gives the class index.
idx = lab_sample_NBI.numpy().argmax()
plt.xlabel(idx2str[idx])

# Right panel: the same sample passed through random_jitter once more.
plt.subplot(122)
plt.title('NBI light with random jitter')
plt.imshow(np.squeeze(random_jitter(img_sample_NBI[0])) * 0.5 + 0.5, cmap='gray')
idx = lab_sample_NBI.numpy().argmax()
plt.xlabel(idx2str[idx])

**Loading many NBI samples (only for data understanding)**

In [None]:
# Collect 25 (image batch, label batch) samples from the NBI training dataset.
images, labels = [], []
for i in tqdm(range(25)):
    # NOTE(review): every iteration builds a new shuffled dataset and a new iterator
    # and keeps only its first element. The same sample can therefore be drawn more
    # than once, and the shuffle buffer is presumably refilled each time — consider
    # shuffling once and taking 25 elements if this cell becomes slow.
    imgs_samples, labels_samples = next(iter(train_NBI_image_label_ds.shuffle(buffer_size=len(train_NBI_imgs))))
    images.append(imgs_samples)
    labels.append(labels_samples)

# Stack the collected image batches into a single array.
images = np.asarray(images)
print("images: {}, amount of labels: {}".format(images.shape, len(labels)))

In [None]:
# Show the 25 collected samples in a 5 x 5 grid, each labelled with its class name.
plt.figure(figsize=(12,12))
for i in range(25):
    plt.subplot(5,5,i+1)
    # Hide the axis ticks.
    plt.xticks([])
    plt.yticks([])
    plt.imshow(np.squeeze(images[i])* 0.5 + 0.5)#convert (batch, high, width, #channels) into (high, width, #channels) 
    # argmax over the one-hot label gives the class index.
    idx = labels[i].numpy().argmax()
    plt.xlabel("label: {}".format(idx2str[idx]))
plt.show()

# Import and reuse the Pix2Pix models

In [None]:
# Number of channels produced by the generators (the frames are loaded with rgb=True).
OUTPUT_CHANNELS = 3

# generator_g translates WL -> NBI and generator_f translates NBI -> WL
# (see train_step).
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

# discriminator_x is applied to WL images and discriminator_y to NBI images
# (see train_step).
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

In [None]:
# Draw one (image batch, label batch) pair per modality.
img_sample_WL, label_sample_WL = next(iter(train_WL_image_label_ds))
img_sample_NBI, label_sample_NBI = next(iter(train_NBI_image_label_ds))

print("info de real data")
# Print the label that belongs to the sample drawn in this cell
# (previously a stale variable from an earlier cell was printed here).
print("img shape: {}, label: {}".format(img_sample_WL.shape, label_sample_WL))
print("min: {}, max: {}".format(tf.reduce_min(img_sample_WL).numpy(), tf.reduce_max(img_sample_WL).numpy()))
print("min: {}, max: {}".format(tf.reduce_min(img_sample_NBI).numpy(), tf.reduce_max(img_sample_NBI).numpy()))

# Run both generators once on the samples.
to_NBI = generator_g([img_sample_WL])
to_WL = generator_f([img_sample_NBI])

print("info de fake data")
print("min: {}, max: {}".format(tf.reduce_min(to_NBI).numpy(), tf.reduce_max(to_NBI).numpy()))
print("min: {}, max: {}".format(tf.reduce_min(to_WL).numpy(), tf.reduce_max(to_WL).numpy()))

plt.figure(figsize=(8, 8))
# Gain applied to the generated images only, to make their output easier to see.
contrast = 8

imgs = [img_sample_WL, to_NBI, img_sample_NBI, to_WL]
title = ['WL', 'To NBI', 'NBI', 'To WL']

for i, (panel_title, panel_imgs) in enumerate(zip(title, imgs)):
    plt.subplot(2, 2, i + 1)
    plt.title(panel_title)
    if i % 2 == 0:
        # Even positions hold the real images.
        plt.imshow(panel_imgs[0] * 0.5 + 0.5)
    else:
        # Odd positions hold the generated images.
        plt.imshow(panel_imgs[0] * 0.5 * contrast + 0.5)
plt.show()

In [None]:
# Shapes of the NBI image batch and of its label batch.
print(img_sample_NBI.shape)
print(label_sample_NBI.shape)

In [None]:
# Visualize the output of each discriminator on a real sample of its own domain.
plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real NBI?')
# [0, ..., -1] selects the first batch item and the last channel of the output.
plt.imshow(discriminator_y([img_sample_NBI])[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real WL?')
plt.imshow(discriminator_x([img_sample_WL])[0, ..., -1], cmap='RdBu_r')

plt.show()

## Classifier network

In [None]:
# Placeholder path: point this at the pretrained NBI classifier (.h5 file).
model_path = "path/to/the/pretrained/NBI_model.h5"
cls_model = keras.models.load_model(model_path, compile=True)
    
# Freeze every layer of the classifier so its weights are not trained.
for layer in cls_model.layers:
    layer.trainable = False
print("all freezing")

## **Loss functions**

In [None]:
# Weight of the cycle-consistency loss; the identity loss uses half of it.
LAMBDA = 10
# Adversarial loss: from_logits=True, so discriminator outputs are treated as logits.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# Classification loss: from_logits=False, so classifier outputs are treated as probabilities.
class_loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

In [None]:
def discriminator_loss(real, generated):
    """
    Discriminator loss: half the sum of the loss on real and on generated outputs.

    Parameters:
    real: Discriminator output for real images; compared against all ones.
    generated: Discriminator output for generated images; compared against all zeros.

    Returns:
    The combined loss, scaled by 0.5.
    """
    loss_on_real = loss_obj(tf.ones_like(real), real)
    loss_on_generated = loss_obj(tf.zeros_like(generated), generated)
    return (loss_on_real + loss_on_generated) * 0.5

def generator_loss(generated):
    """Adversarial generator loss: the discriminator output compared against all ones."""
    target = tf.ones_like(generated)
    return loss_obj(target, generated)

def calc_cycle_loss(real_image, cycled_image):
    """Cycle-consistency loss: LAMBDA times the mean absolute difference of the two images."""
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * mean_abs_diff

def identity_loss(real_image, same_image):
    """Identity loss: LAMBDA * 0.5 times the mean absolute difference of the two images."""
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * mean_abs_diff

def classifier_loss(y_real, y_pred):
    """Categorical cross-entropy between the reference labels and the classifier predictions."""
    cls_loss = class_loss_obj(y_real, y_pred)
    return cls_loss

## **Initializing optimizers, generator and discriminators**

In [None]:
# Learning rate shared by all optimizers.
lr = 2e-4
# One Adam optimizer per network, all with beta_1=0.5.
generator_g_optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.5)

# Optimizer for the classifier. The spelling of this name is also used by the
# checkpoint and by train_step, so it must be changed everywhere or not at all.
cls_model_optimizier = tf.keras.optimizers.Adam(lr, beta_1=0.5)

## **Check points**

In [None]:

# Directory holding the checkpoints of this experiment.
checkpoint_path = "../models/folders/" + experiment

# Track all networks and their optimizers in one checkpoint.
# The keyword names (including the spelling `cls_model_optimizier`) are kept
# unchanged so they keep matching checkpoints that were already saved.
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer,
                           cls_model=cls_model, 
                           cls_model_optimizier=cls_model_optimizier)

# Keep only the two most recent checkpoints on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=2)

# Restore the latest checkpoint once, if there is one, and report what happened.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("Restored from {}".format(ckpt_manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# **Training**

In [None]:
# Number of passes over the training data performed by train_and_checkpoint.
EPOCHS = 50

In [None]:
def generate_images(model, test_input):
    """
    Shows the first image of `test_input` next to the model's output for it.

    Parameters:
    model: A callable model; it is invoked as `model(test_input)`.
    test_input: The input batch passed to the model. Only its first element is shown.

    Notes:
    - Both images are displayed as `value * 0.5 + 0.5`, which maps [-1, 1] onto [0, 1].
    - Nothing is returned; the figure is displayed with matplotlib.
    """
    prediction = model(test_input)

    # (title, image) pairs in display order.
    panels = (
        ('Input Image', test_input[0]),
        ('Predicted Image', prediction[0]),
    )

    plt.figure(figsize=(12, 12))
    for position, (panel_title, panel_image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(panel_title)
        plt.imshow(panel_image * 0.5 + 0.5)
        plt.axis('off')

    plt.show()


In [None]:
@tf.function
def train_step(real_x, real_y, alpha1=1.0, alpha2=1.0, alpha3=1.0):
    """
    Runs one optimization step on one WL batch and one NBI batch.

    Parameters:
    real_x: An (image batch, label batch) pair from the WL dataset. Only the image
            batch is used.
    real_y: An (image batch, label batch) pair from the NBI dataset.
    alpha1, alpha2, alpha3: Not read by this function at present. They default to 1.0
            so that two-argument calls such as `train_step(image_x, image_y)` are valid.
            NOTE(review): presumably intended as loss weights — confirm and wire in.

    Returns:
    None. The generators, discriminators and the classifier's trainable variables
    (if any) are updated in place through their optimizers.
    """
    real_x_img = real_x[0]
    real_y_img = real_y[0]
    real_y_label = real_y[1]

    # persistent=True because the tape is used more than once to compute gradients.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y (WL -> NBI).
        # Generator F translates Y -> X (NBI -> WL).
        fake_y = generator_g(real_x_img, training=True)
        cycled_x = generator_f(fake_y, training=True)
        # Same for the reverse domain translation.
        fake_x = generator_f(real_y_img, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # same_x and same_y are used for the identity loss.
        same_x = generator_f(real_x_img, training=True)
        same_y = generator_g(real_y_img, training=True)

        disc_real_x = discriminator_x(real_x_img, training=True)
        disc_real_y = discriminator_y(real_y_img, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # Classify the generated NBI image.
        pred_y = cls_model(fake_y)

        # Adversarial generator losses.
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        # Classification loss.
        # NOTE(review): pred_y comes from fake_y = G(real_x), but it is compared with
        # the label of real_y, and the WL and NBI datasets are shuffled independently —
        # confirm this pairing is intended.
        # NOTE(review): y_cls_loss is not added to either generator loss; if the
        # classifier layers are frozen, this term does not update any weights — confirm.
        y_cls_loss = classifier_loss(real_y_label, pred_y)

        total_cycle_loss = calc_cycle_loss(real_x_img, cycled_x) + calc_cycle_loss(real_y_img, cycled_y)

        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y_img, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x_img, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # Gradients for the generators and the discriminators.
    generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)

    discriminator_x_gradients = tape.gradient(disc_x_loss,
                                              discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss,
                                              discriminator_y.trainable_variables)
    # Gradients for the classifier network.
    cls_model_gradients = tape.gradient(y_cls_loss, cls_model.trainable_variables)

    # A persistent tape holds its resources until the reference is dropped.
    del tape

    # Apply the gradients with the matching optimizer.
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
                                              generator_g.trainable_variables))

    generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
                                              generator_f.trainable_variables))

    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                  discriminator_x.trainable_variables))

    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                  discriminator_y.trainable_variables))

    cls_model_optimizier.apply_gradients(zip(cls_model_gradients,
                                             cls_model.trainable_variables))

In [None]:
def train_and_checkpoint(ckpt_manager=None, alpha1=1.0, alpha2=1.0, alpha3=1.0):
    """
    Trains for EPOCHS epochs, restoring from and saving to `ckpt_manager`.

    Parameters:
    ckpt_manager: The tf.train.CheckpointManager used to find the latest checkpoint
                  and to save new ones. Required, despite the None default.
    alpha1, alpha2, alpha3: Values forwarded unchanged to `train_step`, whose
                  signature requires them. Default is 1.0 each.

    Raises:
    ValueError: If `ckpt_manager` is None.

    Notes:
    - The latest checkpoint (if any) is restored into the module-level `ckpt`.
    - Each epoch iterates over the zipped WL and NBI training datasets; zipping stops
      at the end of the shorter one.
    - After every epoch the translation of `img_sample_WL` by `generator_g` is shown.
    - A checkpoint is saved every 5 epochs.
    """
    if ckpt_manager is None:
        raise ValueError("train_and_checkpoint requires a tf.train.CheckpointManager")

    latest_checkpoint = ckpt_manager.latest_checkpoint
    ckpt.restore(latest_checkpoint)
    if latest_checkpoint:
        print("Restored from {}".format(latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # Pair each WL batch with an NBI batch; built once and re-iterated every epoch.
    paired_ds = tf.data.Dataset.zip((train_WL_image_label_ds, train_NBI_image_label_ds))

    for epoch in range(EPOCHS):
        start = time.time()

        for n, (image_x, image_y) in enumerate(paired_ds):
            train_step(image_x, image_y, alpha1, alpha2, alpha3)
            # Progress marker every 10 steps.
            if n % 10 == 0:
                print('.', end='')

        clear_output(wait=True)
        # Use the same WL sample after every epoch so that the progress of the
        # model is clearly visible.
        generate_images(generator_g, img_sample_WL)

        if (epoch + 1) % 5 == 0:
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time() - start))

In [None]:
# Start (or resume) training with the checkpoint manager created above.
train_and_checkpoint(ckpt_manager)

# Testing on a single video

In [None]:
def generate_images(model, test_input):
    """
    Shows the first image of `test_input` next to the model's output for it.

    This repeats, statement for statement, the `generate_images` function defined
    earlier in the file.

    Parameters:
    model: A callable model; it is invoked as `model(test_input)`.
    test_input: The input batch passed to the model. Only its first element is shown.
    """
    prediction = model(test_input)
    plt.figure(figsize=(12, 12))
    
    display_list = [test_input[0], prediction[0]]
    title = ['Input Image', 'Predicted Image']

    for i in range(2):
        plt.subplot(1, 2, i+1)
        plt.title(title[i])
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

# **Generate using test dataset**

In [None]:
# Folder holding the frames of one video. The videoN folder must contain the
# frames directly, and load_images reads the class from each file name prefix.
path = "../path/WL_/polyp_class/videoN/"
# load_images returns (images, labels); only the images are needed here.
adenoma_WL, _ = load_images(path, rgb=True)
print("Adenoma WL video_1: ", adenoma_WL.shape)

In [None]:
# Build the inference pipeline for the video frames:
# normalize -> cache -> shuffle -> batch.
# NOTE(review): shuffle() means the frames are not visited in their stored order —
# confirm this is intended for a video.
adenoma_WL_array = np.asarray(adenoma_WL)
adenoma_WL_ds = tf.data.Dataset.from_tensor_slices(adenoma_WL_array)
adenoma_WL_ds = adenoma_WL_ds.map(preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
                BUFFER_SIZE).batch(BATCH_SIZE)

In [None]:
# Translate each batch with generator_g (WL -> NBI) and show the input and the
# output side by side.
for inp in adenoma_WL_ds.take(adenoma_WL.shape[0]):
    generate_images(generator_g, inp)