In [None]:
from datasets import load_dataset

# Reference list of HEIT split names. Only the single `split` chosen below is
# actually loaded; `splits` is kept as a menu of names to pick from.
splits = [
    'SA1B', 'COCONut_relabeld_COCO_val', 'EntitySeg', 'PascalPanopticParts',
    'plantorgans', 'MapillaryMetropolis', 'cityscapes', 'NYUDepthv2',
    'tcd', 'FoodSeg103', 'ADE20k', 'WireFrame',
    'ISAID', 'PhenoBench', 'EgoHOS', 'LIP',
    'SOBA', 'CIHP', 'LoveDA', 'SPIN',
    'SUIM', 'MyFood', 'DIS5K_DIS_VD', 'DUTS_TE',
    'Fashionpedia', 'PartImageNetPP', 'SeginW', 'LVIS',
    'PACO', 'DRAM',
]

# Split under inspection (expected to be one of the names listed in `splits`).
split = 'EgoHOS'

# Fetch that split from the Hugging Face Hub and show its summary plus the first record.
dataset = load_dataset("chendelong/HEIT", split=split)
print(dataset)
print(dataset[0])


In [None]:
import os
import json
import numpy as np
import pycocotools.mask as mask_util
from utils.visualization import visualize_masks

def get_samples(output_dir, split, resolution, model):
    """Return the per-sample result files of one run, ordered by sample index.

    Result files are read from ``{output_dir}/{split}/{resolution}/{model}`` and
    are expected to be named ``<index>.<ext>`` (e.g. ``0.json``). The returned list
    is positional: entry ``k`` is the file for sample ``k``, so it can be paired
    with ``dataset[k]``.

    Directory entries whose name does not start with a decimal index (for example
    a ``.ipynb_checkpoints`` folder) are ignored instead of crashing ``int()``.

    Args:
        output_dir: root directory of the tokenizer outputs.
        split: dataset split name (a sub-directory of ``output_dir``).
        resolution: resolution sub-directory (formatted with ``str``).
        model: model/run sub-directory.

    Returns:
        list of file paths, one per index ``0 .. n-1``.

    Raises:
        ValueError: if two files share an index, or if the indices are not exactly
            ``0 .. n-1``. Either case would silently misalign results with the dataset.
    """
    results_dir = f'{output_dir}/{split}/{resolution}/{model}'

    samples_dict = {}
    for file in os.listdir(results_dir):
        stem = file.split('.')[0]
        if not stem.isdecimal():
            # Not a per-sample result file (hidden file, checkpoint dir, ...).
            continue
        index = int(stem)
        if index in samples_dict:
            raise ValueError(
                f'Duplicate result files for index {index} in {results_dir}: '
                f'{samples_dict[index]!r} and {file!r}'
            )
        samples_dict[index] = os.path.join(results_dir, file)

    # Results are consumed positionally, so a gap must fail loudly rather than shift
    # every later sample onto the wrong dataset entry.
    if sorted(samples_dict) != list(range(len(samples_dict))):
        missing = sorted(set(range(max(samples_dict) + 1)) - samples_dict.keys())
        raise ValueError(
            f'Non-contiguous sample indices in {results_dir}; missing (first 10): {missing[:10]}'
        )

    samples = [samples_dict[index] for index in range(len(samples_dict))]

    print(f'Number of samples: {len(samples)}')
    return samples

def decode_masks(sample_json):
    """Load a JSON list of RLE-encoded masks and decode it into one array.

    Args:
        sample_json: path to a JSON file whose top-level value is a list of RLE
            objects, each passed to ``pycocotools.mask.decode``.

    Returns:
        ``np.ndarray`` built from the decoded masks in file order. When every
        decoded mask is a 2-D array of the same shape (H, W), the result has shape
        (N, H, W). An empty list yields an empty array of shape ``(0,)``.
    """
    # Use a context manager so the file handle is closed deterministically;
    # the previous json.load(open(...)) leaked it.
    with open(sample_json) as f:
        masks_rle = json.load(f)
    return np.array([mask_util.decode(mask_rle) for mask_rle in masks_rle])
    
# Tokenizer outputs are read from {output_dir}/{split}/{resolution}/{model}/,
# one result file per sample index.
output_dir = 'HEIT/outputs/tokenized_HEIT'
resolution = 1024

