In [33]:
from __future__ import print_function, absolute_import

import os
import warnings
import SimpleITK as sitk
import numpy as np
from tqdm import tqdm  # Import tqdm for progress bars
from atlas_registration_functions import register_transform

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Paths. Each one can be overridden with an environment variable of the same
# name, so the notebook runs on another machine without editing this cell.
# The defaults are the original author's local Windows layout.
ELASTIX_PATH = os.environ.get('ELASTIX_PATH', r'D:\Elastix\elastix.exe')
TRANSFORMIX_PATH = os.environ.get('TRANSFORMIX_PATH', r'D:\Elastix\transformix.exe')
DATA_PATH = os.environ.get('DATA_PATH', r'D:\capita_selecta\DevelopmentData\DevelopmentData')
OUTPUT_DIR = os.environ.get('OUTPUT_DIR', r'D:\capita_selecta\results_experiments')

# Fail fast if the Elastix executables are missing.
# - isfile (rather than exists) also rejects a directory that merely carries
#   the executable's name.
# - FileNotFoundError subclasses IOError/OSError, so existing handlers still
#   catch it.
if not os.path.isfile(ELASTIX_PATH):
    raise FileNotFoundError('Elastix cannot be found, please set the correct ELASTIX_PATH.')
if not os.path.isfile(TRANSFORMIX_PATH):
    raise FileNotFoundError('Transformix cannot be found, please set the correct TRANSFORMIX_PATH.')

def calc_dice(true_del, est_del):
    """Compute the Dice similarity coefficient between two delineations.

    Every element with a value > 0 counts as foreground, so both inputs are
    binarized before comparison.

    Parameters
    ----------
    true_del : array_like
        Reference (ground-truth) delineation.
    est_del : array_like
        Estimated delineation. Must have the same shape as ``true_del``.

    Returns
    -------
    float
        ``2 * |A & B| / (|A| + |B|)``, in [0, 1]. Returns 1.0 when both
        masks are empty.

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    true_mask = np.asarray(true_del) > 0
    est_mask = np.asarray(est_del) > 0

    # NumPy broadcasting would silently compare mismatched volumes and return
    # a meaningless score, so require identical shapes.
    if true_mask.shape != est_mask.shape:
        raise ValueError(
            f"Shape mismatch: true_del has shape {true_mask.shape}, "
            f"est_del has shape {est_mask.shape}"
        )

    intersection = np.count_nonzero(true_mask & est_mask)
    total = np.count_nonzero(true_mask) + np.count_nonzero(est_mask)

    if total == 0:
        return 1.0  # Both masks empty: define Dice as a perfect match.

    return 2.0 * intersection / total

In [34]:
# Get patient names and select atlas patients.
# os.listdir returns entries in an arbitrary, filesystem-dependent order, so
# sort to make the atlas selection and the reporting order reproducible.
patient_list = sorted(
    patient for patient in os.listdir(DATA_PATH)
    if os.path.isdir(os.path.join(DATA_PATH, patient))
)
atlas_patients = patient_list[:5]

# Registered results are stored as OUTPUT_DIR/reg_maj_<patient>.mhd.
# Strip the prefix and suffix to recover the patient id.
REG_PREFIX = "reg_maj_"
REG_SUFFIX = ".mhd"
register_patients = sorted(
    name[len(REG_PREFIX):-len(REG_SUFFIX)]
    for name in os.listdir(OUTPUT_DIR)
    if name.startswith(REG_PREFIX) and name.endswith(REG_SUFFIX)
)

true_delineations = []
est_delineations = []

for patient in tqdm(register_patients, desc="Loading delineations"):
    true_delineation = os.path.join(DATA_PATH, patient, 'prostaat.mhd')
    est_delineation = os.path.join(OUTPUT_DIR, f'{REG_PREFIX}{patient}{REG_SUFFIX}')

    true_delineation_image = sitk.ReadImage(true_delineation)
    est_delineation_image = sitk.ReadImage(est_delineation)

    true_array = sitk.GetArrayFromImage(true_delineation_image)
    est_array = sitk.GetArrayFromImage(est_delineation_image)

    # Differently shaped arrays would make the Dice comparison meaningless
    # (or be silently broadcast). Stop and name the offending patient.
    if true_array.shape != est_array.shape:
        raise ValueError(
            f"Patient {patient}: ground-truth shape {true_array.shape} "
            f"does not match registered shape {est_array.shape}"
        )

    true_delineations.append(true_array)
    est_delineations.append(est_array)

dice_scores = []

for patient, true_array, est_array in zip(register_patients, true_delineations, est_delineations):
    dice_score = calc_dice(true_array, est_array)
    print(f"Patient {patient} registration reached a dice score of: {dice_score:.3f}")
    dice_scores.append(dice_score)

if dice_scores:
    print(f"Mean dice score over {len(dice_scores)} patients: {np.mean(dice_scores):.3f}")


Patient p116 registration reached a dice score of: 0.034
Patient p117 registration reached a dice score of: 0.429
