In [4]:
import json
import numpy as np
from PIL import Image
from pycocotools import mask as maskUtils
import cv2

# =========================
# CONFIG
# =========================
gt_json_path = "./flatbug-dataset/cao2022/instances_default.json"
pred_json_path = "./flatbug-dataset/cao2022/sam3_results.json"
IOU_THRESHOLD = 0.5
MAX_DILATION = 25 # max pixels to try

# =========================
# UTILS
# =========================
def dilate_mask(mask, dilation_pixels=3):
    kernel = np.ones((dilation_pixels, dilation_pixels), np.uint8)
    return cv2.dilate(mask.astype(np.uint8), kernel, iterations=1)

def mask_from_seg(segmentation, h, w):
    rle = maskUtils.frPyObjects(segmentation, h, w)
    mask = maskUtils.decode(rle)
    if len(mask.shape) == 3:
        mask = np.any(mask, axis=2)
    return mask.astype(np.uint8)

def compute_metrics(pred_masks, gt_masks, iou_thresh=0.5):
    TP = FP = FN = 0
    matched_gt = set()

    for p_mask in pred_masks:
        # ensure proper type for pycocotools
        p_mask_enc = maskUtils.encode(np.asfortranarray(p_mask.astype(np.uint8)))
        ious = []
        for g_mask in gt_masks:
            g_mask_enc = maskUtils.encode(np.asfortranarray(g_mask.astype(np.uint8)))
            iou_val = maskUtils.iou([p_mask_enc], [g_mask_enc], [0])[0]
            ious.append(iou_val)

        if len(ious) == 0:
            FP += 1
            continue

        max_iou_idx = np.argmax(ious)
        if ious[max_iou_idx] >= iou_thresh:
            TP += 1
            matched_gt.add(max_iou_idx)
        else:
            FP += 1

    FN = len(gt_masks) - len(matched_gt)
    precision = TP / (TP + FP + 1e-6)
    recall = TP / (TP + FN + 1e-6)
    f1 = 2 * precision * recall / (precision + recall + 1e-6)
    return TP, FP, FN, precision, recall, f1

# =========================
# LOAD JSON
# =========================
with open(gt_json_path) as f:
    gt_data = json.load(f)

with open(pred_json_path) as f:
    pred_data = json.load(f)

# -------------------------
# Map image_id → (height, width)
# -------------------------
image_sizes = {img["id"]: (img["height"], img["width"]) for img in gt_data["images"]}

# -------------------------
# Load GT masks
# -------------------------
gt_by_image = {}
for ann in gt_data["annotations"]:
    img_id = ann["image_id"]
    h, w = image_sizes[img_id]
    mask = mask_from_seg(ann["segmentation"], h, w)
    gt_by_image.setdefault(img_id, []).append(mask)

# -------------------------
# Load predicted masks
# -------------------------
pred_by_image = {}
for ann in pred_data["annotations"]:
    img_id = ann["image_id"]
    h, w = image_sizes[img_id]
    mask = mask_from_seg(ann["segmentation"], h, w)
    pred_by_image.setdefault(img_id, []).append(mask)

# =========================
# SEARCH BEST DILATION
# =========================
best_f1 = 0
best_dilation = 0

print("\nSearching best dilation...\n")

for dilation in range(12, MAX_DILATION + 1):
    total_TP = total_FP = total_FN = 0
    for img_id, gt_masks in gt_by_image.items():
        pred_masks = pred_by_image.get(img_id, [])
        # apply dilation
        pred_masks_dilated = [dilate_mask(m, dilation) for m in pred_masks]
        TP, FP, FN, _, _, _ = compute_metrics(pred_masks_dilated, gt_masks, IOU_THRESHOLD)
        total_TP += TP
        total_FP += FP
        total_FN += FN

    precision = total_TP / (total_TP + total_FP + 1e-6)
    recall = total_TP / (total_TP + total_FN + 1e-6)
    f1 = 2 * precision * recall / (precision + recall + 1e-6)

    print(f"Dilation {dilation}: TP={total_TP}, FP={total_FP}, FN={total_FN}, "
          f"Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")

    if f1 > best_f1:
        best_f1 = f1
        best_dilation = dilation

print(f"\n✅ Best dilation: {best_dilation}, F1={best_f1:.4f}")



Searching best dilation...

