In [None]:
## import necessary libraries
import nibabel as nib
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
import glob
import json
import seaborn as sns
import pandas as pd

In [None]:
## Directory layout used to read the raw MRI volumes and to store everything derived from them.
## Every path is relative to the working directory and ends with '/', because later cells
## build file names by plain string concatenation (path + 'crossmoda_' + ...).

# Raw inputs: source-domain (ceT1) scans and labels, target-domain (hrT2) training scans,
# and target-domain validation scans.
path_to_raw_training_samples_from_source_domain = 'dataset/raw_dataset/training_data/source_sample/ceT1_MRI_scans/'
path_to_raw_training_labels_from_source_domain = 'dataset/raw_dataset/training_data/source_sample/ceT1_MRI_labels/'
path_to_raw_training_samples_from_target_domain = 'dataset/raw_dataset/training_data/target_sample/'
path_to_raw_validation_samples_from_target_domain = 'dataset/raw_dataset/validation_data/'

# Outputs of the preprocessing cells; the sub-folder structure mirrors the raw layout.
# NOTE(review): nothing in this notebook creates these folders - presumably they must exist
# before the writing cells run; confirm.
path_to_preprocessed_training_samples_from_source_domain = 'dataset/preprocessed_dataset/training_data/source_sample/ceT1_MRI_scans/'
path_to_preprocessed_training_labels_from_source_domain = 'dataset/preprocessed_dataset/training_data/source_sample/ceT1_MRI_labels/'
path_to_preprocessed_training_samples_from_target_domain = 'dataset/preprocessed_dataset/training_data/target_sample/'
path_to_preprocessed_validation_samples_from_target_domain = 'dataset/preprocessed_dataset/validation_data/'

# Prediction masks: the folder the postprocessing cell reads from, and the folder it writes to.
path_to_raw_prediction_mask = 'prediction_mask/raw_prediction/'
path_to_postprocessed_prediction_mask = 'prediction_mask/postprocessed_prediction/'

In [None]:
# glob.glob returns paths in arbitrary, filesystem-dependent order, while later cells pair
# scans with labels (and prediction masks with scans) purely by list position. Sorting every
# listing makes that pairing deterministic across machines and kernel restarts.
# NOTE: the order is lexicographic, so e.g. 'crossmoda_10...' sorts before 'crossmoda_2...'.
training_samples_from_source_domain = sorted(glob.glob(path_to_raw_training_samples_from_source_domain+'*'))
training_labels_from_source_domain = sorted(glob.glob(path_to_raw_training_labels_from_source_domain+'*'))
training_samples_from_target_domain = sorted(glob.glob(path_to_raw_training_samples_from_target_domain+'*'))
validation_samples_from_target_domain = sorted(glob.glob(path_to_raw_validation_samples_from_target_domain+'*'))

print('Number of training samples from source domain (ceT1) is: ' + str(len(training_samples_from_source_domain)))
print('Number of training labels from source domain (ceT1) is: ' + str(len(training_labels_from_source_domain)))
print('Number of training samples from target domain (hrT2) is: ' + str(len(training_samples_from_target_domain)))
print('Number of validation samples from target domain (hrT2) is: ' + str(len(validation_samples_from_target_domain)))

In [None]:
## display 16 consecutive slices (third array axis, starting at index 29) of the first listed
## training sample from source domain (ceT1), with its label overlaid
img_load = nib.load(training_samples_from_source_domain[0]).get_fdata()
labels_load = nib.load(training_labels_from_source_domain[0]).get_fdata()

for panel, slice_index in enumerate(range(29, 29 + 16), start=1):
    plt.subplot(4, 4, panel)
    plt.imshow(img_load[:, :, slice_index])
    plt.imshow(labels_load[:, :, slice_index], alpha=0.5)
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

In [None]:
## display 16 consecutive slices (third array axis, starting at index 29) of the first listed
## training sample from target domain (hrT2)
img_load = nib.load(training_samples_from_target_domain[0]).get_fdata()

for panel, slice_index in enumerate(range(29, 29 + 16), start=1):
    plt.subplot(4, 4, panel)
    plt.imshow(img_load[:, :, slice_index])
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

In [None]:
## display 16 consecutive slices (third array axis, starting at index 29) of the first listed
## validation sample from target domain (hrT2)
img_load = nib.load(validation_samples_from_target_domain[0]).get_fdata()

for panel, slice_index in enumerate(range(29, 29 + 16), start=1):
    plt.subplot(4, 4, panel)
    plt.imshow(img_load[:, :, slice_index])
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

## Image Preprocessing

In [None]:
def resample_volume(volume, interpolator = sitk.sitkLinear, new_spacing = (0.6, 0.6, 1.0)):
    """Resample a SimpleITK image onto a grid with the requested voxel spacing.

    The output keeps the origin, direction and pixel type of ``volume``. Each output
    dimension is ``round(size * original_spacing / new_spacing)``, so the physical extent
    of the image is preserved up to rounding. Voxels that fall outside the input are
    filled with 0.

    Parameters
    ----------
    volume : SimpleITK image to resample.
    interpolator : SimpleITK interpolator constant; linear by default.
    new_spacing : target spacing, one positive value per image dimension.
        Defaults to (0.6, 0.6, 1.0), the spacing previously hard-coded here.

    Returns
    -------
    The resampled SimpleITK image.

    Raises
    ------
    ValueError
        If ``new_spacing`` has the wrong number of entries or a non-positive entry.
    """
    new_spacing = [float(spacing) for spacing in new_spacing]
    original_spacing = volume.GetSpacing()
    original_size = volume.GetSize()

    # zip() would silently truncate on a length mismatch, so check explicitly
    if len(new_spacing) != len(original_size):
        raise ValueError('new_spacing needs one value per image dimension')
    if any(spacing <= 0 for spacing in new_spacing):
        raise ValueError('new_spacing values must be positive')

    new_size = [
        int(round(size * (old_spacing / target_spacing)))
        for size, old_spacing, target_spacing in zip(original_size, original_spacing, new_spacing)
    ]
    # sitk.Transform() is the identity: only the sampling grid changes, not the geometry
    return sitk.Resample(volume, new_size, sitk.Transform(), interpolator, volume.GetOrigin(), new_spacing, volume.GetDirection(), 0, volume.GetPixelID())

In [None]:
def get_new_z_range(training_image_load,labels_load):
    """Crop an image/label pair along z to a window around the labelled slices.

    The labelled slices are those whose label sum is positive. The window is obtained by
    padding the first and last labelled slice with ``(120 - number_of_labelled_slices) / 2``
    slices on each side (floor before, ceil after), which gives 120 slices when the
    labelled slices are contiguous. If that padding would run past the start or the end
    of the volume, the window is pinned to the first or the last 120 slices instead.

    Parameters
    ----------
    training_image_load : image supporting ``[x, y, z]`` slicing (a SimpleITK image).
    labels_load : label image on the same grid; also passed to sitk.GetArrayFromImage.

    Returns
    -------
    (z_image_crop, z_label_crop) : both inputs sliced with the same z window.

    Raises
    ------
    ValueError
        If the label image contains no positive slice, so no window can be placed.
    """
    label_array = sitk.GetArrayFromImage(labels_load)
    # the array is summed over axes 1 and 2, leaving one value per slice along axis 0 (z)
    label_sum_per_slice = np.sum(label_array, axis=(1, 2))
    labelled_slices = np.argwhere(label_sum_per_slice > 0)
    if len(labelled_slices) == 0:
        raise ValueError('labels_load contains no labelled voxels; cannot locate the ROI along z')

    # plain ints keep the comparisons and arithmetic below free of 1-element arrays
    starting_layer_of_ROI = int(labelled_slices[0][0])
    ending_layer_of_ROI = int(labelled_slices[-1][0])
    number_of_slices = label_sum_per_slice.shape[0]
    expand = (120 - len(labelled_slices)) / 2

    if expand >= starting_layer_of_ROI:
        cropped_starting_layer_of_ROI, cropped_ending_layer_of_ROI = 0, 120

    elif (expand + 1 + ending_layer_of_ROI) >= number_of_slices:
        cropped_starting_layer_of_ROI, cropped_ending_layer_of_ROI = number_of_slices - 120, number_of_slices

    else:
        # this branch previously referenced undefined names `s` and `e` and raised NameError
        cropped_starting_layer_of_ROI = int(starting_layer_of_ROI - np.floor(expand))
        cropped_ending_layer_of_ROI = int(ending_layer_of_ROI + np.ceil(expand) + 1)

    z_label_crop = labels_load[:,:,cropped_starting_layer_of_ROI:cropped_ending_layer_of_ROI]
    z_image_crop = training_image_load[:,:,cropped_starting_layer_of_ROI:cropped_ending_layer_of_ROI]
    return (z_image_crop,z_label_crop)

