In [None]:
import os
import nibabel as nib
import numpy as np
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# ----------------------------------------------------------------------
# 1) Define paths to ground truth, inference, and input directories
# ----------------------------------------------------------------------
# Directory joined with a ground-truth file name to locate the input volume
# that the viewer shows in grayscale.
input_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/data/mindboggle/input"
# Directory listed for the predicted ('.nii.gz') segmentations.
inference_22_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/nnunet/nnUNet_results/Dataset022_bobs/inferences"
# Directory listed for the ground-truth ('.nii.gz') segmentations.
minboggle_ground_truth_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/data/mindboggle/segmentation_nnunet_bobs_convention"

# ----------------------------------------------------------------------
# 2) Define label correspondence dictionary
# ----------------------------------------------------------------------
# (class name, integer label id) pairs, listed in ascending id order.
_LABEL_TABLE = (
    ("unknown", -1),
    ("background", 0),
    ("Left-Cerebral-Exterior", 1),
    ("Cerebral-White-Matter", 2),
    ("Cerebral-Cortex", 3),
    ("Lateral-Ventricle", 4),
    ("Inf-Lat-Vent", 5),
    ("Left-Cerebellum-Exterior", 6),
    ("Cerebellum-White-Matter", 7),
    ("Cerebellum-Cortex", 8),
    ("Thalamus", 9),
    ("Caudate", 10),
    ("Putamen", 11),
    ("Pallidum", 12),
    ("3rd-Ventricle", 13),
    ("4th-Ventricle", 14),
    ("Brain-Stem", 15),
    ("Hippocampus", 16),
    ("Amygdala", 17),
    ("CSF", 18),
    ("Accumbens-area", 19),
    ("Vessel", 20),
    ("Choroid-plexus", 21),
    ("VentralDC", 22),
    ("WM-hypointensities", 23),
    ("Optic-Chiasm", 24),
    ("Vermis", 25),
)

# Name -> id mapping, nested under the "labels" key.
label_correspondence = {"labels": dict(_LABEL_TABLE)}

# ----------------------------------------------------------------------
# 3) Dice score functions
# ----------------------------------------------------------------------
def dice_score_per_label(gt_data, pred_data, label):
    """
    Return the Dice overlap of a single label between two label arrays.

    The elements equal to ``label`` are selected in each array and
    Dice = 2 * |GT ∩ Pred| / (|GT| + |Pred|) is evaluated on the two masks.
    ``np.nan`` is returned when the label occurs in neither array, since the
    ratio is undefined in that case.
    """
    in_gt = gt_data == label
    in_pred = pred_data == label

    overlap = (in_gt & in_pred).sum()
    denominator = in_gt.sum() + in_pred.sum()

    # Label absent from both arrays: nothing to compare.
    if denominator == 0:
        return np.nan

    return 2.0 * overlap / denominator


def dice_score_all_labels(gt_data, pred_data, ignore_label=-1):
    """
    Map every label found in the ground truth to its Dice score.

    Labels are taken from ``np.unique(gt_data)``; a label that only occurs in
    ``pred_data`` is therefore not reported. The label equal to
    ``ignore_label`` (default: -1) is left out of the result.
    """
    return {
        label: dice_score_per_label(gt_data, pred_data, label)
        for label in np.unique(gt_data)
        if label != ignore_label
    }


def dice_score_avg_foreground(gt_data, pred_data, ignore_label=-1, background_label=0):
    """
    Return the mean Dice score over the foreground labels of the ground truth.

    Every label in ``np.unique(gt_data)`` is scored except ``ignore_label``
    and ``background_label``. NaN scores are dropped before averaging, and
    ``np.nan`` is returned when no score remains.
    """
    candidate_scores = (
        dice_score_per_label(gt_data, pred_data, label)
        for label in np.unique(gt_data)
        if not (label == ignore_label or label == background_label)
    )
    valid_scores = [score for score in candidate_scores if not np.isnan(score)]

    if not valid_scores:
        return np.nan

    return np.mean(valid_scores)

# ----------------------------------------------------------------------
# 4) List, sort, and pair .nii.gz files from each directory
# ----------------------------------------------------------------------
def _list_nifti_files(directory):
    """Return the names (in os.listdir order) of the '.nii.gz' entries in *directory*."""
    return [entry for entry in os.listdir(directory) if entry.endswith('.nii.gz')]


