In [9]:
#with mobile_sam model_basic_masking
import cv2
import numpy as np
import onnxruntime as ort
import torch
from mobile_sam import sam_model_registry
from mobile_sam.utils.transforms import ResizeLongestSide
import time
encoder = ort.InferenceSession("../mobile_sam_onnx/mobile_sam_encoder.onnx", providers=["CPUExecutionProvider"])
decoder = ort.InferenceSession("../mobile_sam_onnx/mobile_sam_decoder.onnx", providers=["CPUExecutionProvider"])
sam = sam_model_registry["vit_t"](checkpoint="../mobile_sam_onnx/mobile_sam.pt")
sam.to("cpu")
prompt_encoder = sam.prompt_encoder
transform = ResizeLongestSide(1024)
image_path = "../sample_images/istockphoto-2196087139-612x612.jpg"
orig = cv2.imread(image_path)
H, W = orig.shape[:2]
rgb = cv2.cvtColor(orig, cv2.COLOR_BGR2RGB)
resized = transform.apply_image(rgb).astype(np.float32)
tensor = torch.as_tensor(resized).permute(2, 0, 1)[None]
padded = torch.zeros((1, 3, 1024, 1024), dtype=torch.float32)
padded[:, :, :tensor.shape[2], :tensor.shape[3]] = tensor
padded = padded / 255.0
print(f"Original: {H}x{W} | Resized: {tensor.shape[2]}x{tensor.shape[3]} | Padding: {(1024-tensor.shape[2])}px bottom")
t0 = time.time()
image_embedding = encoder.run(None, {"image": padded.numpy()})[0]
t1 = time.time()
print(f"Encoder runtime: {t1 - t0:.3f}s | Embedding shape: {image_embedding.shape}")
clicks = []

