In [1]:
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 29 11:04:41 2023
@author: 20192757
"""
import random
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from scipy.spatial.distance import directed_hausdorff
from tqdm.auto import tqdm

import u_net
import utils

def dice_score(x, y, eps=1e-5):
    """Return the smoothed Dice similarity coefficient between ``x`` and ``y``.

    Computes ``(2 * sum(x * y) + eps) / (sum(x + y) + eps)``. The ``eps`` term
    is added to both numerator and denominator, so two all-zero inputs score
    1.0 instead of dividing by zero. Works on any array-like supporting
    elementwise ``*``/``+`` and ``.sum()`` (torch tensors, numpy arrays).
    """
    overlap = (x * y).sum()
    combined = (x + y).sum()
    numerator = 2 * overlap + eps
    denominator = combined + eps
    return numerator / denominator


device = "cuda" if torch.cuda.is_available() else "cpu"

# Kept for parity with the training code; the validation split below is fixed
# by patient ID, so it does not actually depend on this seed.
random.seed(42)

# Directory containing one sub-folder per patient.
DATA_DIR = Path.cwd() / "TrainingData" / "TrainingData"

# One trained checkpoint is evaluated per value below; each lives in
# CHECKPOINTS_ROOT / "60_epochs_<n>_number_of_fake" / "model.pth".
# (Judging by the folder name, <n> is the number of fake images used in
# training — TODO confirm.)
FAKE_COUNTS = [0, 4, 8, 12, 16, 20, 28, 32]
CHECKPOINTS_ROOT = Path.cwd() / "final_results" / "final_results_old_unet"

# hyperparameters
# Patients held out for validation, matched against the patient folder name.
VALIDATION_PATIENT_IDS = ("p107", "p117", "p120")
NO_VALIDATION_PATIENTS = len(VALIDATION_PATIENT_IDS)
IMAGE_SIZE = [64, 64]


def _list_patients(data_dir):
    """Return the sorted, non-hidden patient sub-folders of ``data_dir``.

    Only the entry's own name is tested for a leading ".". Testing every part
    of the absolute path would drop *all* patients whenever the project is
    located under a hidden directory. Sorting makes the evaluation order, and
    therefore the order of the output files, independent of the filesystem.
    """
    return sorted(
        path
        for path in data_dir.glob("*")
        if path.is_dir() and not path.name.startswith(".")
    )


def _split_patients(patients, validation_ids):
    """Split ``patients`` into ``(train, validation)`` lists.

    A patient goes to validation when any of ``validation_ids`` occurs in its
    folder name; every other patient goes to train.
    """
    train, validation = [], []
    for path in patients:
        if any(pid in path.name for pid in validation_ids):
            validation.append(path)
        else:
            train.append(path)
    return train, validation


def _hausdorff_distance(pred_mask, true_mask):
    """Return the symmetric Hausdorff distance (in pixels) between two 2-D masks.

    Both masks are first reduced to the coordinates of their foreground pixels
    (values > 0.5). Passing the raw arrays to ``directed_hausdorff`` would
    instead treat every image *row* as a single point, which is not a
    segmentation distance. The symmetric distance is the max of both directed
    distances.

    Returns 0.0 when both masks are empty, and NaN when exactly one is empty
    (the distance is undefined in that case).
    """
    pred_points = np.argwhere(pred_mask > 0.5)
    true_points = np.argwhere(true_mask > 0.5)
    if len(pred_points) == 0 and len(true_points) == 0:
        return 0.0
    if len(pred_points) == 0 or len(true_points) == 0:
        return float("nan")
    return max(
        directed_hausdorff(pred_points, true_points)[0],
        directed_hausdorff(true_points, pred_points)[0],
    )


# The validation data is the same for every checkpoint, so build it once.
patients = _list_patients(DATA_DIR)
train_split, validation_split = _split_patients(patients, VALIDATION_PATIENT_IDS)
partition = {
    "train": train_split,
    "validation": validation_split,
}

# load validation data
valid_dataset = utils.ProstateMRDataset(partition["validation"], IMAGE_SIZE)
valid_dataloader = DataLoader(valid_dataset, batch_size=1)

for number_fake in FAKE_COUNTS:
    CHECKPOINTS_DIR = (
        CHECKPOINTS_ROOT / f"60_epochs_{number_fake}_number_of_fake" / "model.pth"
    )

    unet_model = u_net.UNet(num_classes=1).to(device)
    unet_model.load_state_dict(torch.load(CHECKPOINTS_DIR, map_location=device))
    unet_model.eval()

    # Per-slice metrics on the validation set.
    # TODO: also write the predicted segmentations as .mhd images.
    DiceScores = []
    HausdorffDist = []
    with torch.no_grad():
        for image, target in tqdm(valid_dataloader):
            image = image.to(device)
            # Keep only the first target channel (the model is built with
            # num_classes=1).
            target = target[:, 0:1].to(device)

            # Threshold the sigmoid output at 0.5 to get a binary mask.
            prediction = torch.round(torch.sigmoid(unet_model(image)))

            DiceScores.append(dice_score(prediction, target).item())
            HausdorffDist.append(
                _hausdorff_distance(
                    prediction[0, 0].cpu().numpy(),
                    target[0, 0].cpu().numpy(),
                )
            )

    with open(f"DICE_{number_fake}.txt", "w") as f:
        f.writelines(f"{s}\n" for s in DiceScores)

    with open(f"HD_{number_fake}.txt", "w") as f:
        f.writelines(f"{s}\n" for s in HausdorffDist)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 258/258 [00:05<00:00, 48.13it/s]
100%|██████████| 258/258 [00:05<00:00, 49.21it/s]
100%|██████████| 258/258 [00:05<00:00, 45.16it/s]
100%|██████████| 258/258 [00:05<00:00, 47.03it/s]
100%|██████████| 258/258 [00:05<00:00, 48.55it/s]
100%|██████████| 258/258 [00:05<00:00, 49.18it/s]
100%|██████████| 258/258 [00:05<00:00, 44.33it/s]
100%|██████████| 258/258 [00:05<00:00, 44.05it/s]