In [None]:
def center_crop_on_training_sample(img, center_x, center_y, label):
    """Cut a 256 x 256 window out of the first two axes of ``img``, keeping the third axis whole.

    The window starts 128 positions before ``center_x`` on the first axis and 128 before
    ``center_y`` on the second axis. No bounds check is performed.

    Returns the cropped image when ``label`` is None; otherwise returns the tuple
    ``(cropped_image, cropped_label)`` with the same window applied to ``label``.
    """
    half_window = np.ceil(256/2)
    first_axis_start = int(center_x - half_window)
    second_axis_start = int(center_y - half_window)
    window = (
        slice(first_axis_start, first_axis_start + 256),
        slice(second_axis_start, second_axis_start + 256),
        slice(None),
    )

    cropped_image = img[window]
    if label is not None:
        return (cropped_image, label[window])
    return cropped_image

In [None]:
def center_crop_on_validation_sample(img, center_x, center_y):
    """Cut a 368 x 368 window out of the first two axes of ``img``, keeping the third axis whole.

    The window starts 184 positions before ``center_x`` on the first axis and 184 before
    ``center_y`` on the second axis. No bounds check is performed.
    """
    half_window = np.ceil(368/2)
    first_axis_start = int(center_x - half_window)
    second_axis_start = int(center_y - half_window)
    first_axis_window = slice(first_axis_start, first_axis_start + 368)
    second_axis_window = slice(second_axis_start, second_axis_start + 368)
    return img[first_axis_window, second_axis_window, :]

In [None]:
def finding_center_z_axis(image):
    """Estimate the in-plane centre of the bright content of a volume.

    All voxels whose intensity is at or above the 75th percentile of the whole volume are
    selected, and the rounded mean of their indices along array axes 1 and 2 is returned.
    SimpleITK arrays are indexed (z, y, x), so axis 1 gives ``center_y`` and axis 2 gives
    ``center_x``.

    Parameters
    ----------
    image : SimpleITK image (anything accepted by sitk.GetArrayFromImage).

    Returns
    -------
    (center_x, center_y) : the two rounded mean indices.
    """
    nda = sitk.GetArrayFromImage(image)
    third_quartile = np.percentile(nda, 75)
    # np.argwhere scans the whole volume, so run it once and reuse both index columns
    bright_voxel_indices = np.argwhere(nda >= third_quartile)
    center_y = round(np.mean(bright_voxel_indices[:, 1]))
    center_x = round(np.mean(bright_voxel_indices[:, 2]))
    return (center_x, center_y)

In [None]:
def label_clustering(label_volume,target_label):
    """Binarise a label image at 0.5 and paint the foreground with ``target_label``.

    Voxels whose value exceeds 0.5 become ``target_label``; all others become 0.
    The result is returned as a float64 numpy array shaped like the image array.
    """
    voxel_values = sitk.GetArrayFromImage(label_volume)
    foreground = voxel_values > 0.5
    # multiplying the boolean mask by the label value maps True -> target_label, False -> 0
    return (foreground * target_label).astype(np.float64)

