In [1]:
!pip install nibabel scipy pandas numpy medpy

Defaulting to user installation because normal site-packages is not writeable


In [3]:
# Metrics Computation
import os
from glob import glob
import time
import re
import argparse
import nibabel as nib
import pandas as pd
from medpy.metric.binary import hd, dc
import numpy as np

# Output column names: the case name, then for each structure (in the
# LV, RV, MYO order) its Dice score, predicted volume, and the
# predicted-minus-reference volume error.
HEADER = ["Name"] + [
    column
    for structure in ("LV", "RV", "MYO")
    for column in (f"Dice {structure}", f"Volume {structure}", f"Err {structure}(ml)")
]

#
# Utility functions used to sort strings into natural order
#
def conv_int(i):
    """Return ``i`` as an int if it is a run of decimal digits, else unchanged.

    ``str.isdecimal`` is used rather than ``str.isdigit``: ``isdigit`` also
    accepts characters such as superscripts ("²") that ``int()`` cannot parse,
    which made sorting crash on file names containing them. ``isdecimal``
    accepts exactly the characters that ``\\d`` matches in ``natural_order``.
    """
    return int(i) if i.isdecimal() else i

def natural_order(sord):
    """
    Return a natural-sort key for a single string.

    The string is split into alternating non-digit / digit runs, and each
    digit run is converted to an int, so that e.g. "patient2" sorts before
    "patient10". Intended to be passed as ``key=`` to ``sorted``. If ``sord``
    is a tuple, its first element is used as the string.
    """
    if isinstance(sord, tuple):
        sord = sord[0]
    # The capturing group keeps the digit runs in the split output, so the
    # key always alternates str, int, str, ... and compares without TypeError.
    return [conv_int(c) for c in re.split(r'(\d+)', sord)]

#
# Utility functions to load and save NIfTI files with the nibabel package
#
def load_nii(img_path):
    """
    Read a NIfTI image from disk without any resampling or resizing.

    Args:
        img_path: Path to the NIfTI file.
    Returns:
        A ``(data, affine, header)`` tuple: the voxel array as returned by
        ``get_fdata()`` (floating point), the image affine matrix, and the
        image header.
    """
    image = nib.load(img_path)
    return image.get_fdata(), image.affine, image.header

def save_nii(img_path, data, affine, header):
    """
    Write ``data`` to ``img_path`` as a NIfTI-1 image.

    Args:
        img_path: Destination file path.
        data: Voxel array to store.
        affine: Affine matrix to attach to the image.
        header: Header to attach to the image.
    """
    nib.Nifti1Image(data, affine=affine, header=header).to_filename(img_path)

#
# Metrics computation functions
#
def metrics(img_gt, img_pred, voxel_size, labels=(3, 1, 2)):
    """
    Compute per-structure overlap and volume metrics between two label maps.

    Both inputs must already have identical shapes; no resampling is done here.

    Args:
        img_gt: Reference label map.
        img_pred: Predicted label map with the same shape as ``img_gt``.
        voxel_size: Voxel spacing along each axis (e.g. from
            ``header.get_zooms()``). The product is divided by 1000, so with
            spacing in mm the volumes come out in ml.
        labels: Label values to evaluate, in output order. The default
            ``(3, 1, 2)`` yields LV, RV, MYO, the order ``HEADER`` expects.

    Returns:
        A flat list containing, for each label in ``labels``:
        ``dice, predicted volume, predicted volume - reference volume``.

    Raises:
        ValueError: If the two label maps have different shapes.
    """
    if img_gt.shape != img_pred.shape:
        raise ValueError(f"Shape mismatch: GT {img_gt.shape} vs Pred {img_pred.shape}")

    # Loop-invariant: product of the voxel spacings (volume of one voxel).
    voxel_prod = np.prod(voxel_size)

    res = []
    for c in labels:
        # Boolean masks select exactly the voxels the former copy/zero/clip
        # sequence turned into 1, without two full-size float copies per label.
        gt_c = img_gt == c
        pred_c = img_pred == c

        dice = dc(gt_c, pred_c)
        volpred = pred_c.sum() * voxel_prod / 1000.
        volgt = gt_c.sum() * voxel_prod / 1000.

        res += [dice, volpred, volpred - volgt]

    return res