gt_files = _list_nifti_files(minboggle_ground_truth_path)
inference_files = _list_nifti_files(inference_22_path)

# Both listings are sorted by name; files are later paired by position.
gt_files_sorted = sorted(gt_files)
inference_files_sorted = sorted(inference_files)

# A differing file count is only reported, not treated as an error.
if len(gt_files_sorted) != len(inference_files_sorted):
    print("Warning: The number of files in ground truth and inference directories do not match.")

# ----------------------------------------------------------------------
# 5) Load and process up to 200 pairs of NIfTI files (Dice scores)
# ----------------------------------------------------------------------

# Dice scores accumulated across all scans:
# key = label id, value = list with one Dice score per scan that had the label.
aggregate_scores = {}

num_pairs_to_load = 200

# Files are paired purely by their position in the two sorted listings.
# zip() stops at the shorter listing, so surplus files are ignored.
paired_files = zip(
    gt_files_sorted[:num_pairs_to_load],
    inference_files_sorted[:num_pairs_to_load],
)

for i, (gt_filename, inf_filename) in enumerate(paired_files):
    gt_filepath = os.path.join(minboggle_ground_truth_path, gt_filename)
    inf_filepath = os.path.join(inference_22_path, inf_filename)

    # Load the NIfTI images
    gt_img = nib.load(gt_filepath)
    inf_img = nib.load(inf_filepath)

    # get_fdata() always yields floating point values. The volumes are
    # compared as label maps, so round back to integers: this gives exact
    # equality tests and integer label keys ("Label 3" instead of "Label 3.0").
    # NOTE(review): assumes both files store integer label ids - confirm.
    gt_data = np.rint(gt_img.get_fdata()).astype(np.int32)
    inf_data = np.rint(inf_img.get_fdata()).astype(np.int32)

    print(f"\nPair {i+1}:")
    print(f"Ground Truth File: {gt_filename}")
    print(f"Inference File: {inf_filename}\n")

    # Voxel-wise comparison needs identical grids; report and move on
    # instead of aborting the whole run on one mismatching pair.
    if gt_data.shape != inf_data.shape:
        print(
            f"Skipped: shape mismatch "
            f"(ground truth {gt_data.shape}, inference {inf_data.shape})"
        )
        continue

    # Compute Dice scores for each label (ignoring -1)
    dice_scores = dice_score_all_labels(gt_data, inf_data, ignore_label=-1)

    print("Dice scores per label (label: score):")
    foreground_scores = []
    for label, score in dice_scores.items():
        if np.isnan(score):
            print(f"  Label {label}: no voxels in GT or Prediction")
            continue

        print(f"  Label {label}: {score:.4f}")

        # Accumulate label-wise scores across scans (NaN never reaches here).
        aggregate_scores.setdefault(label, []).append(score)

        # Background (0) is reported per label but excluded from the average;
        # -1 was already dropped by dice_score_all_labels.
        if label != 0:
            foreground_scores.append(score)

    # The foreground average is derived from the scores computed above rather
    # than re-scanning both volumes once more for every label.
    if foreground_scores:
        avg_foreground_dice = np.mean(foreground_scores)
        print(f"Average Foreground Dice (excluding -1 and 0): {avg_foreground_dice:.4f}")
    else:
        avg_foreground_dice = np.nan
        print("Average Foreground Dice (excluding -1 and 0): no valid foreground labels")

# ----------------------------------------------------------------------
# 6) After processing all scans, compute single list of average Dice scores
# ----------------------------------------------------------------------
print("\n=========================================")
print("Final AVERAGE Dice score per label (across ALL scans):")

# Mean Dice per label, taken over the scans that contributed a score for it.
final_scores = {
    label: np.mean(values) for label, values in aggregate_scores.items()
}

# Best-scoring labels first.
sorted_final = sorted(final_scores.items(), key=lambda item: item[1], reverse=True)

# id -> class name lookup, built once. setdefault keeps the first name should
# an id ever be listed twice, matching a first-match linear search.
name_by_id = {}
for class_name, class_id in label_correspondence["labels"].items():
    name_by_id.setdefault(class_id, class_name)

