In [2]:
## SAM3 text-prompted detection: bounding boxes plus dilated-mask overlay

import os
import json
import torch
from PIL import Image, ImageDraw, ImageColor
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
import numpy as np
import gc
import cv2

# ==========================
# CONFIG
# ==========================
# Device the SAM3 model and its processor are placed on.
device = "cuda:0"
# BPE vocabulary file handed to build_sam3_image_model.
bpe_path = "./assets/bpe_simple_vocab_16e6.txt.gz"

# Directory whose immediate sub-folders are the individual datasets.
root_dataset = "./flatbug-dataset"

# Only sub-folders of root_dataset whose name appears here are processed;
# uncomment an entry to include that dataset in the run.
allowed_folders = {
    # "ALUS",
    # "BIOSCAN",
    # "DiversityScanner",
    # "nhm-beetles-crops",
    # "ArTaxOr",
    # "CollembolAI",
    # "gernat2018",
     "cao2022",
    # "sittinger2023",
    # "amarathunga2022",
    # "biodiscover-arm",
}

# Text prompt given to SAM3; also reused as the name of the single COCO category.
prompt_text = "insects"
# COCO category id written on every annotation.
category_id = 1

def dilate_mask(mask, dilation_pixels=3):
    """Dilate a binary mask with a square, all-ones structuring element.

    Args:
        mask: Array-like mask (bool or integer). It is cast to ``uint8``
            before dilation.
        dilation_pixels: Side length, in pixels, of the square kernel. This is
            the kernel size, not the growth per side: the default of 3 grows
            the mask by one pixel in each direction. ``0`` disables dilation.

    Returns:
        A new ``uint8`` array with the same shape as ``mask``.
    """
    mask_u8 = np.asarray(mask).astype(np.uint8)
    if dilation_pixels == 0:
        # A 0x0 kernel is an empty structuring element, and OpenCV substitutes
        # its default 3x3 rectangle for an empty one, so asking for "no
        # dilation" would silently dilate. Return the mask unchanged instead.
        return mask_u8
    kernel = np.ones((dilation_pixels, dilation_pixels), np.uint8)
    return cv2.dilate(mask_u8, kernel, iterations=1)


def mask_to_polygon(mask_np):
    """Trace the external contours of a binary mask as flat coordinate lists.

    Args:
        mask_np: Mask array whose foreground pixels are 1 (or True); it is
            scaled by 255 and cast to ``uint8`` before contour extraction.

    Returns:
        A list with one entry per external contour that has at least three
        points. Each entry is the contour's point array flattened to a plain
        Python list (OpenCV stores contour points as x, y pairs).
    """
    # Local import kept so the helper stays self-contained.
    import cv2

    binary = (mask_np * 255).astype(np.uint8)
    outlines, _hierarchy = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    # Fewer than three points cannot describe a polygon, so those are dropped.
    return [outline.reshape(-1).tolist() for outline in outlines if len(outline) >= 3]

# ==========================
# INIT SAM3 MODEL
# ==========================
# Build the model once up front and switch it to eval mode; the processor
# created here is reused for every image.
model = build_sam3_image_model(bpe_path=bpe_path)
model.to(device)
model.eval()
# NOTE(review): confidence_threshold is presumably the minimum detection score
# the processor keeps — confirm against the sam3 documentation.
processor = Sam3Processor(model, device=device, confidence_threshold=0.5)

# ==========================
# MAIN PROCESSING LOOP
# ==========================
# For every allowed dataset folder: run SAM3 with the text prompt on each
# image, save a visualisation (red boxes + translucent mask overlay) and
# collect COCO-style annotations into one JSON file per dataset.

# Overlay colour is loop-invariant: semi-transparent blue (RGBA).
# Adjust the alpha (last value) if needed.
mask_color = (30, 144, 255, 110)

