In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import mode
from PIL import Image
import os

In [None]:
# Group the training label slices by scan. File names look like "amos_0001_slice52.png",
# so the second "_"-separated field is the scan id. Non-PNG entries (e.g. .DS_Store) are
# skipped, matching the ".png" filter used by compute_average_iou / compute_label_distribution.
file_names = sorted(name for name in os.listdir("../amos22/Train/label") if name.endswith(".png"))

# Single pass over the file names instead of one full scan per scan id.
slices_by_scan = {}
for name in file_names:
    slices_by_scan.setdefault(int(name.split("_")[1]), []).append(name)

slice_nums = set(slices_by_scan)  # the distinct scan ids (name kept for backward compatibility)
# Deterministic scan order (ascending scan id) instead of set-iteration order.
scans_list = [slices_by_scan[scan_id] for scan_id in sorted(slices_by_scan)]

In [None]:
def slice_number(file_name):
    """Return the slice index encoded in a name such as "amos_0001_slice52.png" (-> 52)."""
    return int(file_name.split("slice")[1].split(".")[0])

key = slice_number  # previous name of this sort key, kept for backward compatibility

# Order every scan's slices by slice index. Don't call the loop variable `list`:
# that would shadow the builtin for every later cell in the kernel.
for scan_slices in scans_list:
    scan_slices.sort(key=slice_number)

In [None]:
# Number of positions sampled from every scan.
num_cuts = 14

