# <font color='red'>**Libraries**</font>

In [None]:
import os

# Must run before TensorFlow is imported: restrict CUDA to the GPU at PCI-bus
# index 5 and turn off TensorFlow's oneDNN optimizations.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "5",
    'TF_ENABLE_ONEDNN_OPTS': '0',
})

In [None]:
#allows to import generator and discriminator
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
import tensorflow as tf
#import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
from os import listdir
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array
from numpy import vstack
from numpy import asarray
from numpy import savez_compressed
import numpy as np
from PIL import Image
from tensorflow.keras.utils import plot_model

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Sentinel asking tf.data to tune parallelism/prefetch sizes dynamically.
# The commented-out tf.data.AUTOTUNE is the non-experimental spelling of the same constant.
# NOTE(review): not referenced by any cell visible in this notebook.
#AUTOTUNE = tf.data.AUTOTUNE
AUTOTUNE = tf.data.experimental.AUTOTUNE
import glob
import imageio
from skimage.transform import resize
from tqdm import tqdm

# <font color='red'>**Helper functions**</font>

In [None]:
# scaling the images to [-1, 1]
def normalize(image):
    """Cast ``image`` to float32 and rescale pixel values from [0, 255] to [-1, 1]."""
    as_float = tf.cast(image, tf.float32)
    return as_float / 127.5 - 1

def preprocess_image_test(image):
    """Test-time preprocessing: only the [-1, 1] rescaling done by ``normalize``."""
    return normalize(image)

In [None]:
def load_images(path, size=(256,256)):
    """Load a single image file and return it as a batch of one.

    Args:
        path: path of the image file to read.
        size: ``target_size`` passed to Keras ``load_img`` (default (256, 256)).

    Returns:
        numpy array whose first axis has length 1, holding the ``img_to_array``
        conversion of the resized image, ready to be passed to the generator.
    """
    resized = load_img(path, target_size=size)
    return asarray([img_to_array(resized)])

# <font color='red'>**Loading generator model**</font>

**Import and reuse pix2pix models**

In [None]:
OUTPUT_CHANNELS = 3

# Built in the order generator_g, generator_f, discriminator_x, discriminator_y.
# Only generator_g is used for generation below; all four are registered in the
# checkpoint so its stored objects have counterparts here.
generator_g, generator_f = [
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2)
]
discriminator_x, discriminator_y = [
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2)
]

**Initializing the optimizers for the generators and discriminators**

In [None]:
# One Adam optimizer (learning rate 2e-4, beta_1 0.5) per network, created in the
# order G, F, D_X, D_Y. Nothing is trained here; they are tracked by the checkpoint.
(generator_g_optimizer, generator_f_optimizer,
 discriminator_x_optimizer, discriminator_y_optimizer) = [
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4)
]

**Restoring the trained weights from the latest checkpoint**

In [None]:
checkpoint_path = "../models/"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# This notebook only runs inference: continuing without trained weights would
# write images produced by a randomly initialised generator, so fail fast instead.
latest_checkpoint = ckpt_manager.latest_checkpoint
if latest_checkpoint is None:
    raise FileNotFoundError(
        "No checkpoint found in {!r}; refusing to generate images with an "
        "untrained generator.".format(checkpoint_path))

# Restore exactly once. expect_partial() silences warnings about checkpointed
# values (e.g. optimizer state) that are not consumed, since we never train here.
ckpt.restore(latest_checkpoint).expect_partial()
print("Restored from {}".format(latest_checkpoint))

# <font color='red'>**Generating fake images from full-frame video frames**</font>

In [None]:
# Settings carried over from training. Shuffle buffer and batch size, plus the
# 256x256 image size (load_images uses the same size via its default argument).
BUFFER_SIZE, BATCH_SIZE = 1000, 1
IMG_WIDTH = IMG_HEIGHT = 256

In [None]:
# Run generator_g on every frame of every video and save the result under
# save_path/<class>/<video>/<original frame file name>.
gen_path =  "../../../../../data/polyp_original/WL/"
save_path = "../imgs_results/fake_images_ck8/"
# (rows, cols) each 256x256 generator output is resized to before saving.
FULL_FRAME_SHAPE = (576, 768)

