# Object masks from prompts with SAM

To run in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1A0LpGaDId6-7_okw47YXkqoXNoetHVp8)

The Segment Anything Model (SAM) predicts object masks given prompts that indicate the desired object. The model first converts the image into an image embedding that allows high quality masks to be efficiently produced from a prompt.


## Environment Set-up

In [1]:
import torch
import numpy as np
import IPython
import matplotlib.pyplot as plt
import cv2
import io
import PIL.Image as Image
import os
import sys
from torchvision.transforms import GaussianBlur
from tqdm import tqdm

IN_COLAB = 'google.colab' in sys.modules

In [None]:
print("CUDA is available:", torch.cuda.is_available())


!{sys.executable} -m pip install opencv-python
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!pip install ipympl

sam_checkpoint = "sam_vit_h_4b8939.pth"
face_detection_model = "face_detection_yunet_2023mar.onnx"

if not os.path.isdir('models'):
  !mkdir models
# download SAM
  if not os.path.isfile(f'models/{sam_checkpoint}'):
    !wget https://dl.fbaipublicfiles.com/segment_anything/$sam_checkpoint
    !mv $sam_checkpoint models/

  # download YuNet Faec Detection Model
  if not os.path.isfile(f'models/{face_detection_model}'):
    !wget https://github.com/astaileyyoung/CineFace/raw/main/research/data/face_detection_yunet_2023mar.onnx
    !mv $face_detection_model models/

if not os.path.isdir('images/input_images'):
  !mkdir -p images/input_images

if not os.path.isdir('images/mask_images'):
  !mkdir -p images/mask_images

CUDA is available: True
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-7hg7y3uq
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-7hg7y3uq
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment_anything
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment_anything: filename=segment_anything-1.0-py3-none-any.whl size=36592 sha256=f13a8ed92d20917fe895db8b310f6d38083b96adebaeff5bb23e21f1a1af54c1
  Stored in directory: /tmp/pip-ephem-wheel-cache-hi9wnjsd/wheels/10/cf/59/9ccb2f0a1bcc81d4fbd0e501680b5d088d690c6cfbc02dc99d
Successfully built segment_anything
Installing collected packages: segmen

## Program Set-up

Necessary imports and helper functions for displaying points, boxes, and masks.

In [3]:
def show_mask(mask, ax, random_color=False, display=False):
    """Render a binary mask as a two-colour image on ``ax``.

    Pixels where the mask is set are painted with the mask colour (black by
    default, or a random colour with alpha 0.6 when ``random_color`` is true);
    every other pixel is painted white.

    Args:
        mask: Array whose last two dimensions are (height, width); any leading
            dimensions must have a total size of 1. Non-boolean masks are
            interpreted by truthiness.
        ax: Object exposing ``imshow`` (for example a Matplotlib axes).
        random_color: Use a random RGBA colour instead of black.
        display: Unused; kept so existing calls keep working.
    """
    if random_color:
        # Float RGBA in [0, 1]. The background must be RGBA too, otherwise the
        # mask and background layers have different channel counts.
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        background_color = np.array([1.0, 1.0, 1.0, 1.0])
    else:
        # Masked region is black, everything else white (0-255 integer RGB).
        color = np.array([0, 0, 0])
        background_color = np.array([255, 255, 255])

    mask = np.asarray(mask)
    h, w = mask.shape[-2:]

    # Add a trailing channel axis so the mask broadcasts against the colours.
    # Casting to bool keeps 0/1 integer masks from being bit-inverted.
    mask = mask.astype(bool).reshape(h, w, 1)

    mask_image = np.where(
        mask,
        color.reshape(1, 1, -1),
        background_color.reshape(1, 1, -1),
    )
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars.

    Points whose label equals 1 are drawn in green, points whose label equals
    0 in red. ``coords`` is indexed as ``coords[:, 0]`` (x) and
    ``coords[:, 1]`` (y); ``labels`` selects rows of ``coords``.
    """
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for group, point_color in ((positives, 'green'), (negatives, 'red')):
        ax.scatter(group[:, 0], group[:, 1], color=point_color, **star_style)

def show_box(box, ax):
    """Outline ``box`` on ``ax`` with a green, unfilled rectangle.

    ``box`` is indexed as ``[x0, y0, x1, y1]``; width and height are derived
    from the differences of the two corners.
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top),
        width,
        height,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)


