In [None]:
import os

# Order CUDA devices by PCI bus ID and expose only device "0" to this process.
# This must run before TensorFlow is imported (next cells) to take effect.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
#allows to import generator and discriminator
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
from os import listdir
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array
from numpy import vstack
from numpy import asarray
from numpy import savez_compressed
import numpy as np
from PIL import Image
from tensorflow.keras.utils import plot_model

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
from sklearn import preprocessing

# Newer TensorFlow releases expose the public tf.data.AUTOTUNE; older ones only
# provide it under tf.data.experimental, so fall back to that alias.
try:
    AUTOTUNE = tf.data.AUTOTUNE
except AttributeError:
    AUTOTUNE = tf.data.experimental.AUTOTUNE

from PIL import Image
import glob
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm

In [None]:
import tensorflow as tf
# Show the installed TensorFlow version string.
tf.__version__

In [None]:
# Dataset / batching configuration.
BUFFER_SIZE = 1000
BATCH_SIZE = 1

# Spatial size (pixels) that images are cropped to for the models.
IMG_HEIGHT = IMG_WIDTH = 256

# <font color='red'>**Useful methods**</font>

In [None]:
# load all images in a directory into memory
def load_images(path, size=(256,256)):
    """Load every regular file in a directory as an image and stack them.

    Args:
        path: Directory to scan. A trailing path separator is optional.
        size: ``(height, width)`` passed to ``load_img`` as ``target_size``.

    Returns:
        ``numpy.ndarray`` stacking ``img_to_array`` of each image, in sorted
        file-name order (presumably ``(N, height, width, 3)`` float32 for the
        default RGB mode -- TODO confirm). An empty directory yields an empty
        array of shape ``(0,)``.
    """
    data_list = []
    # os.listdir order is arbitrary; sort so the stacking order is reproducible.
    for filename in sorted(listdir(path)):
        file_path = os.path.join(path, filename)
        # Skip sub-directories and other non-regular entries; every remaining
        # file is still assumed to be an image load_img can decode.
        if not os.path.isfile(file_path):
            continue
        pixels = load_img(file_path, target_size=size)
        data_list.append(img_to_array(pixels))
    return asarray(data_list)

**Data augmentation techniques**

In [None]:
def random_crop(image):
    """Return a random IMG_HEIGHT x IMG_WIDTH x 3 window cut out of ``image``."""
    crop_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
    return tf.image.random_crop(image, size=crop_shape)

def normalize(image):
    """Cast to float32 and rescale so that 0 maps to -1 and 255 maps to 1."""
    as_float = tf.cast(image, tf.float32)
    return as_float / 127.5 - 1

def random_jitter(image):
    """Augment one image: nearest-neighbour resize to 286x286, random crop back
    to IMG_HEIGHT x IMG_WIDTH x 3, then a random left/right flip."""
    upscaled = tf.image.resize(
        image,
        [286, 286],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    cropped = random_crop(upscaled)
    return tf.image.random_flip_left_right(cropped)

**Preprocessing for the train and test splits**

In [None]:
def preprocess_image_train(image):
    """Training-time pipeline: random jitter (resize / crop / flip), then normalize."""
    return normalize(random_jitter(image))

def preprocess_image_test(image):
    """Test-time pipeline: no augmentation, only normalization to float32 [-1, 1]."""
    return normalize(image)

**Import and reuse the Pix2Pix models**

In [None]:
OUTPUT_CHANNELS = 3

# Two U-Net generators (G is the one used for MRI -> SPECT later in this
# notebook; F presumably maps the other way -- confirm) and two discriminators,
# all with instance normalization. Creation order G, F, D_X, D_Y is preserved.
generator_g, generator_f = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2)
)
discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2)
)

**Initializing the optimizers for the generators and discriminators**

In [None]:
# One Adam optimizer (learning rate 2e-4, beta_1 0.5) per network, created in
# the order G, F, D_X, D_Y.
(generator_g_optimizer,
 generator_f_optimizer,
 discriminator_x_optimizer,
 discriminator_y_optimizer) = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4)
)

#nbi_cls_model_optimizier = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# <font color='red'>**Loading models**</font>

In [None]:
#base_model = tf.keras.models.load_model('../models/classifier/binary/MobileNet.h5', compile=True)
#print("model loaded!")

In [None]:
#for layer in base_model.layers:
    #print(layer.name)

