In [None]:
import tqdm
import os
import json
import numpy as np
import pycocotools.mask as mask_util
from datasets import load_dataset
from utils.visualization import visualize_masks
import matplotlib.pyplot as plt
import random
import torch
import torch.nn.functional as F


def load_samples(output_dir, split, resolution, model):
    """
    List the result files of one model, ordered by sample index.

    Files are looked up in ``{output_dir}/{split}/{resolution}/{model}``. The part
    of each file name before the first dot is parsed as the integer sample index.
    Entries whose name does not start with an integer (hidden files, notebook
    checkpoint folders, ...) are ignored instead of aborting the scan.

    Args:
        output_dir (str): Root directory of the results.
        split (str): Dataset split sub-directory.
        resolution (int or str): Resolution sub-directory.
        model (str): Model sub-directory.

    Returns:
        list[str]: ``samples[i]`` is the path of the file for sample index ``i``.

    Raises:
        KeyError: If the indices found do not form the contiguous range ``0..N-1``.
    """
    results_dir = f'{output_dir}/{split}/{resolution}/{model}'

    samples_dict = {}
    for file in os.listdir(results_dir):
        try:
            index = int(file.split('.')[0])
        except ValueError:
            # Not a "<index>.<ext>" result file (e.g. '.ipynb_checkpoints'); skip it.
            continue
        samples_dict[index] = os.path.join(results_dir, file)

    samples = []
    for index in range(len(samples_dict)):
        if index not in samples_dict:
            raise KeyError(
                f'sample index {index} is missing in {results_dir} '
                f'({len(samples_dict)} result files found)'
            )
        samples.append(samples_dict[index])

    return samples

def decode_masks(mask_rles):
    """
    Decode a sequence of RLE-encoded masks and drop the ones that are empty.

    Args:
        mask_rles (iterable): RLE objects accepted by ``pycocotools.mask.decode``.

    Returns:
        np.ndarray: Stacked masks of shape (N, H, W) containing only the masks with
            at least one non-zero pixel. When ``mask_rles`` is empty, an array of
            shape (0, 0, 0) is returned (H and W are unknown in that case).
    """
    masks = [mask_util.decode(mask_rle) for mask_rle in mask_rles]
    if not masks:
        # np.array([]) would be 1-D and the (1, 2) axis reduction below would raise.
        # uint8 presumably matches the dtype pycocotools returns -- TODO confirm.
        return np.zeros((0, 0, 0), dtype=np.uint8)

    masks = np.array(masks)
    # Keep only masks that cover at least one pixel.
    return masks[masks.sum(axis=(1, 2)) > 0]


def label_map_to_random_rgb(label_map, seed=None):
    """
    Convert a (H, W) integer label map into a (H, W, 3) random RGB image.

    Label 0 is always rendered black; every label in 1..max_label gets its own
    random color.

    Args:
        label_map (torch.Tensor or np.ndarray): Single-channel label map of shape (H, W).
            Labels are used as indices into the color table, so they are expected
            to be non-negative.
        seed (int, optional): If provided, colors are drawn from a private generator
            seeded with this value, so they are reproducible and the global NumPy
            random state is left untouched. If None, the global NumPy RNG is used.

    Returns:
        np.ndarray: Color image of shape (H, W, 3) with dtype=np.uint8.
    """
    # RandomState(seed) produces the same stream as np.random.seed(seed) followed by
    # np.random.* calls, but without reseeding the RNG shared by the rest of the notebook.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Convert label_map to a CPU ndarray (detach so tensors tracking gradients work too)
    if isinstance(label_map, torch.Tensor):
        label_map = label_map.detach().cpu().numpy()
    label_map = label_map.astype(np.int32)

    # No pixels at all, or no label > 0: return a blank image
    if label_map.size == 0 or label_map.max() < 1:
        h, w = label_map.shape
        return np.zeros((h, w, 3), dtype=np.uint8)

    # Color table of shape (max_label + 1, 3); row 0 (background) stays black.
    max_label = label_map.max()
    colors = np.zeros((max_label + 1, 3), dtype=np.uint8)
    for lbl in range(1, max_label + 1):
        colors[lbl] = rng.randint(0, 256, size=3)

    # Look up the color of every pixel's label
    return colors[label_map]
    

