In [13]:
import json

# Image id -> file name for the USIS10K test split; the loaders below use it to
# turn an annotation's image_id into an image path.
with open('USIS10K\\multi_class_annotations\\multi_class_test_annotations.json') as annotations_file:
    data = json.load(annotations_file)

id_to_img = {img['id']: img['file_name'] for img in data['images']}

def load_data(path='USIS10K\\multi_class_annotations\\multi_class_test_annotations.json',
              img_dir='USIS10K/test'):
    """
    Load COCO-style ground-truth annotations as a flat list, one entry per annotation.

    Image ids are resolved against the file's own 'images' table, so passing a
    non-default `path` does not look ids up in a table built from another file.
    If the file has no 'images' key, the module-level `id_to_img` is used instead.

    Args:
        path: annotation JSON with 'annotations' (each holding 'image_id',
            'category_id', 'bbox', 'segmentation') and usually 'images'
            (each holding 'id' and 'file_name').
        img_dir: directory prefix used to build 'img_path'.

    Returns:
        List of {'img_path', 'category_id', 'bbox', 'segmentation'} dicts in file
        order; 'bbox' and 'segmentation' are passed through unchanged.

    Raises:
        KeyError: if an annotation references an image id with no file name.
    """
    with open(path) as f:
        data = json.load(f)

    if 'images' in data:
        image_lookup = {img['id']: img['file_name'] for img in data['images']}
    else:
        image_lookup = id_to_img

    return [
        {
            'img_path': f"{img_dir}/{image_lookup[annotation['image_id']]}",
            'category_id': annotation['category_id'],
            'bbox': annotation['bbox'],
            'segmentation': annotation['segmentation'],
        }
        for annotation in data['annotations']
    ]

def load_usis_preds_data(path='usis_sam_preds_rle.json', image_lookup=None, img_dir='USIS10K/test'):
    """
    Load USIS baseline predictions as a flat list, one entry per predicted instance.

    Args:
        path: JSON file whose 'predictions' list holds dicts with 'image_id',
            'category_id', 'bbox' and 'segmentation'.
        image_lookup: mapping image id -> file name. Defaults to the module-level
            `id_to_img` (this file is not read for an 'images' table).
        img_dir: directory prefix used to build 'img_path'.

    Returns:
        List of {'img_path', 'category_id', 'bbox', 'segmentation'} dicts in file
        order; values are passed through unchanged.

    Raises:
        KeyError: if a prediction references an image id missing from the lookup.
    """
    with open(path) as f:
        data = json.load(f)

    if image_lookup is None:
        image_lookup = id_to_img

    return [
        {
            'img_path': f"{img_dir}/{image_lookup[prediction['image_id']]}",
            'category_id': prediction['category_id'],
            'bbox': prediction['bbox'],
            'segmentation': prediction['segmentation'],
        }
        for prediction in data['predictions']
    ]

In [14]:
import json
import numpy as np
from pycocotools import mask as mask_utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from collections import defaultdict
from PIL import Image
from tqdm import tqdm

# Category mapping: SAM3 prompt text -> USIS10K category_id.
# Ids are assigned 1..7 in the order listed below; if the dataset numbers its
# categories differently, compare with the output of get_category_mapping().
PROMPT_TO_CATEGORY = {
    prompt: category_id
    for category_id, prompt in enumerate(
        [
            "wrecks/ruins",
            "fish",
            "reefs",
            "aquatic plants",
            "human divers",
            "robots",
            "sea-floor",
        ],
        start=1,
    )
}

# Check actual category IDs from the dataset
def get_category_mapping(path='USIS10K\\multi_class_annotations\\multi_class_test_annotations.json'):
    """
    Print the dataset's categories and return them as a name -> id dict.

    Args:
        path: COCO-style annotation JSON; its optional 'categories' list holds
            dicts with 'id' and 'name'. A missing key yields an empty mapping.

    Returns:
        Dict mapping category name to category id.
    """
    with open(path) as f:
        data = json.load(f)

    categories = data.get('categories', [])
    print("Categories in dataset:")
    for cat in categories:
        print(f"  ID {cat['id']}: {cat['name']}")
    return {cat['name']: cat['id'] for cat in categories}

# Dataset-derived category name -> id mapping (the call also prints it);
# compare against the hand-written PROMPT_TO_CATEGORY table.
category_map = get_category_mapping()

Categories in dataset:
  ID 1: wrecks/ruins
  ID 2: fish
  ID 3: reefs
  ID 4: aquatic plants
  ID 5: human divers
  ID 6: robots
  ID 7: sea-floor


