In [36]:
from pathlib import Path
import glob
import os
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import SimpleITK as sitk
import imageio.v2 as imageio
from PIL import Image

import torch
from torch import nn
import torchvision as TV
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.nn import UpsamplingNearest2d
from torch.nn.utils import spectral_norm
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

import utils
import u_net

# Root folder of the independent test data (one sub-folder per patient).
# Set the TEST_DATA_DIR environment variable to use another location without editing
# the notebook; the original hard-coded path remains the default.
path_data = os.environ.get("TEST_DATA_DIR", r"C:\Users\Luuk\Desktop\Independent_Test_Data\Test_Data")




In [2]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


### Use the U-Net for the INDEPENDENT test dataset

### Load the test data

In [None]:
# Cwd-relative location of the test data.
# NOTE(review): this is only printed; the patient folders below are discovered under
# `path_data`, so the two can disagree -- confirm which one is intended.
DATA_DIR = Path.cwd() / "Independent_Test_Data" / "Test_Data"
print(DATA_DIR)

# Collect the patient folders (names matching the pattern p*[0-9]).
# glob returns entries in file-system order, so sort them to make the
# patient_1 ... patient_5 assignment below reproducible.
patients = sorted(glob.glob(os.path.join(path_data, "p*[0-9]")))
print(len(patients))

# One single-patient partition per test patient: "patient_k" -> [k-th folder]
# (an empty list when fewer than k folders were found, as with the slicing before).
partition = {f"patient_{k + 1}": patients[k:k + 1] for k in range(5)}

IMAGE_SIZE = [64, 64]

# Build one dataset per patient so each patient can be predicted and saved separately.
dataset_test_p1 = utils.ProstateMRDataset(partition["patient_1"], IMAGE_SIZE)
dataset_test_p2 = utils.ProstateMRDataset(partition["patient_2"], IMAGE_SIZE)
dataset_test_p3 = utils.ProstateMRDataset(partition["patient_3"], IMAGE_SIZE)
dataset_test_p4 = utils.ProstateMRDataset(partition["patient_4"], IMAGE_SIZE)
dataset_test_p5 = utils.ProstateMRDataset(partition["patient_5"], IMAGE_SIZE)

In [None]:
# Report how many samples each single-patient test dataset contains.
for patient_dataset in (dataset_test_p1, dataset_test_p2, dataset_test_p3,
                        dataset_test_p4, dataset_test_p5):
    print(len(patient_dataset))

### Initializing the U-Net model with the weights after training

In [None]:
# Initialise the U-Net with the weights saved during training (run3 -> segmentation_model_weights).
BEST_EPOCH = 59 # epoch with the lowest validation loss
# Path to the checkpoint *file* of that epoch (despite the _DIR suffix).
CHECKPOINTS_DIR = Path.cwd() / "run3" / "segmentation_model_weights" / f"u_net_{BEST_EPOCH}.pth"
print(CHECKPOINTS_DIR)

# Build the U-Net (one output channel) and load the trained U-Net weights into it.
unet_model = u_net.UNet(num_classes=1)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded on this machine when only a CPU is available.
unet_model.load_state_dict(torch.load(CHECKPOINTS_DIR, map_location=device))

### Apply the U-Net to the independent test dataset (5 patients)

### Patient 1

In [None]:
# Run the trained U-Net on every slice of patient 1 and write, per slice, the binarised
# prediction, the ground-truth mask and the MR slice as separate 2D .mhd files.

# set model to evaluation mode (affects layers such as dropout/batch norm, if present)
unet_model.eval()

# Results live next to the test data: <parent of path_data>\Results\Patient_1
results_folder_1 = os.path.join(os.path.dirname(path_data), 'Results', 'Patient_1')
# Make sure the output folder exists before writing any file into it.
os.makedirs(results_folder_1, exist_ok=True)

