## Image segmentation with SAM 3

This notebook demonstrates how to use SAM 3 for image segmentation with text or visual prompts. It covers the following capabilities:

- **Text prompts**: Using natural language descriptions to segment objects (e.g., "person", "face")
- **Box prompts**: Using bounding boxes as exemplar visual prompts

Here the model is used for per-frame segmentation: every image in a folder is processed independently with the same text prompt.

In [None]:
import os
import glob
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import sam3
from PIL import Image
from sam3 import build_sam3_image_model
from sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results

sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")

import requests
from io import BytesIO

from sam3.train.data.collator import collate_fn_api as collate
from sam3.model.utils.misc import copy_data_to_device

In [None]:
import torch

# turn on tfloat32 for Ampere GPUs
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# use float16 autocast for the entire notebook: the context is entered here and never exited,
# so every later cell runs under it. If your card supports bfloat16, torch.bfloat16 is an alternative.
torch.autocast("cuda", dtype=torch.float16).__enter__()

# inference mode for the whole notebook (also entered once and never exited). Disable if you need gradients
torch.inference_mode().__enter__()


# Build Model

In [None]:
# Machine-specific absolute locations of the tokenizer vocabulary and the model checkpoint.
# (Earlier, package-relative variant kept for reference:
#  bpe_path = f"{sam3_root}/assets/bpe_simple_vocab_16e6.txt.gz")
bpe_vocab_path = r"/workspaces/sam3-main/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
sam3_checkpoint_path = r"/workspaces/sam3-main/assets/models/sam3.pt"

# Build the image model once; every later cell reuses this `model` object.
model = build_sam3_image_model(
    bpe_path=bpe_vocab_path,
    checkpoint_path=sam3_checkpoint_path,
)

# Batch Processing Configuration

In [None]:
from sam3.eval.postprocessors import PostProcessImage

# Post-processing options, gathered in one mapping so they are easy to inspect and tweak.
# NOTE(review): the meaning of each option is defined by PostProcessImage; the notes below
# are read off the option names and should be confirmed against that class.
postprocess_options = {
    "max_dets_per_img": -1,           # presumably "no per-image cap"
    "iou_type": "segm",
    "use_original_sizes_box": True,   # boxes reported in original-size coordinates (presumably)
    "use_original_sizes_mask": True,  # masks reported at original size (presumably)
    "convert_mask_to_rle": False,
    "detection_threshold": 0.65,
    "to_cpu": True,
}

postprocessor = PostProcessImage(**postprocess_options)

In [None]:
# import numpy as np
# import cv2
# from PIL import Image as PILImage

# class CLAHETransformAPI:
#     def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
#         self.clip_limit = clip_limit
#         self.tile_grid_size = tile_grid_size

#     def __call__(self, datapoint, **kwargs):
#         # 1. Ensure we have images
#         if not hasattr(datapoint, 'images') or not datapoint.images:
#             return datapoint

#         for img_wrapper in datapoint.images:
#             # 2. Extract PIL Image
#             # We check the most likely locations for the actual bitmap
#             actual_img = None
#             if hasattr(img_wrapper, 'image'):
#                 actual_img = img_wrapper.image
#             elif hasattr(img_wrapper, '_image'):
#                 actual_img = img_wrapper._image
            
#             if actual_img is None:
#                 continue

#             # 3. Process with CLAHE
#             img_np = np.array(actual_img)
            
#             # Normalize 16-bit to 8-bit
#             if img_np.dtype == np.uint16:
#                 img_min, img_max = img_np.min(), img_np.max()
#                 img_np = (255.0 * (img_np - img_min) / (img_max - img_min + 1e-6)).astype(np.uint8)
            
#             # CLAHE Logic
#             if len(img_np.shape) == 3 and img_np.shape[2] == 3:
#                 lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
#                 l, a, b = cv2.split(lab)
#                 clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
#                 l = clahe.apply(l)
#                 img_np = cv2.cvtColor(cv2.merge((l, a, b)), cv2.COLOR_LAB2RGB)
#             else:
#                 if img_np.ndim == 3: img_np = img_np.squeeze()
#                 clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
#                 img_np = clahe.apply(img_np)
#                 img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)

#             # 4. Update the wrapper
#             new_pil = PILImage.fromarray(img_np)
#             if hasattr(img_wrapper, 'image'):
#                 img_wrapper.image = new_pil
#             elif hasattr(img_wrapper, '_image'):
#                 img_wrapper._image = new_pil
            
#             # CRITICAL: Some SAM3 wrappers have a 'tensor' or '_tensor' attribute 
#             # that needs to be cleared so 'ToTensorAPI' regenerates it.
#             for tensor_attr in ['tensor', '_tensor', 'data', '_data']:
#                 if hasattr(img_wrapper, tensor_attr):
#                     setattr(img_wrapper, tensor_attr, None)

#         return datapoint

In [None]:
from sam3.train.transforms.basic_for_api import ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI, CLAHETransformAPI

# transform = ComposeAPI(
#     transforms=[
#         RandomResizeAPI(sizes=1008, max_size=1008, square=True, consistent_transform=False),
#         # RandomResizeAPI(sizes=1152, max_size=1152, square=True, consistent_transform=False),
#         ToTensorAPI(),
#         NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
#     ]
# )

# Preprocessing pipeline; the steps are handed to ComposeAPI in this order.
preprocessing_steps = [
    # Contrast enhancement goes first, on the raw image.
    CLAHETransformAPI(clip_limit=3.0),
    # Resize to the mandatory 1008 square (avoids the RoPE error seen with other sizes).
    RandomResizeAPI(sizes=1008, max_size=1008, square=True, consistent_transform=False),
    # Image -> tensor.
    ToTensorAPI(),
    # Standard SAM3 normalization.
    NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]

transform = ComposeAPI(transforms=preprocessing_steps)

In [None]:
from sam3.train.data.sam3_image_dataset import InferenceMetadata, FindQueryLoaded, Image as SAMImage, Datapoint
from typing import List

# Notebook-wide running id: add_text_prompt() stamps each new query with the current value
# (as both coco_image_id and original_image_id) and then increments it.
GLOBAL_COUNTER = 1

def create_empty_datapoint():
    """Return a fresh Datapoint that has no image and no queries yet.

    A datapoint stands for a single image on which several queries can be applied at once.
    """
    blank_datapoint = Datapoint(find_queries=[], images=[])
    return blank_datapoint

def set_image(datapoint, pil_image):
    """Make `pil_image` the one and only image of `datapoint`.

    Whatever was stored in `datapoint.images` before is replaced. The size is recorded
    as [height, width], i.e. the reverse of PIL's (width, height) ordering.
    """
    width, height = pil_image.size
    wrapped_image = SAMImage(data=pil_image, objects=[], size=[height, width])
    datapoint.images = [wrapped_image]

def add_text_prompt(datapoint, text_query):
    """Attach a text query to `datapoint` and return the id assigned to it.

    The datapoint must already hold exactly one image, because that image's stored
    size is copied into the query metadata. The id is taken from GLOBAL_COUNTER,
    which is advanced by one on every successful call.
    """
    global GLOBAL_COUNTER
    assert len(datapoint.images) == 1, "please set the image first"

    query_id = GLOBAL_COUNTER

    # The stored size is forwarded in the order it was saved. With set_image() from
    # this notebook that order is [height, width].
    size_first, size_second = datapoint.images[0].size

    metadata = InferenceMetadata(
        coco_image_id=query_id,
        original_image_id=query_id,
        original_category_id=1,
        original_size=[size_first, size_second],
        object_id=0,
        frame_index=0,
    )
    query = FindQueryLoaded(
        query_text=text_query,
        image_id=0,
        object_ids_output=[],
        is_exhaustive=True,
        query_processing_order=0,
        inference_metadata=metadata,
    )
    datapoint.find_queries.append(query)

    GLOBAL_COUNTER = query_id + 1
    return query_id

# Setup Batching Utilities

In [None]:
# --- Run configuration ---
# Directory that is scanned for input images.
folder_path = "/root/data/230624DS30/230624DS30_p0001"
# Directory that receives the mask / box / score files.
output_folder = "/root/data/Segmentation_SAM3/230624DS30/230624DS30_p0001"
# Text query describing what to segment.
text_prompt = "cells"
# Number of inputs per forward pass (the tiled inference cell batches *tiles* with it).
batch_size = 2

# Make sure the output directory (and any missing parents) exists.
Path(output_folder).mkdir(parents=True, exist_ok=True)

# Get all image files. Every pattern is tried in lower and upper case; collecting the hits
# in a set removes the duplicates a case-insensitive filesystem returns for both variants.
image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp']
matched_paths = set()
for ext in image_extensions:
    matched_paths.update(glob.glob(os.path.join(folder_path, ext)))
    matched_paths.update(glob.glob(os.path.join(folder_path, ext.upper())))

# Keep only files whose *file name* contains the tag. The directory part is ignored, so a
# "w00" somewhere inside folder_path cannot accidentally let every file through.
name_tag = "w00"
image_files = sorted(p for p in matched_paths if name_tag in os.path.basename(p))

# Only use the last `max_images` files of the sorted list (must be >= 1).
max_images = 4
image_files = image_files[-max_images:]

# Print the first few names; slicing keeps this safe when fewer files were found.
preview_count = 4
for preview_path in image_files[:preview_count]:
    print(preview_path)


print(f"Found {len(image_files)} images to process")
print(f"Output will be saved to: {output_folder}")

# Batch Process All Images

In [None]:
# from tqdm import tqdm

# # Process images in batches
# for batch_idx in tqdm(range(0, len(image_files), batch_size), desc="Processing batches"):
#     batch_files = image_files[batch_idx:batch_idx + batch_size]
    
#     # Create datapoints for this batch
#     datapoints = []
#     query_ids = []
#     images = []
    
#     for img_path in batch_files:
#         img = Image.open(img_path)

#         # Only take the size that the model allows (1008x)
#         img = i
        
#         # Convert to RGB if needed (handles grayscale and 16-bit images)
#         if img.mode != 'RGB':
#             # For 16-bit or grayscale images, convert to 8-bit RGB
#             if img.mode == 'I;16' or img.mode == 'I':
#                 # Normalize 16-bit to 8-bit range
#                 img_array = np.array(img)
#                 img_array = ((img_array - img_array.min()) / (img_array.max() - img_array.min()) * 255).astype(np.uint8)
#                 img = Image.fromarray(img_array)
#             img = img.convert('RGB')
        
#         images.append(img)
        
#         datapoint = create_empty_datapoint()
#         set_image(datapoint, img)
#         query_id = add_text_prompt(datapoint, text_prompt)
#         query_ids.append(query_id)
        
#         datapoint = transform(datapoint)
#         datapoints.append(datapoint)
    
#     # Collate and move to GPU
#     batch = collate(datapoints, dict_key="dummy")["dummy"]
#     batch = copy_data_to_device(batch, torch.device("cuda"), non_blocking=True)
    
#     # Forward pass
#     output = model(batch)
    
#     # Post-process results
#     processed_results = postprocessor.process_results(output, batch.find_metadatas)
    
#     # Save results for each image
#     for i, (img_path, query_id, img) in enumerate(zip(batch_files, query_ids, images)):
#         result = processed_results[query_id]
        
#         # Get base filename without extension
#         base_name = Path(img_path).stem
        
#         # Save masks as single image with IDs
#         if "masks" in result and len(result["masks"]) > 0:
#             masks = result["masks"]
            
#             # Convert to numpy if it's a tensor
#             if torch.is_tensor(masks):
#                 masks = masks.cpu().numpy()
            
#             height, width = img.size[1], img.size[0]
            
#             # Handle different mask shapes
#             if masks.ndim == 4:
#                 masks = masks.squeeze(1)
            
#             mask_image = np.zeros((height, width), dtype=np.uint16)
#             for mask_idx in range(masks.shape[0]):
#                 mask_image[masks[mask_idx] > 0.5] = mask_idx + 1
            
#             mask_pil = Image.fromarray(mask_image)
#             mask_pil.save(os.path.join(output_folder, f"{base_name}_masks.png"))
        
#         # Save boxes
#         if "boxes" in result and len(result["boxes"]) > 0:
#             boxes = result["boxes"]
#             # Convert to numpy if it's a tensor
#             if torch.is_tensor(boxes):
#                 boxes = boxes.cpu().numpy()
#             np.savetxt(os.path.join(output_folder, f"{base_name}_boxes.txt"), boxes, fmt="%.2f")
        
#         # Save scores
#         if "scores" in result and len(result["scores"]) > 0:
#             scores = result["scores"]
#             # Convert to numpy if it's a tensor
#             if torch.is_tensor(scores):
#                 scores = scores.cpu().numpy()
#             np.savetxt(os.path.join(output_folder, f"{base_name}_scores.txt"), scores, fmt="%.4f")
#             print(os.path.join(output_folder, f"{base_name}_scores.txt"))

# print(f"Processing complete! Results saved to {output_folder}")

In [None]:
import torch
import numpy as np
import torchvision
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import os

# --- Configuration for 2304x2304 Input ---
IMAGE_SIZE = 2304  # expected width and height of every input image, in pixels
TILE_SIZE = 1008   # tile edge length; same value as the resize target in `transform`
# Stride calculation: (2304 - 1008) / 2 = 648.
# This gives exactly 3 tile origins per axis: [0, 648, 1296] (the last tile ends at 2304)
STRIDES = [0, 648, IMAGE_SIZE - TILE_SIZE]
# Boxes that come within this many pixels of an *internal* tile border are discarded.
EDGE_MARGIN = 5


def _load_tile_rgb(path, left, top):
    """Return the TILE_SIZE x TILE_SIZE crop at (left, top) of the image at `path` as 8-bit RGB.

    The source file is closed before returning. Integer-mode images (PIL modes that start
    with "I", e.g. "I" and "I;16") are min-max rescaled to 0-255 *per tile* before the RGB
    conversion; every other non-RGB mode is converted directly.
    """
    import warnings

    with Image.open(path) as full_img:
        if full_img.size != (IMAGE_SIZE, IMAGE_SIZE):
            # The tile grid and the stitching step both assume IMAGE_SIZE; a crop that
            # reaches outside the image is padded by PIL instead of raising.
            warnings.warn(
                f"{path}: image size {full_img.size} differs from the expected "
                f"{(IMAGE_SIZE, IMAGE_SIZE)}; tiling/stitching results may be wrong"
            )
        tile = full_img.crop((left, top, left + TILE_SIZE, top + TILE_SIZE))
        tile.load()  # make sure pixel data is in memory before the file is closed

    if tile.mode != 'RGB':
        if tile.mode.startswith('I'):
            arr = np.array(tile)
            # Normalize 16/32-bit integers to the 8-bit range (epsilon guards flat tiles)
            arr = ((arr - arr.min()) / (arr.max() - arr.min() + 1e-6) * 255).astype(np.uint8)
            tile = Image.fromarray(arr)
        tile = tile.convert('RGB')
    return tile


def _interior_keep_mask(boxes, left, top):
    """Return a bool tensor selecting the boxes that do not touch an internal tile border.

    `boxes` holds tile-local [x1, y1, x2, y2] rows. A tile side counts as a real image
    border (detections touching it are kept) when the tile starts at 0 on that axis or
    ends at/after IMAGE_SIZE; on every other side, boxes within EDGE_MARGIN pixels of the
    tile border are rejected because the overlapping neighbour tile sees them in full.
    """
    keep = torch.ones(len(boxes), dtype=torch.bool, device=boxes.device)
    if left > 0:
        keep &= boxes[:, 0] > EDGE_MARGIN
    if top > 0:
        keep &= boxes[:, 1] > EDGE_MARGIN
    if left + TILE_SIZE < IMAGE_SIZE:
        keep &= boxes[:, 2] < TILE_SIZE - EDGE_MARGIN
    if top + TILE_SIZE < IMAGE_SIZE:
        keep &= boxes[:, 3] < TILE_SIZE - EDGE_MARGIN
    return keep


print(f"Starting Tiled Inference: {len(image_files)} images, Tile Size={TILE_SIZE}, Grid={len(STRIDES)}x{len(STRIDES)}")

# We process the tiles in batches, not the full images, to keep memory low.
# Flatten all tasks: specific image path + specific tile coordinates
all_tile_tasks = [
    {"path": img_path, "base_name": Path(img_path).stem, "top": top, "left": left}
    for img_path in image_files
    for top in STRIDES
    for left in STRIDES
]

# Dictionary to hold results before stitching
# Structure: { "filename": { "boxes": [], "scores": [], "masks": [], "mask_offsets": [] } }
stitched_results = {}

# Process tiles in chunks
for batch_start in tqdm(range(0, len(all_tile_tasks), batch_size), desc="Processing Tiles"):
    batch_tasks = all_tile_tasks[batch_start : batch_start + batch_size]

    datapoints = []
    metadata_batch = []

    # 1. Prepare Batch
    for task in batch_tasks:
        tile = _load_tile_rgb(task["path"], task["left"], task["top"])

        # Construct Datapoint
        dp = create_empty_datapoint()
        set_image(dp, tile)
        qid = add_text_prompt(dp, text_prompt)

        # Transform (CLAHE, resizing and normalization happen here)
        dp = transform(dp)

        datapoints.append(dp)
        metadata_batch.append({
            "base_name": task["base_name"],
            "left": task["left"],
            "top": task["top"],
            # [left, top, left, top]: adding it to an xyxy box moves it to full-image coordinates
            "offset": torch.tensor([task["left"], task["top"], task["left"], task["top"]]),
            "query_id": qid,
        })

    # 2. Inference
    if not datapoints:
        continue

    batch_input = collate(datapoints, dict_key="dummy")["dummy"]
    batch_input = copy_data_to_device(batch_input, torch.device("cuda"), non_blocking=True)

    with torch.inference_mode():
        output = model(batch_input)

    # 3. Process & Collect Results with EDGE FILTERING
    processed_batch = postprocessor.process_results(output, batch_input.find_metadatas)
    if isinstance(processed_batch, dict):
        results_by_key = processed_batch
        results_in_order = list(processed_batch.values())
    else:
        results_by_key = None
        results_in_order = list(processed_batch)

    for position, meta in enumerate(metadata_batch):
        # Prefer the explicit query id; fall back to batch position when the
        # post-processor returned a plain sequence or uses different keys.
        if results_by_key is not None and meta["query_id"] in results_by_key:
            res = results_by_key[meta["query_id"]]
        elif position < len(results_in_order):
            res = results_in_order[position]
        else:
            continue

        # Every processed image gets an entry, even when none of its tiles has detections.
        per_image = stitched_results.setdefault(
            meta["base_name"],
            {"boxes": [], "scores": [], "masks": [], "mask_offsets": []},
        )

        if "boxes" not in res or len(res["boxes"]) == 0:
            continue

        boxes = res["boxes"]  # Local tile coordinates (0-1008), x1, y1, x2, y2
        scores = res["scores"]
        masks = res["masks"]

        # --- EDGE REJECTION LOGIC ---
        keep_mask = _interior_keep_mask(boxes, meta["left"], meta["top"])
        if not keep_mask.any():
            continue

        boxes = boxes[keep_mask]
        scores = scores[keep_mask]
        masks = masks[keep_mask]

        # Shift boxes to Global Coordinates
        offset = meta["offset"].to(boxes.device)
        per_image["boxes"].append(boxes + offset)
        per_image["scores"].append(scores)
        per_image["masks"].append(masks)

        # Store one (left, top) offset per detection so each mask can be placed when stitching
        per_image["mask_offsets"].append(offset[:2].repeat(len(scores), 1))

In [None]:
# --- Stitching & Saving Phase ---
# Size of the reconstructed label image; must match the frame size the tiles were cut from.
FULL_HEIGHT, FULL_WIDTH = 2304, 2304
# Boxes overlapping a higher-scoring box by more than this IoU are treated as duplicates.
# A LOWER threshold (e.g. 0.3) suppresses more aggressively: it removes more duplicates but
# is also more likely to drop a distinct, tightly packed neighbour. A higher one does the opposite.
NMS_IOU_THRESHOLD = 0.5
# Largest instance id that survives the uint16 PNG written below.
MAX_UINT16_ID = int(np.iinfo(np.uint16).max)
stitch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Stitching tiles and saving results...")

for base_name, data in tqdm(stitched_results.items(), desc="Saving Images"):
    if not data["boxes"]:
        print(f"No detections for {base_name}")
        continue

    # 1. Concatenate all tiles for this image.
    #    Boxes and scores are cast to float32 because torchvision.ops.nms rejects half precision.
    all_boxes = torch.cat(data["boxes"]).detach().to(dtype=torch.float32)
    all_scores = torch.cat(data["scores"]).detach().to(dtype=torch.float32)
    all_masks = torch.cat(data["masks"]).detach()
    all_offsets = torch.cat(data["mask_offsets"]).detach()

    # 2. Global NMS removes duplicates found in the tile overlap regions.
    #    The returned indices are ordered by decreasing score.
    keep_indices = torchvision.ops.nms(all_boxes, all_scores, iou_threshold=NMS_IOU_THRESHOLD)

    # 3. Filter results based on NMS
    final_boxes = all_boxes[keep_indices].cpu().numpy()
    final_scores = all_scores[keep_indices].cpu().numpy()
    final_masks = all_masks[keep_indices]
    final_offsets = all_offsets[keep_indices]

    if len(final_scores) > MAX_UINT16_ID:
        print(
            f"WARNING: {base_name} has {len(final_scores)} instances; ids above "
            f"{MAX_UINT16_ID} wrap around in the uint16 mask PNG"
        )

    # 4. Reconstruct the full-size label image: 0 = background, k = row k-1 of the box/score files
    full_mask = torch.zeros((FULL_HEIGHT, FULL_WIDTH), dtype=torch.int32, device=stitch_device)

    for instance_idx, (mask_tensor, offset) in enumerate(zip(final_masks, final_offsets)):
        # Squeeze if mask is (1, H, W) -> (H, W)
        if mask_tensor.ndim == 3:
            mask_tensor = mask_tensor.squeeze(0)

        # Binarize: boolean masks are used as-is, logits (negative values present)
        # are cut at 0, probabilities at 0.5.
        if mask_tensor.dtype == torch.bool:
            binary_mask = mask_tensor
        elif mask_tensor.min() < 0:
            binary_mask = mask_tensor > 0.0
        else:
            binary_mask = mask_tensor > 0.5

        # Placement rectangle of this tile, clipped to the image boundaries
        x_start, y_start = int(offset[0]), int(offset[1])
        x_end = min(x_start + TILE_SIZE, FULL_WIDTH)
        y_end = min(y_start + TILE_SIZE, FULL_HEIGHT)

        # Part of the tile mask that falls inside the image, moved to the label image's device
        valid_binary_mask = binary_mask[: y_end - y_start, : x_end - x_start].to(full_mask.device)

        # `region` is a view, so writing into it updates full_mask in place.
        # Only background pixels are painted: masks arrive in decreasing score order,
        # so a higher-scoring instance keeps any pixel it already owns.
        region = full_mask[y_start:y_end, x_start:x_end]
        region[valid_binary_mask & (region == 0)] = instance_idx + 1

    # 5. Save Files
    full_mask_np = full_mask.cpu().numpy().astype(np.uint16)

    # Save Mask
    Image.fromarray(full_mask_np).save(os.path.join(output_folder, f"{base_name}_masks.png"))

    # Save Boxes
    np.savetxt(os.path.join(output_folder, f"{base_name}_boxes.txt"), final_boxes, fmt="%.2f")

    # Save Scores
    np.savetxt(os.path.join(output_folder, f"{base_name}_scores.txt"), final_scores, fmt="%.4f")

print(f"Processing complete! Results saved to {output_folder}")

In [None]:
# import torch
# import numpy as np
# from PIL import Image

# def optimized_prepare_batch(img_paths):
#     tensors = []
#     for path in img_paths:
#         # Load as-is (16-bit)
#         img = Image.open(path)
#         arr = np.array(img).astype(np.float32)
        
#         # Normalize 0-65535 to 0-1 on CPU quickly
#         arr /= 65535.0 
        
#         # To tensor (H, W) -> (1, H, W)
#         t = torch.from_numpy(arr).unsqueeze(0)
#         tensors.append(t)
    
#     # Stack into (B, 1, H, W)
#     batch = torch.stack(tensors).to("cuda", non_blocking=True)
    
#     # "Triple" the data ON THE GPU using .expand
#     # This creates a view, not a memory copy, until necessary
#     batch_rgb = batch.expand(-1, 3, -1, -1) 
#     return batch_rgb

# # Updated Benchmark for 2304x2304
# def benchmark_sam3_real_res(model, test_batch_sizes=[1, 2, 4, 8]):
#     # Note: 2304x2304 is huge, you likely won't get past batch 4 or 8
#     for b in test_batch_sizes:
#         try:
#             # Match your actual input resolution
#             dummy_batch = torch.randn(b, 3, 2304, 2304).to("cuda")
            
#             with torch.inference_mode():
#                 # Warmup
#                 _ = model(dummy_batch)
                
#                 torch.cuda.synchronize()
#                 start = time.perf_counter()
#                 for _ in range(5):
#                     _ = model(dummy_batch)
#                 torch.cuda.synchronize()
                
#                 throughput = (b * 5) / (time.perf_counter() - start)
#                 print(f"BS: {b} | {throughput:.2f} images/sec")
#         except torch.cuda.OutOfMemoryError:
#             print(f"BS: {b} | OOM")
#             torch.cuda.empty_cache()
#             break

# benchmark_sam3_real_res(model)

In [None]:
import matplotlib.patches as patches

# Visualize results for a sample image
if len(image_files) > 0:
    sample_idx = -1  # last image of the processed list
    sample_path = image_files[sample_idx]
    base_name = Path(sample_path).stem

    # Paths for all result files
    mask_path = os.path.join(output_folder, f"{base_name}_masks.png")
    box_path = os.path.join(output_folder, f"{base_name}_boxes.txt")
    score_path = os.path.join(output_folder, f"{base_name}_scores.txt")

    if os.path.exists(mask_path) and os.path.exists(box_path) and os.path.exists(score_path):
        # Copy the pixel data out so the files can be closed right away
        with Image.open(sample_path) as sample_file:
            img = sample_file.copy()
        with Image.open(mask_path) as mask_file:
            mask_image = np.array(mask_file)

        # Empty result files yield empty arrays instead of a loadtxt warning
        boxes = np.loadtxt(box_path).reshape(-1, 4) if os.path.getsize(box_path) > 0 else np.empty((0, 4))
        scores = np.loadtxt(score_path).reshape(-1) if os.path.getsize(score_path) > 0 else np.empty(0)

        fig, axes = plt.subplots(1, 2, figsize=(20, 10))

        # Original Image + Bounding Boxes & Confidence
        axes[0].imshow(img)
        for box, score in zip(boxes, scores):
            # Box format: [x1, y1, x2, y2]
            x1, y1, x2, y2 = box
            rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='r', facecolor='none')
            axes[0].add_patch(rect)

            # Confidence Label
            label = f"{score*100:.1f}%"
            axes[0].text(x1, y1 - 10, label, color='white', fontsize=6,
                         fontweight='bold', bbox=dict(facecolor='red', alpha=0.2, edgecolor='none'))

        axes[0].set_title(f"Detections: {base_name}")
        axes[0].axis("off")

        # Mask Visualization
        axes[1].imshow(mask_image, cmap='nipy_spectral')
        axes[1].set_title(f"Masks: {len(scores)} objects found")
        axes[1].axis("off")

        plt.tight_layout()

        # Save image (must happen before plt.show(), which may clear the figure)
        plt.savefig("check.png")

        plt.show()

        print(f"Sample: {base_name}")
        # min()/max() raise on empty arrays, so only report a range when there are scores
        if scores.size > 0:
            print(f"Confidence range: {scores.min():.2f} to {scores.max():.2f}")
        else:
            print("Confidence range: n/a (no detections)")
    else:
        print(f"Missing result files for {base_name}. Check your output folder!")
else:
    print("No images found in the folder!")

# Optional: Single Image Examples (Original Code)

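The cells below show the original single-image processing examples for reference. They are left commented out: they rely on names such as `processor`, `inference_state`, `image_path`, `width` and `height` that are not defined in the cells above.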

### Visual prompt: a single bounding box

In [None]:
# # Here the box is in  (x,y,w,h) format, where (x,y) is the top left corner.
# box_input_xywh = torch.tensor([480.0, 290.0, 110.0, 360.0]).view(-1, 4)
# box_input_cxcywh = box_xywh_to_cxcywh(box_input_xywh)

# norm_box_cxcywh = normalize_bbox(box_input_cxcywh, width, height).flatten().tolist()
# print("Normalized box input:", norm_box_cxcywh)

# processor.reset_all_prompts(inference_state)
# inference_state = processor.add_geometric_prompt(
#     state=inference_state, box=norm_box_cxcywh, label=True
# )

# img0 = Image.open(image_path)
# image_with_box = draw_box_on_image(img0, box_input_xywh.flatten().tolist())
# plt.imshow(image_with_box)
# plt.axis("off")  # Hide the axis
# plt.show()

In [None]:
# plot_results(img0, inference_state)

### Visual prompt: multi-box prompting (with positive and negative boxes)

In [None]:
# box_input_xywh = [[480.0, 290.0, 110.0, 360.0], [370.0, 280.0, 115.0, 375.0]]
# box_input_cxcywh = box_xywh_to_cxcywh(torch.tensor(box_input_xywh).view(-1,4))
# norm_boxes_cxcywh = normalize_bbox(box_input_cxcywh, width, height).tolist()

# box_labels = [True, False]

# processor.reset_all_prompts(inference_state)

# for box, label in zip(norm_boxes_cxcywh, box_labels):
#     inference_state = processor.add_geometric_prompt(
#         state=inference_state, box=box, label=label
#     )

# img0 = Image.open(image_path)
# image_with_box = img0
# for i in range(len(box_input_xywh)):
#     if box_labels[i] == 1:
#         color = (0, 255, 0)
#     else:
#         color = (255, 0, 0)
#     image_with_box = draw_box_on_image(image_with_box, box_input_xywh[i], color)
# plt.imshow(image_with_box)
# plt.axis("off")  # Hide the axis
# plt.show()

In [None]:
# plot_results(img0, inference_state)