In [None]:
## image preprocessing on training samples from source domain (ceT1):
## resample -> centre crop in-plane -> crop along z around the labels -> rescale intensities to [0, 255]
for i in range(len(training_samples_from_source_domain)):

    save_path_for_preprocessed_training_samples_from_source_domain = path_to_preprocessed_training_samples_from_source_domain + 'crossmoda_'+str(i+1)+'_ceT1.nii'
    save_path_for_preprocessed_training_labels_from_source_domain = path_to_preprocessed_training_labels_from_source_domain + 'crossmoda_'+str(i+1)+'_Label.nii'

    # scans and labels are paired purely by list position
    training_sample_volume = sitk.ReadImage(training_samples_from_source_domain[i])
    training_label_volume = sitk.ReadImage(training_labels_from_source_domain[i])

    # Split the label map into one binary (0.0 / 1.0) mask per class so that each mask can be
    # resampled with linear interpolation and re-binarised at 0.5 afterwards.
    training_label_volume_nda = sitk.GetArrayFromImage(training_label_volume)
    label_1 = (training_label_volume_nda == 1).astype(np.float64)
    label_2 = (training_label_volume_nda == 2).astype(np.float64)

    # NOTE(review): only the spacing is carried over to the mask images; origin and direction
    # stay at the GetImageFromArray defaults - confirm whether they should be copied as well.
    label_1_volume = sitk.GetImageFromArray(label_1, isVector=False)
    label_1_volume.SetSpacing(training_label_volume.GetSpacing())

    label_2_volume = sitk.GetImageFromArray(label_2, isVector=False)
    label_2_volume.SetSpacing(training_label_volume.GetSpacing())

    resampled_training_sample_from_source_domain = resample_volume(training_sample_volume)

    resampled_training_label_1_from_source_domain = resample_volume(label_1_volume)
    resampled_training_label_2_from_source_domain = resample_volume(label_2_volume)

    resampled_training_clusterred_label_1_from_source_domain = label_clustering(resampled_training_label_1_from_source_domain,1)
    resampled_training_clusterred_label_2_from_source_domain = label_clustering(resampled_training_label_2_from_source_domain,2)

    # Merge the two class masks; voxels claimed by both (1 + 2 = 3) are assigned to class 1.
    merged_training_label_nda = resampled_training_clusterred_label_1_from_source_domain + resampled_training_clusterred_label_2_from_source_domain
    merged_training_label_nda = np.where(merged_training_label_nda > 2, 1, merged_training_label_nda)
    resampled_training_label_from_source_domain = sitk.GetImageFromArray(merged_training_label_nda, isVector=False)
    # An image built from an array starts with unit spacing. The spacing used to be copied
    # from the new image onto itself (a no-op); take it from the resampled mask instead so the
    # written label carries the resampled spacing.
    resampled_training_label_from_source_domain.SetSpacing(resampled_training_label_1_from_source_domain.GetSpacing())

    (center_x,center_y) = finding_center_z_axis(resampled_training_sample_from_source_domain)
    (center_cropped_img, center_cropped_label) = center_crop_on_training_sample(resampled_training_sample_from_source_domain, center_x, center_y, label=resampled_training_label_from_source_domain)
    (cropped_training_sample_from_source_domain, cropped_training_label_from_source_domain) = get_new_z_range(center_cropped_img,center_cropped_label)

    normalised_training_sample_from_source_domain = sitk.RescaleIntensity(cropped_training_sample_from_source_domain,outputMinimum=0.0, outputMaximum=255.0)

    sitk.WriteImage(normalised_training_sample_from_source_domain, save_path_for_preprocessed_training_samples_from_source_domain)
    sitk.WriteImage(cropped_training_label_from_source_domain, save_path_for_preprocessed_training_labels_from_source_domain)

In [None]:
## image preprocessing on training samples from target domain (hrT2):
## resample -> centre crop in-plane -> rescale intensities to [0, 255]
## output files are numbered from 106 upwards, in listing order
for case_number, raw_sample_path in enumerate(training_samples_from_target_domain, start=106):

    save_path_for_preprocessed_training_samples_from_target_domain = path_to_preprocessed_training_samples_from_target_domain + 'crossmoda_'+str(case_number)+'_hrT2.nii'

    training_sample_volume = sitk.ReadImage(raw_sample_path)
    resampled_training_sample_from_target_domain = resample_volume(training_sample_volume)

    # no labels exist for this domain, so only the image is cropped (label=None)
    (center_x,center_y) = finding_center_z_axis(resampled_training_sample_from_target_domain)
    cropped_training_sample_from_target_domain = center_crop_on_training_sample(
        resampled_training_sample_from_target_domain,
        center_x,
        center_y,
        label = None,
    )

    normalised_training_sample_from_target_domain = sitk.RescaleIntensity(
        cropped_training_sample_from_target_domain,
        outputMinimum=0.0,
        outputMaximum=255.0,
    )
    sitk.WriteImage(normalised_training_sample_from_target_domain, save_path_for_preprocessed_training_samples_from_target_domain)

In [None]:
## image preprocessing on validation samples from target domain (hrT2):
## centre crop in-plane (368 x 368) -> rescale intensities to [0, 255]; no resampling is applied here
## output files are numbered from 211 upwards, in listing order
for case_number, raw_sample_path in enumerate(validation_samples_from_target_domain, start=211):

    save_path_for_preprocessed_validation_samples_from_target_domain = path_to_preprocessed_validation_samples_from_target_domain + 'crossmoda_'+str(case_number)+'_hrT2.nii'

    validation_sample_from_target_domain = sitk.ReadImage(raw_sample_path)

    (center_x,center_y) = finding_center_z_axis(validation_sample_from_target_domain)
    cropped_validation_sample_from_target_domain = center_crop_on_validation_sample(
        validation_sample_from_target_domain,
        center_x,
        center_y,
    )

    normalised_validation_sample_from_target_domain = sitk.RescaleIntensity(
        cropped_validation_sample_from_target_domain,
        outputMinimum=0.0,
        outputMaximum=255.0,
    )
    sitk.WriteImage(normalised_validation_sample_from_target_domain, save_path_for_preprocessed_validation_samples_from_target_domain)

In [None]:
# glob.glob returns paths in arbitrary order; the display cells below overlay entry 0 of the
# label listing on entry 0 of the scan listing, so the listings are sorted to keep that
# positional pairing deterministic (lexicographic order).
preprocessed_training_samples_from_source_domain = sorted(glob.glob(path_to_preprocessed_training_samples_from_source_domain+'*'))
preprocessed_training_labels_from_source_domain = sorted(glob.glob(path_to_preprocessed_training_labels_from_source_domain+'*'))
preprocessed_training_samples_from_target_domain = sorted(glob.glob(path_to_preprocessed_training_samples_from_target_domain+'*'))
preprocessed_validation_samples_from_target_domain = sorted(glob.glob(path_to_preprocessed_validation_samples_from_target_domain+'*'))

print('Number of preprocessed training samples from source domain (ceT1) is: ' + str(len(preprocessed_training_samples_from_source_domain)))
print('Number of preprocessed training labels from source domain (ceT1) is: ' + str(len(preprocessed_training_labels_from_source_domain)))
print('Number of preprocessed training samples from target domain (hrT2) is: ' + str(len(preprocessed_training_samples_from_target_domain)))
print('Number of preprocessed validation samples from target domain (hrT2) is: ' + str(len(preprocessed_validation_samples_from_target_domain)))

In [None]:
## display 16 consecutive slices (third array axis, starting at index 44) of the first listed
## preprocessed training sample from source domain (ceT1), with its label overlaid
img_load = nib.load(preprocessed_training_samples_from_source_domain[0]).get_fdata()
labels_load = nib.load(preprocessed_training_labels_from_source_domain[0]).get_fdata()

for panel, slice_index in enumerate(range(44, 44 + 16), start=1):
    plt.subplot(4, 4, panel)
    plt.imshow(img_load[:, :, slice_index])
    plt.imshow(labels_load[:, :, slice_index], alpha=0.5)
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

In [None]:
## display 16 consecutive slices (third array axis, starting at index 29) of the first listed
## preprocessed training sample from target domain (hrT2)
img_load = nib.load(preprocessed_training_samples_from_target_domain[0]).get_fdata()

for panel, slice_index in enumerate(range(29, 29 + 16), start=1):
    plt.subplot(4, 4, panel)
    plt.imshow(img_load[:, :, slice_index])
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

In [None]:
## display 16 consecutive slices (third array axis, starting at index 29) of the first listed
## preprocessed validation sample from target domain (hrT2)
img_load = nib.load(preprocessed_validation_samples_from_target_domain[0]).get_fdata()

for panel, slice_index in enumerate(range(29, 29 + 16), start=1):
    plt.subplot(4, 4, panel)
    plt.imshow(img_load[:, :, slice_index])
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

## Example of fake hrT2 volume generated by CUT model using real ceT1 volume

In [None]:
## list the generated (fake) hrT2 volumes and display 16 consecutive slices
## (third array axis, starting at index 44) of the first one
source_dir = 'dataset/fake_hrT2_MRI_scan/fake_hrT2_volume/'
# sorted(): this listing is later paired by position with the preprocessed source labels,
# and glob.glob alone gives no ordering guarantee
fake_hrT2_volume = sorted(glob.glob(source_dir+'*'))

img_load = nib.load(fake_hrT2_volume[0]).get_fdata()

for i in range(16):
    plt.subplot(4, 4,i + 1)
    plt.imshow(img_load[:,:,44+i])
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

The code for training and running inference with the CUT model and the nnU-Net model can be found at https://drive.google.com/drive/folders/1oZbSLDman28BaQufGXKWHdgpc72lYlmP?usp=sharing

## Reducing tumour signal on generated hrT2 volume

In [None]:
destination_base_dir = 'dataset/fake_hrT2_MRI_scan/fake_hrT2_volume_with_reduced_tumour_signal/'

# Label i is applied to fake volume i, so both listings must be in the same, deterministic
# order. glob.glob gives no ordering guarantee, hence both are sorted here (lexicographic).
postprocessed_training_labels_from_source_domain = sorted(glob.glob(path_to_preprocessed_training_labels_from_source_domain+'*'))
fake_hrT2_volume_sorted = sorted(fake_hrT2_volume)

for i in range(len(postprocessed_training_labels_from_source_domain)):
    labels_load = nib.load(postprocessed_training_labels_from_source_domain[i]).get_fdata()
    # scaling map: 0.5 wherever the label equals 1, 1 everywhere else
    vs_label = np.where(labels_load == 1, 0.5,1)

    original_volume_load = nib.load(fake_hrT2_volume_sorted[i])

    # halve the intensity inside label 1 and leave the rest of the volume untouched
    reduce_volume_np = np.multiply(original_volume_load.get_fdata(), vs_label)

    # NOTE(review): these file names are numbered from 0, whereas the preprocessing cells
    # number their outputs from 1 - confirm this is intended.
    destination = destination_base_dir+'crossmoda_'+str(i)+'_fake_hrT2_with_reduced_signal.nii'

    # reuse the affine of the fake volume so the geometry is unchanged
    new_volume = nib.Nifti1Image(reduce_volume_np, original_volume_load.affine)

    nib.save(new_volume, destination)

In [None]:
## display 16 consecutive slices (third array axis, starting at index 44) of the first listed
## fake hrT2 volume with reduced tumour signal
fake_hrT2_with_reduced_signal = glob.glob(destination_base_dir+'*')
img_load = nib.load(fake_hrT2_with_reduced_signal[0]).get_fdata()

for panel, slice_index in enumerate(range(44, 44 + 16), start=1):
    plt.subplot(4, 4, panel)
    plt.imshow(img_load[:, :, slice_index])
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

## Image Postprocessing on prediction mask generated by nnU-Net segmentation model

In [None]:
def get_new_z_range_post_processed(sample):
    """Return the (start, stop) slice window along array axis 0 for ``sample``.

    Slices whose voxel sum is positive count as occupied. The first and last occupied
    slice are padded with ``(120 - number_of_occupied_slices) / 2`` slices on each side
    (floor before, ceil after). If the padding would run past the start or the end of
    the volume, the window is pinned to ``(0, 120)`` or to the last 120 slices instead.
    """
    slice_sums = np.sum(sitk.GetArrayFromImage(sample), axis=(1, 2))
    occupied = np.argwhere(slice_sums > 0)
    first_occupied, last_occupied = occupied[0], occupied[-1]
    depth = slice_sums.shape[0]
    expand = (120 - len(occupied)) / 2

    # window would start before slice 0: pin it to the start of the volume
    if expand >= first_occupied:
        return (0, 120)
    # window would end past the last slice: pin it to the end of the volume
    if (expand + 1 + last_occupied) >= depth:
        return (depth - 120, depth)
    window_start = int((first_occupied - np.floor(expand))[0])
    window_stop = int((last_occupied + np.ceil(expand) + 1)[0])
    return (window_start, window_stop)

In [None]:
def uncrop(sample, center_x, center_y, prediction):
    """Paste a cropped prediction mask back into a zero array shaped like ``sample``.

    The 368 x 368 in-plane window is recomputed from ``center_x`` / ``center_y`` exactly as
    in the cropping step, and the z window comes from get_new_z_range_post_processed applied
    to the cropped sample. The prediction array is written into that window; everything
    outside it stays 0.

    NOTE(review): the assignment requires the prediction array to match the window shape -
    confirm that the inference pipeline produces masks of exactly that size.

    Returns a numpy array with the shape of sitk.GetArrayFromImage(sample).
    """
    half_window = np.ceil(368/2)
    x_start = int(center_x - half_window)
    y_start = int(center_y - half_window)
    x_window = slice(x_start, x_start + 368)
    y_window = slice(y_start, y_start + 368)

    # the image is sliced in (x, y, z) order, the numpy array below in (z, y, x) order
    (z_start, z_stop) = get_new_z_range_post_processed(sample[x_window, y_window, :])

    full_size_mask = np.zeros(sitk.GetArrayFromImage(sample).shape)
    full_size_mask[z_start:z_stop, y_window, x_window] = sitk.GetArrayFromImage(prediction)
    return full_size_mask

In [None]:
# sorted(): the postprocessing loop pairs prediction mask i with validation scan i, and
# glob.glob alone returns paths in arbitrary order
prediction_masks = sorted(glob.glob(path_to_raw_prediction_mask+'*'))
postprocessed_prediction_masks = sorted(glob.glob(path_to_postprocessed_prediction_mask+'*'))

print('Number of prediction masks is: ' + str(len(prediction_masks)))
# the second count refers to the postprocessed masks; the message used to repeat the first one
print('Number of postprocessed prediction masks is: ' + str(len(postprocessed_prediction_masks)))