Dilation 12: TP=592, FP=36, FN=13, Precision=0.9427, Recall=0.9785, F1=0.9603
Dilation 13: TP=593, FP=35, FN=12, Precision=0.9443, Recall=0.9802, F1=0.9619
Dilation 14: TP=591, FP=37, FN=14, Precision=0.9411, Recall=0.9769, F1=0.9586
Dilation 15: TP=587, FP=41, FN=19, Precision=0.9347, Recall=0.9686, F1=0.9514
Dilation 16: TP=587, FP=41, FN=19, Precision=0.9347, Recall=0.9686, F1=0.9514
Dilation 17: TP=581, FP=47, FN=25, Precision=0.9252, Recall=0.9587, F1=0.9417
Dilation 18: TP=577, FP=51, FN=29, Precision=0.9188, Recall=0.9521, F1=0.9352
Dilation 19: TP=575, FP=53, FN=32, Precision=0.9156, Recall=0.9473, F1=0.9312
Dilation 20: TP=567, FP=61, FN=40, Precision=0.9029, Recall=0.9341, F1=0.9182
Dilation 21: TP=559, FP=69, FN=48, Precision=0.8901, Recall=0.9209, F1=0.9053
Dilation 22: TP=551, FP=77, FN=56, Precision=0.8774, Recall=0.9077, F1=0.8923
Dilation 23: TP=536, FP=92, FN=71, Precision=0.8535, Recall=0.8830, F1=0.8680
Dilation 24: TP=523, FP=105, FN=83,

In [None]:
import os
import json
import numpy as np
from collections import defaultdict
from pycocotools import mask as maskUtils
import cv2
from PIL import Image, ImageDraw, ImageFont

# ==========================
# CONFIG
# ==========================
root_dataset = "./flatbug-dataset"
output_overlay_root = "./dilated_metrics"

datasets_to_eval = {
    "nhm-beetles-crops",
    "cao2022",
    "gernat2018",
    "sittinger2023",
    "amarathunga2022",
    "biodiscover-arm",
}

IOU_THRESHOLD = 0.5
MAX_DILATION = 50
NUM_VIS_IMAGES = 2  # number of images to visualize per dataset

# ==========================
# UTILITIES
# ==========================
def seg_to_rle_mask(segmentation, H, W):
    if segmentation is None:
        return np.zeros((H, W), dtype=np.uint8)
    if isinstance(segmentation, dict):
        return maskUtils.decode(segmentation).astype(np.uint8)
    if isinstance(segmentation, list):
        rle = maskUtils.frPyObjects(segmentation, H, W)
        rle = maskUtils.merge(rle)
        return maskUtils.decode(rle).astype(np.uint8)
    return np.zeros((H, W), dtype=np.uint8)

def dilate_mask(mask, dilation_pixels):
    if dilation_pixels <= 0:
        return mask
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_pixels, dilation_pixels))
    return cv2.dilate(mask, kernel, iterations=1)

def compute_metrics(pred_masks, gt_masks, iou_thresh=0.5):
    TP = FP = 0
    matched_gt = set()
    for p_mask in pred_masks:
        if p_mask is None:
            continue
        p_mask_rle = maskUtils.encode(np.asfortranarray(p_mask.astype(np.uint8)))
        ious = []
        for g_mask in gt_masks:
            if g_mask is None:
                continue
            g_mask_rle = maskUtils.encode(np.asfortranarray(g_mask.astype(np.uint8)))
            iou_val = maskUtils.iou([p_mask_rle], [g_mask_rle], [0])[0]
            ious.append(iou_val)
        if not ious or max(ious) < iou_thresh:
            FP += 1
        else:
            best_idx = np.argmax(ious)
            if best_idx not in matched_gt:
                TP += 1
                matched_gt.add(best_idx)
            else:
                FP += 1
    FN = len(gt_masks) - len(matched_gt)
    precision = TP / (TP + FP + 1e-6)
    recall = TP / (TP + FN + 1e-6)
    f1 = 2 * precision * recall / (precision + recall + 1e-6)
    return TP, FP, FN, precision, recall, f1

def overlay_masks_on_image(image_path, gt_masks, pred_masks, save_path):
    """
    Overlay GT and Pred masks on original image, distinct colors, labels.
    """
    image = Image.open(image_path).convert("RGB")
    overlay = Image.new("RGBA", image.size, (0,0,0,0))
    draw = ImageDraw.Draw(overlay)

    # colors
    gt_color = (0,255,0,120)      # semi-transparent green
    pred_color = (30,144,255,120) # semi-transparent blue

    # draw GT masks
    for m in gt_masks:
        ys, xs = np.where(m>0)
        for x, y in zip(xs, ys):
            draw.point((x,y), fill=gt_color)

    # draw Pred masks
    for m in pred_masks:
        ys, xs = np.where(m>0)
        for x, y in zip(xs, ys):
            draw.point((x,y), fill=pred_color)

    # composite and add labels
    final = Image.alpha_composite(image.convert("RGBA"), overlay)
    draw_final = ImageDraw.Draw(final)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except:
        font = ImageFont.load_default()

    draw_final.text((10,10), "GT", fill=(0,255,0), font=font)
    draw_final.text((10,35), "Pred", fill=(30,144,255), font=font)

    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    final.convert("RGB").save(save_path)

# ==========================
# MAIN LOOP
# ==========================
best_dilations = {}

for dataset_name in sorted(os.listdir(root_dataset)):
    if dataset_name.lower() not in datasets_to_eval:
        continue

    dataset_path = os.path.join(root_dataset, dataset_name)
    gt_file = os.path.join(dataset_path, "instances_default.json")
    pred_file = os.path.join(dataset_path, "sam3_results.json")

    if not os.path.isfile(gt_file) or not os.path.isfile(pred_file):
        print(f"❌ Missing GT or SAM3 predictions in {dataset_name}, skipping.")
        continue

    print(f"\n======================\nEvaluating dataset: {dataset_name}\n======================")

    gt_data = json.load(open(gt_file))
    pred_data = json.load(open(pred_file))

    # map annotations per image
    gt_by_image = defaultdict(list)
    image_sizes = {}
    image_file_map = {}
    for im in gt_data["images"]:
        image_sizes[im["id"]] = (im["height"], im["width"])
        image_file_map[im["id"]] = os.path.join(dataset_path, im["file_name"])
    for ann in gt_data["annotations"]:
        gt_by_image[ann["image_id"]].append(ann)

    pred_by_image = defaultdict(list)
    for ann in pred_data["annotations"]:
        pred_by_image[ann["image_id"]].append(ann)

    # search for best dilation
    best_f1 = -1
    best_dilation = 0
    f1_history = []

    dilation = 0
    while dilation <= MAX_DILATION:
        TP_total = FP_total = FN_total = 0

        for img_id, gt_objs in gt_by_image.items():
            H, W = image_sizes[img_id]
            gt_masks = [seg_to_rle_mask(g["segmentation"], H, W) for g in gt_objs]
            pred_objs = pred_by_image.get(img_id, [])
            pred_masks = [dilate_mask(seg_to_rle_mask(p["segmentation"], H, W), dilation) for p in pred_objs]

            TP, FP, FN, _, _, _ = compute_metrics(pred_masks, gt_masks, IOU_THRESHOLD)
            TP_total += TP
            FP_total += FP
            FN_total += FN

        precision = TP_total / (TP_total + FP_total + 1e-6)
        recall = TP_total / (TP_total + FN_total + 1e-6)
        f1_score = 2 * precision * recall / (precision + recall + 1e-6)
        f1_history.append(f1_score)

        print(f"Dilation {dilation}: TP={TP_total}, FP={FP_total}, FN={FN_total}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1_score:.4f}")

        # stop if F1 decreasing for 2 consecutive steps
        if len(f1_history) >= 3:
            if f1_history[-1] < f1_history[-2] < f1_history[-3]:
                break

        if f1_score > best_f1:
            best_f1 = f1_score
            best_dilation = dilation

        dilation += 1

    best_dilations[dataset_name] = best_dilation
    print(f"✅ Best dilation for {dataset_name}: {best_dilation}, F1={best_f1:.4f}")

    # ==========================
    # SAVE VISUAL OVERLAYS FOR 2 IMAGES
    # ==========================
    save_overlay_path = os.path.join(output_overlay_root, dataset_name)
    saved_count = 0
    for img_id, gt_objs in gt_by_image.items():
        if saved_count >= NUM_VIS_IMAGES:
            break
        H, W = image_sizes[img_id]
        image_path = image_file_map[img_id]

        gt_masks = [seg_to_rle_mask(g["segmentation"], H, W) for g in gt_objs]
        pred_objs = pred_by_image.get(img_id, [])
        pred_masks = [dilate_mask(seg_to_rle_mask(p["segmentation"], H, W), best_dilation) for p in pred_objs]

        save_path = os.path.join(save_overlay_path, os.path.basename(image_path))
        overlay_masks_on_image(image_path, gt_masks, pred_masks, save_path)
        saved_count += 1

# ==========================
# Mean best dilation
# ==========================
mean_best_dilation = np.mean(list(best_dilations.values()))
print("\n======================")
print("Best dilations per dataset:", best_dilations)
print(f"Mean best dilation across datasets: {mean_best_dilation:.2f}")
print("======================")