with torch.no_grad():
    # Iterate over the dataset's actual length instead of a hard-coded slice count.
    for predict_index in range(len(dataset_test_p1)):
        mr_slice, target = dataset_test_p1[predict_index]
        # Add a batch dimension, then binarise the sigmoid probabilities by rounding.
        output = torch.sigmoid(unet_model(mr_slice[np.newaxis, ...]))
        prediction = torch.round(output)

        # each image, ground truth and mask is saved as a 2D file
        sitk.WriteImage(sitk.GetImageFromArray(prediction[0, 0]),
                        os.path.join(results_folder_1, '{}_mask_2D.mhd'.format(predict_index)))
        sitk.WriteImage(sitk.GetImageFromArray(target[0]),
                        os.path.join(results_folder_1, '{}_ground_truth_2D.mhd'.format(predict_index)))
        sitk.WriteImage(sitk.GetImageFromArray(mr_slice[0]),
                        os.path.join(results_folder_1, '{}_image_2D.mhd'.format(predict_index)))


In [None]:
# Re-assemble the 2D slice files written for patient 1 into 3D volumes (predicted mask,
# ground truth, MR image), save each volume, and show every slice side by side.
num_slices = len(dataset_test_p1)

mask_3D = []
ground_truth_3D = []
image_3D = []
for predict_index in range(num_slices):
    # Read each slice inline; the previous `read_image` temporary shadowed the
    # torchvision.io.read_image function imported in the first cell.
    mask_3D.append(sitk.GetArrayFromImage(sitk.ReadImage(
        os.path.join(results_folder_1, '{}_mask_2D.mhd'.format(predict_index)))))
    ground_truth_3D.append(sitk.GetArrayFromImage(sitk.ReadImage(
        os.path.join(results_folder_1, '{}_ground_truth_2D.mhd'.format(predict_index)))))
    image_3D.append(sitk.GetArrayFromImage(sitk.ReadImage(
        os.path.join(results_folder_1, '{}_image_2D.mhd'.format(predict_index)))))

# Stack the slice lists into volumes and write them next to the 2D files.
save_path_mask = os.path.join(results_folder_1, 'mask_3D.mhd')
sitk.WriteImage(sitk.GetImageFromArray(mask_3D), save_path_mask)

save_path_ground_truth = os.path.join(results_folder_1, 'ground_truth_3D.mhd')
sitk.WriteImage(sitk.GetImageFromArray(ground_truth_3D), save_path_ground_truth)

save_path_image = os.path.join(results_folder_1, 'image_3D.mhd')
sitk.WriteImage(sitk.GetImageFromArray(image_3D), save_path_image)

# Read the volumes back so the figures show exactly what was written to disk.
mask_result = sitk.ReadImage(save_path_mask)
mask_array_result = sitk.GetArrayFromImage(mask_result)

ground_truth_result = sitk.ReadImage(save_path_ground_truth)
ground_truth_array_result = sitk.GetArrayFromImage(ground_truth_result)

image_result = sitk.ReadImage(save_path_image)
image_array_result = sitk.GetArrayFromImage(image_result)

# show each slice of the 3D image
for i in range(image_array_result.shape[0]):
    # Overlay ground truth on MR image & predicted mask on MR image
    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(image_array_result[i,:,:], cmap="gray")
    ax[0].imshow(ground_truth_array_result[i,:,:], cmap="gray", alpha=0.5)
    ax[0].set_title("Ground Truth, slice {}".format(i))
    ax[1].imshow(image_array_result[i,:,:], cmap="gray")
    ax[1].imshow(mask_array_result[i,:,:], cmap="gray", alpha=0.5)
    ax[1].set_title("Mask, slice {}".format(i))

    plt.show()
    # Release the figure explicitly so looping over many slices does not pile up open figures.
    plt.close(fig)
    

### Patient 2

In [None]:
# Run the trained U-Net on every slice of patient 2 and write, per slice, the binarised
# prediction, the ground-truth mask and the MR slice as separate 2D .mhd files.

# set model to evaluation mode (affects layers such as dropout/batch norm, if present)
unet_model.eval()

# Results live next to the test data: <parent of path_data>\Results\Patient_2
results_folder_2 = os.path.join(os.path.dirname(path_data), 'Results', 'Patient_2')
# Make sure the output folder exists before writing any file into it.
os.makedirs(results_folder_2, exist_ok=True)