In [None]:
#backbone = base_model.get_layer('mobilenet_1.00_224')
#x = base_model.get_layer('global_average_pooling2d')(backbone.output)
#x = base_model.get_layer('dense')(x)
#x = base_model.get_layer('dropout')(x)
#x = base_model.get_layer('dense_1')(x)
#
#nbi_cls_model = tf.keras.Model(inputs=backbone.input, outputs=x)
#print(nbi_cls_model.summary())

In [None]:
# CycleGAN checkpoint directory for the MRI -> SPECT models.
checkpoint_path = "../models/cyclegan/preprocessed/mri_to_spect/"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)
                           # nbi_cls_model=nbi_cls_model)  # classifier currently disabled

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Restore the latest checkpoint exactly once, if one exists.
if ckpt_manager.latest_checkpoint:
    # expect_partial(): the visible cells only run inference, so optimizer slot
    # variables saved in the checkpoint are presumably never created/consumed;
    # this silences the resulting "unresolved object" warnings.
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    print("Restored from {}".format(ckpt_manager.latest_checkpoint))
else:
    # NOTE: without a checkpoint the generators keep their initial random
    # weights, so any images produced by later cells are meaningless.
    print("Initializing from scratch.")

In [None]:
# Print a layer-by-layer summary of generator G (the generator used for MRI -> SPECT).
generator_g.summary()

In [None]:
# Render the generator G graph, with layer names and tensor shapes, to model_plot.png.
plot_model(
    generator_g,
    show_layer_names=True,
    show_shapes=True,
    to_file='model_plot.png',
)

# <font color='red'>**Predicting over full test subjects**</font>
## Main
### Original data

In [None]:
def rgb2gray(path, size, pixels):
    """Re-read ``path`` in grayscale and replicate it across every channel.

    Args:
        path: Image file read with ``load_img(..., color_mode="grayscale")``.
        size: ``(height, width)`` target size passed to ``load_img``.
        pixels: Array whose shape (assumed ``(height, width, channels)``)
            defines the output shape; its values are not used.

    Returns:
        float64 ``numpy.ndarray`` of shape ``pixels.shape`` in which every
        channel holds the same grayscale intensities.
    """
    # Convert the PIL image to a (height, width) array once.
    gray = np.asarray(load_img(path, target_size=size, color_mode="grayscale"))
    img2 = np.zeros(pixels.shape)
    # Broadcast the single plane into all channels, whatever their count.
    img2[...] = gray[..., np.newaxis]
    return img2

In [None]:
### ========= here is for control to parkinson ========= ###
# split = "test"
# modality = "parkinson"
# gen_path = "../data/full_rois/mri/" + split + "_" + modality + "/"
# save_path = "../imgs_results/full_rois/mri/" + split + "_" + modality + "/"

# if modality == "control":
#      print("loading generator_g")
#      generator = generator_g
# else:
#      print("loading generator_f")
#      generator = generator_f     

### ========= here is for T1 to SPECT ========= ###
split, modality = "train", "mri"

# Input folder for this split/modality and the folder translated images go to.
gen_path = ("../../../../../../Datasets/Parkinson/radiological/PPMI/spect-mri/filtered/"
            + split + "_" + modality + "/")
save_path = "../imgs_results/full_rois/mri_to_spect/" + split + "_" + modality + "/"
print(gen_path)
print(save_path)

# MRI inputs go through generator G; any other modality through generator F.
if modality != "mri":
    print("loading generator_f")
    generator = generator_f
else:
    print("loading generator_g")
    generator = generator_g

files = sorted(os.listdir(gen_path))
size, rgb = (256, 256), True

for filename in files:

     # NOTE(review): the load / predict / save steps below are commented out, so
     # this loop currently only splits each file name; `data_list`, `clase` and
     # `id_img` are computed but never used. After the loop `filename` remains
     # bound to the last entry of `files` (hidden state visible to later cells).
     data_list = list()
     general_info = filename.split('_')
     # Third '_'-separated token is used as the class sub-folder name;
     # raises IndexError for names with fewer than three tokens.
     clase = general_info[2]
     # Last '_'-separated token (includes the file extension).
     id_img = general_info[-1]

     # # load and resize the image
     # pixels = load_img(gen_path + filename, target_size=size, color_mode= "rgb")
     # # convert to numpy array
     # pixels = img_to_array(pixels)

     # if rgb==False:
     #           #convert rgb to gray
     #           pixels = rgb2gray(gen_path + filename, size, pixels)
     # else:
     #      None

     # data_list.append(pixels)
     # img_array = asarray(data_list)

     # split_ds = tf.data.Dataset.from_tensor_slices(img_array)
     # split_ds = split_ds.map(preprocess_image_test, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)

     # sample = next(iter(split_ds))
     # fake = generator(sample)
     # fake = fake[0]* 0.5 + 0.5
     # # so that PIL Image can save it (multiply by 255 and cast to uint8)
     # fake = np.array(fake) * 255
     # fake = fake.astype(np.uint8)
     # fake_img = Image.fromarray(fake)

     # #for save
     # directory = save_path + clase 
     # if not os.path.exists(directory):
     #     os.makedirs(directory)

     # salve_path = directory + '/' + filename
     # fake_img.save(salve_path) 

