# Imports and Installations

In [None]:
import numpy as np
import tensorflow as tf
import glob
import matplotlib.pyplot as plt
from skimage.transform import resize
from tensorflow.keras.utils import plot_model
import pathlib
import imageio
import glob
import PIL
import nibabel as nib
import os
from tkinter import Tcl
import cv2
import tensorflow_docs.vis.embed as embed
import tensorflow_addons.layers as tfal
from keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Conv2D,Conv2DTranspose,LeakyReLU,Activation,Concatenate,Add
from scipy import ndimage
import shutil
import json
import tensorflow.keras.layers as L

# Further helper functions

In [3]:
def preprocess_image_train(image):
    """Rescale pixel intensities from [0, 255] to [-1, 1].

    Used for the synthesis model's inputs during training and testing; the
    generator's final ``tanh`` layer produces values in the same range.

    Args:
        image: Array or tensor of pixel values (presumably 0-255 — confirm).

    Returns:
        ``image / 127.5 - 1`` with the same shape.
    """
    scaled = image / 127.5
    return scaled - 1

# Generate one frame of the training-progress GIFs: predict a slice, plot
# input / prediction / ground truth side by side and save the figure as a PNG.
def generate_images_GIF(img_input, model, img_true, mode, order):
    """Predict one slice, display it next to its input and target, and save the figure.

    Relies on the module-level names ``seq_1``, ``seq_2`` and ``brats_num``.
    NOTE(review): ``brats_num`` is not assigned in the visible notebook cells —
    define it before calling this function.

    Args:
        img_input: Batched generator input (presumably scaled to [-1, 1] by
            ``preprocess_image_train``); ``img_input[0, :, :, 0]`` is displayed.
        model: Generator, called as ``model(img_input)``.
        img_true: Batched ground-truth slice in the same layout as the prediction.
        mode: 1 labels the panels ``seq_1 -> seq_2``; any other value labels them
            ``seq_2 -> seq_1``. A PNG frame is written only for mode 1 or 2.
        order: Frame index, used as the PNG file name.

    Returns:
        ``(error, pred_vol)``: the SSIM between ``img_true`` and the prediction
        (``max_val=2`` for the [-1, 1] range) and a NumPy copy of the
        prediction's first channel of the first batch item.
    """
    prediction = model(img_input)
    pred_vol = prediction[0, :, :, 0].numpy().copy()
    error = tf.image.ssim(img_true, prediction, max_val=2)

    # Rotate by 270 degrees (k=3) for display only; pred_vol/error use unrotated data.
    display_list = [np.rot90(img_input[0, :, :, 0], 3),
                    np.rot90(prediction[0, :, :, 0], 3),
                    np.rot90(img_true[0, :, :, 0], 3)]

    src, dst = (seq_1, seq_2) if mode == 1 else (seq_2, seq_1)
    title = [f'{src} True', f'{dst} predicted', f'{dst} True']

    gif_dir = r'E:\Graduation Project\GIFs and Models\Brats {}\Predicted\{}-{}-GIF'
    frame_file = r'E:\Graduation Project\GIFs and Models\Brats {}\Predicted\{}-{}-GIF\{}.png'
    # Both direction folders are created regardless of mode.
    os.makedirs(gif_dir.format(brats_num, seq_1, seq_2), exist_ok=True)
    os.makedirs(gif_dir.format(brats_num, seq_2, seq_1), exist_ok=True)

    fig = plt.figure(figsize=(10, 6))
    for idx in range(3):
        plt.subplot(1, 3, idx + 1)
        plt.title(title[idx])
        # Map [-1, 1] back to [0, 1] for imshow.
        plt.imshow(display_list[idx] * 0.5 + 0.5, cmap='gray')
        plt.axis('off')

    # Save once, after all three panels are drawn (previously the file was
    # rewritten after every panel and only the last write survived).
    if mode == 1:
        plt.savefig(frame_file.format(brats_num, seq_1, seq_2, order))
    elif mode == 2:
        plt.savefig(frame_file.format(brats_num, seq_2, seq_1, order))
    plt.show()
    # pyplot keeps figures alive until closed; release it so frames don't accumulate.
    plt.close(fig)
    return error, pred_vol