In [None]:
## display 16 consecutive slices (third array axis, starting at index 24) of the first listed
## preprocessed validation sample, with the first listed nnU-Net prediction mask overlaid
img_load = nib.load(preprocessed_validation_samples_from_target_domain[0]).get_fdata()
labels_load = nib.load(prediction_masks[0]).get_fdata()

for panel, slice_index in enumerate(range(24, 24 + 16), start=1):
    plt.subplot(4, 4, panel)
    plt.imshow(img_load[:, :, slice_index])
    plt.imshow(labels_load[:, :, slice_index], alpha=0.5)
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

In [None]:
## image postprocessing: paste every cropped prediction mask back into the grid of its
## original (un-cropped) validation scan and write the result
# Mask i is matched with validation scan i, so both listings are sorted locally to make the
# pairing independent of the arbitrary order returned by glob.glob.
sorted_prediction_masks = sorted(prediction_masks)
sorted_validation_samples_from_target_domain = sorted(validation_samples_from_target_domain)

for i in range(len(sorted_prediction_masks)):

    save_path_for_postprocessed_prediction_mask = path_to_postprocessed_prediction_mask + 'crossmoda_'+str(i+211)+'_Label.nii'

    validation_sample_before_cropping = sitk.ReadImage(sorted_validation_samples_from_target_domain[i])
    prediction = sitk.ReadImage(sorted_prediction_masks[i])

    # recompute the crop centre exactly as the preprocessing cell did
    (center_x,center_y) = finding_center_z_axis(validation_sample_before_cropping)

    uncrop_prediction = uncrop(validation_sample_before_cropping, center_x, center_y, prediction)

    new_prediction_volume = sitk.GetImageFromArray(uncrop_prediction, isVector=False)

    # An image built from an array has unit spacing, zero origin and identity direction.
    # Previously only the spacing was restored; CopyInformation also carries over origin and
    # direction so the mask lines up with the scan in physical space. The sizes match because
    # uncrop() returns an array shaped like the scan.
    new_prediction_volume.CopyInformation(validation_sample_before_cropping)

    sitk.WriteImage(new_prediction_volume, save_path_for_postprocessed_prediction_mask)

In [None]:
# re-list both folders after postprocessing; sorted() keeps positional pairing with the
# validation scans deterministic, since glob.glob alone returns paths in arbitrary order
prediction_masks = sorted(glob.glob(path_to_raw_prediction_mask+'*'))
postprocessed_prediction_masks = sorted(glob.glob(path_to_postprocessed_prediction_mask+'*'))

print('Number of prediction masks is: ' + str(len(prediction_masks)))
# the second count refers to the postprocessed masks; the message used to repeat the first one
print('Number of postprocessed prediction masks is: ' + str(len(postprocessed_prediction_masks)))

In [None]:
## display 16 consecutive slices (third array axis, starting at index 24) of the first listed
## raw validation sample (hrT2), with the first listed postprocessed prediction mask overlaid
img_load = nib.load(validation_samples_from_target_domain[0]).get_fdata()
labels_load = nib.load(postprocessed_prediction_masks[0]).get_fdata()

for panel, slice_index in enumerate(range(24, 24 + 16), start=1):
    plt.subplot(4, 4, panel)
    plt.imshow(img_load[:, :, slice_index])
    plt.imshow(labels_load[:, :, slice_index], alpha=0.5)
# the figure size does not depend on the panel, so it is applied once to the finished grid
plt.gcf().set_size_inches(200, 200)
plt.show()

## Framework Performance Evaluation

In [None]:
## load the evaluation results of the three experiments from their JSON files
baseline_experiment_result_path = 'results/baseline_experiment_result.json'
replicate_work_result_path = 'results/replicate_work_result.json'
proposed_work_result_path = 'results/proposed_work_result.json'


def _read_json(path):
    """Open ``path`` for reading and return the parsed JSON document."""
    with open(path, 'r') as handle:
        return json.load(handle)


baseline_experiment_result = _read_json(baseline_experiment_result_path)
replicated_work_result = _read_json(replicate_work_result_path)
proposed_work_result = _read_json(proposed_work_result_path)

In [None]:
## collect the per-case values of every metric for each experiment
def _per_case_values(result, metric):
    """Return the values stored under result['case'][metric] as a list."""
    return list(result['case'][metric].values())