with torch.no_grad():
    # Iterate over the dataset's actual length instead of a hard-coded slice count.
    for predict_index in range(len(dataset_test_p2)):
        mr_slice, target = dataset_test_p2[predict_index]
        # Add a batch dimension, then binarise the sigmoid probabilities by rounding.
        output = torch.sigmoid(unet_model(mr_slice[np.newaxis, ...]))
        prediction = torch.round(output)

        # each image, ground truth and mask is saved as a 2D file
        sitk.WriteImage(sitk.GetImageFromArray(prediction[0, 0]),
                        os.path.join(results_folder_2, '{}_mask_2D.mhd'.format(predict_index)))
        sitk.WriteImage(sitk.GetImageFromArray(target[0]),
                        os.path.join(results_folder_2, '{}_ground_truth_2D.mhd'.format(predict_index)))
        sitk.WriteImage(sitk.GetImageFromArray(mr_slice[0]),
                        os.path.join(results_folder_2, '{}_image_2D.mhd'.format(predict_index)))

In [None]:
# Re-assemble the 2D slice files written for patient 2 into 3D volumes (predicted mask,
# ground truth, MR image), save each volume, and show every slice side by side.
num_slices = len(dataset_test_p2)

mask_3D = []
ground_truth_3D = []
image_3D = []
for predict_index in range(num_slices):
    # Read each slice inline; the previous `read_image` temporary shadowed the
    # torchvision.io.read_image function imported in the first cell.
    mask_3D.append(sitk.GetArrayFromImage(sitk.ReadImage(
        os.path.join(results_folder_2, '{}_mask_2D.mhd'.format(predict_index)))))
    ground_truth_3D.append(sitk.GetArrayFromImage(sitk.ReadImage(
        os.path.join(results_folder_2, '{}_ground_truth_2D.mhd'.format(predict_index)))))
    image_3D.append(sitk.GetArrayFromImage(sitk.ReadImage(
        os.path.join(results_folder_2, '{}_image_2D.mhd'.format(predict_index)))))

# Stack the slice lists into volumes and write them next to the 2D files.
save_path_mask = os.path.join(results_folder_2, 'mask_3D.mhd')
sitk.WriteImage(sitk.GetImageFromArray(mask_3D), save_path_mask)

save_path_ground_truth = os.path.join(results_folder_2, 'ground_truth_3D.mhd')
sitk.WriteImage(sitk.GetImageFromArray(ground_truth_3D), save_path_ground_truth)

# Patient 2's image volume belongs in patient 2's folder (was results_folder_3).
save_path_image = os.path.join(results_folder_2, 'image_3D.mhd')
sitk.WriteImage(sitk.GetImageFromArray(image_3D), save_path_image)

# Read the volumes back so the figures show exactly what was written to disk.
mask_result = sitk.ReadImage(save_path_mask)
mask_array_result = sitk.GetArrayFromImage(mask_result)

ground_truth_result = sitk.ReadImage(save_path_ground_truth)
ground_truth_array_result = sitk.GetArrayFromImage(ground_truth_result)

image_result = sitk.ReadImage(save_path_image)
image_array_result = sitk.GetArrayFromImage(image_result)

# show each slice of the 3D image
for i in range(image_array_result.shape[0]):
    # Overlay ground truth on MR image & predicted mask on MR image
    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(image_array_result[i,:,:], cmap="gray")
    ax[0].imshow(ground_truth_array_result[i,:,:], cmap="gray", alpha=0.5)
    ax[0].set_title("Ground Truth, slice {}".format(i))
    ax[1].imshow(image_array_result[i,:,:], cmap="gray")
    ax[1].imshow(mask_array_result[i,:,:], cmap="gray", alpha=0.5)
    ax[1].set_title("Mask, slice {}".format(i))

    plt.show()
    # Release the figure explicitly so looping over many slices does not pile up open figures.
    plt.close(fig)

### Patient 3

In [None]:
# Run the trained U-Net on every slice of patient 3 and write, per slice, the binarised
# prediction, the ground-truth mask and the MR slice as separate 2D .mhd files.

# set model to evaluation mode (affects layers such as dropout/batch norm, if present)
unet_model.eval()