# Predicting (generating) the images. This used to be defined twice; the second
# (three-argument) definition replaced the first, so only that one ever ran.
def predict_image(img_input, model, img_true=None):
    """Run ``model`` on one input batch and return a placeholder error plus the prediction.

    Args:
        img_input: Batched model input.
        model: Callable (e.g. the Keras generator) applied to ``img_input``; its
            output is indexed as ``[0, :, :, 0]`` and converted with ``.numpy()``.
        img_true: Ground-truth batch. Currently unused because the SSIM
            computation is disabled; optional for that reason.

    Returns:
        ``(0, pred_vol)``: ``0`` is a placeholder error and ``pred_vol`` is a
        NumPy copy of the first batch item's first channel.
    """
    prediction = model(img_input)
    pred_vol = prediction[0, :, :, 0].numpy().copy()
    # Disabled: error = tf.image.ssim(img_true, prediction, max_val=2)
    return 0, pred_vol

def predict_image_and_calc_loss(img_input, model, img_true):
    """Predict one batch; intended to also score it against ``img_true``.

    NOTE(review): despite the name, no loss is computed — the SSIM call is
    disabled and the error slot is the placeholder ``0``.

    Args:
        img_input: Batched model input.
        model: Callable applied to ``img_input``.
        img_true: Ground-truth batch (unused while SSIM is disabled).

    Returns:
        ``(0, pred_vol)`` where ``pred_vol`` is a NumPy copy of
        ``model(img_input)[0, :, :, 0]``.
    """
    first_channel = model(img_input)[0, :, :, 0]
    placeholder_error = 0
    # Disabled: tf.image.ssim(img_true, prediction, max_val=2)
    return placeholder_error, first_channel.numpy().copy()

def predict_image_and_NO_loss(img_input, model):
    """Return the model's prediction for ``img_input`` without any scoring.

    Args:
        img_input: Batched model input.
        model: Callable (e.g. the Keras generator) applied to ``img_input``.

    Returns:
        NumPy copy of ``model(img_input)[0, :, :, 0]`` — first batch item, first channel.
    """
    channel0 = model(img_input)[0, :, :, 0]
    return channel0.numpy().copy()

def black_seq_generator(test_path, brats_num, T1_FLAG=True, T2_FLAG=True, FLAIR_FLAG=True):
    """Write all-zero ("black") volumes for the selected sequences into ``test_path``.

    The zero volumes take the shape and affine of the first file (sorted by
    name) in ``test_path``. They are named ``BraTS2021_0<brats_num:04d>_<seq>.nii.gz``
    so they follow the same hierarchy and naming as real inputs in
    segmentation comparison experiments.

    Args:
        test_path: Subject folder that receives the black volumes.
        brats_num: Subject number, zero-padded to 4 digits in the file name.
        T1_FLAG, T2_FLAG, FLAIR_FLAG: Which sequences to write.

    Raises:
        IndexError: if ``test_path`` contains no file to take the shape from.
    """
    reference_path = sorted(glob.glob(test_path + '/*'))[0]
    reference = nib.load(reference_path)
    # Same shape and affine as the reference so the volume aligns with the subject.
    black = nib.Nifti1Image(np.zeros(reference.shape), reference.affine)

    prefix = test_path + '/' + f'BraTS2021_0{brats_num:04d}'
    # Each selected sequence is written exactly once (the previous version
    # re-saved the same files again for every pair of enabled flags).
    for enabled, suffix in ((FLAIR_FLAG, 'flair'), (T1_FLAG, 't1'), (T2_FLAG, 't2')):
        if enabled:
            nib.save(black, f'{prefix}_{suffix}.nii.gz')
    print("Done")