for dataset_name in sorted(os.listdir(root_dataset)):
    if dataset_name not in allowed_folders:
        continue

    dataset_path = os.path.join(root_dataset, dataset_name)
    print("\n==============================")
    print(f"Processing dataset: {dataset_name}")
    print("==============================")

    output_json = os.path.join(dataset_path, "sam3_results.json")
    output_image_folder = os.path.join(dataset_path, "sam3_output_images")
    os.makedirs(output_image_folder, exist_ok=True)

    coco_output = {
        "images": [],
        "annotations": [],
        "categories": [{"id": category_id, "name": prompt_text}]
    }

    # Ids are 1-based and restart for every dataset.
    annotation_id = 1
    image_id = 1

    # Sorted so image ids are reproducible between runs; os.listdir returns
    # entries in arbitrary order.
    for filename in sorted(os.listdir(dataset_path)):
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue

        img_path = os.path.join(dataset_path, filename)
        # The context manager closes the file handle as soon as the RGB copy
        # has been made.
        with Image.open(img_path) as src_image:
            orig_image = src_image.convert("RGB")
        orig_w, orig_h = orig_image.size

        # Visualisation image (RGB) that receives the boxes and score labels.
        draw_image = orig_image.copy()
        draw = ImageDraw.Draw(draw_image)

        # Fully transparent RGBA canvas for the mask overlay, kept as a NumPy
        # array so each mask can be painted with one vectorised assignment
        # instead of one ImageDraw.point call per pixel.
        overlay_rgba = np.zeros((orig_h, orig_w, 4), dtype=np.uint8)

        # ===== SAM3 PREDICTION =====
        image = orig_image.copy()
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            with torch.inference_mode():
                state = processor.set_image(image)
                processor.reset_all_prompts(state)
                state = processor.set_text_prompt(prompt_text, state)

        masks = state["masks"]
        boxes = state["boxes"]
        scores = state["scores"]

        # Resolution of the returned masks; with no detections fall back to the
        # image size so both scale factors become 1.
        if len(masks) > 0:
            sam_h, sam_w = masks.shape[-2:]
        else:
            sam_w, sam_h = orig_w, orig_h
        # NOTE(review): assumes the boxes are expressed on the same pixel grid
        # as the masks — confirm against the Sam3Processor output format.
        scale_x = orig_w / sam_w
        scale_y = orig_h / sam_h

        # Register the image even when nothing is detected in it.
        coco_output["images"].append({
            "id": image_id,
            "file_name": filename,
            "width": orig_w,
            "height": orig_h
        })

        for idx, mask in enumerate(masks):

            # Scale the (x0, y0, x1, y1) box to original-image pixels.
            x0, y0, x1, y1 = boxes[idx].cpu().tolist()
            x0 *= scale_x
            y0 *= scale_y
            x1 *= scale_x
            y1 *= scale_y
            w = x1 - x0
            h = y1 - y0

            # Draw the box and its score just above the top-left corner.
            draw.rectangle([x0, y0, x0 + w, y0 + h], outline="red", width=2)
            draw.text((x0, max(0, y0 - 12)), f"{scores[idx]:.2f}", fill="red")

            # Resize the mask to the original size (nearest neighbour keeps it
            # binary), then dilate it.
            mask_np = mask.cpu().numpy().squeeze()
            mask_resized = np.array(
                Image.fromarray((mask_np * 255).astype(np.uint8))
                .resize((orig_w, orig_h), Image.NEAREST)
            ) > 127
            mask_resized = mask_resized.astype(np.uint8)
            mask_resized = dilate_mask(mask_resized, dilation_pixels=3)

            # ==========================
            # DRAW MASK OVERLAY
            # ==========================
            # Overwrite (not blend) the covered pixels, so overlapping masks
            # keep the same colour and alpha.
            overlay_rgba[mask_resized == 1] = mask_color

            # Masks that yield no usable contour stay in the visualisation but
            # are left out of the annotations.
            seg = mask_to_polygon(mask_resized)
            if not seg:
                continue

            coco_output["annotations"].append({
                "id": annotation_id,
                "image_id": image_id,
                "file_name": filename,
                "category_id": category_id,
                "bbox": [float(x0), float(y0), float(w), float(h)],
                "segmentation": seg,
                # Area is the pixel count of the dilated mask.
                "area": float(np.sum(mask_resized)),
                "iscrowd": 0,
                "score": float(scores[idx])
            })
            annotation_id += 1

        # =========================================
        # COMPOSITE MASK OVERLAY + ORIGINAL IMAGE
        # =========================================
        # An (H, W, 4) uint8 array is interpreted by Pillow as an RGBA image.
        overlay = Image.fromarray(overlay_rgba)
        draw_image = Image.alpha_composite(draw_image.convert("RGBA"), overlay).convert("RGB")

        # Save visualization
        out_img_path = os.path.join(output_image_folder, filename)
        draw_image.save(out_img_path)

        # Cleanup: collect garbage first so tensors released by the collector
        # are already free when the CUDA cache is emptied.
        del state, masks, boxes, scores
        gc.collect()
        torch.cuda.empty_cache()

        image_id += 1

    # Save JSON
    with open(output_json, "w") as f:
        json.dump(coco_output, f, indent=2)

    print(f"Saved SAM3 JSON → {output_json}")
    print(f"Saved SAM3 images → {output_image_folder}")

print("\nAll datasets processed successfully.")



Processing dataset: cao2022
Saved SAM3 JSON → ./flatbug-dataset/cao2022/sam3_results.json
Saved SAM3 images → ./flatbug-dataset/cao2022/sam3_output_images

All datasets processed successfully.


In [3]:
# Show per-GPU memory usage and utilisation.
!nvidia-smi

Thu Dec 11 00:35:23 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.163.01             Driver Version: 550.163.01     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA TITAN Xp                Off |   00000000:03:00.0 Off |                  N/A |
| 23%   26C    P8              9W /  250W |       2MiB /  12288MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA TITAN Xp                Off |   00

In [None]:
## Kill the ipykernel process to force a kernel restart and clear the memory.
## NOTE: `pkill -f` matches every process whose command line contains
## "ipykernel", so other running notebook kernels may be killed as well.
!pkill -f ipykernel