In [1]:
import json
import PIL.Image as Image

data = json.loads(open('USIS10K\\multi_class_annotations\\multi_class_test_annotations.json').read())

id_to_img = {}

for img in data['images']:
    id_to_img[img['id']] = img['file_name']

from collections import defaultdict

from collections import defaultdict
import numpy as np
from pycocotools import mask as mask_utils
from PIL import Image, ImageDraw

PROMPT_TO_CATEGORY = {
    "wrecks/ruins": 1,
    "fish": 2,
    "reefs": 3,
    "aquatic plants": 4,
    "human divers": 5,
    "robots": 6,
    "sea-floor": 7,
}

CATEGORY_TO_PROMPT = {v: k for k, v in PROMPT_TO_CATEGORY.items()}


def polygon_to_rle(segmentation, img_height, img_width):
    # segmentation may be: [x1,y1,...] or [[...]] or list of polygons
    if len(segmentation) == 0:
        return None

    # Normalize to list of polygons
    if isinstance(segmentation[0], list):
        polys = segmentation
    else:
        polys = [segmentation]

    mask = Image.new("L", (img_width, img_height), 0)
    draw = ImageDraw.Draw(mask)

    for poly in polys:
        coords = np.array(poly).reshape(-1, 2)
        draw.polygon(coords.flatten().tolist(), outline=1, fill=1)

    mask_np = np.array(mask, dtype=np.uint8)
    rle = mask_utils.encode(np.asfortranarray(mask_np))
    rle["counts"] = rle["counts"].decode("ascii")

    return rle


def group_items_to_list(items):
    grouped = defaultdict(lambda: defaultdict(lambda: {"boxes": [], "masks_rle": []}))

    for d in items:
        img_path = d["img_path"]
        img_loaded = Image.open(img_path)

        cat = CATEGORY_TO_PROMPT[d["category_id"]]

        grouped[img_path][cat]["boxes"].append(d["bbox"])
        grouped[img_path][cat]["masks_rle"].append(
            polygon_to_rle(d["segmentation"], img_loaded.height, img_loaded.width)
        )

    out = []
    for img_path, cats in grouped.items():
        entry = {"img_path": img_path}
        entry.update(cats)
        out.append(entry)

    return out


def load_data(path='USIS10K\\multi_class_annotations\\multi_class_test_annotations.json'):
    data = json.loads(open(path).read())
    polished_data = []

    for annotation in data['annotations']:
        category_id = annotation['category_id']
        image_id = annotation['image_id']
        bbox = annotation['bbox']
        segmentation = annotation['segmentation']
        img_path = f"USIS10K/test/{id_to_img[image_id]}"

        polished_data.append({
            'img_path': img_path,
            'category_id': category_id,
            'bbox': bbox,
            'segmentation': segmentation
        })

    return group_items_to_list(polished_data)

def load_usis_preds_data(path='usis_sam_preds_rle.json'):
    data = json.loads(open(path).read())
    polished_data = []

    for annotation in data['predictions']:
        category_id = annotation['category_id']
        image_id = annotation['image_id']
        bbox = annotation['bbox']
        segmentation = annotation['segmentation']
        img_path = f"USIS10K/test/{id_to_img[image_id]}"

        polished_data.append({
            'img_path': img_path,
            'category_id': category_id,
            'bbox': bbox,
            'segmentation': segmentation
        })

    return polished_data

In [2]:
gt_data = load_data()


In [3]:
def convert_detections(input_data):
    """
    Convert detection data from input format to output format.
    
    Args:
        input_data: List of dictionaries with 'file_name' and 'detections'
        
    Returns:
        List of dictionaries with 'img_path' and category-based detections
    """
    
    # Category mapping
    LABEL_TO_CATEGORY = {
        "Fish": "fish",
        "Reefs": "reefs",
        "Aquatic plants": "aquatic plants",
        "Human divers": "human divers",
        "Robots": "robots",
        "Sea-floor": "sea-floor",
        "Wrecks/ruins": "wrecks/ruins"
    }
    
    # All possible categories
    ALL_CATEGORIES = [
        "fish", "wrecks/ruins", "reefs", "aquatic plants", 
        "human divers", "robots", "sea-floor"
    ]
    
    result = []
    
    for item in input_data:
        file_name = item['file_name']
        detections = item['detections']
        
        # Initialize output structure
        output_item = {
            'img_path': f'USIS10K/test/{file_name}.jpg',
        }

        image = Image.open(output_item['img_path'])
        width, height = image.size
        
        # Initialize all categories with empty lists
        for category in ALL_CATEGORIES:
            output_item[category] = {
                'boxes': [],
                'scores': [],
                'masks_rle': []
            }
        
        # Group detections by category
        for detection in detections:
            label = detection['label']
            box_2d = detection['box_2d']

            try:
                abs_y1 = int(box_2d[0]/1000 * height)
                abs_x1 = int(box_2d[1]/1000 * width)
                abs_y2 = int(box_2d[2]/1000 * height)
                abs_x2 = int(box_2d[3]/1000 * width)
                new_bbox = [abs_x1, abs_y1, abs_x2, abs_y2]
            except Exception as e:
                print(f"Error processing box_2d {box_2d} for image {file_name}: {e}")
                new_bbox = [0, 0, 0, 0]
            
            # Get the category name (normalize case)
            category = LABEL_TO_CATEGORY.get(label)
            
            if category:
                # Add box to the appropriate category
                output_item[category]['boxes'].append(new_bbox)
                output_item[category]['scores'].append(1.0)  # Default score
        
        result.append(output_item)
    
    return result