In [2]:

def masks_to_label_map(masks, device='cuda', output_size=(1024, 1024)):
    """
    Converts multiple binary masks into a single labeled map (with optional resizing).

    Every non-zero pixel of mask ``i`` is labeled ``i + 1``; pixels covered by no
    mask stay 0. Where masks overlap, the mask with the highest index wins, so the
    result only ever contains valid mask ids.

    Args:
        masks (torch.Tensor or array-like): A tensor (or array) of shape (N, H, W),
            where N is the number of masks. N may be 0.
        device (str): The device to place the tensor (e.g. 'cpu', 'cuda').
        output_size (tuple or None): The desired output height and width
            (H_out, W_out). If None, the map keeps the (H, W) size of the masks.

    Returns:
        torch.LongTensor: A labeled map of shape (H_out, W_out).
    """

    # Ensure masks are in torch.Tensor format and move to device
    if not isinstance(masks, torch.Tensor):
        masks = torch.as_tensor(masks, device=device)
    else:
        masks = masks.to(device)

    # Build the map from the trailing (H, W) dims so that N == 0 also works.
    label_map = torch.zeros(masks.shape[1:], dtype=torch.int64, device=masks.device)
    for i, mask in enumerate(masks):
        # Write the label directly into the int64 map. Computing (i + 1) * mask would
        # be evaluated in the mask's dtype (a Python int does not promote an integer
        # tensor), so uint8 masks could not carry labels above 255; summing would also
        # turn overlapping masks into an unrelated label id.
        label_map.masked_fill_(mask != 0, i + 1)

    # Optionally resize label_map to (H_out, W_out); nearest-neighbour keeps labels intact
    if output_size is not None:
        label_map = label_map.unsqueeze(0).unsqueeze(0).float()  # Shape: (1,1,H,W)
        label_map = F.interpolate(label_map, size=output_size, mode='nearest')
        label_map = label_map.squeeze(0).squeeze(0).long()       # Shape: (H_out, W_out)

    return label_map


def label_map_to_contour(label_map, tolerance):
    """
    Extract the segment boundaries of a label map.

    A pixel is a boundary pixel when the largest and the smallest label inside its
    (2 * tolerance + 1) x (2 * tolerance + 1) neighbourhood differ, i.e. when the
    neighbourhood touches more than one label. Background pixels (label 0) are
    never reported.

    Args:
        label_map (torch.LongTensor): A labeled map of shape (H, W).
        tolerance (int): Half-size of the square neighbourhood, in pixels.

    Returns:
        torch.BoolTensor: Boolean map of shape (H, W), True on boundary pixels.
    """
    window = 2 * tolerance + 1

    def _window_max(values):
        # Sliding-window maximum that keeps the spatial size unchanged.
        return F.max_pool2d(values, kernel_size=window, stride=1, padding=tolerance)

    # max_pool2d works on float (N, C, H, W) input
    labels = label_map[None, None].float()

    neighbourhood_max = _window_max(labels)       # dilation
    neighbourhood_min = -_window_max(-labels)     # erosion, as a negated max of the negation

    # More than one label in the window <=> max and min disagree
    touches_other_label = torch.ne(neighbourhood_max, neighbourhood_min)[0, 0]

    # Background is not part of any contour
    return touches_other_label & label_map.ne(0)


In [3]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def contour_metrics(
    contour_gt: torch.Tensor,
    contour_pred: torch.Tensor,
    contour_pred_dilated: torch.Tensor,
    tolerance: int = 5,
    eps: float = 1e-6
):
    """
    Calculate precision, recall, and F1-score between boolean boundary maps,
    ignoring the outer 'tolerance' pixels on each edge.

    Precision is measured with the raw prediction, recall with the dilated
    prediction; F1 combines the two.

    Args:
        contour_gt (torch.Tensor): Ground truth boundary map (bool), shape (H, W).
        contour_pred (torch.Tensor): Predicted boundary map (bool), shape (H, W).
        contour_pred_dilated (torch.Tensor): Dilated predicted boundary map (bool),
            shape (H, W).
        tolerance (int): Number of pixels to crop from each edge. 0 disables cropping.
            If the crop leaves no pixels, all metrics evaluate to 0.
        eps (float): A small value to avoid division by zero.

    Returns:
        dict: A dictionary containing 'precision', 'recall', and 'f1' as floats.

    Raises:
        ValueError: If ``tolerance`` is negative.
    """
    if tolerance < 0:
        raise ValueError(f'tolerance must be non-negative, got {tolerance}')

    # Ensure all tensors are boolean
    contour_gt = contour_gt.bool()
    contour_pred = contour_pred.bool()
    contour_pred_dilated = contour_pred_dilated.bool()

    # Crop with explicit end indices: a [t:-t] slice is empty for t == 0,
    # which would silently zero every metric.
    height, width = contour_gt.shape[:2]
    rows = slice(tolerance, max(tolerance, height - tolerance))
    cols = slice(tolerance, max(tolerance, width - tolerance))
    cropped_gt = contour_gt[rows, cols]
    cropped_pred = contour_pred[rows, cols]
    cropped_pred_dilated = contour_pred_dilated[rows, cols]

    # Precision with the raw prediction
    tp = torch.sum(cropped_gt & cropped_pred).float()
    fp = torch.sum(~cropped_gt & cropped_pred).float()
    precision = tp / (tp + fp + eps)

    # Recall with the dilated prediction
    tp = torch.sum(cropped_gt & cropped_pred_dilated).float()
    fn = torch.sum(cropped_gt & ~cropped_pred_dilated).float()
    recall = tp / (tp + fn + eps)

    # Compute F1
    f1 = 2 * (precision * recall) / (precision + recall + eps)

    return {
        'precision': precision.item(),
        'recall': recall.item(),
        'f1': f1.item()
    }

In [None]:

# Root of the per-model segmentation outputs, and root for the metrics derived from them.
results_dir = 'evaluation_intrinsic/outputs/segmentation_results'
output_dir = 'evaluation_intrinsic/outputs/segmentation_metrics'

# Resolution sub-directories to scan.
resolutions = [384, 768, 1024, 1500]

# Dataset splits to scan.
splits = ['SA1B', 'COCONut_relabeld_COCO_val', 'PascalPanopticParts', 'ADE20k', 'EgoHOS']

# Every model name expected to have results. The repetitive families are generated;
# `sorted` reproduces the plain string ordering used for the patch / superpixel names
# (e.g. 'patch_29_...' < 'patch_2_...' < 'patch_30_...').
all_models = [
    # DirectSAM, two model sizes x ten thresholds
    *(
        f'directsam_{size}_sa1b_2ep@{threshold}'
        for size in ('large', 'tiny')
        for threshold in ('0.05', '0.1', '0.15', '0.2', '0.25', '0.3', '0.35', '0.4', '0.45', '0.5')
    ),

    'fastsam',
    'mobilesamv2',

    'panoptic_mask2former_base',
    'panoptic_mask2former_large',
    'panoptic_mask2former_small',
    'panoptic_mask2former_tiny',
    'panoptic_mask2former_large_ade',
    'panoptic_oneformer_large',
    'panoptic_oneformer_tiny',
    'panoptic_oneformer_large_coco',

    # Raster patches, 2..31 per side
    *sorted(f'patch_{per_side}_per_side_raster' for per_side in range(2, 32)),

    'sam_vit_b',
    'sam_vit_h',
    'sam_vit_h_48points',
    'sam_vit_h_64points',
    'sam_vit_h_64points_1layer',
    'sam_vit_l',

    # SLIC superpixels, k*k segments for k = 2..16
    *sorted(f'superpixel_slic_{k * k}' for k in range(2, 17)),
]


