In [16]:
import json
import PIL.Image as Image

# Module-level lookup from COCO image id to file name for the test split.
# Forward slashes resolve on both Windows and POSIX; the context manager closes
# the annotation file instead of leaking the handle.
with open('USIS10K/multi_class_annotations/multi_class_test_annotations.json') as f:
    data = json.load(f)

id_to_img = {img['id']: img['file_name'] for img in data['images']}

from collections import defaultdict

from collections import defaultdict
import numpy as np
from pycocotools import mask as mask_utils
from PIL import Image, ImageDraw

# Text prompt -> USIS10K category id; ids 1..7 follow the listed order.
PROMPT_TO_CATEGORY = {
    prompt: category_id
    for category_id, prompt in enumerate(
        (
            "wrecks/ruins",
            "fish",
            "reefs",
            "aquatic plants",
            "human divers",
            "robots",
            "sea-floor",
        ),
        start=1,
    )
}

# Reverse lookup: category id -> text prompt.
CATEGORY_TO_PROMPT = {category_id: prompt for prompt, category_id in PROMPT_TO_CATEGORY.items()}


def polygon_to_rle(segmentation, img_height, img_width):
    """
    Rasterize a polygon segmentation with PIL and encode it as a COCO RLE dict.

    `segmentation` is either one flat polygon [x1, y1, x2, y2, ...] or a list
    of such polygons; every polygon is filled into the same binary mask of
    size (img_height, img_width).

    Returns None for an empty segmentation; otherwise the RLE dict from
    pycocotools with its 'counts' decoded to an ASCII str.
    """
    if not len(segmentation):
        return None

    # A list whose first element is itself a list is already a polygon list.
    polygons = segmentation if isinstance(segmentation[0], list) else [segmentation]

    canvas = Image.new("L", (img_width, img_height), 0)
    painter = ImageDraw.Draw(canvas)
    for polygon in polygons:
        # reshape(-1, 2) raises on an odd number of coordinates before drawing.
        flat_xy = np.array(polygon).reshape(-1, 2).ravel().tolist()
        painter.polygon(flat_xy, outline=1, fill=1)

    binary_mask = np.array(canvas, dtype=np.uint8)
    encoded = mask_utils.encode(np.asfortranarray(binary_mask))
    encoded["counts"] = encoded["counts"].decode("ascii")
    return encoded


def group_items_to_list(items):
    """
    Group flat annotation records into one entry per image.

    Args:
        items: iterable of dicts with 'img_path', 'category_id', 'bbox' and
            'segmentation' keys (the records load_data builds).

    Returns:
        List of dicts, one per distinct img_path in first-seen order, shaped as
        {'img_path': path, <category name>: {'boxes': [...], 'masks_rle': [...]}, ...}.
        Category names come from CATEGORY_TO_PROMPT. Masks come from the
        module-level polygon_to_rle, sized to each image's real height/width.

    NOTE: polygon_to_rle is looked up at call time and a later cell redefines
    that name, so the rasterizer used depends on which cells have already run.
    """
    grouped = defaultdict(lambda: defaultdict(lambda: {"boxes": [], "masks_rle": []}))
    # img_path -> (height, width). Only the image size is needed and an image
    # may carry several annotations, so each file is opened once.
    image_sizes = {}

    for d in items:
        img_path = d["img_path"]
        if img_path not in image_sizes:
            # Image.open keeps the file handle open until the image is closed;
            # the context manager releases it as soon as the size is read.
            with Image.open(img_path) as img:
                image_sizes[img_path] = (img.height, img.width)
        img_height, img_width = image_sizes[img_path]

        cat = CATEGORY_TO_PROMPT[d["category_id"]]

        grouped[img_path][cat]["boxes"].append(d["bbox"])
        grouped[img_path][cat]["masks_rle"].append(
            polygon_to_rle(d["segmentation"], img_height, img_width)
        )

    return [{"img_path": img_path, **cats} for img_path, cats in grouped.items()]