# Run to inspect. Alternative runs that were previously selected here:
# 'superpixel_slic' and 'mobilesamv2'.
model = 'directsam_tiny_dsa_100ep@0.05'

samples = get_samples(output_dir, split, resolution, model)

In [None]:
import matplotlib.pyplot as plt
import random 
from HEIT.metrics import create_circular_kernel, masks_to_contour, contour_recall

# Contour-recall settings. `tolerance` is passed to create_circular_kernel and is
# also the width, in pixels, of the image border dropped from the ground truth.
tolerance = 10
CIRCULAR_KERNEL = create_circular_kernel(tolerance)

# Number of leading samples to visualize (capped at the number of available results).
NUM_VISUALIZE = 5


def _masks_to_label_map(masks):
    """Collapse an (N, H, W) stack of masks into an (H, W) int32 map of 1-based mask indices.

    Pixels covered by no mask stay 0. Where masks overlap, the later mask wins, so
    every non-zero value is the index of a real mask. Summing ``(i + 1) * mask``
    would instead create meaningless labels in overlap regions.
    """
    labels = np.zeros(masks.shape[1:], dtype=np.int32)
    for mask_idx, mask in enumerate(masks):
        labels[mask > 0] = mask_idx + 1
    return labels


def _zero_border(arr, border):
    """Set a ``border``-pixel frame along the first two axes of ``arr`` to 0, in place.

    Uses explicit end offsets rather than ``arr[-border:]``: with border == 0 that
    slice is ``arr[0:]`` (the whole axis) and would erase every contour.
    """
    if border <= 0:
        return
    h, w = arr.shape[:2]
    arr[:border] = 0
    arr[max(h - border, 0):] = 0
    arr[:, :border] = 0
    arr[:, max(w - border, 0):] = 0


def _plot_tokens(image, masks, labels):
    """Show the image, its mask overlay titled with the non-empty mask count, and the label map."""
    _, axes = plt.subplots(1, 3, figsize=(20, 10))
    axes[0].imshow(image)
    axes[0].set_title('Image')

    axes[1].imshow(visualize_masks(image, masks))
    axes[1].set_title(f"{(np.sum(masks, axis=(1, 2)) > 0).sum()} tokens")

    axes[2].imshow(labels, cmap='plasma')
    axes[2].set_title('order')
    plt.show()


def _plot_contour_recall(image, gt_contour, pred_contour, missing_contour, recall):
    """Show the ground-truth contour, the missed contour over the image, and the missed contour over the prediction."""
    _, axes = plt.subplots(1, 3, figsize=(20, 10))
    axes[0].imshow(gt_contour, cmap='Greens')
    axes[0].set_title('Ground Truth')

    axes[1].imshow(missing_contour, cmap='Reds')
    axes[1].imshow(image, alpha=0.1)
    axes[1].set_title(f'Recall: {recall:.2f}')

    axes[2].imshow(missing_contour, cmap='Reds')
    axes[2].imshow(pred_contour, cmap='Blues', alpha=0.3)
    axes[2].set_title('Predicted')
    plt.show()


# `samples[k]` is assumed to be the result file for `dataset[k]` (see get_samples).
for sample_idx in range(min(NUM_VISUALIZE, len(samples))):
    sample = dataset[sample_idx]
    image = np.array(sample['image'])
    masks = decode_masks(samples[sample_idx]).astype(np.int32)

    labels = _masks_to_label_map(masks)
    _plot_tokens(image, masks, labels)

    # np.array copies, so zeroing the border does not touch the dataset record.
    # Ground-truth contour within `tolerance` pixels of the border is zeroed
    # (presumably to ignore image-edge artifacts).
    gt_contour = np.array(sample['contour'])
    _zero_border(gt_contour, tolerance)
    pred_contour = masks_to_contour(masks, CIRCULAR_KERNEL)

    recall = contour_recall(gt_contour, pred_contour)
    # Ground-truth contour pixels not covered by the predicted contour.
    missing_contour = np.logical_and(gt_contour, np.logical_not(pred_contour))
    _plot_contour_recall(image, gt_contour, pred_contour, missing_contour, recall)
