In [None]:
import tqdm
import os
import json
import numpy as np
import pycocotools.mask as mask_util
from datasets import load_dataset
from utils.visualization import visualize_masks
import matplotlib.pyplot as plt
import random 
from HEIT.metrics import create_circular_kernel, masks_to_contour, contour_recall


def get_samples(output_dir, split, resolution, model):
    """Return the result files of one model, ordered by sample index.

    Files are looked up in ``{output_dir}/{split}/{resolution}/{model}``. The
    part of each file name before the first ``.`` is parsed as the integer
    sample index; entries whose name does not start with an integer (for
    example ``.DS_Store`` or ``.ipynb_checkpoints``) are ignored.

    Args:
        output_dir: Root directory of the stored outputs.
        split: Name of the split sub-directory.
        resolution: Name of the resolution sub-directory.
        model: Name of the model sub-directory.

    Returns:
        list[str]: ``samples[i]`` is the path of the file with index ``i``.

    Raises:
        KeyError: If the indices found are not contiguous starting from 0.
    """
    results_dir = f'{output_dir}/{split}/{resolution}/{model}'

    samples_dict = {}
    for file in os.listdir(results_dir):
        try:
            index = int(file.split('.')[0])
        except ValueError:
            # Not a "<index>.<ext>" result file; skip it instead of crashing.
            continue
        samples_dict[index] = os.path.join(results_dir, file)

    # Look indices up as 0..N-1 (rather than sorting the keys) so that
    # samples[i] is guaranteed to be the file of sample i; a gap raises KeyError.
    return [samples_dict[index] for index in range(len(samples_dict))]