# For every (resolution, split) pair, report which models have a complete set of
# predictions and how that set differs from the expected `all_models` list.
for resolution in resolutions:
    for split in splits:
        # Section banner for this split
        print('\n\n', '=' * 64, split, '-' * 64, sep='\n')

        dataset = load_dataset("chendelong/HEIT", split=split)
        os.makedirs(f'{output_dir}/{split}/{resolution}', exist_ok=True)

        # A model counts as available only if it has one result file per dataset sample
        models = []
        for model in os.listdir(f'{results_dir}/{split}/{resolution}'):
            samples_pred = load_samples(results_dir, split, resolution, model)
            if len(samples_pred) != len(dataset):
                print(f'{model} at {resolution} resolution only has {len(samples_pred)}/{len(dataset)} samples')
                continue
            models.append(model)

        # Differences between the expected and the available models
        missing_models = set(all_models) - set(models)
        additional_models = set(models) - set(all_models)

        for header, names in (
            ("Missing models:", missing_models),
            ("Additional models:", additional_models),
        ):
            if not names:
                continue
            print(header)
            for model in sorted(names):
                print('-', model)

        print(f'{len(models)} models at {resolution} resolution in {split} dataset')
        models.sort()

In [5]:
import torch
import torch.nn.functional as F

def compute_monosemanticity(
    contour_gt: torch.Tensor, 
    label_map_pred: torch.LongTensor, 
    tolerance: int = 5
) -> list:
    """
    Count, for every predicted segment, the ground-truth contour pixels that lie
    inside the segment once it has been shrunk by `tolerance` pixels.

    Args:
        contour_gt (torch.Tensor): Boolean tensor of shape (H, W), True on contour pixels.
        label_map_pred (torch.LongTensor): Label map of shape (H, W); every distinct
            value is treated as one segment.
        tolerance (int): Number of pixels by which each segment is eroded before
            counting.

    Returns:
        list[int]: One count per distinct label, ordered by ascending label value
            (the order returned by `torch.unique`).
    """
    label_ids = torch.unique(label_map_pred)

    # One binary mask per label via broadcasting: (1,H,W) == (K,1,1) -> (K,H,W),
    # then add a channel dim and cast to float for pooling -> (K,1,H,W)
    per_label = label_map_pred.unsqueeze(0) == label_ids.view(-1, 1, 1)
    per_label = per_label.unsqueeze(1).float()

    # Erode all masks at once: a min-filter written as a negated max-pool of the negation
    window = 2 * tolerance + 1
    shrunk = -F.max_pool2d(-per_label, kernel_size=window, stride=1, padding=tolerance)
    interiors = shrunk.squeeze(1).bool()  # (K,H,W)

    # Contour pixels that survive inside each eroded segment
    hits = interiors & contour_gt.unsqueeze(0)
    return hits.sum(dim=(1, 2)).tolist()

In [None]:
import cv2 

# Evaluate one model on one split and (optionally) visualize the result.
# Alternative model to inspect: 'patch_31_per_side_raster'
model = 'directsam_tiny_sa1b_2ep@0.1'

results_dir = 'evaluation_intrinsic/outputs/segmentation_results'
split = 'EgoHOS'
resolution = 768

dataset = load_dataset("chendelong/HEIT", split=split)

samples_pred = load_samples(results_dir, split, resolution, model)

# Neighbourhood half-sizes (in pixels) for the recall and monosemanticity metrics
tolerance_recall = 5
tolerance_monosemanticity = 25

device = 'cuda'
do_visualization = True

# Mask areas are reported as a fraction of a resolution x resolution image
total_area = resolution**2