print("Label | Class Name                       | Dice")
print("------|----------------------------------|----------")

for label, score in sorted_final:
    # Labels without an entry in the correspondence table are shown as "???".
    label_name = name_by_id.get(label, "???")
    print(f"{int(label):5d} | {label_name:30s} | {score:.4f}")



Pair 1:
Ground Truth File: Afterthought-1.nii.gz
Inference File: bobs_0000.nii.gz

Dice scores per label (label: score):
  Label 0.0: 0.9650
  Label 3.0: 0.6066
  Label 4.0: 0.8657
  Label 5.0: 0.4786
  Label 7.0: 0.2309
  Label 9.0: 0.1791
  Label 10.0: 0.2388
  Label 11.0: 0.0485
  Label 12.0: 0.0794
  Label 13.0: 0.0000
  Label 14.0: 0.0000
  Label 15.0: 0.0017
  Label 16.0: 0.0077
  Label 17.0: 0.3323
  Label 18.0: 0.7041
  Label 19.0: 0.5073
  Label 20.0: 0.0000
  Label 22.0: 0.7485
  Label 24.0: 0.0000
Average Foreground Dice (excluding -1 and 0): 0.2794

Pair 2:
Ground Truth File: Colin27-1.nii.gz
Inference File: bobs_0001.nii.gz

Dice scores per label (label: score):
  Label 0.0: 0.9789
  Label 3.0: 0.7493
  Label 4.0: 0.8263
  Label 5.0: 0.6548
  Label 7.0: 0.6544
  Label 9.0: 0.3529
  Label 10.0: 0.2737
  Label 11.0: 0.3607
  Label 12.0: 0.1239
  Label 13.0: 0.0000
  Label 14.0: 0.0034
  Label 15.0: 0.0001
  Label 16.0: 0.2750
  Label 17.0: 0.5916
  Label 18.0: 0.4142
  Labe

: 

In [17]:
import os
import nibabel as nib
import numpy as np
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Suppose you already have this label correspondence dictionary:
# (class name, integer label id) pairs, listed in ascending id order.
_LABEL_TABLE = (
    ("unknown", -1),
    ("background", 0),
    ("Left-Cerebral-Exterior", 1),
    ("Cerebral-White-Matter", 2),
    ("Cerebral-Cortex", 3),
    ("Lateral-Ventricle", 4),
    ("Inf-Lat-Vent", 5),
    ("Left-Cerebellum-Exterior", 6),
    ("Cerebellum-White-Matter", 7),
    ("Cerebellum-Cortex", 8),
    ("Thalamus", 9),
    ("Caudate", 10),
    ("Putamen", 11),
    ("Pallidum", 12),
    ("3rd-Ventricle", 13),
    ("4th-Ventricle", 14),
    ("Brain-Stem", 15),
    ("Hippocampus", 16),
    ("Amygdala", 17),
    ("CSF", 18),
    ("Accumbens-area", 19),
    ("Vessel", 20),
    ("Choroid-plexus", 21),
    ("VentralDC", 22),
    ("WM-hypointensities", 23),
    ("Optic-Chiasm", 24),
    ("Vermis", 25),
)

# Name -> id mapping, nested under the "labels" key.
label_correspondence = {"labels": dict(_LABEL_TABLE)}

def _keep_selected_labels(label_array, selected_labels):
    """
    Return *label_array* with every non-selected label replaced by 0.

    ``selected_labels`` holds numeric label ids and/or the special strings
    "all" (the array is returned unchanged) and "none" (an all-zero array is
    returned). "all" takes precedence over "none"; any other string is ignored.
    """
    if "all" in selected_labels:
        return label_array
    if "none" in selected_labels:
        return np.zeros_like(label_array)

    wanted_ids = np.array(
        [value for value in selected_labels if not isinstance(value, str)],
        dtype=label_array.dtype,
    )
    return np.where(np.isin(label_array, wanted_ids), label_array, 0)


