In [21]:
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image
import numpy as np
import torch

In [5]:
# Build the SAM2 image predictor from the small Hiera checkpoint on the CPU.
device = "cpu"
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-small",
    device=device,
)

# Read the sample picture as an RGB NumPy array and hand it to the predictor.
image = np.array(Image.open("../notebooks/images/truck.jpg").convert("RGB"))
predictor.set_image(image)

In [3]:
# Prompts used throughout the notebook: one point with its label, plus one box.
# NOTE(review): the box is presumably (x0, y0, x1, y1) in image pixels and
# label 0 presumably marks a negative/background click -- confirm.
input_point = np.array([[575, 750]])
input_label = np.array([0])
input_box = np.array([425, 600, 700, 875])

Prepare coords

In [17]:
def prepare_points(point_coords=None, point_labels=None, boxes=None):
    """Merge point and box prompts into a single ``(coords, labels)`` pair.

    Each box is rewritten as two corner points labelled 2 and 3 and placed in
    front of any user-supplied points, so the prompt encoder can be called
    with ``points=...`` and ``boxes=None``.

    Args:
        point_coords: Optional tensor of point coordinates. It is concatenated
            with the box corners along dim 1 (assumes a (B, N, 2) layout --
            TODO confirm against the caller).
        point_labels: Optional tensor of labels matching ``point_coords``;
            required whenever ``point_coords`` is given.
        boxes: Optional tensor whose elements can be reshaped to
            ``(-1, 2, 2)``, e.g. a flat ``(4,)`` box, ``(B, 4)`` or
            ``(B, 2, 2)``.

    Returns:
        A ``(coords, labels)`` tuple, or ``None`` when neither points nor
        boxes were supplied.

    Raises:
        ValueError: If ``point_coords`` is given without ``point_labels``.
    """
    if point_coords is not None:
        if point_labels is None:
            raise ValueError("point_labels must be provided together with point_coords")
        concat_points = (point_coords, point_labels)
    else:
        concat_points = None

    # Embed prompts
    if boxes is not None:
        box_coords = boxes.reshape(-1, 2, 2)
        box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
        # Size the labels from the reshaped coordinates rather than from
        # `boxes`: a flat (4,) box has boxes.size(0) == 4 but is a single box.
        box_labels = box_labels.repeat(box_coords.size(0), 1)
        # we merge "boxes" and "points" into a single "concat_points" input (where
        # boxes are added at the beginning) to sam_prompt_encoder
        if concat_points is not None:
            concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
            concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
            concat_points = (concat_coords, concat_labels)
        else:
            concat_points = (box_coords, box_labels)
    return concat_points

This is the actual encoder call being made. Boxes have already been merged into `concat_points`, so `boxes` is always `None`:

In [None]:
# Reference only (not executed): the prompt-encoder call this notebook mirrors.
# sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
#     points=concat_points,
#     boxes=None,
#     masks=mask_input,
# )

In [15]:
# Let the predictor pre-process the raw prompts; the first returned value is
# not needed in this notebook.
# NOTE(review): `_prep_prompts` is a private SAM2 helper. The 4th and 5th
# positional arguments are presumably the mask input and a
# normalize-coordinates flag -- confirm against the installed sam2 version.
prep_result = predictor._prep_prompts(input_point, input_label, input_box, None, True)
_, unnorm_coords, labels, unnorm_box = prep_result

Compare preparation: points, box, both

In [24]:
# Point prompt alone: with no box, the (coords, labels) pair is returned as given.
points_only = prepare_points(point_coords=unnorm_coords, point_labels=labels)
points_only

(tensor([[[327.1111, 640.0000]]]), tensor([[0]], dtype=torch.int32))

In [25]:
# Box prompt alone: the box becomes two corner points labelled 2 and 3.
box_only = prepare_points(point_coords=None, point_labels=None, boxes=unnorm_box)
box_only

(tensor([[[241.7778, 512.0000],
          [398.2222, 746.6667]]]),
 tensor([[2, 3]], dtype=torch.int32))

In [26]:
# Both prompts: box corners come first, followed by the user point.
points_and_box = prepare_points(
    point_coords=unnorm_coords,
    point_labels=labels,
    boxes=unnorm_box,
)
points_and_box