def mouse(event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
        clicks.clear()
        clicks.append((x, y))

cv2.namedWindow("MobileSAM Live")
cv2.setMouseCallback("MobileSAM Live", mouse)
while True:
    frame = orig.copy()

    if clicks:
        px, py = clicks[0]
        pt = np.array([[px, py]])
        pt = torch.as_tensor(transform.apply_coords(pt, (H, W)), dtype=torch.float32)[None]
        labels = torch.ones(1, dtype=torch.int64)[None]
        with torch.no_grad():
            sparse, dense = prompt_encoder(points=(pt, labels), boxes=None, masks=None)
        sparse = sparse.detach().cpu().numpy()
        dense = dense.detach().cpu().numpy()
        t2 = time.time()
        mask, iou = decoder.run(None, {
            "image_embedding": image_embedding,
            "sparse_prompt": sparse,
            "dense_prompt": dense
        })
        t3 = time.time()
        print(f"Decoder runtime: {t3 - t2:.3f}s")
        mask = mask[0][0]  
        new_h, new_w = tensor.shape[2] // 4, tensor.shape[3] // 4
        pad_y = (1024 - tensor.shape[2]) // 4
        pad_x = (1024 - tensor.shape[3]) // 4
        print(f"Cropping mask: valid {new_h}x{new_w}, pad_y={pad_y}, pad_x={pad_x}")
        mask = mask[:new_h, :new_w]
        mask = cv2.resize(mask, (W, H))
        mask_bin = (mask > 0.5).astype(np.uint8)
        frame[mask_bin == 1] = (0, 255, 0)

    cv2.imshow("MobileSAM Live", frame)
    if cv2.waitKey(1) & 0xFF == 27:
        break

cv2.destroyAllWindows()


Original: 408x612 | Resized: 683x1024 | Padding: 341px bottom
Encoder runtime: 0.476s | Embedding shape: (1, 256, 64, 64)
Decoder runtime: 0.029s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.027s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.021s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.021s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.021s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.022s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.022s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.021s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.022s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.024s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.026s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtime: 0.027s
Cropping mask: valid 170x256, pad_y=85, pad_x=0
Decoder runtim

In [None]:
#with mobile_sam model and black and white masking
import cv2
import numpy as np
import onnxruntime as ort
import torch
from mobile_sam import sam_model_registry
from mobile_sam.utils.transforms import ResizeLongestSide
import time
encoder = ort.InferenceSession("../mobile_sam_onnx/mobile_sam_encoder.onnx", providers=["CPUExecutionProvider"])
decoder = ort.InferenceSession("../mobile_sam_onnx/mobile_sam_decoder.onnx", providers=["CPUExecutionProvider"])
sam = sam_model_registry["vit_t"](checkpoint="../mobile_sam_onnx/mobile_sam.pt")
sam.to("cpu")
prompt_encoder = sam.prompt_encoder
transform = ResizeLongestSide(1024)
image_path = "../sample_images/gettyimages-2168448371-612x612.jpg"
orig = cv2.imread(image_path)
H, W = orig.shape[:2]
rgb = cv2.cvtColor(orig, cv2.COLOR_BGR2RGB)
resized = transform.apply_image(rgb).astype(np.float32)
tensor = torch.as_tensor(resized).permute(2, 0, 1)[None]
padded = torch.zeros((1, 3, 1024, 1024), dtype=torch.float32)
padded[:, :, :tensor.shape[2], :tensor.shape[3]] = tensor
padded = padded / 255.0
print(f"Original: {H}x{W} | Resized: {tensor.shape[2]}x{tensor.shape[3]} | Padding: {(1024-tensor.shape[2])}px bottom")
t0 = time.time()
image_embedding = encoder.run(None, {"image": padded.numpy()})[0]
t1 = time.time()
print(f"Encoder runtime: {t1 - t0:.3f}s | Embedding shape: {image_embedding.shape}")
clicks = []
def mouse(event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
        clicks.clear()
        clicks.append((x, y))

cv2.namedWindow("MobileSAM Live")
cv2.namedWindow("Segmentation Mask")
cv2.setMouseCallback("MobileSAM Live", mouse)
mask_bin = np.zeros((H, W), dtype=np.uint8)

while True:
    frame = orig.copy()
    mask_display = np.zeros_like(orig)

    if clicks:
        px, py = clicks[0]
        pt = np.array([[px, py]])
        pt = torch.as_tensor(transform.apply_coords(pt, (H, W)), dtype=torch.float32)[None]
        labels = torch.ones(1, dtype=torch.int64)[None]
        with torch.no_grad():
            sparse, dense = prompt_encoder(points=(pt, labels), boxes=None, masks=None)
        sparse = sparse.detach().cpu().numpy()
        dense = dense.detach().cpu().numpy()
        t2 = time.time()
        mask, iou = decoder.run(None, {
            "image_embedding": image_embedding,
            "sparse_prompt": sparse,
            "dense_prompt": dense
        })
        t3 = time.time()
        print(f"Decoder runtime: {t3 - t2:.3f}s")
        mask = mask[0][0]
        new_h, new_w = tensor.shape[2] // 4, tensor.shape[3] // 4
        pad_y = (1024 - tensor.shape[2]) // 4
        pad_x = (1024 - tensor.shape[3]) // 4
        mask = mask[:new_h, :new_w]
        mask = cv2.resize(mask, (W, H))
        mask_bin = (mask > 0.5).astype(np.uint8)
        frame[mask_bin == 1] = (0, 255, 0)
        mask_display = np.stack([mask_bin * 255] * 3, axis=-1)
    cv2.imshow("MobileSAM Live", frame)
    cv2.imshow("Segmentation Mask", mask_display)
    key = cv2.waitKey(1) & 0xFF
    if key == 27:
        break
    elif key == ord('s'):
        cv2.imwrite("segmentation_mask.png", mask_display)
        print("ðŸ’¾ Saved segmentation mask as segmentation_mask.png")

cv2.destroyAllWindows()


Original: 401x612 | Resized: 671x1024 | Padding: 353px bottom
Encoder runtime: 0.436s | Embedding shape: (1, 256, 64, 64)


qt.qpa.plugin: Could not find the Qt platform plugin "wayland" in "/home/logan78/.local/lib/python3.10/site-packages/cv2/qt/plugins"


Decoder runtime: 0.026s
Decoder runtime: 0.027s
Decoder runtime: 0.021s
Decoder runtime: 0.018s
Decoder runtime: 0.018s
Decoder runtime: 0.021s
Decoder runtime: 0.018s
Decoder runtime: 0.022s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: 0.018s
Decoder runtime: 0.018s
Decoder runtime: 0.022s
Decoder runtime: 0.019s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: 0.022s
Decoder runtime: 0.022s
Decoder runtime: 0.019s
Decoder runtime: 0.019s
Decoder runtime: 0.019s
Decoder runtime: 0.018s
Decoder runtime: 0.022s
Decoder runtime: 0.022s
Decoder runtime: 0.018s
Decoder runtime: 0.019s
Decoder runtime: 0.019s
Decoder runtime: 0.019s
Decoder runtime: 0.019s
Decoder runtime: 0.018s
Decoder runtime: 0.018s
Decoder runtime: 0.020s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: 0.021s
Decoder runtime: