In [None]:
import os
# Order CUDA devices by PCI bus ID (so indices match nvidia-smi) and expose only GPU 0 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
#allows to import generator and discriminator
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
from os import listdir
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array
from numpy import vstack
from numpy import asarray
from numpy import savez_compressed
import numpy as np
from PIL import Image
from tensorflow.keras.utils import plot_model

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
from sklearn import preprocessing

# Lets tf.data choose the parallelism of `Dataset.map(..., num_parallel_calls=AUTOTUNE)` at runtime.
# `tf.data.AUTOTUNE` is the newer spelling. The `experimental` alias is used here, presumably
# for compatibility with the installed TensorFlow version.
AUTOTUNE = tf.data.experimental.AUTOTUNE

from PIL import Image
import glob
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm
import seaborn as sns

In [None]:
# Pipeline / model hyper-parameters.
BUFFER_SIZE = 1000             # not referenced in the cells shown here
BATCH_SIZE = 1                 # images per tf.data batch
IMG_HEIGHT = IMG_WIDTH = 256   # spatial size fed to the networks
LAMBDA = 10                    # weight applied to the L1 cycle-consistency term

# <font color='red'>**Useful methods**</font>

In [None]:
def load_images(path, size=(256,256)):
    """Load every file in a directory into memory as an image array.

    Args:
        path: directory to scan; every entry is assumed to be an image file.
        size: ``target_size`` passed to ``load_img`` (height, width).

    Returns:
        numpy array stacking one ``img_to_array`` result per file, in sorted
        filename order.
    """
    data_list = []
    # sorted(): os.listdir order is arbitrary, so the output order (and anything
    # indexed by it) would not be reproducible across machines.
    for filename in sorted(listdir(path)):
        # os.path.join works whether or not `path` ends with a separator; the
        # previous `path + filename` silently required a trailing slash.
        pixels = load_img(os.path.join(path, filename), target_size=size)
        data_list.append(img_to_array(pixels))
    return asarray(data_list)

**Data augmentation techniques**

In [None]:
def random_crop(image):
    """Return a randomly positioned IMG_HEIGHT x IMG_WIDTH x 3 window cut from `image`."""
    crop_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
    return tf.image.random_crop(image, size=crop_shape)

def normalize(image):
    """Cast `image` to float32 and rescale it as x / 127.5 - 1.

    Assumes 8-bit pixel values in [0, 255], which then map to [-1, 1].
    """
    return tf.cast(image, tf.float32) / 127.5 - 1

def random_jitter(image):
    """Training-time augmentation: upscale, random crop back to size, random horizontal flip."""
    # Nearest-neighbour upscale to 286 x 286 so the subsequent crop can shift.
    enlarged = tf.image.resize(image, [286, 286],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    cropped = random_crop(enlarged)  # back to IMG_HEIGHT x IMG_WIDTH x 3
    return tf.image.random_flip_left_right(cropped)

**Preprocess splits**

In [None]:
def preprocess_image_train(image):
    """Training preprocessing: random jitter augmentation followed by [-1, 1] normalisation."""
    return normalize(random_jitter(image))

def preprocess_image_test(image):
    """Evaluation preprocessing: normalisation only (no random augmentation)."""
    normalized = normalize(image)
    return normalized

**Import and reuse the Pix2Pix models**

In [None]:
OUTPUT_CHANNELS = 3

# G (generator_g): T1-MRI -> dopaminergic estimate; F (generator_f): the reverse mapping.
# Construction order (g, f, then x, y) matches the original cell.
generator_g, generator_f = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm') for _ in range(2)
)
discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False) for _ in range(2)
)

**Initializing the Adam optimizers for the generators and discriminators**

In [None]:
# One independent Adam optimizer per network, all with lr=2e-4 and beta_1=0.5.
generator_g_optimizer, generator_f_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(2)
)
discriminator_x_optimizer, discriminator_y_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(2)
)

# <font color='red'>**Loading models**</font>

In [None]:
checkpoint_path = "../models/cyclegan/preprocessed/mri_to_spect/"
# Track every network and optimizer. This is presumably the same object graph that was
# used when the checkpoint was written, so that restore can match all variables.
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Restore the most recent checkpoint exactly once. The cell previously called
# ckpt.restore twice on the same path and printed two confirmations.
latest_checkpoint = ckpt_manager.latest_checkpoint
if latest_checkpoint:
    ckpt.restore(latest_checkpoint)
    print("Restored from {}".format(latest_checkpoint))
else:
    # NOTE(review): the cells below only evaluate the networks, so without a checkpoint
    # every reported loss comes from randomly initialised weights.
    print("Initializing from scratch.")

# <font color='red'>**Predicting over full test subjects**</font>
## Main
### Original CycleGAN data

In [None]:
gen_path = '../../../../../../Datasets/Parkinson/radiological/PPMI/spect-mri/filtered/'
# NOTE(review): despite the `_test` names (and the "TEST set" plot title at the end), this
# reads the *_TRAIN.csv split. Confirm which split the evaluation is meant to use.
# os.path.join takes the directory and the file name as separate arguments. Joining a single
# pre-concatenated string did nothing and relied on gen_path's trailing '/'.
csv_test = os.path.join(gen_path, 'control_pd_SPECT_fullRois_TRAIN.csv')
mri_test_df = pd.read_csv(csv_test, sep=',', header=None)
mri_test_df.columns = ["path", "label"]  # no header row: image path, diagnosis label

mri_test_df.groupby('label').count()

In [None]:
# Parse case and slice indices from the file name, then keep slices 42..131 of every subject.
# The dot before "png" is escaped so the pattern only matches a literal ".png" suffix.
# astype(int) raises if any path does not follow the `_case_<n>_slice_<m>.png` pattern.
mri_test_df[['case_number', 'slice_number']] = (
    mri_test_df['path'].str.extract(r'_case_(\d+)_slice_(\d+)\.png').astype(int)
)
in_slice_range = mri_test_df['slice_number'].between(42, 131)  # same as > 41 and < 132
# Build a new frame instead of calling drop(inplace=True) on a filtered view,
# which triggers pandas' SettingWithCopyWarning.
mri_test_df_v2 = mri_test_df.loc[in_slice_range].drop(columns=['slice_number', 'case_number'])

print("len mri_tset_df_v2: ", len(mri_test_df_v2))

In [None]:
# Split the filtered slice table by diagnosis label.
control_df, parkinson_df = (
    mri_test_df_v2[mri_test_df_v2['label'] == diagnosis]
    for diagnosis in ("control", "parkinson")
)

print(len(control_df))
print(len(parkinson_df))

# Cycle consistency loss

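$$ X \rightarrow G(X) \rightarrow F(G(X)) = \hat{X} \approx X $$

In this case:
* $X$: T1-MRI
* $G(X)$: dopaminergic estimation from the T1-MRI
* $F(G(X))$: reverse mapping that reconstructs the T1-MRI from the dopaminergic estimation

The cycle loss measures how far the reconstruction is from the input: $\mathcal{L}_{cyc} = \lambda \cdot \operatorname{mean}\lvert X - F(G(X)) \rvert$, with $\lambda$ = `LAMBDA`.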

In [None]:
def calc_cycle_loss(real_image, cycled_image):
    """Weighted L1 cycle-consistency loss: LAMBDA * mean(|real_image - cycled_image|)."""
    abs_diff = tf.abs(real_image - cycled_image)
    return LAMBDA * tf.reduce_mean(abs_diff)

In [None]:
def get_subjects(df):
    """Return the distinct subject (case) identifiers present in `df`.

    The identifier is the third ``_``-separated token of each file name in the
    ``path`` column. For example, ``<prefix>_case_<id>_slice_<n>.png`` gives
    ``<id>``, assuming ``<prefix>`` itself contains no underscore.

    Args:
        df: DataFrame with a ``path`` column of image file paths.

    Returns:
        Sorted list of unique case identifiers (strings).
    """
    # Sorting makes the order deterministic. Iterating a set of strings depends on the
    # per-process hash seed, so the previous list(set(...)) order changed between runs,
    # and with it the subject order in every downstream list and plot.
    cases = {path.split("/")[-1].split("_")[2] for path in df['path']}
    return sorted(cases)

In [None]:
# Distinct subject identifiers per cohort (one entry per subject, not per slice).
unique_control_cases, unique_pd_cases = (
    get_subjects(cohort_df) for cohort_df in (control_df, parkinson_df)
)

print("len unique_control_cases: ", len(unique_control_cases))
print("len unique_pd_cases: ", len(unique_pd_cases))

In [None]:
def _case_key(path):
    """Subject identifier of one slice path: third '_'-token of the file name (same rule as get_subjects)."""
    return path.split("/")[-1].split("_")[2]


def _load_slices(paths, size):
    """Load the images at `paths` (color_mode="rgb") and stack them into one array, one entry per path."""
    return asarray([img_to_array(load_img(p, target_size=size, color_mode="rgb")) for p in paths])


def get_subjects_errors(df, unique_cases):
    """Mean per-slice cycle-consistency loss for each subject.

    For every case in ``unique_cases``, all slices of that subject in ``df`` are
    loaded, normalised with ``preprocess_image_test``, translated with
    ``generator_g`` and mapped back with ``generator_f``. The loss
    ``calc_cycle_loss(real, cycled)`` is computed per slice and averaged over
    that subject's slices only.

    Args:
        df: DataFrame with a ``path`` column of slice image paths.
        unique_cases: case identifiers, as returned by ``get_subjects``.

    Returns:
        (subjects, errors_avg): ``subjects`` lists ``unique_cases`` in the same order,
        and ``errors_avg[i]`` is the mean slice loss of ``subjects[i]`` (NaN if that
        subject has no slices in ``df``).
    """
    size = (IMG_HEIGHT, IMG_WIDTH)  # load_img target_size is (height, width)
    # Exact match on the subject token. The previous str.contains() substring match let
    # case "12" also pick up "123", "312", or any path containing "12".
    case_keys = df['path'].map(_case_key)

    subjects, errors_avg = [], []
    for case in unique_cases:
        subjects.append(case)
        paths = df.loc[case_keys == case, 'path'].tolist()
        if not paths:
            errors_avg.append(float('nan'))
            continue

        # All slices of one subject go through the generators in a single predict call,
        # instead of building a one-image tf.data pipeline per slice.
        real_x = preprocess_image_test(_load_slices(paths, size))
        fake_y = generator_g.predict(real_x, verbose=0)
        cycled_x = generator_f.predict(fake_y, verbose=0)

        # Per-slice losses are collected fresh for each subject. The previous version never
        # cleared its list, so each "average" also included every earlier subject's slices.
        slice_errors = [
            calc_cycle_loss(real_x[k:k + 1], cycled_x[k:k + 1]).numpy()
            for k in range(len(paths))
        ]
        errors_avg.append(np.mean(slice_errors))

    return subjects, errors_avg

In [None]:
# Per-subject mean cycle loss for each cohort. Inside get_subjects_errors, every slice of
# every subject is passed through generator_g and then generator_f.
control_pat, control_error_avg = get_subjects_errors(control_df, unique_control_cases)
pd_pat, pd_error_avg = get_subjects_errors(parkinson_df, unique_pd_cases)

# Sanity check: each cohort should have exactly one average per subject.
print("len control_pat: ", len(control_pat))
print("len control_error: ", len(control_error_avg))
print("len pd_pat: ", len(pd_pat))
print("len pd_error: ", len(pd_error_avg))

In [None]:
# Ground truth: 0 = control, 1 = PD. Scores (ecm) are the per-subject mean cycle losses in the same order.
y_true = [0 for _ in control_pat] + [1 for _ in pd_pat]
ecm = [*control_error_avg, *pd_error_avg]

print("len y_true: ", len(y_true))
print("len ecm: ", len(ecm))

In [None]:
# x positions 0 .. n_control - 1 for the control cohort's per-subject line plot.
ids = list(range(len(control_pat)))

In [None]:
# Plotting both lines.
# The control and PD cohorts usually contain different numbers of subjects, so each
# series gets its own x positions. Plotting pd_error_avg against the control `ids`
# raised a length-mismatch ValueError whenever the counts differed.
# NOTE: x is the position within each cohort's subject list, so control point i and
# PD point i are unrelated subjects.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(range(len(control_error_avg)), control_error_avg, marker='o', linestyle='-', color='b', label='Avg error Control')
ax.plot(range(len(pd_error_avg)), pd_error_avg, marker='s', linestyle='--', color='r', label='Avg error PD')

ax.set_xlabel('Subject ID')
ax.set_ylabel('Cycle Loss')
ax.set_title('Control against PD')
ax.legend()
ax.grid(True)
plt.show()

### Boxplot and violin plots

In [None]:
# Side-by-side boxplots of the per-subject mean cycle loss for each cohort.
fig, ax = plt.subplots()
sns.boxplot(data=[control_error_avg, pd_error_avg], ax=ax)

ax.set_xlabel('Data')
ax.set_ylabel('Error Value')
ax.set_title('Boxplots of Cycle loss values')
ax.set_xticks([0, 1])
ax.set_xticklabels(['Control Error Values', 'PD Error Values'])

plt.show()

In [None]:
# Violin plots of the per-subject mean cycle loss for each cohort.
fig, ax = plt.subplots()
sns.violinplot(data=[control_error_avg, pd_error_avg], ax=ax)

ax.set_xlabel('Data')
ax.set_ylabel('Error Value')
ax.set_title('Violin Plots of Cycle loss values')
ax.set_xticks([0, 1])
ax.set_xticklabels(['Control Error Values', 'PD Error Values'])

plt.show()

### Precision and recall curves 

In [None]:
from sklearn.metrics import precision_recall_curve

# Higher mean cycle loss is treated as evidence for the positive class (PD = 1).
precision, recall, thresholds = precision_recall_curve(y_true, ecm)

# precision/recall have one more entry than thresholds (the final point has no
# threshold), hence the [:-1] when plotting them against thresholds.
fig, ax = plt.subplots()
ax.plot(thresholds, precision[:-1], label='Precision')
ax.plot(thresholds, recall[:-1], label='Recall')

ax.set_xlabel('Threshold')
ax.set_ylabel('Precision/Recall')
# NOTE(review): the title says TEST, but the CSV loaded earlier is the *_TRAIN.csv split. Confirm.
ax.set_title('Precision and Recall vs. Threshold Curve TEST set')
ax.legend()
ax.grid(True)

plt.show()