def copy_subfolders_into_another_folder(paths_txt, source_folder, destination_folder):
    """Copy the listed immediate subfolders of ``source_folder`` into ``destination_folder``.

    ``paths_txt`` holds a list of subfolder names written either as JSON
    (``["a", "b"]``) or as a Python list repr (``['a', 'b']``). Single quotes
    are turned into double quotes before JSON parsing.
    NOTE(review): that conversion breaks names containing an apostrophe.

    Entries in ``source_folder`` that are files, or whose names are not listed,
    are ignored.

    Raises:
        json.JSONDecodeError: if the file is not parseable after quote conversion.
        FileExistsError: from ``shutil.copytree`` if a target folder already exists.
    """
    with open(paths_txt, 'r') as file:
        contents = file.read()

    # A set gives O(1) membership tests while scanning the source folder.
    wanted = set(json.loads(contents.replace("'", "\"")))

    for item in os.listdir(source_folder):
        src = os.path.join(source_folder, item)
        if item in wanted and os.path.isdir(src):
            shutil.copytree(src, os.path.join(destination_folder, item))

# The find_brain_* helpers report whether a given slice of the volume contains
# any foreground (non-zero) voxel; they are used when cropping the brain.
def find_brain_width_wise(dep, hei, i, img):
    """Return ``i`` if slice ``img[:, i, :]`` has a non-zero voxel, else 0 (width-wise cropping).

    Only the top-left ``hei`` x ``dep`` region of the slice is inspected, i.e.
    rows ``0..hei-1`` (axis 0) and columns ``0..dep-1`` (axis 2).
    NOTE: the result is ambiguous when ``i == 0``.

    Args:
        dep: Number of axis-2 positions to inspect.
        hei: Number of axis-0 positions to inspect.
        i: Index along axis 1.
        img: Object exposing ``get_fdata()`` (e.g. a nibabel image).
    """
    slice2D = img.get_fdata()[:, i, :]
    # Vectorized check over the same region the original nested loops scanned.
    region = slice2D[:max(hei, 0), :max(dep, 0)]
    return i if np.any(region != 0) else 0

def find_brain_height_wise(dep, wid, i, img):
    """Return ``i`` if slice ``img[i, :, :]`` has a non-zero voxel, else 0 (height-wise cropping).

    Only the top-left ``wid`` x ``dep`` region of the slice is inspected, i.e.
    axis-1 positions ``0..wid-1`` and axis-2 positions ``0..dep-1``.
    NOTE: the result is ambiguous when ``i == 0``.

    Args:
        dep: Number of axis-2 positions to inspect.
        wid: Number of axis-1 positions to inspect.
        i: Index along axis 0.
        img: Object exposing ``get_fdata()`` (e.g. a nibabel image).
    """
    slice2D = img.get_fdata()[i, :, :]
    # Vectorized check over the same region the original nested loops scanned.
    region = slice2D[:max(wid, 0), :max(dep, 0)]
    return i if np.any(region != 0) else 0

def find_brain_depth_wise(wid, hei, i, img):
    """Return ``i`` if slice ``img[:, :, i]`` has a non-zero voxel, else 0 (depth-wise cropping).

    Only the top-left ``wid`` x ``hei`` region of the slice is inspected, i.e.
    axis-0 positions ``0..wid-1`` and axis-1 positions ``0..hei-1``.
    NOTE: the result is ambiguous when ``i == 0``.

    Args:
        wid: Number of axis-0 positions to inspect.
        hei: Number of axis-1 positions to inspect.
        i: Index along axis 2.
        img: Object exposing ``get_fdata()`` (e.g. a nibabel image).
    """
    slice2D = img.get_fdata()[:, :, i]
    # Vectorized check over the same region the original nested loops scanned.
    region = slice2D[:max(wid, 0), :max(hei, 0)]
    return i if np.any(region != 0) else 0