def load_data(path='USIS10K/multi_class_annotations/multi_class_test_annotations.json',
              img_dir='USIS10K/test'):
    """
    Load a USIS10K COCO-style annotation file and group its annotations per image.

    Args:
        path: annotation JSON containing 'images' and 'annotations' tables.
            Forward slashes resolve on both Windows and POSIX.
        img_dir: directory prefix joined with each image's file_name to form
            'img_path' (default matches the original hard-coded prefix).

    Returns:
        The list built by group_items_to_list: one dict per image with an
        'img_path' key plus one key per category name holding
        {'boxes': [...], 'masks_rle': [...]}.
    """
    with open(path) as f:
        data = json.load(f)

    # Resolve file names from this file's own image table; the module-level
    # id_to_img is always built from the test split, so using it would give
    # wrong paths whenever a different `path` is passed.
    image_names = {img['id']: img['file_name'] for img in data['images']}

    polished_data = [
        {
            'img_path': f"{img_dir}/{image_names[annotation['image_id']]}",
            'category_id': annotation['category_id'],
            'bbox': annotation['bbox'],
            'segmentation': annotation['segmentation'],
        }
        for annotation in data['annotations']
    ]

    return group_items_to_list(polished_data)

def load_usis_preds_data(path='usis_sam_preds_rle.json', img_dir='USIS10K/test', image_names=None):
    """
    Load USIS-SAM predictions as a flat list of per-instance records.

    Unlike load_data, the records are NOT grouped per image.

    Args:
        path: JSON file with a top-level 'predictions' list; each prediction
            has 'category_id', 'image_id', 'bbox' and 'segmentation'.
        img_dir: directory prefix joined with the image file name.
        image_names: optional {image_id: file_name} mapping. When omitted, the
            module-level id_to_img (built from the test annotations) is used.

    Returns:
        List of dicts with 'img_path', 'category_id', 'bbox' and 'segmentation'.
    """
    if image_names is None:
        image_names = id_to_img

    with open(path) as f:
        data = json.load(f)

    return [
        {
            'img_path': f"{img_dir}/{image_names[pred['image_id']]}",
            'category_id': pred['category_id'],
            'bbox': pred['bbox'],
            'segmentation': pred['segmentation'],
        }
        for pred in data['predictions']
    ]

In [17]:
import json
import numpy as np
from pycocotools import mask as mask_utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from collections import defaultdict
from PIL import Image
from tqdm import tqdm

# Prompt name -> category_id in USIS10K. These ids must match the dataset's own
# category table (printed by get_category_mapping below); edit them here if an
# annotation file numbers its categories differently.
PROMPT_TO_CATEGORY = dict(
    [
        ("wrecks/ruins", 1),
        ("fish", 2),
        ("reefs", 3),
        ("aquatic plants", 4),
        ("human divers", 5),
        ("robots", 6),
        ("sea-floor", 7),
    ]
)

# Check actual category IDs from the dataset
def get_category_mapping(path='USIS10K/multi_class_annotations/multi_class_test_annotations.json'):
    """
    Print the dataset's category table and return it as {name: id}.

    Args:
        path: COCO-style annotation JSON. Defaults to the USIS10K test
            annotations; forward slashes resolve on both Windows and POSIX.

    Returns:
        Dict mapping each category 'name' to its 'id'; empty if the file has
        no 'categories' table.
    """
    with open(path) as f:
        categories = json.load(f).get('categories', [])

    print("Categories in dataset:")
    for cat in categories:
        print(f"  ID {cat['id']}: {cat['name']}")
    return {cat['name']: cat['id'] for cat in categories}

category_map = get_category_mapping()
# The hard-coded PROMPT_TO_CATEGORY must agree with the dataset's category
# table; a mismatch would attach the wrong category names to every annotation.
assert category_map == PROMPT_TO_CATEGORY, (
    f"PROMPT_TO_CATEGORY disagrees with dataset categories: {category_map}"
)

Categories in dataset:
  ID 1: wrecks/ruins
  ID 2: fish
  ID 3: reefs
  ID 4: aquatic plants
  ID 5: human divers
  ID 6: robots
  ID 7: sea-floor


In [18]:
# Load ground truth grouped per image (one dict per image keyed by category
# name), the USIS-SAM predictions (flat list), and the SAM3 predictions.
gt_data = load_data()
usis_data = load_usis_preds_data()
with open('sam3_usis10k_preds_rle.json') as f:
    predictions = json.load(f)
# SAM3 prediction paths may contain Windows separators; normalize to '/' so
# they compare equal to the ground-truth paths built by load_data.
for pred in predictions:
    pred['img_path'] = pred['img_path'].replace('\\', '/')

In [19]:
# Same prompt <-> category id mapping as defined earlier in the notebook;
# it must stay identical to those definitions (ids 1..7 in listed order).
PROMPT_TO_CATEGORY = dict(
    zip(
        ("wrecks/ruins", "fish", "reefs", "aquatic plants", "human divers", "robots", "sea-floor"),
        range(1, 8),
    )
)

CATEGORY_TO_PROMPT = dict(zip(PROMPT_TO_CATEGORY.values(), PROMPT_TO_CATEGORY.keys()))

In [20]:
def calculate_iou(box1, box2):
    """
    Intersection-over-union of two axis-aligned boxes.

    Both boxes are [x_min, y_min, x_max, y_max]. Returns 0.0 when the boxes
    do not overlap (touching edges count as no overlap) or when the union has
    no positive area.
    """
    ax0, ay0, ax1, ay1 = box1
    bx0, by0, bx1, by1 = box2

    overlap_w = min(ax1, bx1) - max(ax0, bx0)
    overlap_h = min(ay1, by1) - max(ay0, by0)
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0

    intersection = overlap_w * overlap_h
    union = (ax1 - ax0) * (ay1 - ay0) + (bx1 - bx0) * (by1 - by0) - intersection
    if union > 0:
        return intersection / union
    return 0.0

def gt_bbox_to_xyxy(bbox):
    """
    Turn a ground-truth box [x, y, width, height] into corner form
    [x_min, y_min, x_max, y_max], the layout calculate_iou expects.
    """
    left, top, width, height = bbox
    right = left + width
    bottom = top + height
    return [left, top, right, bottom]


def polygon_to_rle(segmentation, img_height, img_width):
    """
    Convert a polygon segmentation to a single merged RLE using pycocotools.

    This redefines (and shadows) the PIL-based polygon_to_rle from an earlier
    cell, so it keeps that function's contract:
      * `segmentation` may be one flat polygon [x1, y1, x2, y2, ...] or a list
        of such polygons (all are merged into one mask);
      * an empty segmentation returns None;
      * 'counts' is returned as a str so the RLE stays JSON-serializable.
    """
    if len(segmentation) == 0:
        return None

    # frPyObjects needs a list of polygons; a bare flat polygon would make it
    # call len() on a coordinate and raise TypeError.
    polys = segmentation if isinstance(segmentation[0], list) else [segmentation]

    rles = mask_utils.frPyObjects(polys, img_height, img_width)
    rle = mask_utils.merge(rles)
    if isinstance(rle["counts"], bytes):
        rle["counts"] = rle["counts"].decode("ascii")
    return rle


def calculate_mask_iou(rle1, rle2):
    """
    IoU between two RLE-encoded masks, computed by pycocotools.

    The third argument to mask_utils.iou is the per-ground-truth iscrowd flag
    list; [0] requests a plain (non-crowd) IoU.
    """
    pairwise = mask_utils.iou([rle1], [rle2], [0])
    first_row = pairwise[0]
    return first_row[0]

def _normalize_img_path(img_path):
    """Canonicalize an image path by replacing backslashes with '/'."""
    return img_path.replace('\\', '/')


def _index_prediction_boxes(predictions):
    """Merge prediction entries into {normalized img_path: {category name: [boxes]}}."""
    index = {}
    for pred_item in predictions:
        boxes_by_cat = index.setdefault(_normalize_img_path(pred_item["img_path"]), {})
        for cat_name, cat_data in pred_item.items():
            if cat_name == "img_path":
                continue
            boxes_by_cat.setdefault(cat_name, []).extend(cat_data["boxes"])
    return index


