# Setup
Downloading required pip packages and the preview dataset

In [1]:
!pip install TotalSegmentator "pyvista[jupyter]" "jupyterlab" SimpleITK medim monai --quiet

In [2]:
!gdown "https://drive.google.com/uc?export=download&id=1mgwuCpvTc3INnqGHiXhxJsbEzZPasplx" -O CT_subset_big.zip

Downloading...
From (original): https://drive.google.com/uc?export=download&id=1mgwuCpvTc3INnqGHiXhxJsbEzZPasplx
From (redirected): https://drive.google.com/uc?export=download&id=1mgwuCpvTc3INnqGHiXhxJsbEzZPasplx&confirm=t&uuid=16caf139-6ab4-4040-ac77-d6f4acfad759
To: /content/CT_subset_big.zip
100% 452M/452M [00:05<00:00, 89.2MB/s]


In [3]:
!unzip CT_subset_big.zip -d CT_Set

Archive:  CT_subset_big.zip
replace CT_Set/s0000/segmentations/iliopsoas_left.nii.gz? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace CT_Set/s0000/segmentations/iliac_artery_left.nii.gz? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


# Segmentation


In [5]:
# @title ##### Select one of the included preview CT scans as well as the AI model to use:
import ipywidgets as widgets
from IPython.display import display
import os

# Get list of subject folders in CT_Set
ct_set_path = "CT_Set"
subject_folders = [f.path for f in os.scandir(ct_set_path) if f.is_dir() and f.name.startswith('s')]
subject_folders.sort()

# Create a dropdown widget to select a subject folder
subject_selector = widgets.Dropdown(
    options=subject_folders,
    description='Select Subject Folder:',
    disabled=False,
    style = {'description_width': '150px'}
)

# Define the available models
available_models = ['totalSegmentator', 'MONAI SegResNet', 'STU-Net-B']

# Create a dropdown widget for model selection
model_selector = widgets.Dropdown(
    options=available_models,
    description='Select AI Model:',
    disabled=False,
    style = {'description_width': '150px'}
)

display(subject_selector, model_selector)