In [15]:
# Load the evaluation inputs:
#   gt_data     - one dict per ground-truth annotation (flat list built by load_data)
#   usis_data   - one dict per USIS baseline prediction instance
#   predictions - SAM3 predictions as stored by the inference run
gt_data = load_data()
usis_data = load_usis_preds_data()

with open('sam3_usis10k_preds_rle_simple_prompt.json') as sam3_preds_file:
    predictions = json.load(sam3_preds_file)

# Normalize Windows separators so paths match the forward-slash paths built by the loaders.
for pred in predictions:
    pred['img_path'] = pred['img_path'].replace('\\', '/')

In [16]:
# Prompt text -> USIS10K category id. The prompt strings are also the
# per-category keys inside each prediction dict (e.g. pred['fish']['boxes']).
PROMPT_TO_CATEGORY = {
    "wrecks/ruins": 1,
    "fish": 2,
    "reefs": 3,
    "aquatic plants": 4,
    "human divers": 5,
    "robots": 6,
    "sea-floor": 7,
}

# Guard against this hand-written table drifting from the dataset's own categories
# (category_map is read from the annotation file by get_category_mapping()).
assert PROMPT_TO_CATEGORY == category_map, (
    f"PROMPT_TO_CATEGORY does not match the dataset categories: {category_map}"
)

# Reverse lookup: category id -> prompt / prediction key.
CATEGORY_TO_PROMPT = {v: k for k, v in PROMPT_TO_CATEGORY.items()}

In [17]:
def convert_usis_to_prediction_format(usis_data, category_offset=1):
    """
    Group USIS per-instance predictions into per-image dicts shaped like the SAM3 predictions.

    Input (usis_data), one entry per predicted instance:
        [{'img_path': '...', 'category_id': 0, 'bbox': [x, y, w, h], 'segmentation': {...}}, ...]

    Output, one entry per image (in first-seen order):
        [{'img_path': '...',
          'fish':  {'boxes': [[x1, y1, x2, y2], ...], 'scores': [...], 'masks_rle': [...]},
          'reefs': {'boxes': [...], 'scores': [...], 'masks_rle': [...]},
          ...}, ...]

    Args:
        usis_data: list produced by load_usis_preds_data().
        category_offset: added to each instance's category_id before looking it up
            in CATEGORY_TO_PROMPT. The default of 1 treats USIS ids as 0-based against
            the 1-based prompt table (NOTE(review): inferred from the original
            hard-coded +1; confirm against the USIS prediction file). Pass 0 if the
            input ids are already 1-based.

    Returns:
        List of per-image dicts. Every prompt in PROMPT_TO_CATEGORY is present
        (possibly empty); ids with no prompt are stored under 'unknown_<id>'.
        Scores are always 1.0 because USIS predictions carry no confidence.
    """
    from collections import defaultdict

    # Group annotations by image path (dict insertion order = first-seen order).
    img_annotations = defaultdict(list)
    for ann in usis_data:
        img_annotations[ann['img_path']].append(ann)

    converted = []
    for img_path, annotations in img_annotations.items():
        pred = {'img_path': img_path}

        # Every known prompt gets an entry, even with no detections, like the SAM3 layout.
        for prompt in PROMPT_TO_CATEGORY:
            pred[prompt] = {'boxes': [], 'scores': [], 'masks_rle': []}

        for ann in annotations:
            category_id = ann['category_id'] + category_offset
            prompt_name = CATEGORY_TO_PROMPT.get(category_id, f"unknown_{category_id}")
            slot = pred.setdefault(prompt_name, {'boxes': [], 'scores': [], 'masks_rle': []})

            # [x, y, w, h] -> [x1, y1, x2, y2]
            x, y, w, h = ann['bbox']
            slot['boxes'].append([x, y, x + w, y + h])

            # USIS predictions don't have scores; 1.0 is a placeholder.
            slot['scores'].append(1.0)

            # NOTE(review): a mask is appended only when the instance has one, so a
            # missing segmentation makes masks_rle[i] stop lining up with boxes[i];
            # the segmentation evaluator looks masks up by the matched box index.
            if ann.get('segmentation'):
                slot['masks_rle'].append(ann['segmentation'])

        converted.append(pred)

    return converted


def align_predictions_by_image(predictions, usis_predictions):
    """
    Reorder USIS predictions to follow the SAM3 `predictions` list, matching on 'img_path'.

    The result has one entry per element of `predictions`: the USIS prediction dict
    for that image (the same object, not a copy), or a fresh placeholder in which every
    prompt of PROMPT_TO_CATEGORY maps to empty boxes/scores/masks_rle when USIS has
    nothing for that image. When several USIS entries share a path, the last one wins.
    """
    lookup = {}
    for usis_pred in usis_predictions:
        lookup[usis_pred['img_path']] = usis_pred

    def placeholder(path):
        blank = {'img_path': path}
        blank.update(
            {prompt: {'boxes': [], 'scores': [], 'masks_rle': []} for prompt in PROMPT_TO_CATEGORY}
        )
        return blank

    return [
        lookup[sam3_pred['img_path']] if sam3_pred['img_path'] in lookup
        else placeholder(sam3_pred['img_path'])
        for sam3_pred in predictions
    ]


# Regroup USIS per-instance predictions into one dict per image (the SAM3 layout),
# then order them to follow `predictions` so both lists line up entry-for-entry.
usis_predictions = convert_usis_to_prediction_format(usis_data=usis_data)
usis_predictions_aligned = align_predictions_by_image(
    predictions=predictions,
    usis_predictions=usis_predictions,
)

In [18]:
import numpy as np
from pycocotools import mask as mask_utils


def calculate_iou(box1, box2):
    """
    Intersection-over-union of two axis-aligned boxes in [x_min, y_min, x_max, y_max] form.

    Returns 0.0 when the boxes do not overlap (edges that merely touch count as no
    overlap) or when the union area is not positive.
    """
    ax0, ay0, ax1, ay1 = box1
    bx0, by0, bx1, by1 = box2

    left, top = max(ax0, bx0), max(ay0, by0)
    right, bottom = min(ax1, bx1), min(ay1, by1)
    if not (right > left and bottom > top):
        return 0.0

    overlap = (right - left) * (bottom - top)
    area_a = (ax1 - ax0) * (ay1 - ay0)
    area_b = (bx1 - bx0) * (by1 - by0)
    union = area_a + area_b - overlap
    return overlap / union if union > 0 else 0.0


def gt_bbox_to_xyxy(bbox):
    """Turn a COCO-style [x, y, width, height] box into corner form [x_min, y_min, x_max, y_max]."""
    left, top, width, height = bbox
    right = left + width
    bottom = top + height
    return [left, top, right, bottom]


def polygon_to_rle(segmentation, img_height, img_width):
    """
    Convert a COCO segmentation to a single RLE mask.

    Accepted inputs:
      * polygon list ``[[x1, y1, x2, y2, ...], ...]``: every part is rasterized on an
        img_height x img_width canvas and merged into one mask;
      * uncompressed RLE dict ``{'size': [h, w], 'counts': [int, ...]}`` (the form COCO
        uses for crowd annotations): encoded via frPyObjects, which returns a single
        RLE, so no merge is needed;
      * compressed RLE dict ``{'size': [h, w], 'counts': str | bytes}``: returned
        unchanged, since it is already in the target form (its own 'size' applies and
        img_height / img_width are ignored).

    Returns:
        A pycocotools RLE dict.
    """
    if isinstance(segmentation, dict):
        if isinstance(segmentation.get('counts'), (str, bytes)):
            # Already compressed; frPyObjects expects list counts for dict input.
            return segmentation
        return mask_utils.frPyObjects(segmentation, img_height, img_width)

    rles = mask_utils.frPyObjects(segmentation, img_height, img_width)
    return mask_utils.merge(rles)


def calculate_mask_iou(rle1, rle2):
    """
    Mask IoU of two pycocotools RLE masks.

    rle2 is passed as a non-crowd ground truth (iscrowd=0), so the value is plain
    intersection / union. NOTE(review): pycocotools reports -1 instead of an IoU when
    the two masks have different 'size' — confirm if sizes can differ here.
    """
    non_crowd = [0]
    pairwise = mask_utils.iou([rle1], [rle2], non_crowd)
    first_row = pairwise[0]
    return first_row[0]


def evaluate_recall(predictions, gt_data, iou_threshold=0.5, ignore_labels=False):
    """
    Evaluate box recall: the fraction of ground-truth boxes hit by at least one prediction.

    `predictions` and `gt_data` are paired positionally: predictions[i] must hold the
    predictions for the image of annotation gt_data[i]. A GT box counts as detected
    when some predicted box reaches `iou_threshold` IoU with it. There is no one-to-one
    assignment, so one predicted box may detect several GT boxes.

    Args:
        predictions: list of dicts with 'img_path' plus one key per prompt name, each
            mapping to {'boxes': [[x1, y1, x2, y2], ...], ...}.
        gt_data: list of dicts with 'img_path', 'category_id' and 'bbox' ([x, y, w, h]).
        iou_threshold: minimum box IoU for a match.
        ignore_labels: if True, predicted boxes of any category may match a GT box.

    Returns:
        {'overall': {'total_gt', 'detected', 'recall'},
         'per_category': {name: {'total_gt', 'detected', 'recall'}}}

    Raises:
        ValueError: if a paired prediction and GT entry name different images
            (the lists are misaligned, so the recall would be meaningless).

    Warns:
        UserWarning: if the lists differ in length; as with zip, only the first
            min(len) pairs are evaluated.
    """
    import warnings

    if len(predictions) != len(gt_data):
        warnings.warn(
            f"evaluate_recall: {len(predictions)} predictions vs {len(gt_data)} ground-truth "
            f"annotations; only the first {min(len(predictions), len(gt_data))} pairs are evaluated."
        )

    results = {
        'overall': {'total_gt': 0, 'detected': 0, 'recall': 0.0},
        'per_category': {}
    }

    for pair_idx, (pred, gt) in enumerate(zip(predictions, gt_data)):
        # Pairing is positional, so make sure both entries describe the same image.
        pred_path, gt_path = pred.get('img_path'), gt.get('img_path')
        if (pred_path is not None and gt_path is not None
                and str(pred_path).replace('\\', '/') != str(gt_path).replace('\\', '/')):
            raise ValueError(
                f"Misaligned inputs at index {pair_idx}: prediction is for {pred_path!r} "
                f"but ground truth is for {gt_path!r}"
            )

        gt_category_id = gt['category_id']
        gt_category_name = CATEGORY_TO_PROMPT.get(gt_category_id, f"unknown_{gt_category_id}")
        gt_bbox_xyxy = gt_bbox_to_xyxy(gt['bbox'])

        cat_stats = results['per_category'].setdefault(
            gt_category_name, {'total_gt': 0, 'detected': 0, 'recall': 0.0}
        )
        results['overall']['total_gt'] += 1
        cat_stats['total_gt'] += 1

        if ignore_labels:
            # Every predicted box of every category is a candidate.
            candidate_boxes = (
                box
                for key, category_data in pred.items()
                if key != 'img_path'
                for box in category_data.get('boxes', [])
            )
        elif gt_category_name in pred:
            candidate_boxes = pred[gt_category_name]['boxes']
        else:
            candidate_boxes = []

        if any(calculate_iou(box, gt_bbox_xyxy) >= iou_threshold for box in candidate_boxes):
            results['overall']['detected'] += 1
            cat_stats['detected'] += 1

    for stats in [results['overall'], *results['per_category'].values()]:
        if stats['total_gt'] > 0:
            stats['recall'] = stats['detected'] / stats['total_gt']

    return results


def evaluate_recall_with_segmentation(predictions, gt_data, bbox_iou_threshold=0.5, 
                                      mask_iou_threshold=0.5, ignore_labels=False,
                                      img_height=480, img_width=640):
    """
    Evaluate box recall and, for box-matched GT instances, mask recall.

    predictions[i] is paired with gt_data[i], the same positional contract as
    evaluate_recall. For each GT annotation, the first predicted box (in dict / list
    order) with IoU >= bbox_iou_threshold is its match. If the GT has a segmentation,
    only the mask at the same index as the matched box in 'masks_rle' is compared. The
    GT counts as segmented when that mask IoU >= mask_iou_threshold. An unmatched box,
    or a matched box with no mask, is a segmentation miss.

    Args:
        predictions: per-image dicts with 'img_path' plus one key per prompt name, each
            holding {'boxes': [[x1, y1, x2, y2], ...], 'masks_rle': [RLE, ...], ...}.
        gt_data: dicts with 'img_path', 'category_id', 'bbox' ([x, y, w, h]) and
            optionally 'segmentation' (polygons, or RLE if polygon_to_rle accepts it).
        bbox_iou_threshold: minimum box IoU for a match.
        mask_iou_threshold: minimum mask IoU for a segmentation match.
        ignore_labels: if True, boxes of any predicted category may match.
        img_height, img_width: canvas used to rasterize a GT polygon when the matched
            predicted RLE has no 'size'. Otherwise the predicted mask's own [h, w] is
            used, so both masks share a shape.

    Returns:
        {'bbox': {...}, 'segmentation': {...},
         'per_category': {name: {'bbox': {...}, 'segmentation': {...}}}}
        where every leaf is {'total_gt', 'detected', 'recall'}.

    Raises:
        ValueError: if a paired prediction and GT entry name different images.

    Warns:
        UserWarning: if the lists differ in length; only the first min(len) pairs
            are evaluated.
    """
    import warnings

    if len(predictions) != len(gt_data):
        warnings.warn(
            f"evaluate_recall_with_segmentation: {len(predictions)} predictions vs "
            f"{len(gt_data)} ground-truth annotations; only the first "
            f"{min(len(predictions), len(gt_data))} pairs are evaluated."
        )

    def new_counter():
        return {'total_gt': 0, 'detected': 0, 'recall': 0.0}

    results = {
        'bbox': new_counter(),
        'segmentation': new_counter(),
        'per_category': {}
    }

    for pair_idx, (pred, gt) in enumerate(zip(predictions, gt_data)):
        # Pairing is positional, so make sure both entries describe the same image.
        pred_path, gt_path = pred.get('img_path'), gt.get('img_path')
        if (pred_path is not None and gt_path is not None
                and str(pred_path).replace('\\', '/') != str(gt_path).replace('\\', '/')):
            raise ValueError(
                f"Misaligned inputs at index {pair_idx}: prediction is for {pred_path!r} "
                f"but ground truth is for {gt_path!r}"
            )

        gt_category_id = gt['category_id']
        gt_category_name = CATEGORY_TO_PROMPT.get(gt_category_id, f"unknown_{gt_category_id}")
        gt_bbox_xyxy = gt_bbox_to_xyxy(gt['bbox'])
        gt_has_segmentation = bool(gt.get('segmentation'))

        cat_stats = results['per_category'].setdefault(
            gt_category_name, {'bbox': new_counter(), 'segmentation': new_counter()}
        )

        results['bbox']['total_gt'] += 1
        cat_stats['bbox']['total_gt'] += 1
        if gt_has_segmentation:
            results['segmentation']['total_gt'] += 1
            cat_stats['segmentation']['total_gt'] += 1

        # Which (category, data) groups may supply a matching box.
        if ignore_labels:
            groups = [(name, data) for name, data in pred.items() if name != 'img_path']
        elif gt_category_name in pred:
            groups = [(gt_category_name, pred[gt_category_name])]
        else:
            groups = []

        matched_category, matched_pred_idx = None, -1
        for category_name, category_data in groups:
            for idx, pred_box in enumerate(category_data.get('boxes', [])):
                if calculate_iou(pred_box, gt_bbox_xyxy) >= bbox_iou_threshold:
                    matched_category, matched_pred_idx = category_name, idx
                    break
            if matched_category is not None:
                break

        if matched_category is None:
            continue

        results['bbox']['detected'] += 1
        cat_stats['bbox']['detected'] += 1

        if not gt_has_segmentation:
            continue

        masks = pred[matched_category].get('masks_rle', [])
        pred_rle = masks[matched_pred_idx] if matched_pred_idx < len(masks) else None
        if not pred_rle:
            # No mask stored for the matched box -> segmentation miss.
            continue

        # Rasterize the GT at the predicted mask's resolution. pycocotools' iou yields -1
        # for masks of different sizes, so a fixed canvas would make every image of
        # another resolution an automatic miss.
        if isinstance(pred_rle, dict) and 'size' in pred_rle:
            height, width = pred_rle['size']
        else:
            height, width = img_height, img_width
        gt_rle = polygon_to_rle(gt['segmentation'], height, width)

        if calculate_mask_iou(pred_rle, gt_rle) >= mask_iou_threshold:
            results['segmentation']['detected'] += 1
            cat_stats['segmentation']['detected'] += 1

    counters = [results['bbox'], results['segmentation']]
    for stats in results['per_category'].values():
        counters.extend((stats['bbox'], stats['segmentation']))
    for counter in counters:
        if counter['total_gt'] > 0:
            counter['recall'] = counter['detected'] / counter['total_gt']

    return results


def print_results(results):
    """Print the overall and per-category (alphabetical) recall from evaluate_recall()."""
    rule = "=" * 60
    overall = results['overall']
    lines = [
        rule,
        "RECALL EVALUATION RESULTS",
        rule,
        "",
        "Overall Recall:",
        f"  Detected: {overall['detected']}/{overall['total_gt']}",
        f"  Recall: {overall['recall']:.2%}",
        "",
        "Per-Category Recall:",
    ]
    for name in sorted(results['per_category']):
        stats = results['per_category'][name]
        lines.extend([
            f"  {name}:",
            f"    Detected: {stats['detected']}/{stats['total_gt']}",
            f"    Recall: {stats['recall']:.2%}",
        ])
    lines.append(rule)
    print("\n".join(lines))


def print_segmentation_results(results):
    """Print overall and per-category (alphabetical) box and mask recall from evaluate_recall_with_segmentation()."""
    rule = "=" * 60
    box, seg = results['bbox'], results['segmentation']
    lines = [
        rule,
        "BBOX AND SEGMENTATION RECALL EVALUATION",
        rule,
        "",
        "Overall Bounding Box Recall:",
        f"  Detected: {box['detected']}/{box['total_gt']}",
        f"  Recall: {box['recall']:.2%}",
        "",
        "Overall Segmentation Recall:",
        f"  Detected: {seg['detected']}/{seg['total_gt']}",
        f"  Recall: {seg['recall']:.2%}",
        "",
        "Per-Category Results:",
    ]
    for name in sorted(results['per_category']):
        cat_box = results['per_category'][name]['bbox']
        cat_seg = results['per_category'][name]['segmentation']
        lines.extend([
            "",
            f"  {name}:",
            f"    BBox - Detected: {cat_box['detected']}/{cat_box['total_gt']}, "
            f"Recall: {cat_box['recall']:.2%}",
            f"    Seg  - Detected: {cat_seg['detected']}/{cat_seg['total_gt']}, "
            f"Recall: {cat_seg['recall']:.2%}",
        ])
    lines.append(rule)
    print("\n".join(lines))



In [19]:
def compare_results(results1, results2, name1="Results 1", name2="Results 2"):
    """
    Print a metric-by-metric comparison of two evaluate_recall() result dicts.

    Each line shows the results1 value → the results2 value, the absolute change (Δ),
    and the change relative to results1. A relative change from a baseline of 0 is
    undefined, so it is printed as "n/a", or "+0.00%" when both values are 0.

    Args:
        results1: baseline results ({'overall': {...}, 'per_category': {...}}).
        results2: results compared against the baseline, same structure.
        name1: label for results1.
        name2: label for results2.
    """
    metrics = ('total_gt', 'detected', 'recall')

    def format_line(metric, val1, val2, indent=""):
        diff = val2 - val1
        if val1 != 0:
            relative = f"{diff / val1 * 100:+.2f}%"
        else:
            relative = "+0.00%" if diff == 0 else "n/a"
        if metric == 'recall':
            return f"{indent}{metric:12s}: {val1:.4f} → {val2:.4f} (Δ {diff:+.4f}, {relative})"
        return f"{indent}{metric:12s}: {val1:6d} → {val2:6d} (Δ {diff:+6d}, {relative})"

    print(f"\n{'='*80}")
    print(f"COMPARISON: {name1} vs {name2}")
    print(f"{'='*80}\n")

    print("OVERALL METRICS")
    print("-" * 80)
    for metric in metrics:
        print(format_line(metric, results1['overall'][metric], results2['overall'][metric]))

    print("\n\nPER-CATEGORY METRICS")
    print("-" * 80)

    # Union of categories from both results, alphabetical.
    all_categories = sorted(set(results1['per_category']) | set(results2['per_category']))
    for category in all_categories:
        cat1 = results1['per_category'].get(category)
        cat2 = results2['per_category'].get(category)

        print(f"\n{category.upper()}")
        if cat1 is None:
            print(f"  Only in {name2}: {cat2}")
            continue
        if cat2 is None:
            print(f"  Only in {name1}: {cat1}")
            continue

        for metric in metrics:
            print(format_line(metric, cat1[metric], cat2[metric], indent="  "))

    print(f"\n{'='*80}\n")

In [20]:
# Label-aware box recall (IoU >= 0.5): a GT box only counts as found when a
# prediction of the *same* category overlaps it.
results_sam3 = evaluate_recall(predictions=predictions, gt_data=gt_data, iou_threshold=0.5)
results_usis = evaluate_recall(predictions=usis_predictions_aligned, gt_data=gt_data, iou_threshold=0.5)

compare_results(
    results1=results_sam3,
    results2=results_usis,
    name1="SAM3 Predictions",
    name2="USIS Predictions",
)


COMPARISON: SAM3 Predictions vs USIS Predictions

OVERALL METRICS
--------------------------------------------------------------------------------
total_gt    :   2860 →   2860 (Δ     +0, +0.00%)
detected    :   1548 →   2364 (Δ   +816, +52.71%)
recall      : 0.5413 → 0.8266 (Δ +0.2853, +52.71%)


PER-CATEGORY METRICS
--------------------------------------------------------------------------------

AQUATIC PLANTS
  total_gt    :     74 →     74 (Δ     +0, +0.00%)
  detected    :     32 →     22 (Δ    -10, -31.25%)
  recall      : 0.4324 → 0.2973 (Δ -0.1351, -31.25%)

FISH
  total_gt    :   1566 →   1566 (Δ     +0, +0.00%)
  detected    :   1365 →   1414 (Δ    +49, +3.59%)
  recall      : 0.8716 → 0.9029 (Δ +0.0313, +3.59%)

HUMAN DIVERS
  total_gt    :    193 →    193 (Δ     +0, +0.00%)
  detected    :      2 →    181 (Δ   +179, +8950.00%)
  recall      : 0.0104 → 0.9378 (Δ +0.9275, +8950.00%)

REEFS
  total_gt    :    759 →    759 (Δ     +0, +0.00%)
  detected    :    141 →    564 (Δ

In [21]:
# Label-agnostic box recall: a GT box counts as found if a prediction of *any*
# category overlaps it with IoU >= 0.5.
results_sam3, results_usis = (
    evaluate_recall(model_preds, gt_data, iou_threshold=0.5, ignore_labels=True)
    for model_preds in (predictions, usis_predictions_aligned)
)

compare_results(results_sam3, results_usis, name1="SAM3 Predictions", name2="USIS Predictions")


COMPARISON: SAM3 Predictions vs USIS Predictions

OVERALL METRICS
--------------------------------------------------------------------------------
total_gt    :   2860 →   2860 (Δ     +0, +0.00%)
detected    :   1728 →   2515 (Δ   +787, +45.54%)
recall      : 0.6042 → 0.8794 (Δ +0.2752, +45.54%)


PER-CATEGORY METRICS
--------------------------------------------------------------------------------

AQUATIC PLANTS
  total_gt    :     74 →     74 (Δ     +0, +0.00%)
  detected    :     36 →     53 (Δ    +17, +47.22%)
  recall      : 0.4865 → 0.7162 (Δ +0.2297, +47.22%)

FISH
  total_gt    :   1566 →   1566 (Δ     +0, +0.00%)
  detected    :   1379 →   1449 (Δ    +70, +5.08%)
  recall      : 0.8806 → 0.9253 (Δ +0.0447, +5.08%)

HUMAN DIVERS
  total_gt    :    193 →    193 (Δ     +0, +0.00%)
  detected    :      2 →    188 (Δ   +186, +9300.00%)
  recall      : 0.0104 → 0.9741 (Δ +0.9637, +9300.00%)

REEFS
  total_gt    :    759 →    759 (Δ     +0, +0.00%)
  detected    :    273 →    602 (Δ

In [23]:
def compare_bbox_seg_results(results1, results2, name1="Results 1", name2="Results 2"):
    """
    Print a metric-by-metric comparison of two evaluate_recall_with_segmentation() results.

    For both 'bbox' and 'segmentation', overall and per category, each line shows the
    results1 value → the results2 value, the absolute change (Δ), and the change
    relative to results1. A relative change from a baseline of 0 is undefined and is
    printed as "n/a", or "+0.00%" when both values are 0.

    Args:
        results1: baseline results with 'bbox', 'segmentation' and 'per_category'.
        results2: results compared against the baseline, same structure.
        name1: label for results1.
        name2: label for results2.
    """
    metrics = ('total_gt', 'detected', 'recall')
    task_types = ('bbox', 'segmentation')

    def format_line(metric, val1, val2, indent):
        diff = val2 - val1
        if val1 != 0:
            relative = f"{diff / val1 * 100:+.2f}%"
        else:
            relative = "+0.00%" if diff == 0 else "n/a"
        if metric == 'recall':
            return f"{indent}{metric:12s}: {val1:.4f} → {val2:.4f} (Δ {diff:+.4f}, {relative})"
        return f"{indent}{metric:12s}: {val1:6d} → {val2:6d} (Δ {diff:+6d}, {relative})"

    print(f"\n{'='*90}")
    print(f"COMPARISON: {name1} vs {name2}")
    print(f"{'='*90}\n")

    for task_type in task_types:
        print(f"{task_type.upper()} - OVERALL METRICS")
        print("-" * 90)
        for metric in metrics:
            print(format_line(metric, results1[task_type][metric], results2[task_type][metric], "  "))
        print()

    print("\nPER-CATEGORY METRICS")
    print("-" * 90)

    # Union of categories from both results, alphabetical.
    all_categories = sorted(set(results1['per_category']) | set(results2['per_category']))
    for category in all_categories:
        cat1 = results1['per_category'].get(category)
        cat2 = results2['per_category'].get(category)

        print(f"\n{category.upper()}")
        if cat1 is None:
            print(f"  Only in {name2}")
            continue
        if cat2 is None:
            print(f"  Only in {name1}")
            continue

        for task_type in task_types:
            print(f"  {task_type}:")
            for metric in metrics:
                print(format_line(metric, cat1[task_type][metric], cat2[task_type][metric], "    "))

    print(f"\n{'='*90}\n")


def summarize_bbox_seg_differences(results1, results2, name1="Results 1", name2="Results 2", top_k=5):
    """
    Print a short summary of recall changes between two evaluate_recall_with_segmentation() results.

    Shows the overall bbox and segmentation recall change. Then, for each of the two
    tasks, it lists the `top_k` categories present in both results with the largest
    absolute recall change; ties keep results1's category order. Relative changes are
    computed against results1. For a baseline of 0 they are printed as "n/a", or
    "+0.00%" when the value stayed 0.

    Args:
        results1: baseline results dict.
        results2: results compared against the baseline.
        name1: label for results1.
        name2: label for results2.
        top_k: number of categories listed per task (default 5).
    """
    def relative_change(old, diff):
        if old != 0:
            return f"{diff / old * 100:+.2f}%"
        return "+0.00%" if diff == 0 else "n/a"

    def arrow(diff):
        return "↑" if diff > 0 else "↓" if diff < 0 else "→"

    print(f"\n{'='*90}")
    print(f"SUMMARY OF KEY DIFFERENCES: {name1} vs {name2}")
    print(f"{'='*90}\n")

    print("OVERALL RECALL CHANGES:")
    for task_type in ('bbox', 'segmentation'):
        r1 = results1[task_type]['recall']
        r2 = results2[task_type]['recall']
        diff = r2 - r1
        status = "↑ IMPROVED" if diff > 0 else "↓ DECLINED" if diff < 0 else "→ UNCHANGED"
        print(f"  {task_type:15s}: {r1:.4f} → {r2:.4f} ({diff:+.4f}, {relative_change(r1, diff)}) {status}")

    shared_categories = [c for c in results1['per_category'] if c in results2['per_category']]
    for task_type in ('bbox', 'segmentation'):
        print(f"\n\nCATEGORIES WITH LARGEST {task_type.upper()} RECALL CHANGES:")
        changes = []
        for category in shared_categories:
            r1 = results1['per_category'][category][task_type]['recall']
            r2 = results2['per_category'][category][task_type]['recall']
            changes.append((category, r1, r2, r2 - r1))

        # Stable sort: equal |Δ| keeps results1's category order.
        changes.sort(key=lambda item: abs(item[3]), reverse=True)
        for cat, r1, r2, diff in changes[:top_k]:
            print(f"  {arrow(diff)} {cat:20s}: {r1:.4f} → {r2:.4f} "
                  f"({diff:+.4f}, {relative_change(r1, diff)})")

    print(f"\n{'='*90}\n")

In [24]:
# Box + mask recall at IoU 0.5 for both models, with category labels respected.
seg_results_sam3, seg_results_usis = (
    evaluate_recall_with_segmentation(
        model_preds, gt_data,
        bbox_iou_threshold=0.5,
        mask_iou_threshold=0.5,
        ignore_labels=False,
    )
    for model_preds in (predictions, usis_predictions_aligned)
)

# For the full metric-by-metric breakdown, call instead:
#   compare_bbox_seg_results(seg_results_sam3, seg_results_usis, "SAM3 Predictions", "USIS Predictions")
summarize_bbox_seg_differences(seg_results_sam3, seg_results_usis, "SAM3 Predictions", "USIS Predictions")


SUMMARY OF KEY DIFFERENCES: SAM3 Predictions vs USIS Predictions

OVERALL RECALL CHANGES:
  bbox           : 0.5413 → 0.8266 (+0.2853, +52.71%) ↑ IMPROVED
  segmentation   : 0.4857 → 0.7168 (+0.2311, +47.59%) ↑ IMPROVED


CATEGORIES WITH LARGEST BBOX RECALL CHANGES:
  ↑ human divers        : 0.0104 → 0.9378 (+0.9275, +8950.00%)
  ↑ wrecks/ruins        : 0.0261 → 0.7908 (+0.7647, +2925.00%)
  ↑ robots              : 0.0851 → 0.7234 (+0.6383, +750.00%)
  ↑ reefs               : 0.1858 → 0.7431 (+0.5573, +300.00%)
  ↑ sea-floor           : 0.0000 → 0.4118 (+0.4118, +0.00%)


CATEGORIES WITH LARGEST SEGMENTATION RECALL CHANGES:
  ↑ human divers        : 0.0104 → 0.8446 (+0.8342, +8050.00%)
  ↑ wrecks/ruins        : 0.0261 → 0.7320 (+0.7059, +2700.00%)
  ↑ robots              : 0.0851 → 0.6170 (+0.5319, +625.00%)
  ↑ reefs               : 0.1449 → 0.5665 (+0.4216, +290.91%)
  ↑ sea-floor           : 0.0000 → 0.3088 (+0.3088, +0.00%)