# Results live next to the test data: <parent of path_data>\Results\Patient_3
results_folder_3 = os.path.join(os.path.dirname(path_data), 'Results', 'Patient_3')
# Make sure the output folder exists before writing any file into it.
os.makedirs(results_folder_3, exist_ok=True)

with torch.no_grad():
    # Iterate over the dataset's actual length instead of a hard-coded slice count.
    for predict_index in range(len(dataset_test_p3)):
        mr_slice, target = dataset_test_p3[predict_index]
        # Add a batch dimension, then binarise the sigmoid probabilities by rounding.
        output = torch.sigmoid(unet_model(mr_slice[np.newaxis, ...]))
        prediction = torch.round(output)

        # each image, ground truth and mask is saved as a 2D file
        sitk.WriteImage(sitk.GetImageFromArray(prediction[0, 0]),
                        os.path.join(results_folder_3, '{}_mask_2D.mhd'.format(predict_index)))
        sitk.WriteImage(sitk.GetImageFromArray(target[0]),
                        os.path.join(results_folder_3, '{}_ground_truth_2D.mhd'.format(predict_index)))
        sitk.WriteImage(sitk.GetImageFromArray(mr_slice[0]),
                        os.path.join(results_folder_3, '{}_image_2D.mhd'.format(predict_index)))

In [None]:
# Re-assemble the 2D slice files written for patient 3 into 3D volumes (predicted mask,
# ground truth, MR image), save each volume, and show every slice side by side.
num_slices = len(dataset_test_p3)

mask_3D = []
ground_truth_3D = []
image_3D = []
for predict_index in range(num_slices):
    # Read each slice inline; the previous `read_image` temporary shadowed the
    # torchvision.io.read_image function imported in the first cell.
    mask_3D.append(sitk.GetArrayFromImage(sitk.ReadImage(
        os.path.join(results_folder_3, '{}_mask_2D.mhd'.format(predict_index)))))
    ground_truth_3D.append(sitk.GetArrayFromImage(sitk.ReadImage(
        os.path.join(results_folder_3, '{}_ground_truth_2D.mhd'.format(predict_index)))))
    image_3D.append(sitk.GetArrayFromImage(sitk.ReadImage(
        os.path.join(results_folder_3, '{}_image_2D.mhd'.format(predict_index)))))

# Stack the slice lists into volumes and write them next to the 2D files.
save_path_mask = os.path.join(results_folder_3, 'mask_3D.mhd')
sitk.WriteImage(sitk.GetImageFromArray(mask_3D), save_path_mask)

save_path_ground_truth = os.path.join(results_folder_3, 'ground_truth_3D.mhd')
sitk.WriteImage(sitk.GetImageFromArray(ground_truth_3D), save_path_ground_truth)

save_path_image = os.path.join(results_folder_3, 'image_3D.mhd')
sitk.WriteImage(sitk.GetImageFromArray(image_3D), save_path_image)

# Read the volumes back so the figures show exactly what was written to disk.
mask_result = sitk.ReadImage(save_path_mask)
mask_array_result = sitk.GetArrayFromImage(mask_result)

ground_truth_result = sitk.ReadImage(save_path_ground_truth)
ground_truth_array_result = sitk.GetArrayFromImage(ground_truth_result)

image_result = sitk.ReadImage(save_path_image)
image_array_result = sitk.GetArrayFromImage(image_result)

# show each slice of the 3D image
for i in range(image_array_result.shape[0]):
    # Overlay ground truth on MR image & predicted mask on MR image
    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(image_array_result[i,:,:], cmap="gray")
    ax[0].imshow(ground_truth_array_result[i,:,:], cmap="gray", alpha=0.5)
    ax[0].set_title("Ground Truth, slice {}".format(i))
    ax[1].imshow(image_array_result[i,:,:], cmap="gray")
    ax[1].imshow(mask_array_result[i,:,:], cmap="gray", alpha=0.5)
    ax[1].set_title("Mask, slice {}".format(i))

    plt.show()
    # Release the figure explicitly so looping over many slices does not pile up open figures.
    plt.close(fig)