gemini_2_5_flash_converted = convert_detections(json.load(open('gemini_2.5_flash.json')))
gemini_2_5_pro_converted = convert_detections(json.load(open('gemini_2.5_pro.json')))


Error processing box_2d [180, 480, 680] for image test_01506: list index out of range
Error processing box_2d [550, 570, 750] for image test_01506: list index out of range


In [4]:
len(gt_data), len(gemini_2_5_flash_converted), len(gemini_2_5_pro_converted)

(1596, 1594, 1596)

In [5]:
print(gt_data[0].keys())
print(gt_data[1].keys())
print(gemini_2_5_flash_converted[0]['fish'].keys())

dict_keys(['img_path', 'fish'])
dict_keys(['img_path', 'reefs'])
dict_keys(['boxes', 'scores', 'masks_rle'])


In [6]:
gt_data = {img['img_path']: img for img in gt_data}
gemini_2_5_flash_converted = {img['img_path']: img for img in gemini_2_5_flash_converted}
gemini_2_5_pro_converted = {img['img_path']: img for img in gemini_2_5_pro_converted}

In [7]:
for key in list(gt_data.keys()):
    gt_instance = gt_data[key]
    gemini_instance = gemini_2_5_flash_converted.get(key)

    print(gt_instance, gemini_instance)
    break

