In [None]:
# Import libraries
import os
import time
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Literal

from inference import get_model
from PIL import Image, ImageDraw, ImageOps
from tqdm import tqdm
from ultralytics import SAM, FastSAM

In [None]:
# RoboflowDetector class
class RoboflowDetector:
    def __init__(self, model_id: str, api_key: str):
        self.model = get_model(model_id=model_id, api_key=api_key)

    def detect(self, img_path: Path) -> list[float] | None:
        results = self.model.infer(str(img_path))
        predictions = results[0].predictions

        if not predictions:
            return None

        det = predictions[0]
        return [
            det.x - det.width / 2,
            det.y - det.height / 2,
            det.x + det.width / 2,
            det.y + det.height / 2,
        ]

In [None]:
# SAMSegmenter class
class SAMSegmenter:
    """Wraps an Ultralytics ``SAM`` model with point- and box-prompted entry points."""

    def __init__(self, model_name: str = "sam2.1_b.pt"):
        # `model_name` is the checkpoint handed to ultralytics.SAM.
        self.model = SAM(model_name)

    def segment(self, img: Image.Image):
        """Segment ``img`` with one positive (label 1) point prompt at its centre pixel.

        Intended for an image already cropped around the object of interest.
        Returns whatever the SAM model call returns.
        """
        centre = [img.width // 2, img.height // 2]
        return self.model(img, points=[centre], labels=[1], verbose=False)

    def segment_bbox(self, img_path: Path, bbox: list[float]):
        """Segment the image stored at ``img_path`` using ``bbox`` as a box prompt.

        ``bbox`` is forwarded unchanged; it is presumably ``[x_min, y_min, x_max, y_max]``
        in that image's pixel coordinates -- TODO confirm with callers.
        """
        source = str(img_path)
        return self.model(source, bboxes=[bbox], verbose=False)

In [None]:
# FastSAMSegmenter class
class FastSAMSegmenter:
    """Wraps an Ultralytics ``FastSAM`` model with the same interface as ``SAMSegmenter``."""

    def __init__(self, model_name: str = "FastSAM-s.pt"):
        # `model_name` is the checkpoint handed to ultralytics.FastSAM.
        self.model = FastSAM(model_name)

    def segment(self, img: Image.Image):
        """Segment everything in ``img`` and keep only the largest mask.

        FastSAM runs without prompts. When the first result carries at least one
        mask, that result is trimmed in place to the single mask with the largest
        summed value (its pixel count, assuming binary masks). If boxes are present,
        the matching box row is kept alongside it. Otherwise the results are
        returned untouched.
        """
        results = self.model(img, verbose=False)
        head = results[0] if results else None
        if head is None or head.masks is None:
            return results

        mask_stack = head.masks.data
        if len(mask_stack) == 0:
            return results

        # Index of the mask covering the most area; a one-element slice keeps the
        # leading "number of masks" dimension intact.
        winner = mask_stack.sum(dim=(1, 2)).argmax().item()
        keep = slice(winner, winner + 1)
        head.masks.data = mask_stack[keep]
        if head.boxes is not None:
            head.boxes.data = head.boxes.data[keep]
        return results

    def segment_bbox(self, img_path: Path, bbox: list[float]):
        """Segment the image stored at ``img_path`` using ``bbox`` as a box prompt."""
        source = str(img_path)
        return self.model(source, bboxes=[bbox], verbose=False)

In [None]:
# img_pipeline function
def img_pipeline(
    img_path: Path,
    detect_fn: Callable[[Path], list[float] | None],
    segment_fn: Callable[..., Any],
    det_output_dir: Path = Path("detection_output"),
    seg_output_dir: Path = Path("sam_output"),
    txt_output_dir: Path | None = None,
    mode: Literal["crop", "bbox"] = "crop",
):
    bbox = detect_fn(img_path)
    if bbox is None:
        return

    img = Image.open(img_path)
    img_corrected = ImageOps.exif_transpose(img)

    img_viz = img_corrected.copy()
    draw = ImageDraw.Draw(img_viz)
    draw.rectangle(bbox, outline="red", width=5)

    det_output_dir.mkdir(exist_ok=True, parents=True)
    img_viz.save(det_output_dir / (img_path.stem + ".png"))

    if mode == "crop":
        cropped = img_corrected.crop(bbox)
        results_guided = segment_fn(cropped)
    else:
        with NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            img_corrected.save(tmp, format="PNG")
            tmp_path = Path(tmp.name)
        
        try:
            results_guided = segment_fn(tmp_path, bbox)
        finally:
            if tmp_path.exists():
                tmp_path.unlink()

    seg_output_dir.mkdir(exist_ok=True, parents=True)

    if txt_output_dir is not None:
        txt_output_dir.mkdir(exist_ok=True, parents=True)

    for result in results_guided:
        res_plotted = result.plot()
        img_sam = Image.fromarray(res_plotted[..., ::-1])
        img_sam.save(seg_output_dir / (img_path.stem + ".png"))

        if txt_output_dir is not None and result.masks is not None:
            mask_data = result.masks.data[0]
            ys, xs = (mask_data > 0.5).nonzero(as_tuple=True)
            if len(xs) == 0:
                continue

            h, w = mask_data.shape
            xs_norm = xs.float() / float(w)
            ys_norm = ys.float() / float(h)

            parts: list[str] = ["0"]
            for x_val, y_val in zip(xs_norm.tolist(), ys_norm.tolist()):
                parts.append(f"{x_val:.6f}")
                parts.append(f"{y_val:.6f}")

            txt_path = txt_output_dir / f"{img_path.stem}.txt"
            txt_path.write_text(" ".join(parts))

In [None]:
# Configuration
# Read the Roboflow key from the environment instead of hardcoding it: notebooks
# get shared and committed. A key that was previously committed here should be rotated.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY", "")
if not ROBOFLOW_API_KEY:
    raise RuntimeError(
        "ROBOFLOW_API_KEY is not set; export it in the environment before starting the kernel."
    )
ROBOFLOW_MODEL_ID = "raspberrypi_redball/2"
SAM_MODEL_NAME = "sam2.1_b.pt"

detector = RoboflowDetector(model_id=ROBOFLOW_MODEL_ID, api_key=ROBOFLOW_API_KEY)
# Pass the configured checkpoint explicitly so that editing SAM_MODEL_NAME takes effect.
segmenter = SAMSegmenter(SAM_MODEL_NAME)

In [None]:
# Setup paths
imgs_path = Path("balls")
DET_OUTPUT = Path("detection_output_folder")
SEG_OUTPUT = Path("seg_output_folder")
TEXT_OUTPUT = Path("txt_output_folder")
# File extensions treated as input images (compared case-insensitively).
IMG_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}

assert imgs_path.is_dir(), f"Folder '{imgs_path}' not found"
# Keep only regular image files so stray entries (sub-folders, .DS_Store, ...) don't
# crash the pipeline; sort for a reproducible processing order.
img_paths = sorted(
    p for p in imgs_path.iterdir() if p.is_file() and p.suffix.lower() in IMG_SUFFIXES
)
assert img_paths, f"No images with extensions {sorted(IMG_SUFFIXES)} found in '{imgs_path}'"

In [None]:
# Run pipeline
# perf_counter is monotonic, so the duration is unaffected by wall-clock adjustments.
start_time = time.perf_counter()

for img_path in tqdm(img_paths):
    img_pipeline(
        img_path,
        detect_fn=detector.detect,
        segment_fn=segmenter.segment_bbox,
        det_output_dir=DET_OUTPUT,
        seg_output_dir=SEG_OUTPUT,
        txt_output_dir=TEXT_OUTPUT,
        mode="bbox",
    )

duration = time.perf_counter() - start_time
# Guard the per-image average against an empty image list (ZeroDivisionError).
per_img = duration / len(img_paths) if img_paths else 0.0
print(f"Done in {duration:.2f}s ({per_img:.2f}s/img)")