###   Datasets

In [4]:
root_path = r'Path Of BraTS Validation Brain Cropped Volumes'
# Subject folders holding the original (cropped) NIfTI volumes.
data_list = sorted(glob.glob(root_path + '/*'))

# NOTE(review): this placeholder ends with a trailing space; make sure the real path does not.
results_path = r'Path Of BraTS Validation Brain Cropped 2D Images '
# Per-subject folders holding the 2D slice images (read from <subject>/images/GIF T2 later).
results_data_list = sorted(glob.glob(results_path + '/*'))

# Double-check the size of the data. Bare expressions in the middle of a notebook
# cell are never displayed, so print the counts explicitly.
print(len(data_list), len(results_data_list))

# Source sequence (generator input) and target sequence (synthesised).
seq_1 = 'T2'
seq_2 = 'FLAIR'

In [None]:
# Display both listings (Jupyter renders the last expression of a cell).
data_list, results_data_list


# The model architecture currently in use: Squeeze-and-Excitation Attention GANs

# GENERATOR 

In [None]:
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model

def conv_block(x, num_filters):
    """Apply two stages of 3x3 conv -> instance norm -> ReLU.

    Args:
        x: Input Keras tensor.
        num_filters: Filter count for both convolutions.

    Returns:
        Output Keras tensor; the spatial size is unchanged (``padding="same"``).
    """
    # Layers are created in the same order as before, so the model graph is identical.
    for _ in range(2):
        x = L.Conv2D(num_filters, 3, padding="same")(x)
        x = tfal.InstanceNormalization(axis=-1)(x)
        x = L.Activation("relu")(x)
    return x


