In [1]:
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 29 11:04:41 2023
@author: 20192757
"""
import random
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from scipy.spatial.distance import directed_hausdorff
from tqdm.auto import tqdm
import SimpleITK as sitk

import u_net
import utils

def dice_score(x, y, eps=1e-5):
    return (2*(x*y).sum()+eps) / ((x+y).sum()+eps)

  from .autonotebook import tqdm as notebook_tqdm


Make changes in DATA_DIR for path to test data, numbers for patient numbers in test data, and CHECKPOINTS_DIR for path to model. Also number_fake has to be the model with the best result previously.

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# to ensure reproducible training/validation split
random.seed(42)

# directorys with data and to stored training checkpoints
DATA_DIR = Path.cwd() / "TrainingData" / "TrainingData" 

# number of fake images with best performance
number_fake = 12

# path to model with best performance 
CHECKPOINTS_DIR = Path.cwd()/"final_results"/"final_results_old_unet"/"60_epochs_{x}_number_of_fake".format(x=number_fake)/"model.pth"

# patient numbers of test data
numbers = ['p107','p117','p120','p133','p116'] 


In [3]:
# hyperparameters
NO_VALIDATION_PATIENTS = 5
IMAGE_SIZE = [64, 64]

# find patient folders in training directory
# excluding hidden folders (start with .)
patients = [
    path
    for path in DATA_DIR.glob("*")
    if not any(part.startswith(".") for part in path.parts)
]

pat = []

for i in numbers: 
    #i is the patient number as string, so 'p107'
    locals()[ i ] = [] 
    #makes lists of patient numbers individually, so p107 = []

    for j in patients: 
        #loop over the paths 
        if i in str(j): 
            #when the patient number as a string is in the path, it is added to the list of that patient number
            locals()[i].append(j)
            pat.append(locals()[i])


# load validation data
for i in range(len(pat)): 
    valid_dataset = utils.ProstateMRDataset(pat[i], IMAGE_SIZE)
    valid_dataloader = DataLoader(valid_dataset, batch_size=1)

    unet_model = u_net.UNet(num_classes=1).to(device)
    unet_model.load_state_dict(torch.load(CHECKPOINTS_DIR, map_location=device))
    unet_model.eval()

    slices = [] 
    with torch.no_grad():
        for image, target in tqdm(valid_dataloader):
            image = image.to(device)
            target = target[:,0:1].to(device)

            output = torch.sigmoid(unet_model(image))

            prediction = torch.round(output)
            
            slices.append(prediction.numpy())
    threedim = sitk.GetImageFromArray(slices)
    writer = sitk.ImageFileWriter()
    writer.SetFileName("Patient{x}.mhd".format(x=numbers[i]))
    writer.Execute(threedim)



  9%|▉         | 8/86 [00:01<00:15,  4.91it/s]


KeyboardInterrupt: 

In [9]:
# Load image and segmentation of atlas patient / patients
#atlas_im_path = os.path.join(atlas_path, 'mr_bffe.mhd')
itk_image102 = sitk.ReadImage("Patientp107.mhd")
image_array102 = sitk.GetArrayViewFromImage(itk_image102)
# Print the image dimensions
print(image_array102.shape)

(86, 1, 1, 64, 64)


In [8]:
matplotlib qt 

In [9]:
from IndexTracker import IndexTracker
fig, ax = plt.subplots(1, 2)
tracker1 = IndexTracker(ax[0], image_array102[:,0,0,:,:]);
fig.canvas.mpl_connect('scroll_event', tracker1.onscroll);

plt.show()

up 3.0
up 2.0
up 1.0
up 1.0
up 2.0
up 2.0
up 2.0
up 1.0
up 3.0
up 2.0
up 1.0
up 2.0
up 3.0
up 1.0
down -1.0
down -2.0
down -2.0
down -1.0
down -1.0
down -1.0
down -2.0
down -2.0
down -1.0
down -1.0
down -1.0
down -2.0
down -2.0
down -1.0


In [10]:
patients

[WindowsPath('C:/Users/20192757/Documents/Year 4 (Master)/Q3/Capita selecta MIA/TrainingData/TrainingData/p102'),
 WindowsPath('C:/Users/20192757/Documents/Year 4 (Master)/Q3/Capita selecta MIA/TrainingData/TrainingData/p107'),
 WindowsPath('C:/Users/20192757/Documents/Year 4 (Master)/Q3/Capita selecta MIA/TrainingData/TrainingData/p108'),
 WindowsPath('C:/Users/20192757/Documents/Year 4 (Master)/Q3/Capita selecta MIA/TrainingData/TrainingData/p109'),
 WindowsPath('C:/Users/20192757/Documents/Year 4 (Master)/Q3/Capita selecta MIA/TrainingData/TrainingData/p115'),
 WindowsPath('C:/Users/20192757/Documents/Year 4 (Master)/Q3/Capita selecta MIA/TrainingData/TrainingData/p116'),
 WindowsPath('C:/Users/20192757/Documents/Year 4 (Master)/Q3/Capita selecta MIA/TrainingData/TrainingData/p117'),
 WindowsPath('C:/Users/20192757/Documents/Year 4 (Master)/Q3/Capita selecta MIA/TrainingData/TrainingData/p119'),
 WindowsPath('C:/Users/20192757/Documents/Year 4 (Master)/Q3/Capita selecta MIA/Training

Just for a test

In [None]:
# for i in patientsNumbers: 
#     locals()[ i ] = []

# for i in patients:
#     if 'p107' in str(i):
#         p107.append(i)
#         pat.append(p107)
#     elif 'p117' in str(i): 
#         p117.append(i)
#         pat.append(p117)
#     elif 'p120' in str(i):
#         p120.append(i)
#         pat.append(p120)
#     elif 'p133' in str(i):
#         p133.append(i)
#         pat.append(p133)
#     elif 'p116' in str(i): 
#         p116.append(i)
#         pat.append(p116)