# Cut i is taken at relative position i / num_cuts of each scan (the "Slice X%" titles in
# plot_slices assume exactly this). Computing the index as i * length // num_cuts keeps the
# spacing proportional for any scan length; a floored fixed stride (length // num_cuts)
# is 0 for scans shorter than num_cuts and only covers a prefix of short scans.
wanted_cuts = [
    [scan_slices[i * len(scan_slices) // num_cuts] for i in range(num_cuts)]
    for scan_slices in scans_list
]

In [None]:
# Define the path where the training label masks are stored
DATASET_PATH = "../amos22/Train/label"  # Change this to your dataset path

def load_image(image_path):
    """
    Load an image file (here: a labelled segmentation mask) as a NumPy array.

    The file is opened in a ``with`` block so its handle is released as soon as the
    pixel data has been copied into the array; ``Image.open`` is lazy and would
    otherwise keep the file open until the image object is garbage-collected.

    Args:
        image_path (str): Path to the image file.

    Returns:
        np.ndarray: The pixel data; (H, W) for single-channel masks
        (assumes the label PNGs are single-channel — TODO confirm).
    """
    with Image.open(image_path) as img:
        return np.array(img)

def _mode_ignoring_background(slice_images):
    """Per-pixel most frequent non-zero label along axis 0; 0 where every value is 0."""
    # Assumes non-negative integer labels (as np.bincount, used previously, required).
    labels = np.asarray(slice_images).astype(np.int64, copy=False)
    max_label = int(labels.max()) if labels.size else 0
    if max_label <= 0:
        # Only background in the whole stack.
        return np.zeros(labels.shape[1:], dtype=np.int64)

    # counts[c - 1][pixel] = number of scans in which that pixel carries label c (c >= 1).
    counts = np.stack([np.count_nonzero(labels == c, axis=0) for c in range(1, max_label + 1)])
    # argmax returns the first maximum, i.e. the smallest label on ties
    # (same tie-break as np.bincount(...).argmax()).
    most_frequent = counts.argmax(axis=0) + 1
    # Pixels that are background in every scan stay background.
    most_frequent[~np.any(labels != 0, axis=0)] = 0
    return most_frequent


def compute_most_frequent_class(slices_list, dataset_path, with_background=True):
    """
    Compute the most frequent class per pixel for each relative slice position.

    For every position ``slice_idx`` the ``slice_idx``-th file of every scan is loaded,
    the masks are stacked, and a per-pixel majority vote is taken.

    Args:
        slices_list: List of per-scan lists of image filenames; every inner list is
            assumed to have the same length as the first one.
        dataset_path: Root path where images are stored.
        with_background (bool): If False, background (0) is ignored in the vote; a pixel
            is 0 only if it is background in every scan. Ties go to the smallest label.

    Returns:
        dict: {slice_idx * 10: most_frequent_class_array} (keys 0, 10, 20, ...).
        Empty dict if ``slices_list`` is empty.
        NOTE(review): predict_slice_label matches these keys against the absolute slice
        number parsed from a filename, so they only line up positionally for scans of
        about 10 slices per position — confirm this is intended.
    """
    if not slices_list:
        return {}

    num_slices = len(slices_list[0])  # Assuming all scans have the same number of slices
    most_frequent_slices = {}

    for slice_idx in range(num_slices):
        # Shape: (num_scans, H, W) — one mask per scan for this position.
        slice_images = np.stack(
            [load_image(os.path.join(dataset_path, scan[slice_idx])) for scan in slices_list],
            axis=0,
        )

        if with_background:
            # Most frequent class including background (label 0)
            most_frequent_class, _ = mode(slice_images, axis=0, keepdims=False)
        else:
            # Vectorised replacement for a per-pixel apply_along_axis/bincount loop, whose
            # output dtype depended on the first pixel's result (int vs. NaN).
            most_frequent_class = _mode_ignoring_background(slice_images)

        most_frequent_slices[slice_idx * 10] = np.asarray(most_frequent_class).squeeze().astype(int)

    return most_frequent_slices  # Dictionary of {relative slice key: (H, W) array}

def plot_slices(most_frequent_slices, ncols=5):
    """
    Plot each computed "average" slice in a grid, ordered by key.

    Panel i (0-based, in key order) is titled with its relative position
    100 * i / number_of_slices percent.

    Args:
        most_frequent_slices (dict): {slice_position: 2D array} with the most
            frequent class per relative slice.
        ncols (int): Number of columns per row in the plot.

    Returns:
        None. Draws nothing for an empty dict (plt.subplots(0, ncols) would raise).
    """
    if not most_frequent_slices:
        return

    sorted_keys = sorted(most_frequent_slices)
    num_slices = len(sorted_keys)
    nrows = (num_slices + ncols - 1) // ncols  # ceil(num_slices / ncols)

    # squeeze=False always yields a 2D array of Axes, so one ravel() handles every grid shape.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2), squeeze=False)
    flat_axes = axes.ravel()

    for i, (ax, key) in enumerate(zip(flat_axes, sorted_keys)):
        ax.imshow(most_frequent_slices[key], cmap="jet", interpolation="nearest")
        ax.set_title(f"Slice {(100/num_slices * i):.2f}%")  # Relative slice position
        ax.axis("off")

    # Hide unused subplot axes
    for ax in flat_axes[num_slices:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
# Majority-vote atlas in which background (label 0) competes like any other class.
most_frequent_slices = compute_most_frequent_class(
    wanted_cuts,
    dataset_path=DATASET_PATH,
    with_background=True,
)
plot_slices(most_frequent_slices, ncols=5)

In [None]:
# Majority vote over foreground labels only. This rebinds `most_frequent_slices`,
# so the prediction/evaluation cells that follow use this background-free atlas.
most_frequent_slices = compute_most_frequent_class(
    wanted_cuts, DATASET_PATH, with_background=False
)
plot_slices(most_frequent_slices, ncols=5)

In [None]:
def compute_iou(prediction, ground_truth, ignore_background=True, max_classes=16):
    """
    Compute the Intersection over Union (IoU) for each class in a segmentation map.

    Only labels that occur in ``prediction`` or ``ground_truth`` are scored. Of those,
    at most the ``max_classes`` smallest labels are kept (16 by default, which preserves
    this function's previous behaviour); pass ``max_classes=None`` to score every label.

    Args:
        prediction (np.ndarray): Predicted label map.
        ground_truth (np.ndarray): Ground-truth label map with the same shape as ``prediction``.
        ignore_background (bool): If True, class 0 is left out of both ``iou_per_class``
            and the mean.
        max_classes (int or None): Upper bound on the number of (ascending) labels scored.

    Returns:
        iou_per_class (dict): {class label: IoU} for every scored class.
        mean_iou (float): Mean of the non-NaN per-class IoUs, or 0.0 if there are none.

    Raises:
        ValueError: If ``prediction`` and ``ground_truth`` differ in shape; comparing them
            would otherwise broadcast silently or fail with an unrelated error.
    """
    prediction = np.asarray(prediction)
    ground_truth = np.asarray(ground_truth)
    if prediction.shape != ground_truth.shape:
        raise ValueError(
            f"prediction shape {prediction.shape} does not match ground truth shape {ground_truth.shape}"
        )

    # Sorted labels present in either map (np.union1d flattens its inputs).
    present_classes = np.union1d(prediction, ground_truth)
    if max_classes is not None:
        present_classes = present_classes[:max_classes]

    iou_per_class = {}
    for cls in present_classes:
        if ignore_background and cls == 0:
            continue  # Skip background if requested

        pred_mask = prediction == cls
        gt_mask = ground_truth == cls
        intersection = np.count_nonzero(pred_mask & gt_mask)
        union = np.count_nonzero(pred_mask | gt_mask)
        # union > 0 for any label present in either map; the guard is purely defensive.
        iou_per_class[cls] = intersection / union if union > 0 else float("nan")

    # Mean IoU only over valid (non-NaN) classes
    valid_ious = [iou for iou in iou_per_class.values() if not np.isnan(iou)]
    mean_iou = float(np.mean(valid_ious)) if valid_ious else 0.0

    return iou_per_class, mean_iou


def predict_slice_label(slice_name, most_frequent_slices):
    """
    Return the stored label map whose key is closest to the slice number in ``slice_name``.

    The slice number is the integer between the last "_slice" and the following "."
    (e.g. "amos_0001_slice52.png" -> 52). On ties, the key met first while iterating
    ``most_frequent_slices`` wins.

    NOTE(review): keys produced by compute_most_frequent_class are position_index * 10,
    while the number parsed here is an absolute slice index within a scan — confirm
    the two are meant to be compared directly.

    Args:
        slice_name (str): File name (or path) of the slice.
        most_frequent_slices (dict): {numeric key: label map}.

    Returns:
        The value stored under the nearest key (a label map in this notebook).
    """
    after_marker = slice_name.rsplit('_slice', 1)[-1]
    digits, _, _ = after_marker.partition('.')
    target = int(digits)

    def distance_to_target(candidate_key):
        return abs(candidate_key - target)

    nearest_key = min(most_frequent_slices, key=distance_to_target)
    return most_frequent_slices[nearest_key]

In [None]:
# Predict one known training slice from the atlas and compare it with its mask.
img_path = "amos_0001_slice52.png"
ground_truth_path = os.path.join(DATASET_PATH, img_path)

predicted_labels = predict_slice_label(img_path, most_frequent_slices)
ground_truth_labels = load_image(ground_truth_path)
iou_per_class, mean_iou = compute_iou(predicted_labels, ground_truth_labels, ignore_background=False)

# Side-by-side view: prediction on the left, ground truth on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (predicted_labels, "Predicted Segmentation"),
    (ground_truth_labels, "Ground Truth Segmentation"),
)
for ax, (label_map, panel_title) in zip(axes, panels):
    ax.imshow(label_map, cmap="jet", interpolation="nearest")
    ax.set_title(panel_title)
    ax.axis("off")
plt.tight_layout()
plt.show()

# Scores first, then quick sanity checks on the two label maps.
print(mean_iou)
for cls_iou in iou_per_class.items():
    print(cls_iou)

for caption, value in (
    ("Unique labels in prediction:", np.unique(predicted_labels)),
    ("Unique labels in ground truth:", np.unique(ground_truth_labels)),
    ("Prediction shape:", predicted_labels.shape),
    ("Ground truth shape:", ground_truth_labels.shape),
):
    print(caption, value)

overlapping_pixels = np.any(predicted_labels == ground_truth_labels)
print("Any overlapping pixels:", overlapping_pixels)

In [None]:
from tqdm import tqdm

def compute_average_iou(folder_path, most_frequent_slices, ignore_background=True):
    """
    Average the per-image IoU of every class over all PNG masks in a folder.

    Each mask is compared with ``predict_slice_label(img_name, most_frequent_slices)``
    via compute_iou. NaN IoUs are skipped, so a class is averaged only over the
    images in which it received a score. Files are visited in sorted name order.

    Args:
        folder_path (str): Path to the dataset folder containing segmentation masks.
        most_frequent_slices (dict): Precomputed most frequent class segmentations,
            passed to predict_slice_label.
        ignore_background (bool): Forwarded to compute_iou; if True, class 0 is not scored.

    Returns:
        avg_iou_per_class (dict): {class label: average IoU}, ordered by class label.
        overall_mean_iou (float): Unweighted mean of the per-class averages
            (a macro average over classes; 0.0 if no class was scored).
    """
    iou_sum = {}    # class -> sum of its non-NaN IoUs
    iou_count = {}  # class -> number of images contributing to that sum

    for img_name in tqdm(sorted(os.listdir(folder_path))):
        if not img_name.endswith(".png"):  # Ensure only images are processed
            continue

        ground_truth_labels = load_image(os.path.join(folder_path, img_name))
        predicted_labels = predict_slice_label(img_name, most_frequent_slices)
        iou_per_class, _ = compute_iou(predicted_labels, ground_truth_labels, ignore_background=ignore_background)

        for cls, iou in iou_per_class.items():
            if np.isnan(iou):
                continue  # Ignore NaN values
            iou_sum[cls] = iou_sum.get(cls, 0.0) + iou
            iou_count[cls] = iou_count.get(cls, 0) + 1

    avg_iou_per_class = {cls: iou_sum[cls] / iou_count[cls] for cls in sorted(iou_sum)}

    # Unpacking instead of list(...) keeps this working even if `list` has been rebound.
    per_class_values = [*avg_iou_per_class.values()]
    overall_mean_iou = float(np.mean(per_class_values)) if per_class_values else 0.0

    return avg_iou_per_class, overall_mean_iou

# Ground-truth label masks of the validation split. The location is derived from the
# training label folder instead of a machine-specific absolute path: with
# DATASET_PATH = "../amos22/Train/label" this resolves to "../amos22/Validation/label".
dataset_folder = os.path.normpath(os.path.join(DATASET_PATH, "..", "..", "Validation", "label"))

# Compute IoU across the dataset
avg_iou_per_class, overall_mean_iou = compute_average_iou(dataset_folder, most_frequent_slices, ignore_background=False)

# Print results
print(f"Overall Mean IoU: {overall_mean_iou:.4f}")
print("Average IoU per class across dataset:")
for cls, iou in avg_iou_per_class.items():
    print(f"Class {cls}: {iou:.4f}")


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

def compute_label_distribution(folder_path):
    """
    Return, for each label, the fraction of all mask pixels in ``folder_path`` that carry it.

    Only files ending in ".png" are read. Labels appear in the order they are first
    encountered while walking ``os.listdir(folder_path)``.

    Args:
        folder_path (str): Path to the dataset folder containing segmentation masks.

    Returns:
        label_distribution (dict): {class label: pixel count / total pixel count};
        empty if the folder holds no PNG files.
    """
    pixels_per_label = defaultdict(int)
    pixels_seen = 0

    for file_name in tqdm(os.listdir(folder_path)):
        if file_name.endswith(".png"):
            mask = load_image(os.path.join(folder_path, file_name))
            labels_here, counts_here = np.unique(mask, return_counts=True)
            for lbl, cnt in zip(labels_here, counts_here):
                pixels_per_label[lbl] += cnt
            pixels_seen += counts_here.sum()

    return {lbl: cnt / pixels_seen for lbl, cnt in pixels_per_label.items()}

def plot_label_distribution(label_distribution):
    """
    Draw a bar chart of how often each class label occurs.

    Args:
        label_distribution (dict): {class label: relative frequency}, e.g. the
            output of compute_label_distribution. Each label gets its own x tick.
    """
    # Unpacking instead of list(...) keeps this working even if `list` has been rebound.
    class_labels = [*label_distribution]
    shares = [*label_distribution.values()]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(class_labels, shares, color='skyblue')
    ax.set_xlabel("Class Label")
    ax.set_ylabel("Relative Frequency")
    ax.set_title("Relative Amount of Labels in the Dataset")
    ax.set_xticks(class_labels)  # one tick per label
    ax.grid(axis="y", linestyle="--", alpha=0.7)

    plt.show()

In [None]:

# Validation label folder.
# NOTE(review): `dataset_folder` is not used below — the distribution is computed over
# DATASET_PATH (the training labels). Confirm which split is intended.
dataset_folder = r"/Users/yp/PycharmProjects/AMOS/amos22/Validation/label"

# Relative label frequencies; plotted in the next cell.
label_distribution = compute_label_distribution(folder_path=DATASET_PATH)

In [None]:
plot_label_distribution(label_distribution=label_distribution)