all_metrics = []
for i in tqdm.tqdm(range(len(samples_pred))):
    # NOTE(review): the loop index is replaced by a random one and the loop ends with
    # `break`, so this cell previews a single random sample. Remove both to evaluate
    # the whole split. randrange excludes the upper bound (randint would include
    # len(samples_pred) and could index past the end).
    i = random.randrange(len(samples_pred))
    sample = dataset[i]
    image = np.array(sample['image'])

    contour_gt = np.array(sample['contour'])
    contour_gt = torch.tensor(contour_gt, device=device).bool()

    # Close the prediction file as soon as it is parsed
    with open(samples_pred[i]) as sample_file:
        sample_pred = json.load(sample_file)

    masks_pred = decode_masks(sample_pred['rles'])
    masks_pred = torch.tensor(masks_pred, device=device)

    # NOTE(review): the label map is always resized to 1024x1024, independent of
    # `resolution`; presumably this matches the ground-truth contour size -- confirm.
    label_map_pred = masks_to_label_map(masks_pred, device=device, output_size=(1024, 1024))
    contour_pred_thin = label_map_to_contour(label_map_pred, 1)
    contour_pred_dilated = label_map_to_contour(label_map_pred, tolerance_recall)

    metrics = contour_metrics(contour_gt, contour_pred_thin, contour_pred_dilated, tolerance=tolerance_recall)
    metrics['time'] = sample_pred['time']
    metrics['n_tokens'] = masks_pred.shape[0]

    mask_areas = masks_pred.sum(dim=(1, 2)) / total_area
    mask_areas = mask_areas.cpu().numpy().tolist()
    metrics['mask_areas'] = mask_areas

    # One overlap count per label in torch.unique(label_map_pred) order
    monosemanticity = compute_monosemanticity(contour_gt, label_map_pred, tolerance_monosemanticity)
    metrics['monosemanticity'] = monosemanticity

    print(metrics)
    all_metrics.append(metrics)

    if do_visualization:

        # Thicken the ground-truth contour so it stays visible in the plots
        contour_gt_np = contour_gt.cpu().numpy().astype(np.uint8)
        dilation_size = 2
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (2*dilation_size + 1, 2*dilation_size + 1)
        )
        contour_gt_dilated = cv2.dilate(contour_gt_np, kernel)
        contour_gt_dilated = contour_gt_dilated.astype(bool)

        plt.figure(figsize=(40, 10))

        # Panel 1: input image
        plt.subplot(1, 4, 1)
        plt.imshow(image)

        # Panel 2: predicted segments in random colors with white boundaries
        plt.subplot(1, 4, 2)

        label_map_vis = label_map_to_random_rgb(label_map_pred.cpu().numpy())
        label_map_vis[contour_pred_thin.cpu().numpy()>0] = [255, 255, 255]
        plt.imshow(label_map_vis)

        plt.title(f"{(masks_pred.sum(dim=(1, 2)) > 0).sum().item()} tokens")

        # Panel 3: predicted (blue) vs. ground-truth (red) contours
        plt.subplot(1, 4, 3)
        plt.imshow(contour_pred_thin.cpu().numpy(), cmap='Blues', alpha=0.8)
        plt.imshow(contour_pred_dilated.cpu().numpy(), cmap='Blues', alpha=0.3)
        plt.imshow(contour_gt_dilated, cmap='Reds', alpha=0.5)
        plt.imshow(image, alpha=0.1)
        plt.title(f"Predicted: P={metrics['precision']:.2f}, R={metrics['recall']:.2f}, F1={metrics['f1']:.2f}")

        # Panel 4: segments colored by whether ground-truth contours cross their interior
        plt.subplot(1, 4, 4)
        plt.imshow(contour_pred_thin.cpu().numpy(), cmap='Blues')
        plt.imshow(image, alpha=0.3)

        unique_labels = torch.unique(label_map_pred)

        # 0 = no contour pixel inside the eroded segment, 1 = at least one
        visual_label_map = torch.zeros_like(label_map_pred)
        for lbl_i, lbl_id in enumerate(unique_labels):
            if monosemanticity[lbl_i] == 0:
                visual_label_map[label_map_pred == lbl_id] = 0
            else:
                visual_label_map[label_map_pred == lbl_id] = 1

        visual_label_map_np = visual_label_map.cpu().numpy()
        color_map = np.array([
            [50, 150,   150],   # monosemantic
            [255,  50,   50],  # not monosemantic
        ], dtype=np.uint8)
        rgb_image = color_map[visual_label_map_np]

        plt.imshow(contour_gt_dilated, cmap='Reds')

        rgb_image[contour_pred_thin.cpu().numpy()>0] = [255, 255, 255]
        # plt.imshow(contour_pred.cpu().numpy(), cmap='Blues', alpha=0.2)
        plt.imshow(rgb_image, alpha=0.5)

        plt.show()

    break

# output_file = f'{output_dir}/{split}/{resolution}/{model}.json'
# json.dump(all_metrics, open(output_file, 'w'))