def se_block(x, num_filters, ratio=8):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools ``x`` to one value per channel. The result passes
    through a ReLU bottleneck of ``num_filters // ratio`` units and a sigmoid
    expansion back to ``num_filters``. ``x`` is then rescaled channel-wise.

    Args:
        x: Input Keras tensor with ``num_filters`` channels.
        num_filters: Channel count of ``x``.
        ratio: Bottleneck reduction factor.

    Returns:
        ``x`` multiplied by the per-channel gate.
    """
    gate_shape = (1, 1, num_filters)
    gate = L.GlobalAveragePooling2D()(x)
    gate = L.Reshape(gate_shape)(gate)
    gate = L.Dense(num_filters // ratio, activation="relu", use_bias=False)(gate)
    gate = L.Dense(num_filters, activation="sigmoid", use_bias=False)(gate)
    # The gate already has shape (1, 1, num_filters). This second Reshape is a
    # no-op, but it is kept so the layer graph stays unchanged (presumably needed
    # to restore existing checkpoints).
    gate = L.Reshape(gate_shape)(gate)
    return L.Multiply()([x, gate])

def encoder_block(x, num_filters):
    """One U-Net encoder stage: conv block + SE attention, then 2x2 max-pooling.

    Returns:
        ``(skip, pooled)``: the pre-pooling features, used as the skip
        connection, and the spatially halved tensor.
    """
    skip = se_block(conv_block(x, num_filters), num_filters)
    pooled = L.MaxPool2D((2, 2))(skip)
    return skip, pooled

def decoder_block(x, s, num_filters):
    """One U-Net decoder stage.

    Performs bilinear 2x upsampling, concatenation with skip tensor ``s``,
    then a conv block and SE attention.
    """
    upsampled = L.UpSampling2D(interpolation="bilinear")(x)
    merged = L.Concatenate()([upsampled, s])
    return se_block(conv_block(merged, num_filters), num_filters)

def squeeze_attention_unet(input_shape=(256, 256, 3)):
    """Build the SE-attention U-Net generator.

    The network consists of:
    - four encoder stages (64, 128, 256, 512 filters);
    - an SE-gated 1024-filter bottleneck;
    - four decoder stages consuming the skips in reverse;
    - a final 1x1 conv with ``tanh`` that emits 3 channels in [-1, 1].

    Args:
        input_shape: ``(H, W, C)`` of the input images.

    Returns:
        A Keras ``Model`` named ``"Squeeze-Attention-UNET"``.
    """
    inputs = L.Input(input_shape)

    # Encoder: keep each stage's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bottleneck.
    x = se_block(conv_block(x, 1024), 1024)

    # Decoder: deepest skip first, matching the encoder in reverse.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = L.Conv2D(3, (1, 1), activation='tanh')(x)
    return Model(inputs, outputs, name="Squeeze-Attention-UNET")


# DISCRIMINATOR 

In [None]:
def downsample(filters, size, apply_norm=True):
    """Return a Sequential stride-2 downsampling unit.

    The unit applies a conv with no bias and N(0, 0.02) init, then an optional
    instance norm, then LeakyReLU. It halves the spatial size
    (``strides=2``, ``padding='same'``).

    Args:
        filters: Number of conv filters.
        size: Kernel size.
        apply_norm: Insert ``InstanceNormalization`` after the conv when true.
    """
    stages = [
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=tf.random_normal_initializer(0., 0.02),
                               use_bias=False),
    ]
    if apply_norm:
        stages.append(tfal.InstanceNormalization(axis=-1))
    stages.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(stages)


def discriminator():
    """Convolutional discriminator that maps a 256x256x3 image to a 30x30 score map.

    The network applies, in order:
    - three stride-2 downsampling stages (64, 128, 256 filters; no
      normalization on the first);
    - zero-padding and a stride-1 4x4 conv (512 filters) with instance norm
      and LeakyReLU;
    - zero-padding and a final 4x4 conv with 3 filters and no activation.

    Returns:
        A Keras ``Model`` whose output has shape ``(batch, 30, 30, 3)``.
    """
    init = tf.random_normal_initializer(0., 0.02)
    image = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')

    feats = downsample(64, 4, False)(image)        # (bs, 128, 128, 64)
    feats = downsample(128, 4)(feats)              # (bs, 64, 64, 128)
    feats = downsample(256, 4)(feats)              # (bs, 32, 32, 256)

    feats = tf.keras.layers.ZeroPadding2D()(feats)  # (bs, 34, 34, 256)
    feats = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=init,
                                   use_bias=False)(feats)  # (bs, 31, 31, 512)
    feats = tfal.InstanceNormalization()(feats)
    feats = tf.keras.layers.LeakyReLU()(feats)
    feats = tf.keras.layers.ZeroPadding2D()(feats)  # (bs, 33, 33, 512)

    scores = tf.keras.layers.Conv2D(3, 4, strides=1, kernel_initializer=init)(feats)  # (bs, 30, 30, 3)
    return tf.keras.Model(inputs=image, outputs=scores)

# CHECKPOINT

In [None]:

# Rebuild the networks and optimizers, presumably so the object graph matches the
# one saved during training. No training step runs in the visible cells, so the
# tiny learning rate has no effect here.
generator_g = squeeze_attention_unet()
discriminator_x = discriminator()
generator_g_optimizer = tf.keras.optimizers.legacy.Adam(2e-10, beta_1=0.5 )
discriminator_x_optimizer = tf.keras.optimizers.legacy.Adam(2e-10, beta_1=0.5 )


# Loading the "GANs" model
checkpoint_path = r"Path of the Last Check Model"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           discriminator_x=discriminator_x,
                           generator_g_optimizer=generator_g_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=300)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print(f'Last Check Point: {ckpt_manager.latest_checkpoint}')
    print('Latest checkpoint restored!!')
else:
    # Without this message the notebook would silently predict with random weights.
    print(f'WARNING: no checkpoint found in {checkpoint_path!r}; the models are untrained.')

print("Generator's parameters = {:,}".format(generator_g.count_params()))
print("Discriminator's parameters = {:,}".format(discriminator_x.count_params()))

# Copy from root to result

In [None]:
# Display the source and results roots (Jupyter renders the last expression of a cell).
root_path, results_path

In [None]:
# Output root for this run (BraTS-Africa, synthesising FLAIR from T2). Every
# subject folder of root_path is copied here. Its FLAIR file is later deleted
# and replaced by the synthesised one.
FLAIR_syn_result_vol_path = r'Path of Saved The prediction volumes'

# Snapshot the subject folders (immediate subdirectories) first, then copy each one.
subject_dirs = [entry.path for entry in os.scandir(root_path) if entry.is_dir()]
for subject_dir in subject_dirs:
    target = os.path.join(FLAIR_syn_result_vol_path, os.path.basename(subject_dir))
    shutil.copytree(subject_dir, target)

FLAIR_syn_result_vols = sorted(glob.glob(FLAIR_syn_result_vol_path + '/*'))
print(len(FLAIR_syn_result_vols))

# Sanity check: count every file below the output root.
total_files = sum(len(names) for _, _, names in os.walk(FLAIR_syn_result_vol_path))
print('Total number of files in folder and subfolders:', total_files)

# Each subject should hold all 4 sequences + the segmentation (5 files); list the others.
for copied_dir in FLAIR_syn_result_vols:
    if len(glob.glob(copied_dir + '/*')) != 5:
        print(copied_dir)
print(len(FLAIR_syn_result_vols))

# Delete T1

In [None]:
# Delete every file whose name ends in 't1.nii.gz' below T1_syn_result_vol_path,
# for runs where T1 is the sequence to be synthesised. That variable is not
# assigned in the visible notebook cells, so report and skip instead of failing
# with NameError.
t1_target_root = globals().get('T1_syn_result_vol_path')
if t1_target_root is None:
    print('T1_syn_result_vol_path is not defined; skipping T1 deletion.')
else:
    for dirpath, dirnames, filenames in os.walk(t1_target_root):
        for filename in filenames:
            if filename.endswith('t1.nii.gz'):
                os.remove(os.path.join(dirpath, filename))

# Delete T2

In [None]:
# Delete every file whose name ends in 't2.nii.gz' below T2_syn_result_vol_path,
# for runs where T2 is the sequence to be synthesised. That variable is not
# assigned in the visible notebook cells, so report and skip instead of failing
# with NameError.
t2_target_root = globals().get('T2_syn_result_vol_path')
if t2_target_root is None:
    print('T2_syn_result_vol_path is not defined; skipping T2 deletion.')
else:
    for dirpath, dirnames, filenames in os.walk(t2_target_root):
        for filename in filenames:
            if filename.endswith('t2.nii.gz'):
                os.remove(os.path.join(dirpath, filename))

# Delete FLAIR

In [10]:
# Remove the original FLAIR volume from every copied subject folder; the
# prediction loop later writes the synthesised FLAIR under the same name.
for folder, _subdirs, names in os.walk(FLAIR_syn_result_vol_path):
    for victim in [name for name in names if name.endswith('flair.nii.gz')]:
        os.remove(os.path.join(folder, victim))

# Delete T1c

In [None]:
# Remove every *t1ce.nii.gz file below the output root. Only run this when T1ce
# is the sequence being synthesised; otherwise it also strips T1ce from the
# subjects used later for segmentation.
for folder, _subdirs, names in os.walk(FLAIR_syn_result_vol_path):
    for victim in [name for name in names if name.endswith('t1ce.nii.gz')]:
        os.remove(os.path.join(folder, victim))

### Double-check that each subject folder now contains ONLY 3 sequences plus the segmentation (one sequence was deleted so that it can be replaced by its prediction later)


In [None]:
# Every subject folder should now hold 3 sequences + the segmentation (4 files).
subject_dirs = sorted(glob.glob(FLAIR_syn_result_vol_path + '/*'))
for subject_dir in subject_dirs:
    n_files = len(glob.glob(subject_dir + '/*'))
    if n_files != 4:
        print(subject_dir)

total_files = sum(len(names) for _, _, names in os.walk(FLAIR_syn_result_vol_path))

print('Total number of files in folder and subfolders:', total_files)

print(len(subject_dirs))

In [12]:
# Accumulators and counters for the prediction loop. Several are not read by
# the visible cells and look like leftovers from the training notebook.
SSIM1, SSIM2 = [], []

vol_1, vol_2 = [], []

ssim_1_list, ssim_2_list = [], []

order_1 = order_2 = 0

ssim_score_1 = ssim_score_2 = 0

# Source (input) and target (synthesised) sequences.
seq_1, seq_2 = 'T2', 'FLAIR'

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

# For every subject, the cell:
# - runs the generator on each 2D input slice (T2 images under images/GIF T2);
# - stacks the predictions into a volume and resizes it to the original NIfTI shape;
# - saves it under the name of the subject's first file (the FLAIR).
for i, path in enumerate(results_data_list):
    vol_2 = []  # start each subject with an empty stack; don't rely on earlier cells
    input_slices_path = os.path.join(path, "images", "GIF T2")

    slices_ds = tf.keras.preprocessing.image_dataset_from_directory(
                                  input_slices_path,
                                  seed=123,
                                  image_size=(256, 256),
                                  batch_size=1,
                                  shuffle = False)
    # NOTE(review): slices are consumed in the order Keras lists the files.
    # Confirm that the file names sort in slice order (e.g. zero-padded);
    # otherwise the stacked volume is scrambled.
    slices_ds = slices_ds.cache().prefetch(buffer_size=AUTOTUNE)
    slices_ds = slices_ds.map(lambda x, _: (preprocess_image_train(x)))

    for image_x in slices_ds:
        vol_2.append(predict_image_and_NO_loss(image_x, generator_g))

    subject_files = sorted(glob.glob(data_list[i] + '/*'))
    original_vol = nib.load(subject_files[0])
    original_shape = original_vol.shape

    # (slices, H, W) -> (H, W, slices)
    vol_2 = np.array(vol_2).transpose(1, 2, 0)

    # Nearest-neighbour (order=0) resize back to the original volume shape.
    vol_2 = ndimage.zoom(vol_2, (original_shape[0]/vol_2.shape[0],
                                 original_shape[1]/vol_2.shape[1],
                                 original_shape[2]/vol_2.shape[2]), order=0)

    # NOTE(review): values stay in the generator's tanh range [-1, 1]; no
    # rescaling to the original intensity range is applied here.
    v2 = nib.Nifti1Image(np.array(vol_2), original_vol.affine)
    FLAIR_name = os.path.basename(subject_files[0])  # for flair
    # FLAIR_name = os.path.basename(subject_files[3])  # for T1ce

    FLAIR_res_path = os.path.join(FLAIR_syn_result_vols[i], FLAIR_name)
    nib.save(v2, FLAIR_res_path)

    vol_1 = []
    vol_2 = []
    print(f"Volume #{i} is done")

#### The following cell ZIPs the results dataset (which now contains, for each subject, 3 original sequences, 1 synthesized sequence and the corresponding segmentation map) for use in the segmentation task

In [14]:
import shutil
import os

# Archive the whole results tree for the segmentation stage.
source_folder_path = r'Path of Saved The prediction volumes'
destination_folder_path = r'Path of Saved The prediction volumes as Zip'
zip_filename = 'syn'
zip_filepath = os.path.join(destination_folder_path, zip_filename)
# make_archive appends the ".zip" extension to base_name itself.
shutil.make_archive(base_name=zip_filepath, format='zip', root_dir=source_folder_path)
print("Zipping complete.")

Zipping complete.