### CycleGan data

In [None]:
# Root folder of the filtered PPMI SPECT/MRI dataset.
gen_path = '../../../../../../Datasets/Parkinson/radiological/PPMI/spect-mri/filtered/'
# Join directory and file name as two arguments so the separator is handled
# whether or not gen_path ends with one.
csv_test = os.path.join(gen_path, 'control_pd_SPECT_fullRois_TEST.csv')
# NOTE(review): the CSV name says "SPECT", but its rows are used below as the
# MRI test set fed to the MRI -> SPECT generator -- confirm the paths are MRI.
# The file has no header row; each row is (image path, label).
mri_test_df = pd.read_csv(csv_test, sep=',', header=None)
mri_test_df.columns = ["path", "label"]

In [None]:
# Number of non-null entries per column for each label (i.e. images per label).
mri_test_df.groupby('label').agg('count')

In [None]:
#test
# Keep only slices with MIN_SLICE < slice_number < MAX_SLICE (i.e. 42..131),
# where the slice number is parsed from names like ".../control_case_3104_slice_042.png".
MIN_SLICE = 41
MAX_SLICE = 132

slice_number = (
    mri_test_df['path']
    .str.extract(r'_case_\d+_slice_(\d+)\.png', expand=False)
    .astype(int)  # raises if any path does not match the pattern
)
# Build a new frame rather than adding columns to mri_test_df (which would make
# re-running earlier cells show different output) or dropping columns in place
# on a filtered slice (which triggers SettingWithCopyWarning).
mri_test_df_v2 = mri_test_df[(slice_number > MIN_SLICE) & (slice_number < MAX_SLICE)].copy()

print("len mri_test_df_v2: ", len(mri_test_df_v2))

In [None]:
save_path = "../imgs_results/full_rois/mri_to_spect/mri_filtered_slices/"

# Since we want to convert MRI to SPECT we use generator G.
generator = generator_g

size = (256, 256)
rgb = True


def translate_slice(image_path, generator, size=(256, 256), rgb=True):
    """Translate one image with a CycleGAN generator and return it as a PIL image.

    Args:
        image_path: Path of the input image; passed directly to ``load_img``.
        generator: Keras generator model. Its output is assumed to lie in
            ``[-1, 1]`` (it is mapped to ``[0, 255]`` below) -- TODO confirm
            the final activation.
        size: ``(height, width)`` target size passed to ``load_img``.
        rgb: When False, the image is re-read in grayscale and replicated
            across the channels with ``rgb2gray``.

    Returns:
        ``PIL.Image.Image`` holding the generator output as uint8.
    """
    pixels = img_to_array(load_img(image_path, target_size=size, color_mode="rgb"))
    if not rgb:
        # Convert this very image, not a file name left over from another cell.
        pixels = rgb2gray(image_path, size, pixels)

    # Add the batch axis and apply test-time preprocessing directly; building a
    # tf.data pipeline for a single image only adds overhead.
    sample = preprocess_image_test(pixels[np.newaxis, ...])
    fake = generator(sample, training=False)[0]

    # Map [-1, 1] -> [0, 255] so PIL can save it. The clip keeps the uint8 cast
    # (which truncates) from wrapping around on out-of-range values.
    fake = (np.asarray(fake) * 0.5 + 0.5) * 255
    return Image.fromarray(np.clip(fake, 0, 255).astype(np.uint8))


for path in tqdm(mri_test_df_v2['path'], desc="MRI -> SPECT"):
    # e.g. ".../control_case_3104_slice_042.png" -> class sub-folder "control"
    general_info = os.path.basename(path)
    clase = general_info.split('.')[0].split('_')[0]

    directory = os.path.join(save_path, clase)
    os.makedirs(directory, exist_ok=True)
    translate_slice(path, generator, size, rgb).save(os.path.join(directory, general_info))