# In [None]:
import cv2
import numpy as np
from samexporter.sam2_onnx import SegmentAnything2ONNX
# --- State shared between the mouse callback and the display loop ---------

# True from left-button press until release, i.e. while a box is being dragged.
drawing = False
# Anchor corner of the box (where the button went down); -1 means "not set yet".
ix = iy = -1
# Most recent box as (anchor_x, anchor_y, cursor_x, cursor_y), or None.
current_rect = None
# Latest overlay produced by run_sam2_with_box; None until the first box is drawn.
live_mask = None

# --- Objects created once in main() ---------------------------------------

# The SegmentAnything2ONNX instance.
model = None
# Result of model.encode(image), reused for every box prompt.
embedding = None
# The image loaded with cv2.imread.
image = None

def run_sam2_with_box(x1, y1, x2, y2):
    """Segment the area inside a box prompt and return a preview overlay.

    Reads the module-level ``model``, ``embedding`` and ``image``; all three
    must have been initialised (``main`` does this) before calling.

    Args:
        x1, y1: One corner of the box, in image pixel coordinates.
        x2, y2: The opposite corner. The corners may be supplied in any
            order and may lie outside the image; they are ordered and
            clamped before use.

    Returns:
        A copy of ``image`` blended 60/40 with the binarised mask, with the
        (normalised) box outlined in green.
    """
    img_h, img_w = image.shape[:2]

    # The user can drag in any direction and past the window edge, so put
    # the corners into (left, top, right, bottom) order and keep them inside
    # the image. Clamping is monotonic, so the ordering survives it.
    # Presumably the rectangle prompt expects top-left then bottom-right --
    # verify against samexporter's predict_masks.
    left, right = (min(max(v, 0), img_w - 1) for v in sorted((x1, x2)))
    top, bottom = (min(max(v, 0), img_h - 1) for v in sorted((y1, y2)))

    prompt = [
        {"type": "rectangle", "data": [left, top, right, bottom]},
    ]

    masks = model.predict_masks(embedding, prompt)
    raw_mask = masks[0, 0]

    # NOTE(review): 0.5 is kept from the original; if predict_masks returns
    # raw logits the conventional cut-off would be 0.0 -- confirm.
    binary = (raw_mask > 0.5).astype(np.uint8) * 255

    # addWeighted needs both inputs to have the same size and channel count,
    # so this assumes the mask is already at image resolution.
    mask_bgr = cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
    overlay = cv2.addWeighted(image, 0.6, mask_bgr, 0.4, 0)
    cv2.rectangle(overlay, (left, top), (right, bottom), (0, 255, 0), 2)
    return overlay


def mouse_callback(event, x, y, flags, param):
    """Handle mouse events for the box-drag interaction.

    Registered with ``cv2.setMouseCallback``. A left-button press starts a
    drag and records the anchor corner; every move while dragging, and the
    final release, re-runs segmentation on the box between the anchor and
    the cursor and stores the overlay in ``live_mask``.

    Args:
        event: OpenCV mouse event code (``cv2.EVENT_*``).
        x, y: Cursor position reported with the event.
        flags: Unused.
        param: Unused.
    """
    global ix, iy, drawing, current_rect, live_mask

    if event == cv2.EVENT_LBUTTONDOWN:
        drawing = True
        ix, iy = x, y
        current_rect = None
        return

    if event not in (cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONUP):
        return

    if not drawing:
        # Either a plain hover, or a button-up whose matching press was never
        # seen by this callback. In the latter case the anchor is stale or
        # still the -1 sentinel, so there is no meaningful box to segment.
        return

    if event == cv2.EVENT_LBUTTONUP:
        drawing = False

    current_rect = (ix, iy, x, y)

    if x == ix or y == iy:
        # Zero-area box (e.g. a click without movement): skip inference and
        # keep whatever overlay is currently displayed.
        return

    live_mask = run_sam2_with_box(ix, iy, x, y)


def main():
    """Run the interactive SAM2 box-prompt demo.

    Loads the image and the ONNX encoder/decoder, computes the image
    embedding once, then shows a window in which dragging a rectangle
    displays the segmentation overlay. Press ``q`` to quit.

    Prints an error and returns early if the image cannot be read.
    """
    global model, image, embedding

    window_name = "Live SAM2"

    image = cv2.imread("../images/truck.jpg")
    if image is None:
        print("ERROR: image not found")
        return

    model = SegmentAnything2ONNX(
        "../output_models/sam2_hiera_tiny.encoder.onnx",
        "../output_models/sam2_hiera_tiny.decoder.onnx",
    )

    # Encode once up front; each box prompt reuses this embedding.
    embedding = model.encode(image)

    cv2.namedWindow(window_name)
    cv2.setMouseCallback(window_name, mouse_callback)

    print("Drag a rectangle to see live segmentation...")

    try:
        while True:
            # imshow does not modify its input, so no defensive copy is needed.
            frame = image if live_mask is None else live_mask
            cv2.imshow(window_name, frame)

            # Mask to the low byte: some OpenCV builds/backends return extra
            # modifier bits in the key code, which would defeat the comparison.
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # Close the window even if the loop exits via an exception
        # (including Ctrl+C in the terminal).
        cv2.destroyAllWindows()


# Start the interactive demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
