In [None]:
import os
import numpy as np
import SimpleITK as sitk

In [42]:
def ensure_dir(directory):
    """Ensure a directory exists, creating it and any missing parents.

    ``os.makedirs(..., exist_ok=True)`` avoids the race between a separate
    existence check and the creation call. It also reports, instead of silently
    ignoring, the case where a non-directory file already occupies the path.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)

In [43]:
def connected_components(input_dir, output_dir, max_components=40):
    """Split binary label masks into instance masks of their largest connected components.

    For every ``.nii.gz`` file in ``input_dir`` (processed in sorted order):
    1. All non-zero voxels are treated as foreground.
    2. Connected components are labelled.
    3. Components are relabelled by decreasing size (1 = largest).
    4. Only labels ``1..max_components`` are kept, each retaining its own label.

    The result is written to ``output_dir`` under the same filename.

    Args:
        input_dir: Directory containing the ``.nii.gz`` label masks.
        output_dir: Directory for the instance masks; created if missing.
        max_components: How many of the largest components to keep.
            ``None`` keeps every component.
    """
    ensure_dir(output_dir)

    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith(".nii.gz"):
            continue
        input_path = os.path.join(input_dir, filename)
        output_path = os.path.join(output_dir, filename)

        label_image = sitk.ReadImage(input_path)

        # ConnectedComponent requires an integer pixel type. NotEqual(..., 0)
        # yields a UInt8 foreground mask, so float-valued masks work too.
        # Every non-zero value still counts as foreground.
        foreground = sitk.NotEqual(label_image, 0)

        components = sitk.ConnectedComponent(foreground)

        # Relabel so that label 1 is the largest component, 2 the next, etc.
        relabeled = sitk.RelabelComponent(components, sortByObjectSize=True)

        if max_components is None:
            instance_mask = relabeled
        else:
            # Zero every label above max_components while keeping the unique
            # labels of the retained components; background (0) stays 0.
            # Cast the mask to relabeled's own pixel type so the product is valid.
            keep = sitk.Cast(relabeled <= max_components, relabeled.GetPixelID())
            instance_mask = keep * relabeled

        sitk.WriteImage(instance_mask, output_path)
        kept = "all" if max_components is None else f"top {max_components}"
        print(f"Saved instance mask with {kept} components to {output_path}")

In [44]:
def _overlap_map(prediction_mask, label_mask):
    """Encode voxel-wise agreement of two boolean masks as a uint8 map.

    Codes: 0 = neither, 1 = label only, 2 = both, 3 = prediction only.
    """
    overlap_map = np.zeros(label_mask.shape, dtype=np.uint8)
    overlap_map[label_mask & ~prediction_mask] = 1
    overlap_map[label_mask & prediction_mask] = 2
    overlap_map[~label_mask & prediction_mask] = 3
    return overlap_map


def compute_overlap(prediction_dir, label_dir, output_dir):
    """Compute overlap maps between predictions and ground-truth labels.

    For every ``.nii.gz`` file in ``prediction_dir`` (sorted order), the
    same-named file in ``label_dir`` is read. Both images are binarised with
    ``> 0``, and a uint8 map is written to ``output_dir`` with these codes:
    0 = background in both, 1 = label only, 2 = overlap, 3 = prediction only.

    Args:
        prediction_dir: Directory containing prediction ``.nii.gz`` masks.
        label_dir: Directory containing label masks with matching filenames.
        output_dir: Directory for the overlap maps; created if missing.

    Raises:
        ValueError: If a prediction and its label have different image sizes.
    """
    ensure_dir(output_dir)

    for filename in sorted(os.listdir(prediction_dir)):
        if not filename.endswith(".nii.gz"):
            continue
        prediction_path = os.path.join(prediction_dir, filename)
        label_path = os.path.join(label_dir, filename)
        output_path = os.path.join(output_dir, filename)

        prediction = sitk.ReadImage(prediction_path)
        label = sitk.ReadImage(label_path)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if prediction.GetSize() != label.GetSize():
            raise ValueError(
                f"Size mismatch for {filename}: prediction {prediction.GetSize()} "
                f"vs label {label.GetSize()}"
            )

        prediction_mask = sitk.GetArrayFromImage(prediction) > 0
        label_mask = sitk.GetArrayFromImage(label) > 0

        overlap_image = sitk.GetImageFromArray(_overlap_map(prediction_mask, label_mask))
        # Reuse the prediction's origin/spacing/direction for the output geometry.
        overlap_image.CopyInformation(prediction)

        sitk.WriteImage(overlap_image, output_path)
        print(f"Saved overlap map to {output_path}")

In [None]:
base_dir = "images_for_comparison"

# Ground truth: split the semantic label masks into per-component instance masks.
semantic_labels_dir = os.path.join(base_dir, "labels_semantic")
instance_labels_dir = os.path.join(base_dir, "labels_instance")
connected_components(
    input_dir=semantic_labels_dir,
    output_dir=instance_labels_dir,
)

# Every method's predictions are compared against the semantic labels.
# Each entry is (prediction subdirectory, overlap output subdirectory).
prediction_dirs = [
    ("prediction_dins", "overlap_dins"),
    ("prediction_sw_fastedit", "overlap_sw_fastedit"),
    ("prediction_sam2", "overlap_sam2"),
]

for prediction_dir, overlap_dir in prediction_dirs:
    compute_overlap(
        prediction_dir=os.path.join(base_dir, prediction_dir),
        label_dir=semantic_labels_dir,
        output_dir=os.path.join(base_dir, overlap_dir),
    )


In [None]:
# Split the SAM2 predictions into instance masks of their 40 largest components.
# Paths are built from base_dir so they stay consistent with the cell above.
connected_components(
    input_dir=os.path.join(base_dir, "prediction_sam2"),
    output_dir=os.path.join(base_dir, "prediction_sam2_instance"),
    max_components=40,
)