Dropdown(description='Select Subject Folder:', options=('CT_Set/s0000', 'CT_Set/s0001', 'CT_Set/s0002', 'CT_Se…

Dropdown(description='Select AI Model:', options=('totalSegmentator', 'MONAI SegResNet', 'STU-Net-B'), style=D…

In [23]:
# @title Visualize the selected CT scan (optional)
import nibabel as nib
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interactive, Output
import os
from IPython.display import display

selected_folder = subject_selector.value

# Load the NIfTI file
file_path = os.path.join(selected_folder, "ct.nii.gz")
img = nib.load(file_path)
data = img.get_fdata()

# Get dimensions
n_slices_axial = data.shape[2]
n_slices_coronal = data.shape[1]
n_slices_sagittal = data.shape[0]

# Outputs for manual layout
out_axial = Output()
out_coronal = Output()
out_sagittal = Output()

# Define slice viewers for each axis
def view_axial_slice(slice_idx):
    with out_axial:
        out_axial.clear_output(wait=True)
        plt.figure(figsize=(4, 4))
        plt.imshow(data[:, :, slice_idx].T, cmap="gray", origin="lower")
        plt.title(f"Axial {slice_idx}")
        plt.axis("off")
        plt.show()

def view_coronal_slice(slice_idx):
    with out_coronal:
        out_coronal.clear_output(wait=True)
        plt.figure(figsize=(4, 4))
        plt.imshow(data[:, slice_idx, :].T, cmap="gray", origin="lower")
        plt.title(f"Coronal {slice_idx}")
        plt.axis("off")
        plt.show()

def view_sagittal_slice(slice_idx):
    with out_sagittal:
        out_sagittal.clear_output(wait=True)
        plt.figure(figsize=(4, 4))
        plt.imshow(data[slice_idx, :, :].T, cmap="gray", origin="lower")
        plt.title(f"Sagittal {slice_idx}")
        plt.axis("off")
        plt.show()

# Sliders
axial_slider = widgets.IntSlider(min=0, max=n_slices_axial - 1, value=n_slices_axial // 2, description="Axial")
coronal_slider = widgets.IntSlider(min=0, max=n_slices_coronal - 1, value=n_slices_coronal // 2, description="Coronal")
sagittal_slider = widgets.IntSlider(min=0, max=n_slices_sagittal - 1, value=n_slices_sagittal // 2, description="Sagittal")

# Link sliders to update functions
widgets.interactive_output(view_axial_slice, {'slice_idx': axial_slider})
widgets.interactive_output(view_coronal_slice, {'slice_idx': coronal_slider})
widgets.interactive_output(view_sagittal_slice, {'slice_idx': sagittal_slider})

# Show initial slices
view_axial_slice(axial_slider.value)
view_coronal_slice(coronal_slider.value)
view_sagittal_slice(sagittal_slider.value)

# Layout: three images in a row, sliders below
display(
    widgets.VBox([
        widgets.HBox([out_axial, out_coronal, out_sagittal]),
        widgets.HBox([axial_slider, coronal_slider, sagittal_slider])
    ])
)

VBox(children=(HBox(children=(Output(), Output(), Output())), HBox(children=(IntSlider(value=220, description=…

In [15]:
# @title Run inference and process output
from totalsegmentator.python_api import totalsegmentator
import requests
from google.colab import userdata
import os
import shutil
import zipfile
import io
import os
import nibabel as nib
import torch
import numpy as np
from monai.bundle import download
from monai.inferers import sliding_window_inference
from monai.transforms import (
    LoadImaged, Spacingd, Orientationd, ScaleIntensityd,
    EnsureTyped, EnsureChannelFirstd
)
from monai.data import Dataset, DataLoader
from monai.networks.nets import SegResNet
import medim
import torch.nn.functional as F


selected_folder = subject_selector.value

input_file = os.path.join(selected_folder, "ct.nii.gz")

output_dir = "output_segments"


# Delete the output directory if it exists
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)

# Create the output directory
os.makedirs(output_dir, exist_ok=True)


def save_segmentations_by_organ(prediction_array, input_nii_path, output_dir, channel_def):
    """
    Saves individual organ segmentations from a prediction array.

    Args:
        prediction_array (np.ndarray): The predicted segmentation array.
        input_nii_path (str): Path to the original input NIfTI file (for affine and header).
        output_dir (str): Directory to save the individual organ segmentation files.
        channel_def (dict): Dictionary mapping organ IDs to organ names.
    """
    img_nib = nib.load(input_nii_path)
    affine = img_nib.affine
    header = img_nib.header

    os.makedirs(output_dir, exist_ok=True)

    # Get unique organ IDs present in the prediction (excluding background)
    organ_ids = np.unique(prediction_array)
    organ_ids = organ_ids[organ_ids != 0] # Exclude background (assuming 0 is background)


    for organ_id in organ_ids:
        mask = (prediction_array == organ_id).astype(np.uint8)
        if mask.sum() == 0:
            continue  # skip empty masks
        # Get organ name from channel_def
        organ_name = channel_def.get(organ_id, f"organ_{organ_id}") # Use organ_id as fallback
        out_img = nib.Nifti1Image(mask, affine, header)
        nib.save(out_img, os.path.join(output_dir, f"{organ_name}.nii.gz"))
    print(f"Saved {len(organ_ids)} individual organ segmentation files to {output_dir}")


# Define channel definition for MONAI SegResNet
monai_channel_def = {
    0: "background",
    1: "spleen",
    2: "kidney_right",
    3: "kidney_left",
    4: "gallbladder",
    5: "liver",
    6: "stomach",
    7: "aorta",
    8: "inferior_vena_cava",
    9: "portal_vein_and_splenic_vein",
    10: "pancreas",
    11: "adrenal_gland_right",
    12: "adrenal_gland_left",
    13: "lung_upper_lobe_left",
    14: "lung_lower_lobe_left",
    15: "lung_upper_lobe_right",
    16: "lung_middle_lobe_right",
    17: "lung_lower_lobe_right",
    18: "vertebrae_L5",
    19: "vertebrae_L4",
    20: "vertebrae_L3",
    21: "vertebrae_L2",
    22: "vertebrae_L1",
    23: "vertebrae_T12",
    24: "vertebrae_T11",
    25: "vertebrae_T10",
    26: "vertebrae_T9",
    27: "vertebrae_T8",
    28: "vertebrae_T7",
    29: "vertebrae_T6",
    30: "vertebrae_T5",
    31: "vertebrae_T4",
    32: "vertebrae_T3",
    33: "vertebrae_T2",
    34: "vertebrae_T1",
    35: "vertebrae_C7",
    36: "vertebrae_C6",
    37: "vertebrae_C5",
    38: "vertebrae_C4",
    39: "vertebrae_C3",
    40: "vertebrae_C2",
    41: "vertebrae_C1",
    42: "esophagus",
    43: "trachea",
    44: "heart_myocardium",
    45: "heart_atrium_left",
    46: "heart_ventricle_left",
    47: "heart_atrium_right",
    48: "heart_ventricle_right",
    49: "pulmonary_artery",
    50: "brain",
    51: "iliac_artery_left",
    52: "iliac_artery_right",
    53: "iliac_vena_left",
    54: "iliac_vena_right",
    55: "small_bowel",
    56: "duodenum",
    57: "colon",
    58: "rib_left_1",
    59: "rib_left_2",
    60: "rib_left_3",
    61: "rib_left_4",
    62: "rib_left_5",
    63: "rib_left_6",
    64: "rib_left_7",
    65: "rib_left_8",
    66: "rib_left_9",
    67: "rib_left_10",
    68: "rib_left_11",
    69: "rib_left_12",
    70: "rib_right_1",
    71: "rib_right_2",
    72: "rib_right_3",
    73: "rib_right_4",
    74: "rib_right_5",
    75: "rib_right_6",
    76: "rib_right_7",
    77: "rib_right_8",
    78: "rib_right_9",
    79: "rib_right_10",
    80: "rib_right_11",
    81: "rib_right_12",
    82: "humerus_left",
    83: "humerus_right",
    84: "scapula_left",
    85: "scapula_right",
    86: "clavicula_left",
    87: "clavicula_right",
    88: "femur_left",
    89: "femur_right",
    90: "hip_left",
    91: "hip_right",
    92: "sacrum",
    93: "face",
    94: "gluteus_maximus_left",
    95: "gluteus_maximus_right",
    96: "gluteus_medius_left",
    97: "gluteus_medius_right",
    98: "gluteus_minimus_left",
    99: "gluteus_minimus_right",
    100: "autochthon_left",
    101: "autochthon_right",
    102: "iliopsoas_left",
    103: "iliopsoas_right",
    104: "urinary_bladder"
}

# Define channel definition for STU-Net
stunet_channel_def = {
    0: "background",
    1: "adrenal_gland_left",
    2: "adrenal_gland_right",
    3: "aorta",
    4: "autochthon_left",
    5: "autochthon_right",
    6: "brain",
    7: "clavicula_left",
    8: "clavicula_right",
    9: "colon",
    10: "duodenum",
    11: "esophagus",
    12: "face",
    13: "femur_left",
    14: "femur_right",
    15: "gallbladder",
    16: "gluteus_maximus_left",
    17: "gluteus_maximus_right",
    18: "gluteus_medius_left",
    19: "gluteus_medius_right",
    20: "gluteus_minimus_left",
    21: "gluteus_minimus_right",
    22: "heart_atrium_left",
    23: "heart_atrium_right",
    24: "heart_myocardium",
    25: "heart_ventricle_left",
    26: "heart_ventricle_right",
    27: "hip_left",
    28: "hip_right",
    29: "humerus_left",
    30: "humerus_right",
    31: "iliac_artery_left",
    32: "iliac_artery_right",
    33: "iliac_vena_left",
    34: "iliac_vena_right",
    35: "iliopsoas_left",
    36: "iliopsoas_right",
    37: "inferior_vena_cava",
    38: "kidney_left",
    39: "kidney_right",
    40: "liver",
    41: "lung_lower_lobe_left",
    42: "lung_lower_lobe_right",
    43: "lung_middle_lobe_right",
    44: "lung_upper_lobe_left",
    45: "lung_upper_lobe_right",
    46: "pancreas",
    47: "portal_vein_and_splenic_vein",
    48: "pulmonary_artery",
    49: "rib_left_1",
    50: "rib_left_10",
    51: "rib_left_11",
    52: "rib_left_12",
    53: "rib_left_2",
    54: "rib_left_3",
    55: "rib_left_4",
    56: "rib_left_5",
    57: "rib_left_6",
    58: "rib_left_7",
    59: "rib_left_8",
    60: "rib_left_9",
    61: "rib_right_1",
    62: "rib_right_10",
    63: "rib_right_11",
    64: "rib_right_12",
    65: "rib_right_2",
    66: "rib_right_3",
    67: "rib_right_4",
    68: "rib_right_5",
    69: "rib_right_6",
    70: "rib_right_7",
    71: "rib_right_8",
    72: "rib_right_9",
    73: "sacrum",
    74: "scapula_left",
    75: "scapula_right",
    76: "small_bowel",
    77: "spleen",
    78: "stomach",
    79: "trachea",
    80: "urinary_bladder",
    81: "vertebrae_C1",
    82: "vertebrae_C2",
    83: "vertebrae_C3",
    84: "vertebrae_C4",
    85: "vertebrae_C5",
    86: "vertebrae_C6",
    87: "vertebrae_C7",
    88: "vertebrae_L1",
    89: "vertebrae_L2",
    90: "vertebrae_L3",
    91: "vertebrae_L4",
    92: "vertebrae_L5",
    93: "vertebrae_T1",
    94: "vertebrae_T10",
    95: "vertebrae_T11",
    96: "vertebrae_T12",
    97: "vertebrae_T2",
    98: "vertebrae_T3",
    99: "vertebrae_T4",
    100: "vertebrae_T5",
    101: "vertebrae_T6",
    102: "vertebrae_T7",
    103: "vertebrae_T8",
    104: "vertebrae_T9"
}

match model_selector.value:
    case 'totalSegmentator':
        totalsegmentator(input_file, output_dir)
    case 'MONAI SegResNet':

        # =========================================
        # MONAI SegResNet Download
        # =========================================
        bundle_dir = "./wholebody_ct"
        os.makedirs(bundle_dir, exist_ok=True)
        !wget -O {bundle_dir}/model.pt https://huggingface.co/MONAI/wholeBody_ct_segmentation/resolve/0.2.7/models/model.pt

        # ===================================================
        # 2. Load Model
        # ===================================================

        # SegResNet model definition (from MONAI bundle)
        model = SegResNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=105,   # 104 organs + background
            init_filters=32,
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1]
        )


        ckpt = torch.load(f"{bundle_dir}/model.pt", map_location="cpu")

        model.load_state_dict(ckpt, strict=True)
        model.eval().cuda()

        # ===================================================
        # 3. Preprocessing
        # ===================================================

        from monai.transforms import Compose

        pre_transforms = Compose([
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
            Orientationd(keys=["image"], axcodes="RAS"),
            ScaleIntensityd(keys=["image"]),
            EnsureTyped(keys=["image"])
        ])

        # ===================================================
        # 4. Inference
        # ===================================================

        def run_inference(input_path):
            data = [{"image": input_path}]
            dataset = Dataset(data=data, transform=pre_transforms)
            loader = DataLoader(dataset, batch_size=1)

            with torch.no_grad():
                for batch in loader:
                    image = batch["image"].cuda()
                    pred = sliding_window_inference(
                        inputs=image,
                        roi_size=(96, 96, 96),
                        sw_batch_size=1,
                        predictor=model,
                        overlap=0.5
                    )
                    # Get the prediction array
                    prediction_array = torch.argmax(pred, dim=1).cpu().numpy()[0]
                    return prediction_array

        # Run inference and save segmentations
        monai_prediction = run_inference(input_file)
        save_segmentations_by_organ(monai_prediction, input_file, output_dir, monai_channel_def)
        del model
        torch.cuda.empty_cache()

    case 'STU-Net-B':
        # ===================================================
        # 1. Load STU-Net-B with dataset parameter
        # ===================================================
        model = medim.create_model("STU-Net-B", dataset="TotalSegmentator")
        model.eval().cuda()

        # ===================================================
        # 2. Load and preprocess NIfTI input
        # ===================================================
        img = nib.load(input_file)
        img_data = img.get_fdata()

        # Normalize input (crucial step!)
        img_data = (img_data - img_data.mean()) / (img_data.std() + 1e-8)

        input_tensor = torch.tensor(
            img_data, dtype=torch.float32
        ).unsqueeze(0).unsqueeze(0).cuda()

        # ===================================================
        # 3. Sliding Window Inference
        # ===================================================
        with torch.no_grad():
            pred_logits = sliding_window_inference(
                inputs=input_tensor,
                roi_size=(96, 96, 96),
                sw_batch_size=1,
                predictor=model,
                overlap=0.5
            )
            prediction = torch.argmax(pred_logits, dim=1).squeeze().cpu().numpy()

        # ===================================================
        # 4. Save segmentations
        # ===================================================
        save_segmentations_by_organ(prediction, input_file, output_dir, stunet_channel_def)

        # ===================================================
        # 5. Cleanup
        # ===================================================
        model.cpu()
        del model, input_tensor, pred_logits, prediction
        torch.cuda.empty_cache()

creating model STU-Net-L
try to load pretrained weights for TotalSegmentator


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


large_ep4k.model:   0%|          | 0.00/3.52G [00:00<?, ?B/s]

  win_data = inputs[unravel_slice[0]].to(sw_device)
  out[idx_zm] += p


Saved 66 individual organ segmentation files to output_segments


In [16]:
# @title Filtering step
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import nibabel as nib
import numpy as np

# Assume output_dir is defined in a previous cell or define it here if needed
# For example, if output_dir is not defined:
# output_dir = "output_segments"

valid_segmentation_names = []

all_segmentation_files = [f for f in os.listdir(output_dir)]

for file_name in all_segmentation_files:
    file_path = os.path.join(output_dir, file_name)
    base_name = file_name.replace(".nii.gz", "")

    try:
        img = nib.load(file_path)
        data = img.get_fdata()

        if np.count_nonzero(data) >= 5: # Only add if it contains at least 5 non-zero voxels
            valid_segmentation_names.append(base_name)
        else:
            print(f"Skipping {file_name}: Data has less than 5 non-zero voxels.")
    except Exception as e:
        print(f"Could not load, process, or validate {file_name}: {e}")
        continue

# Sort the valid segmentation names alphabetically
valid_segmentation_names.sort()

Skipping rib_right_11.nii.gz: Data has less than 5 non-zero voxels.
Skipping rib_left_1.nii.gz: Data has less than 5 non-zero voxels.


# Visualizer

In [17]:
# @title ##### Select the organs to be shown and then run the next cell to visualize them in 3d
# Create a dropdown widget to select organs
organ_selector = widgets.SelectMultiple(
    options=valid_segmentation_names,
    description='Select Organs:',
    disabled=False,
    style = {'description_width': 'initial'}
)

display(organ_selector)

SelectMultiple(description='Select Organs:', options=('aorta', 'autochthon_left', 'autochthon_right', 'clavicu…

In [None]:
# @title 3d Visualization
import nibabel as nib
import pyvista as pv
from skimage import measure
import numpy as np
from scipy.ndimage import gaussian_filter
import colorsys


# This cell should be run after selecting organs in the previous cell

# Assume output_dir is defined in a previous cell or define it here if needed
# Assume organ_selector and its value are available from the previous cell

loaded_meshes = {}
plotter = pv.Plotter(notebook=True)
organ_controls = {}

# Load selected organ masks and create meshes
selected_organs = organ_selector.value

if not selected_organs:
    print("No organs selected. Please select organs in the previous cell.")
else:
    num_organs = len(selected_organs)
    for i, organ_name in enumerate(selected_organs):
        mask_file = os.path.join(output_dir, f"{organ_name}.nii.gz")
        if os.path.exists(mask_file):
            mask = nib.load(mask_file).get_fdata()
            mask_smoothed = gaussian_filter(mask.astype(float), sigma=1)
            verts, faces, _, _ = measure.marching_cubes(mask_smoothed, level=0.5, step_size=1)
            faces = np.hstack((np.ones((faces.shape[0], 1)) * 3, faces)).flatten().astype(np.int64)
            mesh = pv.PolyData(verts, faces)
            loaded_meshes[organ_name] = mesh

            # Calculate initial color using HSV
            hue = i / num_organs  # Distribute hues evenly
            rgb_color = colorsys.hsv_to_rgb(hue, 0.8, 0.8)  # Use a fixed saturation and value
            hex_color = '#%02x%02x%02x' % (int(rgb_color[0]*255), int(rgb_color[1]*255), int(rgb_color[2]*255))

            # Add mesh to plotter and store the actor
            actor = plotter.add_mesh(mesh, color=hex_color, opacity=0.6, name=organ_name)

            # Create controls for each organ
            color_picker = widgets.ColorPicker(concise=False, description='Color:', value=hex_color, disabled=False)
            opacity_slider = widgets.FloatSlider(value=0.6, min=0.0, max=1.0, step=0.05, description='Opacity:', disabled=False, continuous_update=True, orientation='horizontal', readout=True, readout_format='.2f')
            visibility_checkbox = widgets.Checkbox(value=True, description='Visible:', disabled=False)
            render_button = widgets.Button(description="Render Plot")

            # Store controls and actor
            organ_controls[organ_name] = {
                'color_picker': color_picker,
                'opacity_slider': opacity_slider,
                'visibility_checkbox': visibility_checkbox,
                'actor': actor
            }

            def update_color(change, actor=actor):
                actor.prop.color = change['new']

            def update_opacity(change, actor=actor):
                actor.prop.opacity = change['new']

            def update_visibility(change, actor=actor):
                actor.SetVisibility(change['new'])


            color_picker.observe(update_color, names='value')
            opacity_slider.observe(update_opacity, names='value')
            visibility_checkbox.observe(update_visibility, names='value')

    # Display controls and plot

    def render(b=None): # Added b=None to accept button click event
      clear_output()
      for organ_name, controls in organ_controls.items():
          print(f"Controls for {organ_name}:")
          display(widgets.HBox([controls['color_picker'], controls['opacity_slider'], controls['visibility_checkbox']]))
      display(render_button)
      plotter.show(jupyter_backend='html')

      render_button.on_click(render)



    # Initial display of controls and plot
    render()

Controls for aorta:


HBox(children=(ColorPicker(value='#cc2828', description='Color:'), FloatSlider(value=0.6, description='Opacity…

Controls for duodenum:


HBox(children=(ColorPicker(value='#ccb428', description='Color:'), FloatSlider(value=0.6, description='Opacity…

Controls for gallbladder:


HBox(children=(ColorPicker(value='#57cc28', description='Color:'), FloatSlider(value=0.6, description='Opacity…

Controls for inferior_vena_cava:


HBox(children=(ColorPicker(value='#28cc86', description='Color:'), FloatSlider(value=0.6, description='Opacity…

Controls for liver:


HBox(children=(ColorPicker(value='#2886cc', description='Color:'), FloatSlider(value=0.6, description='Opacity…

Controls for pancreas:


HBox(children=(ColorPicker(value='#5728cc', description='Color:'), FloatSlider(value=0.6, description='Opacity…

Controls for portal_vein_and_splenic_vein:


HBox(children=(ColorPicker(value='#cc28b4', description='Color:'), FloatSlider(value=0.6, description='Opacity…

Button(description='Render Plot', style=ButtonStyle())

EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

# Evaluation
Compare the model's predictions to professional predictions by loading both sets of segmentation masks for a selected subject.

Chosen metrics for comparing segmentation masks:
 1. Dice Similarity Coefficient (Dice): Measures the overlap between two segmentation masks.
    > Formula: **Dice = (2 * |A intersect B|) / (|A| + |B|)** Where A and B are the two segmentation masks.\
    Range: 0 (no overlap) to 1 (perfect overlap). Higher is better.

 2. Jaccard Index (IoU - Intersection over Union): Another measure of overlap, similar to Dice.
    > Formula: **Jaccard = |A intersect B| / |A union B|**\
      Range: 0 (no overlap) to 1 (perfect overlap). Higher is better.

 3. Hausdorff Distance: Measures the maximum distance between points in the boundaries of the two segmentation masks.
    > Formula: **Hausdorff(A, B) = max(h(A, B), h(B, A)) where h(A, B) = max_{a in A} min_{b in B} distance(a, b)**\
      Lower is better, as it indicates the boundaries are closer.

In [18]:
# @title Calculating Metrics
from scipy.spatial.distance import directed_hausdorff
import numpy as np
import os
import nibabel as nib
import pandas as pd
from skimage.transform import resize
selected_organs = organ_selector.value

def calculate_metrics(mask1, mask2):
    """
    Calculates Dice coefficient, Jaccard index, and Hausdorff distance between two binary masks.

    Args:
        mask1 (np.ndarray): The first binary mask.
        mask2 (np.ndarray): The second binary mask.

    Returns:
        tuple: A tuple containing the Dice coefficient, Jaccard index, and Hausdorff distance.
               Returns (0.0, 0.0, np.inf) if either mask is empty for Dice and Jaccard,
               and np.inf for Hausdorff distance if either mask is empty.
    """
    # Ensure masks have the same shape
    if mask1.shape != mask2.shape:
        print(f"  Warning: Masks have different shapes: {mask1.shape} vs {mask2.shape}. Resizing the smaller mask.")
        # Determine which mask is smaller and resize it
        if np.prod(mask1.shape) < np.prod(mask2.shape):
            mask1 = resize(mask1, mask2.shape, order=0, preserve_range=True, anti_aliasing=False)
            mask1 = (mask1 > 0.5).astype(np.uint8) # Convert back to binary after resizing
        else:
            mask2 = resize(mask2, mask1.shape, order=0, preserve_range=True, anti_aliasing=False)
            mask2 = (mask2 > 0.5).astype(np.uint8) # Convert back to binary after resizing


    mask1_flat = mask1.flatten()
    mask2_flat = mask2.flatten()


    # Dice Coefficient
    intersection = np.sum(mask1_flat * mask2_flat)
    sum_masks = np.sum(mask1_flat) + np.sum(mask2_flat)
    dice = (2.0 * intersection) / sum_masks if sum_masks else 0.0

    # Jaccard Index
    union = np.sum(mask1_flat + mask2_flat - mask1_flat * mask2_flat)
    jaccard = intersection / union if union else 0.0

    # Hausdorff Distance
    coords1 = np.argwhere(mask1)
    coords2 = np.argwhere(mask2)
    hausdorff_distance = np.inf
    if coords1.shape[0] > 0 and coords2.shape[0] > 0:
        # Calculate directed Hausdorff distance in both directions and take the maximum
        h1 = directed_hausdorff(coords1, coords2)[0]
        h2 = directed_hausdorff(coords2, coords1)[0]
        hausdorff_distance = max(h1, h2)
    elif coords1.shape[0] > 0 or coords2.shape[0] > 0:
        # If one mask is empty and the other is not, Hausdorff distance is infinite
         hausdorff_distance = np.inf


    return dice, jaccard, hausdorff_distance

# Store calculated metrics
comparison_results = {}

# Assume selected_organs, selected_folder, and output_dir are defined in previous cells

# Construct the file path to the professional segmentation directory
professional_segmentation_dir = os.path.join(selected_folder, "segmentations")

# Iterate through the selected organs
for organ_name in selected_organs:
    print(f"Processing organ: {organ_name}")

    # Load professional mask
    professional_mask = None
    professional_file_path = os.path.join(professional_segmentation_dir, f"{organ_name}.nii.gz")
    if os.path.exists(professional_file_path):
        try:
            img = nib.load(professional_file_path)
            professional_mask = img.get_fdata()
            print(f"  Loaded professional mask for {organ_name}")
        except Exception as e:
            print(f"  Could not load professional mask for {organ_name}: {e}")
            professional_mask = None # Ensure mask is None if loading fails
    else:
        print(f"  Professional mask for {organ_name} not found.")


    # Load model mask
    model_mask = None
    model_file_path = os.path.join(output_dir, f"{organ_name}.nii.gz")
    if os.path.exists(model_file_path):
        try:
            img = nib.load(model_file_path)
            model_mask = img.get_fdata()
            print(f"  Loaded model mask for {organ_name}")
        except Exception as e:
            print(f"  Could not load model mask for {organ_name}: {e}")
            model_mask = None # Ensure mask is None if loading fails
    else:
         print(f"  Model mask for {organ_name} not found.")


    # Calculate metrics if both masks are loaded
    if professional_mask is not None and model_mask is not None:
        dice, jaccard, hausdorff = calculate_metrics(model_mask, professional_mask)
        comparison_results[organ_name] = {
            'dice': dice,
            'jaccard': jaccard,
            'hausdorff': hausdorff
        }
        print(f"  Metrics for {organ_name}: Dice={dice:.4f}, Jaccard={jaccard:.4f}, Hausdorff={hausdorff:.4f}")
    else:
        print(f"  Skipping metrics for {organ_name} due to missing mask(s).")

    # Explicitly delete masks to free up memory
    del professional_mask
    del model_mask


print("\nFinished calculating metrics for all selected organs.")

# Create a pandas DataFrame from the comparison_results dictionary
comparison_df = pd.DataFrame.from_dict(comparison_results, orient='index')
comparison_df.index.name = 'Organ'

# Store the current comparison results in the persistent dictionary, keyed by the selected model
# Assuming 'all_comparison_results' and 'model_selector' are available from previous cells
if 'all_comparison_results' not in globals():
    all_comparison_results = {}

selected_model = model_selector.value
all_comparison_results[selected_model] = comparison_df

Processing organ: aorta
  Loaded professional mask for aorta
  Loaded model mask for aorta
  Metrics for aorta: Dice=0.3106, Jaccard=0.1838, Hausdorff=134.5623
Processing organ: autochthon_left
  Loaded professional mask for autochthon_left
  Loaded model mask for autochthon_left
  Metrics for autochthon_left: Dice=0.3684, Jaccard=0.2258, Hausdorff=48.1248
Processing organ: autochthon_right
  Loaded professional mask for autochthon_right
  Loaded model mask for autochthon_right
  Metrics for autochthon_right: Dice=0.0646, Jaccard=0.0334, Hausdorff=61.4329
Processing organ: clavicula_left
  Loaded professional mask for clavicula_left
  Loaded model mask for clavicula_left
  Metrics for clavicula_left: Dice=0.5264, Jaccard=0.3572, Hausdorff=38.6523
Processing organ: colon
  Loaded professional mask for colon
  Loaded model mask for colon
  Metrics for colon: Dice=0.8648, Jaccard=0.7619, Hausdorff=38.1838
Processing organ: duodenum
  Loaded professional mask for duodenum
  Loaded model ma

In [None]:
# @title Results Table
import pandas as pd
import numpy as np
from IPython.display import display

# Assume 'all_comparison_results' and 'model_selector' are available from previous cells

print(f"\nComparison results for '{selected_model}':")

# Set pandas options to display the full DataFrame
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)


# Display the DataFrame
display(comparison_df)

# Analyze the results for the current model
print(f"\nAnalysis of Segmentation Metrics for '{selected_model}':")
print("-" * 30)

if not comparison_df.empty:
    # Find organs with highest/lowest Dice and Jaccard scores
    highest_dice_organ = comparison_df['dice'].idxmax()
    lowest_dice_organ = comparison_df['dice'].idxmin()
    highest_jaccard_organ = comparison_df['jaccard'].idxmax()
    lowest_jaccard_organ = comparison_df['jaccard'].idxmin()

    # Find organs with highest/lowest Hausdorff distances
    # Exclude infinite Hausdorff distances for min calculation if any
    finite_hausdorff = comparison_df[comparison_df['hausdorff'] != np.inf]['hausdorff']
    highest_hausdorff_organ = comparison_df['hausdorff'].idxmax()

    if not finite_hausdorff.empty:
        lowest_hausdorff_organ = finite_hausdorff.idxmin()
    else:
        lowest_hausdorff_organ = "N/A (All Hausdorff distances are infinite)"


    print(f"Overall Performance (based on selected organs):")
    print(f"- Dice Coefficient: Higher values indicate better overlap.")
    print(f"  Highest Dice: '{highest_dice_organ}' ({comparison_df.loc[highest_dice_organ, 'dice']:.4f})")
    print(f"  Lowest Dice: '{lowest_dice_organ}' ({comparison_df.loc[lowest_dice_organ, 'dice']:.4f})")
    print(f"- Jaccard Index (IoU): Higher values indicate better overlap.")
    print(f"  Highest Jaccard: '{highest_jaccard_organ}' ({comparison_df.loc[highest_jaccard_organ, 'jaccard']:.4f})")
    print(f"  Lowest Jaccard: '{lowest_jaccard_organ}' ({comparison_df.loc[lowest_jaccard_organ, 'jaccard']:.4f})")
    print(f"- Hausdorff Distance: Lower values indicate better boundary agreement.")
    print(f"  Highest Hausdorff: '{highest_hausdorff_organ}' ({comparison_df.loc[highest_hausdorff_organ, 'hausdorff']:.4f})")
    if lowest_hausdorff_organ != "N/A (All Hausdorff distances are infinite)":
         print(f"  Lowest Hausdorff: '{lowest_hausdorff_organ}' ({comparison_df.loc[lowest_hausdorff_organ, 'hausdorff']:.4f})")
    else:
        print(f"  Lowest Hausdorff: {lowest_hausdorff_organ}")

else:
    print("No comparison results to analyze.")

# Optionally, display all stored results
print("\nAll Stored Comparison Results:")
print("=" * 30)
for model_name, results_df in all_comparison_results.items():
    print(f"\nResults for Model: {model_name}")
    display(results_df)

# Reset pandas options to default (optional, but good practice)
pd.reset_option('display.max_rows')
pd.reset_option('display.max_columns')
pd.reset_option('display.width')
pd.reset_option('display.max_colwidth')


Comparison results for 'STU-Net-B':


Unnamed: 0_level_0,dice,jaccard,hausdorff
Organ,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
aorta,0.259097,0.148829,129.011627
autochthon_left,0.46443,0.302448,34.014703
autochthon_right,0.066404,0.034342,64.892218
colon,0.823198,0.699522,35.298725
duodenum,0.06883,0.035642,60.514461
esophagus,0.175583,0.096241,50.299105
femur_left,0.433941,0.277091,142.555252
femur_right,0.06305,0.032551,140.132081
gluteus_maximus_left,0.535121,0.365301,126.779336
gluteus_maximus_right,0.005485,0.00275,128.117134



Analysis of Segmentation Metrics for 'STU-Net-B':
------------------------------
Overall Performance (based on selected organs):
- Dice Coefficient: Higher values indicate better overlap.
  Highest Dice: 'colon' (0.8232)
  Lowest Dice: 'rib_left_6' (0.0000)
- Jaccard Index (IoU): Higher values indicate better overlap.
  Highest Jaccard: 'colon' (0.6995)
  Lowest Jaccard: 'rib_left_6' (0.0000)
- Hausdorff Distance: Lower values indicate better boundary agreement.
  Highest Hausdorff: 'humerus_left' (182.9453)
  Lowest Hausdorff: 'rib_left_12' (9.4340)

All Stored Comparison Results:

Results for Model: totalSegmentator


Unnamed: 0_level_0,dice,jaccard,hausdorff
Organ,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
adrenal_gland_left,0.696203,0.533981,2.44949
adrenal_gland_right,0.930474,0.869988,2.0
aorta,0.987893,0.976075,2.236068
atrial_appendage_left,0.956734,0.917057,2.0
autochthon_left,0.983372,0.967287,2.0
autochthon_right,0.982863,0.966303,2.0
brachiocephalic_trunk,0.971692,0.944942,1.414214
brachiocephalic_vein_left,0.980409,0.961571,1.0
brachiocephalic_vein_right,0.976883,0.95481,1.0
clavicula_left,0.977373,0.955748,2.236068



Results for Model: MONAI SegResNet


Unnamed: 0_level_0,dice,jaccard,hausdorff
Organ,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
adrenal_gland_left,0.0,0.0,18.894444
adrenal_gland_right,0.806536,0.675794,3.741657
aorta,0.956558,0.916732,2.828427
autochthon_left,0.949806,0.904411,3.0
autochthon_right,0.956149,0.915982,3.0
clavicula_left,0.935203,0.878292,3.0
clavicula_right,0.953585,0.911287,2.236068
colon,0.891935,0.804949,32.264532
duodenum,0.862767,0.758655,20.223748
esophagus,0.871711,0.772595,4.358899



Results for Model: STU-Net-B


Unnamed: 0_level_0,dice,jaccard,hausdorff
Organ,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
aorta,0.259097,0.148829,129.011627
autochthon_left,0.46443,0.302448,34.014703
autochthon_right,0.066404,0.034342,64.892218
colon,0.823198,0.699522,35.298725
duodenum,0.06883,0.035642,60.514461
esophagus,0.175583,0.096241,50.299105
femur_left,0.433941,0.277091,142.555252
femur_right,0.06305,0.032551,140.132081
gluteus_maximus_left,0.535121,0.365301,126.779336
gluteus_maximus_right,0.005485,0.00275,128.117134


In [None]:
# @title ##### Save results file
# Export the DataFrame to a CSV file
csv_filename = f"{selected_model}_segmentation_comparison_results.csv"
comparison_df.to_csv(csv_filename, index=True)

print(f"Comparison results exported to {csv_filename}")

# Provide a download link
from google.colab import files
files.download(csv_filename)

Comparison results exported to STU-Net-B_segmentation_comparison_results.csv


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>