First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results.

In [4]:
import sys
sys.path.append("..")
import torch
from segment_anything import sam_model_registry, SamPredictor

model_type = "vit_h"
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (more slowly) on machines without CUDA instead of failing in `sam.to`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model selected by `model_type` from the checkpoint stored under models/.
sam = sam_model_registry[model_type](checkpoint=f'models/{sam_checkpoint}')
sam.to(device=device)

# The predictor wraps the model; later cells call set_image()/predict() on it.
predictor = SamPredictor(sam)

  state_dict = torch.load(f)


The face detector class, `FaceDetectorYunet`, is defined below.

In [5]:
class FaceDetectorYunet():
    """Wrapper around OpenCV's YuNet face detector (``cv2.FaceDetectorYN``).

    Images are resized to ``img_size`` before detection. The detected boxes
    and landmarks are then mapped back to the pixel coordinates of the
    original image.
    """

    # Names assigned, in order, to the landmark points parsed from each
    # detection row.
    # NOTE(review): OpenCV's FaceDetectorYN documentation lists the right eye
    # before the left eye; confirm this labelling is the intended one.
    LANDMARK_NAMES = ('left_eye', 'right_eye', 'nose', 'right_mouth', 'left_mouth')

    def __init__(self,
                  model_path='models/face_detection_yunet_2023mar.onnx',
                  img_size=(300, 300),
                  threshold=0.5):
        """Create the underlying OpenCV detector.

        Args:
            model_path: Path to the YuNet ONNX model file.
            img_size: ``(width, height)`` every image is resized to before
                detection.
            threshold: Passed to OpenCV as ``score_threshold``.
        """
        self.model_path = model_path
        self.img_size = img_size
        self.fd = cv2.FaceDetectorYN_create(str(model_path),
                                            "",
                                            img_size,
                                            score_threshold=threshold)

    def draw_faces(self,
                   image,
                   faces,
                   show_confidence=False):
        """Draw a rectangle (and optionally the confidence) for each face.

        Drawing happens in place on ``image``, which is also returned.
        ``faces`` is a list of prediction dicts as produced by ``detect``.
        """
        for face in faces:
            color = (0, 0, 255)
            thickness = 2
            cv2.rectangle(image, (face['x1'], face['y1']), (face['x2'], face['y2']), color, thickness, cv2.LINE_AA)

            if show_confidence:
                confidence = face['confidence']
                confidence = "{:.2f}".format(confidence)
                # Place the label just above the top-left corner of the box.
                position = (face['x1'], face['y1'] - 10)
                font = cv2.FONT_HERSHEY_SIMPLEX
                scale = 0.5
                thickness = 1
                cv2.putText(image, confidence, position, font, scale, color, thickness, cv2.LINE_AA)
        return image

    def scale_coords(self, image, prediction):
        """Map a prediction from the resized frame to ``image`` coordinates.

        On entry ``prediction['x1'], ['y1']`` hold the top-left corner and
        ``prediction['x2'], ['y2']`` hold the box width and height, all in the
        resized (``img_size``) frame. On return ``x1, y1, x2, y2`` are the
        rounded corner coordinates in the original image, any ``landmarks``
        are scaled the same way, and size statistics are added. The dict is
        updated in place and returned.
        """
        ih, iw = image.shape[:2]
        rw, rh = self.img_size
        scale = np.array([iw/rw, ih/rh])

        # Rows: top-left corner, bottom-right corner (corner + width/height).
        corners = np.array([
            (prediction['x1'], prediction['y1']),
            (prediction['x1'] + prediction['x2'], prediction['y1'] + prediction['y2'])
        ]) * scale

        prediction['img_width'] = iw
        prediction['img_height'] = ih
        prediction['x1'] = int(corners[0, 0].round())
        prediction['x2'] = int(corners[1, 0].round())
        prediction['y1'] = int(corners[0, 1].round())
        prediction['y2'] = int(corners[1, 1].round())
        # Width/height use the unrounded corners.
        prediction['face_width'] = (corners[1, 0] - corners[0, 0])
        prediction['face_height'] = (corners[1, 1] - corners[0, 1])
        prediction['area'] = prediction['face_width'] * prediction['face_height']
        prediction['pct_of_frame'] = prediction['area']/(prediction['img_width'] * prediction['img_height'])

        # Landmarks come from the same resized frame as the box, so they need
        # the same scaling to line up with the original image.
        landmarks = prediction.get('landmarks')
        if landmarks:
            prediction['landmarks'] = {
                name: (np.asarray(point) * scale).round().astype(int).tolist()
                for name, point in landmarks.items()
            }
        return prediction

    def detect(self, image):
        """Detect faces in an image.

        Args:
            image: Either a file path (``str``) or an image array as returned
                by ``cv2.imread``.

        Returns:
            A list of prediction dicts in original-image coordinates, or
            ``None`` when the detector reports no faces.

        Raises:
            FileNotFoundError: If ``image`` is a path that cannot be read.
            ValueError: If ``image`` is ``None``.
        """
        if isinstance(image, str):
            source = image
            image = cv2.imread(source)
            if image is None:
                raise FileNotFoundError(f"Could not read image file: {source}")
        elif image is None:
            raise ValueError("image is None; expected a file path or an image array")

        # Conversion intended to drop an alpha channel when one is present.
        img = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        img = cv2.resize(img, self.img_size)
        self.fd.setInputSize(self.img_size)
        _, faces = self.fd.detect(img)
        if faces is None:
            return None
        return self.parse_predictions(image, faces)

    def parse_predictions(self,
                          image,
                          faces):
        """Convert raw detector rows into prediction dicts.

        Each row is read as: four box values, then (x, y) landmark pairs, then
        the confidence as the last element. Box and landmark values are
        truncated to integers before being scaled to ``image`` coordinates.
        """
        data = []
        for num, face in enumerate(faces):
            # Raw box is (x, y, width, height); scale_coords turns the last
            # two into the bottom-right corner.
            x1, y1, x2, y2 = list(map(int, face[:4]))
            flat_points = list(map(int, face[4:len(face)-1]))
            # Integer section count: one section per (x, y) pair.
            point_pairs = np.array_split(flat_points, len(flat_points) // 2)
            landmarks = {name: pair.tolist()
                         for name, pair in zip(self.LANDMARK_NAMES, point_pairs)}
            # Plain float so the value formats and serialises predictably.
            confidence = float(face[-1])
            datum = {'x1': x1,
                     'y1': y1,
                     'x2': x2,
                     'y2': y2,
                     'face_num': num,
                     'landmarks': landmarks,
                     'confidence': confidence,
                     'model': 'yunet'}
            data.append(self.scale_coords(image, datum))
        return data


# Batch Processing

In [None]:
# Download the sample data by its file id and unpack it
# (presumably the download is input_images.zip, which the next line unzips — confirm).
!gdown 1zauKTEZWSwohBKp79lVIs0oLVP5nx2-3
!unzip input_images.zip
# Move the extracted files into the folder the batch-processing cell reads from.
!mv input_images/* images/input_images/
# Clean up the archive and the extraction folder.
!rm -rf input_images.zip
!rm -rf input_images/

In [8]:
def detect_faces(image_dir_path, save_boxes=False):
    """Create a blurred face-mask image for every image in a directory.

    For each file in ``image_dir_path`` the faces are detected, the centre of
    each face box is used as a point prompt for SAM, and the resulting mask is
    rendered, blurred and written to ``images/mask_images/<name>_mask.png``.
    Files that cannot be read as images, and images without detected faces,
    are reported and skipped.

    Relies on the notebook globals ``fd`` (face detector) and ``predictor``
    (SAM predictor).

    Args:
        image_dir_path: Directory containing the input images.
        save_boxes: When true, also write a copy of each image with the
            detected face boxes drawn to ``images/box_images/<name>_box.png``.
    """
    # Create necessary directories if required
    os.makedirs('images/mask_images', exist_ok=True)
    if save_boxes:
        os.makedirs('images/box_images', exist_ok=True)

    # The blur does not depend on the image, so build it once.
    blur = GaussianBlur(11, 20)

    # Sorted so runs process files in a reproducible order.
    for image_path in tqdm(sorted(os.listdir(image_dir_path))):
        img = cv2.imread(os.path.join(image_dir_path, image_path))

        # cv2.imread returns None for entries it cannot decode (sub-folders,
        # non-image files); skip them instead of crashing the whole batch.
        if img is None:
            print(f"\nCould not read {image_path} as an image; skipping.")
            continue

        faces = fd.detect(img)

        # If no faces are detected, skip to the next image
        if not faces:
            print(f"\nNo face detected in {image_path}.")
            continue

        # Prepare face center coordinates and labels
        face_centers = [((face['x1'] + face['x2']) / 2, (face['y1'] + face['y2']) / 2) for face in faces]
        input_point = np.array(face_centers)
        input_label = np.ones(len(input_point))

        file_stem = os.path.splitext(os.path.basename(image_path))[0]

        # Optionally save boxes around faces
        if save_boxes:
            box_file_name = file_stem + "_box.png"
            # Draw on a copy so the image passed to SAM stays untouched.
            img_copy = img.copy()
            fd.draw_faces(img_copy, faces, show_confidence=True)
            cv2.imwrite(f'images/box_images/{box_file_name}', img_copy)

        # Process image and make predictions
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        predictor.set_image(img)
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )

        # Select the best mask
        mask_index = np.argmax(scores)
        mask_input = logits[mask_index, :, :]

        # Feed the best candidate's logits back in to get a single refined mask
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )

        # Create a mask image
        # NOTE(review): the mask is rasterised through Matplotlib, so its pixel
        # size follows the figure settings rather than the input image —
        # confirm that downstream consumers expect this.
        mask_file_name = file_stem + "_mask.png"
        fig = plt.figure(figsize=(10, 10))
        show_mask(masks, plt.gca())
        plt.axis('off')

        # Save the mask with blur applied
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        # Close the figure explicitly so figures do not accumulate over the batch.
        plt.close(fig)
        blurred_mask = blur(Image.open(buf))
        blurred_mask.save(f'images/mask_images/{mask_file_name}')

In [9]:
# Detector used by detect_faces (it reads the global `fd`).
fd = FaceDetectorYunet()
# Relative path, like every other path in this notebook, so the cell also
# works when the working directory is not Colab's /content.
image_dir_path = os.path.join('images', 'input_images')
detect_faces(image_dir_path, save_boxes=True)

 50%|█████     | 7/14 [00:19<00:16,  2.41s/it]


No face detected in inle.png.


100%|██████████| 14/14 [00:33<00:00,  2.37s/it]


# Individual Testing

## Selecting objects with SAM

Run either option 1 (automatic) or option 2 (manual) to select faces.

In [11]:
# Replace with your image path. A relative path is used, like the other paths
# in this notebook, so the cell also works outside Colab's /content directory.
image_path = os.path.join('images', 'input_images', 'chan-myanmar.jpeg')
image = Image.open(image_path)

### [Option 1] Automatic Face Detection

Compute the center point of each detected face's bounding box; these points are used as the prompts for SAM. Set `show_drawn_face` to `True` to display the image with the detected boxes drawn on it.

In [None]:
fd = FaceDetectorYunet()
show_drawn_face = True

# read image and detect faces
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

# detect() returns None when nothing is found; normalise to an empty list so
# the code below does not fail while iterating.
faces = fd.detect(img) or []
if not faces:
    print(f"No face detected in {image_path}.")

# centre of each face box, used as a point prompt
face_centers = [((face['x1']+face['x2'])/2, (face['y1']+face['y2'])/2) for face in faces]

# reshape keeps the (N, 2) layout even when no face was found
input_point = np.array(face_centers, dtype=float).reshape(-1, 2)
input_label = np.ones(len(input_point))

if show_drawn_face:
    # draw rectangle on faces
    if faces:
      fd.draw_faces(img, faces, show_confidence=True)

    # show image: encode to JPEG and hand the raw bytes to IPython
    _, ret = cv2.imencode('.jpg', img)
    i = IPython.display.Image(data=ret.tobytes())
    IPython.display.display(i)

### [Option 2] Manual Face Selection

Run the following cell, then click on the object to be masked. The `(x, y)` coordinates of each click are recorded. The more points you provide for the object, the less ambiguity the model has in determining which object you mean.

In [None]:
%matplotlib widget

if IN_COLAB:
    from google.colab import output
    output.enable_custom_widget_manager()

fig, ax = plt.subplots()
ax.imshow(image)
plt.show()

coordinates = list()

def onclick(event):
    print('xdata=%f, ydata=%f' % (event.xdata, event.ydata))
    coordinates.append((event.xdata, event.ydata))

cid = fig.canvas.mpl_connect('button_press_event', onclick)

In [13]:
# Turn the recorded clicks into point prompts; every point gets label 1.
input_point = np.asarray(coordinates)
input_label = np.ones(input_point.shape[0])

### Visualize the points being recorded (Optional)

In [None]:
%matplotlib inline

plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()

## Segmenting based on Selected Points

Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction. The model returns **masks**, **quality predictions (scores)** for those masks, and low-resolution mask **logits** that can be passed to the next iteration of prediction.

With `multimask_output=True` (the default setting), SAM outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt.
- When `False`, it will return a single mask.
- For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`.

In [None]:
# Read the image, convert OpenCV's BGR channel order to RGB, and hand it to
# the predictor.
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Request multiple candidate masks for the prompt points.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# Show every candidate together with the prompt points and its score.
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    candidate_axes = plt.gca()
    show_mask(mask, candidate_axes)
    show_points(input_point, input_label, candidate_axes)
    candidate_axes.set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    candidate_axes.axis('off')
    plt.show()

 A mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`.

In [None]:
# Choose the candidate to refine: keep the manual index, or comment it out and
# enable the argmax line to take the highest-scoring candidate.
mask_index = 1 # manually select preferred mask (0, 1, or 2)
# mask_index = np.argmax(scores) # select mask with highest score
assert mask_index in range(3), "Select either mask 1 [index 0], mask 2 [index 1], or mask3 [index 2]"

# Logits of the chosen candidate, passed back to the predictor below.
mask_input = logits[mask_index]

# Second pass: same prompts plus the previous logits, single mask requested.
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[np.newaxis, ...],
    multimask_output=False,
)

# Render the refined mask
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
plt.axis('off')

# Output name: input file name without its extension, plus a "_mask.png" suffix.
image_file_name = f"{os.path.splitext(os.path.basename(image_path))[0]}_mask.png"

# Capture the rendered figure as PNG bytes in memory.
buf = io.BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight')
buf.seek(0)

# Blur the rendered mask (intended to improve inpainting quality).
blur = GaussianBlur(11,20)
blurred_mask = blur(Image.open(buf))

# Write the blurred mask to disk, then display the figure.
blurred_mask.save(f'images/mask_images/{image_file_name}')
plt.show()