In [None]:
import sys
import SimpleITK as sitk
import json
import glob
import os
from tqdm import tqdm
import numpy as np
import torch
sys.path.append("/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/src/utils")
from revert_normalisation import get_ct_normalisation_values, revert_normalisation
from image_metrics import ImageMetrics
import pandas as pd

In [None]:


# nnU-Net fold_0 validation outputs whose intensity normalisation should be reverted.
pred_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset202_synthrad2025_task1_MR_mask/nnUNetTrainerMRCT__nnUNetPlans__3d_fullres/fold_0/validation"

# Normalisation statistics (mean/std) are read from a CT dataset's nnU-Net plans file.
# NOTE(review): the plans come from Dataset203 (CT) while the predictions are from
# Dataset202 (MR_mask) — presumably intentional because the prediction targets are CT; confirm.
ct_plan_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset203_synthrad2025_task1_CT/nnUNetPlans.json"
ct_mean, ct_std = get_ct_normalisation_values(ct_plan_path)

# Write the reverted predictions to a sibling folder "<pred_path>_revert_norm",
# which is the folder the evaluation cell reads from.
revert_normalisation(
    pred_path,
    ct_mean,
    ct_std,
    save_path=f"{pred_path}_revert_norm",
)


In [26]:
# compute image metrics for the prediction folders
class ImageMetricsCompute(ImageMetrics):
    """Collect per-patient image metrics and summarise them across patients.

    Builds on the project's ``ImageMetrics`` (whose ``score_patient`` result dicts
    are passed to :meth:`add` in this notebook) by keeping one list of values per
    metric name, plus a parallel list of patient ids in ``storage_id``.
    """

    def __init__(self):
        super().__init__()
        # Create empty storage for the default metrics up front so ``add`` works
        # even without an explicit ``init_storage`` call; callers may still
        # re-initialise with a different set of names.
        self.init_storage(["mae", "psnr", "ms_ssim"])

    def init_storage(self, names: list):
        """(Re)create empty storage for the given metric names.

        Args:
            names: Metric names; each becomes a key of ``self.storage`` holding
                an empty list. Any previously stored values and ids are discarded.
        """
        self.names = list(names)
        self.storage = {name: [] for name in self.names}
        self.storage_id = []

    def add(self, res: dict, patient_id=None):
        """Append one patient's scores.

        Args:
            res: Mapping of metric name -> value. Every key must already exist in
                ``self.storage``.
            patient_id: Optional identifier, appended to ``self.storage_id`` when
                not ``None``.

        Raises:
            KeyError: If ``res`` contains a metric that was not initialised. The
                check runs before anything is appended, so the per-metric lists
                never end up with different lengths.
        """
        unknown = [key for key in res if key not in self.storage]
        if unknown:
            raise KeyError(
                f"Unknown metric(s) {unknown}; storage has {list(self.storage)}"
            )
        for key, value in res.items():
            self.storage[key].append(value)
        # ``is not None`` so falsy-but-valid ids (e.g. 0 or "") are still recorded.
        if patient_id is not None:
            self.storage_id.append(patient_id)

    def aggregate(self):
        """Summarise the stored values of every metric.

        Returns:
            dict mapping each metric name to a dict with keys 'mean', 'std'
            (population, ddof=0), 'max', 'min', '25pc', '50pc', '75pc' and
            'count'. A metric with no stored values gets NaN statistics and a
            count of 0 instead of raising.
        """
        summary = {}
        for name in self.names:
            values = np.asarray(self.storage.get(name, []))
            if values.size == 0:
                # np.max/np.min raise on empty input; report NaN instead.
                stats = {
                    key: np.nan
                    for key in ("mean", "std", "max", "min", "25pc", "50pc", "75pc")
                }
            else:
                q25, q50, q75 = np.percentile(values, [25, 50, 75])
                stats = {
                    "mean": np.mean(values),
                    "std": np.std(values),
                    "max": np.max(values),
                    "min": np.min(values),
                    "25pc": q25,
                    "50pc": q50,
                    "75pc": q75,
                }
            stats["count"] = len(values)
            summary[name] = stats
        return summary

    def reset(self):
        """Empty all stored values and patient ids, keeping the metric names."""
        for key in self.storage:
            self.storage[key] = []
        # Clear the ids as well; otherwise they drift out of step with the
        # metric lists after a reset.
        self.storage_id = []





In [29]:
# Score every reverted prediction against its ground-truth image inside the mask.
testing_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset202_synthrad2025_task1_MR_mask/nnUNetTrainerMRCT__nnUNetPlans__3d_fullres/fold_0/validation_revert_norm"
pred_paths = sorted(glob.glob(os.path.join(testing_path, '*.mha')))
# Fail fast with a clear message if nothing matched (e.g. the revert step has not
# been run yet) instead of producing empty or erroring aggregates later.
if not pred_paths:
    raise FileNotFoundError(f"No .mha predictions found in {testing_path}")