def visualize_triplet(input_filepath, gt_filepath, pred_filepath):
    """
    Visualize input, ground truth, and prediction side by side
    with an interactive slider to move through slices (third array axis).

    A multi-select widget chooses which labels are drawn in the two overlays;
    all other voxels, as well as labels <= 0 (background / unknown), are left
    transparent so the grayscale input stays visible underneath.

    Overlay colours use a fixed scale (0 .. largest id in the label table), so
    a label has the same colour in both panels, on every slice, and in the
    legend.

    Parameters:
        input_filepath: path of the volume shown in grayscale.
        gt_filepath: path of the ground-truth label volume.
        pred_filepath: path of the predicted label volume.
    """

    # Load data
    input_data = nib.load(input_filepath).get_fdata()
    gt_data = nib.load(gt_filepath).get_fdata()
    pred_data = nib.load(pred_filepath).get_fdata()

    # Selectable entries are (display_name, value) pairs; "all" / "none" are
    # special string values, every other value is a numeric label id.
    label_dict = label_correspondence["labels"]
    label_options = [("All labels", "all"), ("None", "none")]
    label_options.extend(
        (f"{lbl_name} ({lbl_id})", lbl_id) for lbl_name, lbl_id in label_dict.items()
    )

    # One colour scale shared by the overlays and the legend.
    max_label = max(label_dict.values())
    color_scale = max_label if max_label > 0 else 1
    overlay_cmap = plt.cm.rainbow

    # The legend does not depend on the slice or the selection, so build it
    # once. Ids <= 0 are never painted and therefore get no legend entry.
    legend_patches = [
        Patch(facecolor=overlay_cmap(lbl_id / color_scale), label=f"{lbl_name} ({lbl_id})")
        for lbl_name, lbl_id in label_dict.items()
        if lbl_id > 0
    ]

    # Restrict the slider to slices that exist in all three volumes.
    last_slice = min(input_data.shape[2], gt_data.shape[2], pred_data.shape[2]) - 1

    @interact(
        slice_idx=widgets.IntSlider(min=0,
                                    max=last_slice,
                                    step=1,
                                    value=last_slice // 2),
        selected_labels=widgets.SelectMultiple(
            options=label_options,
            description='Labels to show',
            value=["all"],  # default selection is 'All labels'
            rows=10  # adjust how many lines the SelectMultiple widget shows
        )
    )
    def update(slice_idx=0, selected_labels=("all",)):
        _fig, axs = plt.subplots(1, 3, figsize=(20, 7))  # bigger figure

        background = input_data[:, :, slice_idx]

        # (a) Input
        axs[0].imshow(background, cmap='gray')
        axs[0].set_title('Input')

        # (b) Ground Truth overlay, (c) Prediction overlay
        panels = (('Ground Truth', gt_data), ('Prediction', pred_data))
        for ax, (title, label_volume) in zip(axs[1:], panels):
            # Only the displayed slice is filtered, not the whole volume.
            shown = _keep_selected_labels(label_volume[:, :, slice_idx], selected_labels)
            # Masked entries are not painted, so hidden labels and background
            # do not tint the image.
            shown = np.ma.masked_less_equal(shown, 0)

            ax.imshow(background, cmap='gray')
            ax.imshow(shown, cmap=overlay_cmap, vmin=0, vmax=color_scale, alpha=0.5)
            ax.set_title(title)

        axs[2].legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')

        plt.tight_layout()
        plt.show()

# ----------------------------------------------------------------------
# 7) Example usage: visualize the first triplet
# ----------------------------------------------------------------------
# Suppose you have these variables defined:
#   input_path, minboggle_ground_truth_path, inference_22_path
#   gt_files_sorted, inference_files_sorted
#
# For demonstration, we'll just show the first pair:
if gt_files_sorted and inference_files_sorted:
    # The input volume is looked up under the same file name as the first
    # ground-truth file; the prediction is the first file of its own listing.
    input_filename = os.path.join(input_path, gt_files_sorted[0])
    gt_filename = os.path.join(minboggle_ground_truth_path, gt_files_sorted[0])
    pred_filename = os.path.join(inference_22_path, inference_files_sorted[0])

    visualize_triplet(input_filename, gt_filename, pred_filename)
else:
    # Indexing [0] on an empty listing would only raise a bare IndexError.
    print("Nothing to visualize: no ground truth / inference '.nii.gz' files were found.")


interactive(children=(IntSlider(value=127, description='slice_idx', max=255), SelectMultiple(description='Labe…