# 👩🏽‍💻 Setup

In [None]:
# Install Segment Anything straight from its GitHub repository. No ref is
# given, so this installs whatever the repository's default branch holds now.
!pip install git+https://github.com/facebookresearch/segment-anything.git

In [None]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [None]:
# Install (or upgrade) the Modelbit client used below for login, image logging and deployment.
!pip install --upgrade modelbit

In [None]:
from segment_anything import sam_model_registry, SamPredictor
import cv2
import urllib
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import modelbit
# Authenticate this notebook session with Modelbit. The returned `mb` handle is
# used later to log figures (mb.log_image) and to deploy (mb.deploy).
mb = modelbit.login()

# 📊 Plotting and Mask Helpers

In [None]:
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a translucent colored layer.

    Args:
        mask: array whose last two dimensions are (height, width). It is
            reshaped to ``(height, width, 1)``, so any leading dimensions
            must have size 1.
        ax: matplotlib Axes that receives the overlay via ``ax.imshow``.
        random_color: when True, pick a random RGB color. Otherwise use a
            fixed blue. The alpha channel is 0.6 in both cases.
    """
    alpha = 0.6
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, alpha])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against a (1, 1, 4) color, giving an RGBA
    # image that is transparent wherever the mask is 0/False.
    overlay = mask.reshape(height, width, 1) * rgba[None, None, :]
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points onto ``ax`` as white-edged star markers.

    Points labelled 1 are drawn in yellow and points labelled 0 in red.
    Points with any other label are not drawn.

    Args:
        coords: 2-D array whose first two columns are the x and y positions.
        labels: array of per-point labels, compared elementwise with 1 and 0
            to select rows of ``coords``.
        ax: matplotlib Axes to draw on.
        marker_size: marker area, passed to ``ax.scatter`` as ``s``.
    """
    # The positive points are drawn first, then the negative ones.
    for label_value, colour in ((1, 'yellow'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)

# NOTE: a duplicate `show_box` (green edge) used to be defined here. It was dead
# code, because the later `show_box` definition in this cell replaced it before
# anything called it. That later definition is the single one kept.

def get_bounding_box(mask):
    """Return the tight bounding box of the truthy pixels in a 2-D mask.

    Args:
        mask: 2-D array-like, such as a boolean ``(H, W)`` NumPy array or a
            list of rows. Element ``mask[y][x]`` counts as inside when it is
            non-zero/truthy.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as built-in ints. The max
        coordinates are inclusive. If no pixel is set, returns
        ``(None, None, None, None)``.
    """
    # np.nonzero does the scan in native code. The previous per-pixel Python
    # double loop cost O(H*W) interpreter steps on every inference call.
    ys, xs = np.nonzero(np.asarray(mask))
    if xs.size == 0:
        return None, None, None, None
    # Cast to built-in int so callers get the same types as the original
    # enumerate-based version. NumPy integer scalars are not accepted by the
    # stdlib json module, for example.
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

def show_box(box, ax, color='yellow'):
    """Outline ``box`` on ``ax`` as an unfilled rectangle.

    Args:
        box: ``(x0, y0, x1, y1)`` corner coordinates, for example the tuple
            returned by ``get_bounding_box``.
        ax: matplotlib Axes to draw on.
        color: edge color of the rectangle. Defaults to yellow, as before.

    If any coordinate is ``None``, nothing is drawn. ``get_bounding_box``
    returns all ``None`` for an empty mask, which previously crashed here
    with ``None - None``.
    """
    if any(coord is None for coord in box[:4]):
        return
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    ax.add_patch(
        plt.Rectangle((x0, y0),
                      x1 - x0,
                      y1 - y0,
                      edgecolor=color,
                      facecolor=(0, 0, 0, 0),  # fully transparent fill
                      lw=2))

# 🧠 Model Configuration

In [None]:
# Build the ViT-B ("vit_b") Segment Anything model from the checkpoint fetched
# with wget during setup, and wrap it in a SamPredictor. The image is loaded
# into it with set_image() in the next cell, before any predict() call.
# NOTE(review): the model is never moved to a GPU here. Presumably it runs on
# CPU, consistent with the `torch==2.0.1+cpu` pinned for deployment — confirm.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

In [None]:
# download the image into cv2
# `import urllib` does not by itself load the `urllib.request` submodule,
# so import it explicitly before using it.
import urllib.request

IMAGE_URL = "https://doc.modelbit.com/img/groceries.jpg"
# Use the response as a context manager so the HTTP connection gets closed.
with urllib.request.urlopen(IMAGE_URL) as resp:
    image_bytes = np.asarray(bytearray(resp.read()), dtype=np.uint8)
# IMREAD_COLOR always decodes to 3-channel 8-bit BGR, so a grayscale file can
# no longer break the BGR->RGB conversion below. IMREAD_IGNORE_ORIENTATION
# keeps the pixels as stored, as the previous IMREAD_UNCHANGED (-1) flag did,
# so prompt coordinates still refer to the same pixels.
image = cv2.imdecode(image_bytes, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
if image is None:  # imdecode signals failure by returning None, not raising
    raise ValueError(f"could not decode an image from {IMAGE_URL}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# load the image into Segment Anything
predictor.set_image(image)

# 🔥 Inference

In [None]:
def segment_image(point_x, point_y):
    """Segment the object under one foreground point and return its bounding box.

    The function runs the module-level ``predictor``, which must already have
    the module-level ``image`` loaded via ``set_image``, using a single
    positive point prompt. It then draws the image, the prompt point, the mask
    and the box, and logs the figure to Modelbit.

    Args:
        point_x: x (column) pixel coordinate of the prompt point.
        point_y: y (row) pixel coordinate of the prompt point.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` of the mask, as computed by
        ``get_bounding_box``. All four values are ``None`` if the mask is empty.
    """
    # Get the segment mask: one positive (label 1) point, a single output mask.
    input_point = np.array([[point_x, point_y]])
    input_label = np.array([1])
    masks, _scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    mask = masks[0]
    bbox = get_bounding_box(mask)

    # Draw the image, the point we selected, the mask, and the bounding box
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)
        show_mask(mask, ax)
        show_points(input_point, input_label, ax)
        if bbox[0] is not None:  # an empty mask has no box to draw
            show_box(bbox, ax)
        plt.show()
        mb.log_image(fig)  # show the segmented image in modelbit logs
    finally:
        # Release the figure. Otherwise every call in a long-running
        # deployment leaves another figure registered with pyplot.
        plt.close(fig)

    # return the bounding box of our segment
    return bbox

In [None]:
# Try it locally: segment whatever lies under pixel (x=400, y=275) of the
# loaded image, plot the result and return its bounding box.
segment_image(400, 275)

# 🚀 Deployment

In [None]:
# Deploy `segment_image` to Modelbit together with the Python packages its
# environment needs.
# NOTE(review): segment-anything is installed from the repository's unpinned
# default branch, so a later deploy may pick up different code than was tested
# here. Consider pinning a commit ("...segment-anything.git@<sha>").
# NOTE(review): segment_image reads the globals `predictor`, `image`, `mb`,
# get_bounding_box and the show_* helpers. Presumably Modelbit captures these at
# deploy time, which would mean the endpoint always segments this one preloaded
# image — confirm.
mb.deploy(segment_image,
          python_packages=[
              "git+https://github.com/facebookresearch/segment-anything.git",
              "matplotlib==3.7.1",
              "numpy==1.22.4",
              "opencv-python-headless==4.8.0.74",
              "torch==2.0.1+cpu",  # CPU-only torch build
              "torchvision==0.15.2"
          ])