def compute_metrics_on_files(path_gt, path_pred):
    """
    Evaluate one prediction file against one reference file and print the result.

    Voxel spacing is taken from the reference file's header. The case name is
    the reference file's basename up to its first '.'. Two lines are printed
    to stdout: the column header and the result row, with each metric
    formatted to 3 decimals.
    """
    gt_data, _, gt_header = load_nii(path_gt)
    pred_data, _, _ = load_nii(path_pred)
    spacing = gt_header.get_zooms()

    case_name = os.path.basename(path_gt).split('.')[0]
    values = [f"{v:.3f}" for v in metrics(gt_data, pred_data, spacing)]

    row_fmt = "{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}"
    for row in (HEADER, [case_name] + values):
        print(row_fmt.format(*row))

def compute_metrics_on_directories(dir_gt, dir_pred):
    """
    Evaluate every reference/prediction pair in two directories and save a CSV.

    Files are paired by identical basename and processed in natural order of
    the reference paths. One row per case (columns ``HEADER``) is written to
    ``results_<YYYYmmdd_HHMMSS>.csv`` in the current working directory.

    Args:
        dir_gt: Directory containing the reference segmentations.
        dir_pred: Directory containing the predicted segmentations.

    Returns:
        The results as a ``pandas.DataFrame`` (the same data written to CSV).

    Raises:
        ValueError: If the two directories do not contain exactly the same
            set of file names.
    """
    # Only regular files are considered; a subdirectory would fail to load.
    lst_gt = sorted((p for p in glob(os.path.join(dir_gt, '*')) if os.path.isfile(p)),
                    key=natural_order)
    lst_pred = [p for p in glob(os.path.join(dir_pred, '*')) if os.path.isfile(p)]

    # Pair by basename instead of zipping two sorted lists: zip() silently
    # dropped trailing cases when one directory had more files than the other.
    pred_by_name = {os.path.basename(p): p for p in lst_pred}
    gt_names = [os.path.basename(p) for p in lst_gt]
    missing = [n for n in gt_names if n not in pred_by_name]
    extra = sorted(set(pred_by_name) - set(gt_names), key=natural_order)
    if missing or extra:
        raise ValueError(f"Name mismatch: missing predictions {missing}, "
                         f"unexpected predictions {extra}")

    rows = []
    for p_gt, name in zip(lst_gt, gt_names):
        gt, _, header = load_nii(p_gt)
        pred, _, _ = load_nii(pred_by_name[name])
        zooms = header.get_zooms()
        # Case name is the basename up to its first '.'.
        rows.append([name.split(".")[0]] + metrics(gt, pred, zooms))

    df = pd.DataFrame(rows, columns=HEADER)
    df.to_csv(f"results_{time.strftime('%Y%m%d_%H%M%S')}.csv", index=False)
    return df

def main(path_gt, path_pred):
    """
    Dispatch evaluation based on whether the inputs are files or directories.

    Two files are evaluated as a single case; two directories are evaluated
    case by case.

    Raises:
        ValueError: If the paths are not both existing files or both
            existing directories.
    """
    if os.path.isfile(path_gt) and os.path.isfile(path_pred):
        compute_metrics_on_files(path_gt, path_pred)
        return
    if os.path.isdir(path_gt) and os.path.isdir(path_pred):
        compute_metrics_on_directories(path_gt, path_pred)
        return
    raise ValueError("Paths must be both files or both directories")

# Paths to the reference segmentations and to the predictions to evaluate.
# Relative paths are resolved against the current working directory.
gt_path = "database/testing/segmentation"   # Change this to your ground truth folder
pred_path = "Segmentations_ensemble"  # Change this to your predictions folder

# Run the evaluation only when executed directly (script or notebook cell,
# where __name__ is "__main__"), not as a side effect of importing this module.
if __name__ == "__main__":
    main(gt_path, pred_path)