{'img_path': 'USIS10K/test/test_00001.jpg', 'fish': {'boxes': [[335.0, 243.0, 200.0, 218.0], [117.0, 62.0, 78.0, 180.0]], 'masks_rle': [{'size': [480, 640], 'counts': 'P[m46i>2O2M2N2O1N2N2N2N2M3M3M3N2N2N2O1N2M4L3M3M4L3M3M4L3M3M4L3M3N2M3M4L6J4L3M3N3L5L3M2M3N2N2N1O2N3M2N2N3N1N2N3N1N2O1O2M2O1O1N2O2N1N2O1O1N3N1O1O1O1O1O1N2O001O1O1O1O1O1O1O001O1O1O0000BlFPKT9Q5mFlJT9U5;100O001O100O001O1O100O001O3M2O2M2N1O1O01O01O0001O01O0000010O001O010O001O010O001O010O0010O01O0010O01O010O0010O01O010O0100O1O1O100O1O100O100O2N1O100O2N1O100O2N1O2N2M2O2YLkF_1W9YNRGd1P9WNSGh1P9RNTGn1Q9hMTGW2Y9WMmFV1iN1l;GaD3a;HhD1Z;MmDF\\;8hDEX;:mDBS;=QE]OR;b0g1N1N2N3N1O2N1O2O1N3M2O2N3Mlk`1'}, {'size': [480, 640], 'counts': 'imf1k1l<a0B9H8I8J8I6K3M2N3M3M3N2M4M2N3M3N06J4L2N2M2N2M4M2M3M4K4O2M2O1O00102M3NO0M2M3N2N2N1O1000O2N3jL]Di2Q<eMSE5T;LnDG];9e1O010O10O10O10O01000O010000O010O100O00100O1O2N101N2N2N2M`^`6'}]}} {'img_path': 'USIS10K/test/test_00001.jpg', 'fish': {'boxes': [[336, 233, 547, 478], [101, 52, 179, 280], [140, 172, 183,

In [8]:
def calculate_iou(box1, box2):
    """
    Calculate IoU between two boxes.
    Boxes are in format [x_min, y_min, x_max, y_max]
    """
    x1_min, y1_min, x1_max, y1_max = box1
    x2_min, y2_min, x2_max, y2_max = box2
    
    # Calculate intersection
    inter_x_min = max(x1_min, x2_min)
    inter_y_min = max(y1_min, y2_min)
    inter_x_max = min(x1_max, x2_max)
    inter_y_max = min(y1_max, y2_max)
    
    if inter_x_max <= inter_x_min or inter_y_max <= inter_y_min:
        return 0.0
    
    inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min)
    
    # Calculate union
    box1_area = (x1_max - x1_min) * (y1_max - y1_min)
    box2_area = (x2_max - x2_min) * (y2_max - y2_min)
    union_area = box1_area + box2_area - inter_area
    
    return inter_area / union_area if union_area > 0 else 0.0

def gt_bbox_to_xyxy(bbox):
    """
    Convert ground truth bbox from [x, y, width, height] to [x_min, y_min, x_max, y_max]
    """
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


def polygon_to_rle(segmentation, img_height, img_width):
    """
    Convert polygon segmentation to RLE format
    """
    rles = mask_utils.frPyObjects(segmentation, img_height, img_width)
    rle = mask_utils.merge(rles)
    return rle


def calculate_mask_iou(rle1, rle2):
    """
    Calculate IoU between two masks in RLE format
    """
    iou = mask_utils.iou([rle1], [rle2], [0])
    return iou[0][0]

def evaluate_recall(predictions, gt_data, iou_threshold=0.5, ignore_labels=False):
    results = {
        'overall': {'total_gt': 0, 'detected': 0, 'recall': 0.0},
        'per_category': {}
    }

    for pred_item, gt_item in zip(predictions, gt_data):
        # Iterate through every category in this ground truth item
        for gt_category_name, gt_category_data in gt_item.items():
            if gt_category_name == "img_path":
                continue

            gt_boxes = gt_category_data["boxes"]

            # Ensure category stats exist
            if gt_category_name not in results['per_category']:
                results['per_category'][gt_category_name] = {
                    'total_gt': 0, 'detected': 0, 'recall': 0.0
                }

            for gt_box in gt_boxes:
                gt_xyxy = gt_bbox_to_xyxy(gt_box)

                results['overall']['total_gt'] += 1
                results['per_category'][gt_category_name]['total_gt'] += 1

                detected = False

                if ignore_labels:
                    # search in ALL predicted categories
                    for pred_cat_name, pred_data in pred_item.items():
                        if pred_cat_name == "img_path":
                            continue
                        for pred_box in pred_data["boxes"]:
                            if calculate_iou(pred_box, gt_xyxy) >= iou_threshold:
                                detected = True
                                break
                        if detected:
                            break
                else:
                    # search only in the SAME predicted category
                    if gt_category_name in pred_item:
                        for pred_box in pred_item[gt_category_name]["boxes"]:
                            if calculate_iou(pred_box, gt_xyxy) >= iou_threshold:
                                detected = True
                                break

                if detected:
                    results['overall']['detected'] += 1
                    results['per_category'][gt_category_name]['detected'] += 1

    # Final recall calculations
    if results['overall']['total_gt'] > 0:
        results['overall']['recall'] = (
            results['overall']['detected'] / results['overall']['total_gt']
        )

    for cat_name, s in results['per_category'].items():
        if s['total_gt'] > 0:
            s['recall'] = s['detected'] / s['total_gt']

    return results


In [9]:
results = evaluate_recall(gemini_2_5_flash_converted.values(), gt_data.values(), ignore_labels=False)
print(results)

{'overall': {'total_gt': 2858, 'detected': 86, 'recall': 0.030090972708187544}, 'per_category': {'fish': {'total_gt': 1565, 'detected': 24, 'recall': 0.015335463258785943}, 'reefs': {'total_gt': 758, 'detected': 56, 'recall': 0.07387862796833773}, 'human divers': {'total_gt': 193, 'detected': 1, 'recall': 0.0051813471502590676}, 'wrecks/ruins': {'total_gt': 153, 'detected': 2, 'recall': 0.013071895424836602}, 'robots': {'total_gt': 47, 'detected': 0, 'recall': 0.0}, 'sea-floor': {'total_gt': 68, 'detected': 3, 'recall': 0.04411764705882353}, 'aquatic plants': {'total_gt': 74, 'detected': 0, 'recall': 0.0}}}


In [10]:
# with open('gemini_2_5_pro_eval_results_label_agnostic.json', 'w') as f:
#     json.dump(results, f, indent=4)