def decode_masks(sample_json):
    """Load a JSON file of RLE-encoded masks and decode them.

    Args:
        sample_json: Path to a JSON file holding a list; each element is
            passed unchanged to ``pycocotools.mask.decode``.

    Returns:
        np.ndarray: The decoded masks stacked with ``np.array`` in file order
        (an empty array of shape ``(0,)`` when the list is empty).
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(sample_json) as f:
        masks_rle = json.load(f)

    return np.array([mask_util.decode(mask_rle) for mask_rle in masks_rle])
    
# --- Configuration -----------------------------------------------------------
# Root directory of the stored outputs and the resolution sub-directory used
# when building result paths.
output_dir = 'HEIT/outputs/tokenized_HEIT'
resolution = 1024

# Split names passed to load_dataset("chendelong/HEIT", split=...).
splits = [
    'SA1B',
    'COCONut_relabeld_COCO_val',
    'EntitySeg',
    'PascalPanopticParts',
    'plantorgans',
    'MapillaryMetropolis',
    'cityscapes',
    'NYUDepthv2',
    'tcd',
    'FoodSeg103',
    'ADE20k',
    'WireFrame',
    'ISAID',
    'PhenoBench',
    'EgoHOS',
    'LIP',
    'SOBA',
    'CIHP',
    'LoveDA',
    'SPIN',
    'SUIM',
    'MyFood',
    'DIS5K_DIS_VD',
    'DUTS_TE',
    'Fashionpedia',
    'PartImageNetPP',
    'SeginW',
    'LVIS',
    'PACO',
    'DRAM',
]

# --- Load the split used by the evaluation cells -----------------------------
split = 'EntitySeg'
dataset = load_dataset("chendelong/HEIT", split=split)

# Show the dataset summary followed by its first sample.
print(dataset)
print(dataset[0])


In [None]:
# For every split, list the models whose output directory holds exactly one
# result file per dataset sample, and report the incomplete ones.
#
# Loop-local names (`split_name`, `split_dataset`, `model_name`, ...) are used
# on purpose: reusing `split` / `dataset` here would overwrite the split that
# was selected and loaded earlier, and later cells would then silently run on
# the last entry of `splits`.
for split_name in splits:
    print('='*64)
    print(split_name)
    print('-'*64)

    split_dir = f'{output_dir}/{split_name}/{resolution}'
    if not os.path.isdir(split_dir):
        # No outputs were produced for this split; report it and move on
        # instead of aborting the whole scan with FileNotFoundError.
        print(f'no outputs found in {split_dir}')
        continue

    split_dataset = load_dataset("chendelong/HEIT", split=split_name)

    models = []
    for model_name in sorted(os.listdir(split_dir)):
        model_samples = get_samples(output_dir, split_name, resolution, model_name)
        if len(model_samples) == len(split_dataset):
            models.append(model_name)
        else:
            print(f'{model_name} only has {len(model_samples)}/{len(split_dataset)} samples')

    print('---')
    # `models` is already sorted because the directory listing was sorted.
    for model_name in models:
        print(model_name)

In [3]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def masks_to_contour_torch(masks, tolerance, device='cuda', output_size=(1024, 1024)):
    """
    Converts a set of masks to a contour map using PyTorch.

    Each mask ``i`` contributes the label ``i + 1`` to a label map (pixels
    covered by several masks receive the sum of their labels). The label map is
    resized with nearest-neighbour interpolation, and a pixel is marked as
    contour when it is labelled and its ``(2 * tolerance + 1)`` square
    neighbourhood contains more than one value.

    Args:
        masks (torch.Tensor or array-like): Masks of shape (N, H, W), where N is
            the number of masks.
        tolerance (int): Half-width of the square neighbourhood used to detect
            label changes.
        device (str or torch.device): Device the computation runs on.
        output_size (tuple[int, int]): Height and width the label map is resized
            to before the contour is extracted.

    Returns:
        torch.BoolTensor: A boolean tensor of shape ``output_size`` indicating
        the contour. It is all False when ``masks`` contains no mask.
    """
    output_size = tuple(output_size)

    # Ensure masks are in torch.Tensor format and move to device
    if not isinstance(masks, torch.Tensor):
        masks = torch.as_tensor(masks, device=device)
    else:
        masks = masks.to(device)

    # Nothing to outline: avoid indexing into an empty stack of masks.
    if masks.shape[0] == 0:
        return torch.zeros(output_size, dtype=torch.bool, device=device)

    # Create label map. The mask is cast to int64 *before* scaling: a Python
    # int times a uint8 tensor stays uint8, so labels above 255 would overflow.
    # All-zero masks add nothing, so no per-mask emptiness check (and the
    # device synchronisation it implies) is needed.
    label_map = torch.zeros(masks.shape[1:], dtype=torch.int64, device=masks.device)
    for i, mask in enumerate(masks):
        label_map += (i + 1) * mask.to(torch.int64)

    # Resize label_map to output_size; 'nearest' keeps label values intact.
    label_map = label_map[None, None].float()  # Shape: (1, 1, H, W)
    label_map = F.interpolate(label_map, size=output_size, mode='nearest')

    # Dilation is a max-pool; erosion is a max-pool of the negated map.
    kernel_size = 2 * tolerance + 1
    dilated = F.max_pool2d(label_map, kernel_size=kernel_size, stride=1, padding=tolerance)
    eroded = -F.max_pool2d(-label_map, kernel_size=kernel_size, stride=1, padding=tolerance)

    # A neighbourhood with differing max and min spans a label change; keep
    # only pixels that belong to some mask.
    boundaries = (dilated != eroded) & (label_map != 0)

    return boundaries[0, 0]

def calculate_metrics(image, gt_contour, masks, tolerance, do_visualization=False, device='cuda'):
    """Score predicted masks against a ground-truth contour map.

    Empty masks are dropped, the remaining masks are turned into a contour map
    with ``masks_to_contour_torch`` and compared pixel-wise with the ground
    truth, whose outer ``tolerance`` pixels on every side are ignored.

    Args:
        image: Image shown in the visualisation; only used when
            ``do_visualization`` is True.
        gt_contour (np.ndarray or torch.Tensor): 2-D ground-truth contour map;
            non-zero pixels are contour. It is not modified.
        masks (np.ndarray or torch.Tensor): Predicted masks of shape (N, H, W).
        tolerance (int): Border width to ignore in the ground truth, also passed
            to ``masks_to_contour_torch``.
        do_visualization (bool): If True, plot image, masks and both contours.
        device (str or torch.device): Device the computation runs on.

    Returns:
        dict: ``precision``, ``recall`` and ``f1`` (all 1.0 when the ground
        truth has no contour pixel left), and ``mask_sizes``, the pixel counts
        of the non-empty masks sorted from large to small.
    """
    # Ensure masks are in torch.Tensor format and move to device
    if not isinstance(masks, torch.Tensor):
        masks = torch.tensor(masks)
    masks = masks.to(device)

    # An input without any mask (e.g. an empty array of shape (0,)) is treated
    # as an empty stack so the reductions over dims (1, 2) below stay valid.
    if masks.dim() != 3 and masks.numel() == 0:
        masks = masks.reshape(0, 1, 1)

    # Remove empty masks; keep the sizes so they are not recomputed later.
    mask_sizes = masks.sum(dim=(1, 2))
    non_empty = mask_sizes > 0
    masks = masks[non_empty]
    mask_sizes = mask_sizes[non_empty]

    # Process ground truth contour on a private copy so the caller's array is
    # left untouched.
    if isinstance(gt_contour, torch.Tensor):
        gt_contour = gt_contour.detach().clone()
    else:
        gt_contour = torch.tensor(gt_contour)
    gt_contour = gt_contour.to(device).bool()
    # Guard tolerance == 0: the slice [-0:] selects the whole array and would
    # erase the entire ground truth.
    if tolerance > 0:
        gt_contour[:tolerance] = False
        gt_contour[-tolerance:] = False
        gt_contour[:, :tolerance] = False
        gt_contour[:, -tolerance:] = False

    # Get predicted contour (no masks left means no predicted contour).
    if masks.shape[0] == 0:
        pred_contour = torch.zeros_like(gt_contour)
    else:
        pred_contour = masks_to_contour_torch(masks, tolerance, device=device)

    # Compute True Positives (TP), False Positives (FP), and False Negatives (FN)
    TP = (gt_contour & pred_contour).sum().item()
    FP = ((~gt_contour) & pred_contour).sum().item()
    FN = (gt_contour & (~pred_contour)).sum().item()

    # Compute precision, recall, and F1 score
    if gt_contour.sum().item() == 0:
        precision = 1.0
        recall = 1.0
        f1 = 1.0
    else:
        # The small epsilon avoids division by zero when a denominator is 0.
        precision = TP / (TP + FP + 1e-8)
        recall = TP / (TP + FN + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)

    # Sort masks based on the number of pixels, from large to small
    sorted_indices = torch.argsort(mask_sizes, descending=True)
    masks = masks[sorted_indices]
    mask_sizes = mask_sizes[sorted_indices].tolist()

    metrics = {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'mask_sizes': mask_sizes
    }

    if do_visualization:
        plt.figure(figsize=(40, 10))
        plt.subplot(1, 4, 1)
        plt.imshow(image)

        plt.subplot(1, 4, 2)
        plt.imshow(visualize_masks(image, masks.cpu().numpy()))
        # Every remaining mask is non-empty, so the count is the stack size.
        plt.title(f"{masks.shape[0]} tokens")

        plt.subplot(1, 4, 3)
        plt.imshow(gt_contour.cpu().numpy(), cmap='Greens')
        plt.title('Ground Truth')

        plt.subplot(1, 4, 4)
        plt.imshow(pred_contour.cpu().numpy(), cmap='Blues', alpha=0.3)
        plt.title(f'Predicted: P={precision:.2f}, R={recall:.2f}, F1={f1:.2f}')

        plt.show()

    return metrics

In [None]:

# Evaluate one model on the first samples of the currently selected split.
model = 'directsam_tiny_dsa_100ep@0.1'
samples = get_samples(output_dir, split, resolution, model)

tolerance = 5      # forwarded to calculate_metrics
max_tokens = 576   # upper bound on the number of masks kept per sample

# Evaluate at most 20 samples, but never more than are actually available,
# so a short split or a partial model output does not raise IndexError.
num_eval_samples = min(20, len(dataset), len(samples))

all_metrics = []
for i in tqdm.tqdm(range(num_eval_samples)):
    sample = dataset[i]
    image = np.array(sample['image'])
    gt_contour = np.array(sample['contour'])

    masks = decode_masks(samples[i]).astype(np.int32)

    # Keep only the first max_tokens masks (in file order) and report it.
    if len(masks) > max_tokens:
        print(f"Sample {i} has {len(masks)} tokens")
        masks = masks[:max_tokens]

    metrics = calculate_metrics(image, gt_contour, masks, tolerance=tolerance, do_visualization=True)
    all_metrics.append(metrics)


In [None]:


# One row per evaluated sample, one column per token; samples with fewer than
# max_tokens masks are right-padded with zeros.
all_mask_sizes = np.zeros((len(all_metrics), max_tokens), dtype=np.int32)
for i, metrics in enumerate(all_metrics):
    all_mask_sizes[i, :len(metrics['mask_sizes'])] = metrics['mask_sizes'][:max_tokens]

# Cumulative mask area per sample, normalised by resolution * resolution and
# then averaged over the samples.
accumulated_mask_sizes = np.cumsum(all_mask_sizes, axis=1) / (resolution * resolution)
accumulated_mask_sizes = np.mean(accumulated_mask_sizes, axis=0)

plt.figure(figsize=(20, 10))
plt.plot(accumulated_mask_sizes)

# Mark the first token count at which the mean cumulative area reaches
# 90%, 95%, 98% and 99%; thresholds that are never reached are skipped.
for percentage in [0.90, 0.95, 0.98, 0.99]:
    if accumulated_mask_sizes[-1] >= percentage:
        x = np.argmax(accumulated_mask_sizes >= percentage)
        plt.plot(x, percentage, 'ro')
        plt.text(x, percentage, f'{percentage:.0%}:{x} tokens', va='bottom', ha='right')

plt.xlim(0, max_tokens)
plt.xlabel('Number of tokens (masks sorted from large to small)')
plt.ylabel('Mean cumulative mask area / (resolution * resolution)')
plt.title('Accumulated mask area vs. number of tokens')
plt.show()