(tensor([[[241.7778, 512.0000],
          [398.2222, 746.6667],
          [327.1111, 640.0000]]]),
 tensor([[2, 3, 0]], dtype=torch.int32))

Compare embeddings and save for later use

In [42]:
def points_to_embeddings(prepared_points):
    """Encode prepared prompts and cross-check the two encoder entry points.

    Calls the predictor's prompt encoder on ``prepared_points`` with ``boxes``
    and ``masks`` set to ``None``, then calls the encoder's ``points_only``
    method on the same input. Each pair of returned tensors is asserted to
    match under ``torch.allclose``.

    Args:
        prepared_points: The ``(coords, labels)`` tuple produced by
            ``prepare_points``.

    Returns:
        The output of the full prompt-encoder call.

    Raises:
        AssertionError: If any pair of tensors from the two calls differs
            beyond ``torch.allclose`` tolerance.
    """
    prompt_encoder = predictor.model.sam_prompt_encoder
    with torch.inference_mode():
        embeddings = prompt_encoder(prepared_points, boxes=None, masks=None)
        embeddings_with_points_only = prompt_encoder.points_only(prepared_points)

    for full_output, points_only_output in zip(embeddings, embeddings_with_points_only):
        assert torch.allclose(full_output, points_only_output)
    return embeddings

In [43]:
# PyTorch embeddings for the point-only prompt.
points_only_embeddings = points_to_embeddings(points_only)

In [52]:
# Shapes of the embeddings returned for the point-only prompt.
[embedding.shape for embedding in points_only_embeddings]

[torch.Size([1, 2, 256]), torch.Size([1, 256, 64, 64])]

In [45]:
# PyTorch embeddings for the box-only prompt.
box_only_embeddings = points_to_embeddings(box_only)

In [53]:
# Shapes of the embeddings returned for the box-only prompt.
[embedding.shape for embedding in box_only_embeddings]

[torch.Size([1, 3, 256]), torch.Size([1, 256, 64, 64])]

In [48]:
# PyTorch embeddings for the combined point + box prompt.
points_and_box_embeddings = points_to_embeddings(points_and_box)

In [54]:
# Shapes of the embeddings returned for the combined point + box prompt.
[embedding.shape for embedding in points_and_box_embeddings]

[torch.Size([1, 4, 256]), torch.Size([1, 256, 64, 64])]

## Compare with Core ML

In [55]:
import coremltools as ct



In [56]:
# Load the Core ML prompt-encoder package; the path is relative to the
# notebook's working directory.
coreml_prompt_encoder = ct.models.MLModel("sam2_small_prompt_encoder.mlpackage")

In [58]:
def encode_prompt(points, labels):
    """Run the Core ML prompt encoder on one prompt.

    Args:
        points: Value fed to the model's ``"points"`` input.
        labels: Value fed to the model's ``"labels"`` input.

    Returns:
        A ``(sparse_embeddings, dense_embeddings)`` tuple taken from the
        prediction dictionary.
    """
    model_inputs = {"points": points, "labels": labels}
    prediction = coreml_prompt_encoder.predict(model_inputs)
    return prediction["sparse_embeddings"], prediction["dense_embeddings"]

In [61]:
# Same point-only prompt through the Core ML model, then show the output shapes.
points_only_coreml_embeddings = encode_prompt(*points_only)
[embedding.shape for embedding in points_only_coreml_embeddings]

[(1, 2, 256), (1, 256, 64, 64)]

In [65]:
# Largest element-wise deviation between the Core ML and PyTorch dense
# embeddings (index 1). A plain sum of signed differences can cancel to 0.0
# even when the tensors disagree, so use the maximum absolute difference.
np.abs(points_only_coreml_embeddings[1] - points_only_embeddings[1].detach().numpy()).max()

0.0

In [66]:
# Combined point + box prompt through the Core ML model, then show the output shapes.
points_and_box_coreml_embeddings = encode_prompt(*points_and_box)
[embedding.shape for embedding in points_and_box_coreml_embeddings]

[(1, 4, 256), (1, 256, 64, 64)]

In [68]:
# Largest element-wise deviation between the Core ML and PyTorch dense
# embeddings (index 1). A plain sum of signed differences can cancel to 0.0
# even when the tensors disagree, so use the maximum absolute difference.
np.abs(points_and_box_coreml_embeddings[1] - points_and_box_embeddings[1].detach().numpy()).max()

0.0