# NOTE: despite the name, this is the nnU-Net *preprocessed* dataset folder; it
# holds the ground-truth images and the evaluation masks.
raw_data_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset202_synthrad2025_task1_MR_mask"
gt_path = os.path.join(raw_data_path, "gt_segmentations")
mask_path = os.path.join(raw_data_path, "masks")

testing_metrics = ImageMetricsCompute()
testing_metrics.init_storage(["mae", "psnr", "ms_ssim"])

# Use a distinct loop variable: the revert-normalisation cell defines ``pred_path``
# as the prediction *folder*, and reusing that name here would overwrite it with
# the last file path.
for pred_file in tqdm(pred_paths):
    # Prediction, ground truth and mask share the same file name.
    filename = os.path.basename(pred_file)
    gt_file = os.path.join(gt_path, filename)
    mask_file = os.path.join(mask_path, filename)

    img_pred = sitk.ReadImage(pred_file)
    img_gt = sitk.ReadImage(gt_file)
    img_mask = sitk.ReadImage(mask_file, sitk.sitkUInt8)

    array_pred = sitk.GetArrayFromImage(img_pred)
    array_gt = sitk.GetArrayFromImage(img_gt)
    array_mask = sitk.GetArrayFromImage(img_mask)

    res = testing_metrics.score_patient(array_gt, array_pred, array_mask)
    testing_metrics.add(res, filename)

# aggregate results
results = testing_metrics.aggregate()
print("Results:", results)







100%|██████████| 116/116 [21:16<00:00, 11.01s/it]


Results: {'mae': {'mean': np.float64(109.44085679290535), 'std': np.float64(27.066969212032742), 'max': np.float64(230.47372171558266), 'min': np.float64(66.28774953739256), '25pc': np.float64(91.29112447986269), '50pc': np.float64(101.21119620008118), '75pc': np.float64(125.07872458603839), 'count': 116}, 'psnr': {'mean': np.float64(25.776341633129967), 'std': np.float64(1.7933116425929951), 'max': np.float64(30.46496126493591), 'min': np.float64(20.271696054072127), '25pc': np.float64(24.711910392796135), '50pc': np.float64(26.12277162982101), '75pc': np.float64(26.900988097737205), 'count': 116}, 'ms_ssim': {'mean': np.float64(0.8419290628152716), 'std': np.float64(0.0718437246649582), 'max': np.float64(0.9525578654148602), 'min': np.float64(0.5338612903359046), '25pc': np.float64(0.8066457260674451), '50pc': np.float64(0.8520516393366632), '75pc': np.float64(0.8896230195890017), 'count': 116}}


In [31]:
# Display the aggregated per-metric summary computed by the evaluation cell.
results

{'mae': {'mean': np.float64(109.44085679290535),
  'std': np.float64(27.066969212032742),
  'max': np.float64(230.47372171558266),
  'min': np.float64(66.28774953739256),
  '25pc': np.float64(91.29112447986269),
  '50pc': np.float64(101.21119620008118),
  '75pc': np.float64(125.07872458603839),
  'count': 116},
 'psnr': {'mean': np.float64(25.776341633129967),
  'std': np.float64(1.7933116425929951),
  'max': np.float64(30.46496126493591),
  'min': np.float64(20.271696054072127),
  '25pc': np.float64(24.711910392796135),
  '50pc': np.float64(26.12277162982101),
  '75pc': np.float64(26.900988097737205),
  'count': 116},
 'ms_ssim': {'mean': np.float64(0.8419290628152716),
  'std': np.float64(0.0718437246649582),
  'max': np.float64(0.9525578654148602),
  'min': np.float64(0.5338612903359046),
  '25pc': np.float64(0.8066457260674451),
  '50pc': np.float64(0.8520516393366632),
  '75pc': np.float64(0.8896230195890017),
  'count': 116}}

In [None]:
# Per-patient scores as a table, one row per evaluated file.
# Columns come straight from the metric storage, so they always match the names
# passed to init_storage ("mae", "psnr", "ms_ssim"); there is no plain "ssim" key.
df = pd.DataFrame(
    {
        'patient_id': testing_metrics.storage_id,
        **testing_metrics.storage,
    }
)

In [30]:
# metrics = ImageMetrics()
# src_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset202_synthrad2025_task1_MR_mask/gt_segmentations/1ABA005.mha"
# mask_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset202_synthrad2025_task1_MR_mask/masks/1ABA005.mha"
# pred_path = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset202_synthrad2025_task1_MR_mask/nnUNetTrainerMRCT__nnUNetPlans__3d_fullres/fold_0/validation/1ABA005.mha"
# # read images
# img_src = sitk.ReadImage(src_path)
# img_pred = sitk.ReadImage(pred_path)
# img_mask = sitk.ReadImage(mask_path, sitk.sitkUInt8)

# # compute scores
# array_src = sitk.GetArrayFromImage(img_src)
# array_pred = sitk.GetArrayFromImage(img_pred)
# array_mask = sitk.GetArrayFromImage(img_mask)

# print(metrics.score_patient(array_src, array_pred, array_mask))