def evaluate_recall(predictions, gt_data, iou_threshold=0.5, ignore_labels=False):
    """
    Box recall of `predictions` against `gt_data` at a fixed IoU threshold.

    Ground-truth and prediction entries are matched by image path (backslashes
    normalized to '/'), not by list position, so the lists may differ in order
    or length. Prediction entries that share a path are merged. A ground-truth
    image with no prediction entry counts all of its boxes as missed.

    A GT box counts as detected if any candidate prediction box reaches
    `iou_threshold`. A single prediction box may satisfy several GT boxes
    (no one-to-one matching).

    Args:
        predictions: list of dicts, each with 'img_path' plus one key per
            category name mapping to a dict with a 'boxes' list. Boxes are
            passed directly to calculate_iou, i.e. expected as
            [x_min, y_min, x_max, y_max].
        gt_data: list in the same layout (as returned by load_data). Its boxes
            are [x, y, width, height] and are converted with gt_bbox_to_xyxy.
        iou_threshold: minimum IoU (inclusive) for a match.
        ignore_labels: if True, a prediction of any category may match a GT
            box; otherwise only predictions under the same category name.

    Returns:
        {'overall': stats, 'per_category': {name: stats}}, where each stats
        dict has 'total_gt', 'detected' and 'recall' (0.0 when total_gt is 0).
    """
    results = {
        'overall': {'total_gt': 0, 'detected': 0, 'recall': 0.0},
        'per_category': {}
    }
    overall = results['overall']

    # Key predictions by path instead of zip()-ing the two lists: zip pairs
    # entries by position and silently drops trailing GT images when the
    # lists have different lengths.
    pred_index = _index_prediction_boxes(predictions)

    for gt_item in gt_data:
        pred_boxes_by_cat = pred_index.get(_normalize_img_path(gt_item["img_path"]), {})

        for gt_category_name, gt_category_data in gt_item.items():
            if gt_category_name == "img_path":
                continue

            cat_stats = results['per_category'].setdefault(
                gt_category_name, {'total_gt': 0, 'detected': 0, 'recall': 0.0}
            )

            if ignore_labels:
                # Every predicted category is a candidate for this GT box.
                candidate_boxes = [box for boxes in pred_boxes_by_cat.values() for box in boxes]
            else:
                candidate_boxes = pred_boxes_by_cat.get(gt_category_name, [])

            for gt_box in gt_category_data["boxes"]:
                gt_xyxy = gt_bbox_to_xyxy(gt_box)
                overall['total_gt'] += 1
                cat_stats['total_gt'] += 1

                if any(calculate_iou(pred_box, gt_xyxy) >= iou_threshold
                       for pred_box in candidate_boxes):
                    overall['detected'] += 1
                    cat_stats['detected'] += 1

    for stats in (overall, *results['per_category'].values()):
        if stats['total_gt'] > 0:
            stats['recall'] = stats['detected'] / stats['total_gt']

    return results


In [21]:
# Label-aware box recall of the SAM3 predictions at IoU >= 0.5.
results_sam3 = evaluate_recall(
    predictions=predictions,
    gt_data=gt_data,
    iou_threshold=0.5,
    ignore_labels=False,
)

In [22]:
# Display the recall summary (overall and per category) as the cell output.
results_sam3

{'overall': {'total_gt': 2860, 'detected': 998, 'recall': 0.34895104895104895},
 'per_category': {'fish': {'total_gt': 1566,
   'detected': 593,
   'recall': 0.3786717752234994},
  'reefs': {'total_gt': 759, 'detected': 292, 'recall': 0.3847167325428195},
  'human divers': {'total_gt': 193,
   'detected': 77,
   'recall': 0.39896373056994816},
  'wrecks/ruins': {'total_gt': 153, 'detected': 0, 'recall': 0.0},
  'robots': {'total_gt': 47, 'detected': 19, 'recall': 0.40425531914893614},
  'sea-floor': {'total_gt': 68, 'detected': 11, 'recall': 0.16176470588235295},
  'aquatic plants': {'total_gt': 74,
   'detected': 6,
   'recall': 0.08108108108108109}}}

In [23]:
# Write the recall summary to sam3_eval_results.json (relative to the working
# directory) as indented JSON.
with open("sam3_eval_results.json", "w") as results_file:
    json.dump(results_sam3, fp=results_file, indent=4)