baseline_experiment_VS_ASSD = _per_case_values(baseline_experiment_result, 'VS_ASSD')
baseline_experiment_VS_Dice = _per_case_values(baseline_experiment_result, 'VS_Dice')
baseline_experiment_Cochlea_ASSD = _per_case_values(baseline_experiment_result, 'Cochlea_ASSD')
baseline_experiment_Cochlea_Dice = _per_case_values(baseline_experiment_result, 'Cochlea_Dice')
baseline_experiment_Mean_Dice = _per_case_values(baseline_experiment_result, 'Mean_Dice')

replicated_work_VS_ASSD = _per_case_values(replicated_work_result, 'VS_ASSD')
replicated_work_VS_Dice = _per_case_values(replicated_work_result, 'VS_Dice')
replicated_work_Cochlea_ASSD = _per_case_values(replicated_work_result, 'Cochlea_ASSD')
replicated_work_Cochlea_Dice = _per_case_values(replicated_work_result, 'Cochlea_Dice')
replicated_work_Mean_Dice = _per_case_values(replicated_work_result, 'Mean_Dice')

proposed_work_VS_ASSD = _per_case_values(proposed_work_result, 'VS_ASSD')
proposed_work_VS_Dice = _per_case_values(proposed_work_result, 'VS_Dice')
proposed_work_Cochlea_ASSD = _per_case_values(proposed_work_result, 'Cochlea_ASSD')
proposed_work_Cochlea_Dice = _per_case_values(proposed_work_result, 'Cochlea_Dice')
proposed_work_Mean_Dice = _per_case_values(proposed_work_result, 'Mean_Dice')

In [None]:
## display the aggregate results of the baseline experiment
## (file-name columns removed, last two rows left out of the display)
baseline_result_table = (
    pd.DataFrame.from_dict(baseline_experiment_result['aggregates'])
    .drop(columns=['gt_fname', 'pred_fname'])
)
baseline_result_table[0:-2]

In [None]:
## display the aggregate results of the replicated work
## (file-name columns removed, last two rows left out of the display)
replicated_work_result_table = (
    pd.DataFrame.from_dict(replicated_work_result['aggregates'])
    .drop(columns=['gt_fname', 'pred_fname'])
)
replicated_work_result_table[0:-2]

In [None]:
## display the aggregate results of the proposed work
## (file-name columns removed, last two rows left out of the display)
proposed_work_result_table = (
    pd.DataFrame.from_dict(proposed_work_result['aggregates'])
    .drop(columns=['gt_fname', 'pred_fname'])
)
proposed_work_result_table[0:-2]

In [None]:
## per-case VS ASSD of the replicated work and the proposed work, as a violin plot with the
## individual cases drawn on top
composite_VS_ASSD = pd.DataFrame({
    'DA without self training': replicated_work_VS_ASSD,
    'DA with self training': proposed_work_VS_ASSD,
})
sns.set_theme(style="whitegrid")
fig, ax = plt.subplots(figsize=(10, 8))
sns.violinplot(data=composite_VS_ASSD, ax=ax)
sns.swarmplot(data=composite_VS_ASSD, color="w", alpha=0.8, ax=ax)
ax.set_title("Average symmetric surface distance comparison on vestibular schwannoma segmentation")
ax.set_ylabel("ASSD")
plt.show()

In [None]:
## per-case VS Dice of the replicated work and the proposed work, as a violin plot with the
## individual cases drawn on top
composite_VS_Dice = pd.DataFrame({
    'DA without self training': replicated_work_VS_Dice,
    'DA with self training': proposed_work_VS_Dice,
})
sns.set_theme(style="whitegrid")
fig, ax = plt.subplots(figsize=(10, 8))
sns.violinplot(data=composite_VS_Dice, ax=ax)
sns.swarmplot(data=composite_VS_Dice, color="w", alpha=0.8, ax=ax)
ax.set_title("Dice similarity coefficient comparison on vestibular schwannoma segmentation")
ax.set_ylabel("DSC")
plt.show()

In [None]:
## per-case cochlea Dice of the replicated work and the proposed work, as a violin plot with
## the individual cases drawn on top
composite_Cochlea_Dice = pd.DataFrame({
    'DA without self training': replicated_work_Cochlea_Dice,
    'DA with self training': proposed_work_Cochlea_Dice,
})
sns.set_theme(style="whitegrid")
fig, ax = plt.subplots(figsize=(10, 8))
sns.violinplot(data=composite_Cochlea_Dice, ax=ax)
sns.swarmplot(data=composite_Cochlea_Dice, color="w", alpha=0.8, ax=ax)
ax.set_title("Dice similarity coefficient comparison on bilateral cochlea segmentation")
ax.set_ylabel("DSC")
plt.show()