# Walk only real sub-directories / files: hidden entries such as
# .ipynb_checkpoints or .DS_Store would otherwise be treated as classes,
# videos or frames and crash the loop.
clases = sorted(
    entry for entry in os.listdir(gen_path)
    if not entry.startswith('.') and os.path.isdir(os.path.join(gen_path, entry))
)
for clase in clases:
    print("working on: ", clase)
    clase_path = os.path.join(gen_path, clase)
    videos = sorted(
        entry for entry in os.listdir(clase_path)
        if not entry.startswith('.') and os.path.isdir(os.path.join(clase_path, entry))
    )
    for video in videos:
        video_path = os.path.join(clase_path, video)
        images = [
            entry for entry in os.listdir(video_path)
            if not entry.startswith('.') and os.path.isfile(os.path.join(video_path, entry))
        ]
        if not images:
            continue
        # The output directory depends only on the video, so create it once here.
        to_save = os.path.join(save_path, clase, video)
        os.makedirs(to_save, exist_ok=True)

        for image in tqdm(images):
            img_path = os.path.join(video_path, image)
            img = preprocess_image_test(load_images(img_path))

            # training=False makes inference mode explicit for any
            # training-dependent layers inside the generator.
            fake_sam = generator_g(img, training=False)
            # Assumes generator output lies in [-1, 1] (the range normalize()
            # maps inputs to); map the single batch element back to [0, 1].
            fake = fake_sam[0].numpy() * 0.5 + 0.5
            fake = resize(fake, FULL_FRAME_SHAPE)
            # matplotlib rejects float RGB values outside [0, 1]; clip to guard
            # against any small numerical overshoot.
            plt.imsave(os.path.join(to_save, image), np.clip(fake, 0.0, 1.0))

# <font color='red'>**Assembling the synthetic NBI frames into videos**</font>

In [None]:
import cv2
import numpy as np
from natsort import natsorted

In [None]:
# Encode each folder of generated frames into an mp4 under
# save_path/<class>/<video>/video_<number>.mp4.
general_folder = '../imgs_results/fake_images_ck8/'
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
save_path = '../fake_videos/'
FPS = 25

# Offset added to each folder's own video number so output names are unique
# across classes; classes not listed use DEFAULT_VIDEO_OFFSET.
# NOTE(review): presumably the number of videos in the preceding classes — confirm.
CLASS_VIDEO_OFFSETS = {'adenoma_WL': 0, 'hiperplastic_WL': 40}
DEFAULT_VIDEO_OFFSET = 61


def _write_video(frame_paths, video_name, fps=FPS):
    """Encode the frame image files in ``frame_paths``, in order, into ``video_name``.

    The frame size is taken from the first frame. Raises IOError if a frame
    cannot be read or the writer cannot be opened, and ValueError if a frame's
    size differs from the first one, so mismatches are reported instead of
    being handed to the writer. The writer is always released.
    """
    first_frame = cv2.imread(frame_paths[0])
    if first_frame is None:
        raise IOError("could not read frame {}".format(frame_paths[0]))
    height, width = first_frame.shape[:2]

    writer = cv2.VideoWriter(filename=video_name, fourcc=fourcc, fps=fps,
                             frameSize=(width, height))
    try:
        if not writer.isOpened():
            raise IOError("could not open video writer for {}".format(video_name))
        for frame_path in tqdm(frame_paths):
            frame = cv2.imread(frame_path)
            if frame is None:
                raise IOError("could not read frame {}".format(frame_path))
            if frame.shape[:2] != (height, width):
                raise ValueError("frame {} has size {}, expected {}".format(
                    frame_path, frame.shape[:2], (height, width)))
            writer.write(frame)
    finally:
        writer.release()


# Filter hidden entries (e.g. .ipynb_checkpoints) explicitly instead of slicing
# off the first listdir()/natsorted() entry, which drops an arbitrary class or a
# real frame depending on directory contents.
clases = sorted(
    entry for entry in os.listdir(general_folder)
    if not entry.startswith('.') and os.path.isdir(os.path.join(general_folder, entry))
)

for clase in clases:
    print("for clase: ", clase)
    cont = CLASS_VIDEO_OFFSETS.get(clase, DEFAULT_VIDEO_OFFSET)
    clase_path = os.path.join(general_folder, clase)
    videos = sorted(
        entry for entry in os.listdir(clase_path)
        if not entry.startswith('.') and os.path.isdir(os.path.join(clase_path, entry))
    )
    for video in videos:
        print(video)
        # Assumes folder names end in "_<integer>" (the video number).
        num_vid = int(video.split('_')[-1])
        video_path = os.path.join(clase_path, video)
        # Natural sort so frame 10 follows frame 9 rather than frame 1.
        images = natsorted(
            entry for entry in os.listdir(video_path)
            if not entry.startswith('.') and os.path.isfile(os.path.join(video_path, entry))
        )
        if not images:
            print("no frames found, skipping")
            continue

        current_save = os.path.join(save_path, clase, video)
        os.makedirs(current_save, exist_ok=True)
        video_name = os.path.join(current_save, "video_" + str(num_vid + cont) + '.mp4')
        _write_video([os.path.join(video_path, image) for image in images], video_name)
print("finished!")