# Object masks from prompts with SAM


The Segment Anything Model (SAM) predicts object masks given prompts that indicate the desired object. The model first converts the image into an image embedding, which allows high-quality masks to be produced efficiently from a prompt.


## Environment Set-up

In [1]:
import torch
import numpy as np
import IPython
import matplotlib.pyplot as plt
import cv2
import io
import PIL.Image as Image
import os
import sys
from torchvision.transforms import GaussianBlur
from tqdm import tqdm

# Colab detection flag: True when the `google.colab` module is already loaded in this interpreter.
IN_COLAB = 'google.colab' in sys.modules

In [2]:
print("CUDA is available:", torch.cuda.is_available())


!{sys.executable} -m pip install opencv-python
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!pip install ipympl

sam_checkpoint = "sam_vit_h_4b8939.pth"
face_detection_model = "face_detection_yunet_2023mar.onnx"

if not os.path.isdir('models'):
  !mkdir models
# download SAM
  if not os.path.isfile(f'models/{sam_checkpoint}'):
    !wget https://dl.fbaipublicfiles.com/segment_anything/$sam_checkpoint
    !mv $sam_checkpoint models/

  # download YuNet Faec Detection Model
  if not os.path.isfile(f'models/{face_detection_model}'):
    !wget https://github.com/astaileyyoung/CineFace/raw/main/research/data/face_detection_yunet_2023mar.onnx
    !mv $face_detection_model models/

if not os.path.isdir('images/original_images'):
  !mkdir -p images/original_images

if not os.path.isdir('images/masks'):
  !mkdir -p images/masks

CUDA is available: True
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-1ga40ceq
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-1ga40ceq
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... [?25l[?25hdone


## Program Set-up

Load the SAM model and predictor. Running on CUDA and using the default model are recommended for best results.

In [3]:
from segment_anything import sam_model_registry, SamPredictor

# NOTE(review): segment_anything is pip-installed in the set-up cell, so this
# path tweak may be unnecessary - confirm before removing.
sys.path.append("..")

model_type = "vit_h"
# Prefer the GPU, but fall back to the CPU so this cell does not fail outright
# on a runtime without CUDA (inference will be much slower there).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model variant selected by `model_type` from the downloaded checkpoint.
sam = sam_model_registry[model_type](checkpoint=f'models/{sam_checkpoint}')
sam.to(device=device)

# Shared predictor used by both the batch and the Gradio workflows below.
predictor = SamPredictor(sam)

  state_dict = torch.load(f)


The face detector class, `FaceDetectorYunet`, and helper functions for displaying points, boxes, and masks are defined below.

In [4]:
class FaceDetectorYunet():
    """Face detector built on OpenCV's YuNet model (``cv2.FaceDetectorYN``).

    Detection runs on a copy of the image resized to ``img_size``; bounding
    boxes and landmarks are then mapped back to the pixel coordinates of the
    image that was passed in.
    """

    def __init__(self,
                  model_path='models/face_detection_yunet_2023mar.onnx',
                  img_size=(300, 300),
                  threshold=0.5):
        """Create the underlying OpenCV detector.

        Args:
            model_path: Path to the YuNet ONNX model file.
            img_size: ``(width, height)`` the image is resized to before detection.
            threshold: Forwarded to OpenCV as ``score_threshold``.
        """
        self.model_path = model_path
        self.img_size = img_size
        self.fd = cv2.FaceDetectorYN_create(str(model_path),
                                            "",
                                            img_size,
                                            score_threshold=threshold)

    def draw_faces(self, image, faces, show_confidence=False):
        """Draw a red (BGR) rectangle for every face onto ``image`` in place.

        Args:
            image: Image array to draw on; it is modified and also returned.
            faces: Iterable of prediction dicts with ``x1, y1, x2, y2`` corner
                coordinates (and ``confidence`` when ``show_confidence`` is set).
            show_confidence: When True, write the confidence (two decimals)
                10 pixels above the top-left corner of each box.

        Returns:
            The same ``image`` object that was passed in.
        """
        for face in faces:
            color = (0, 0, 255)
            thickness = 2
            cv2.rectangle(image, (face['x1'], face['y1']), (face['x2'], face['y2']), color, thickness, cv2.LINE_AA)

            if show_confidence:
                confidence = face['confidence']
                confidence = "{:.2f}".format(confidence)
                position = (face['x1'], face['y1'] - 10)
                font = cv2.FONT_HERSHEY_SIMPLEX
                scale = 0.5
                thickness = 1
                cv2.putText(image, confidence, position, font, scale, color, thickness, cv2.LINE_AA)
        return image

    def scale_coords(self, image, prediction):
        """Map a box from detector-input coordinates to ``image`` coordinates.

        On entry ``prediction['x2']`` / ``prediction['y2']`` are treated as the
        box width / height (they are added to ``x1`` / ``y1``); on exit they
        hold the bottom-right corner. The dict is updated in place with the
        scaled corners plus ``img_width``, ``img_height``, ``face_width``,
        ``face_height``, ``area`` and ``pct_of_frame``.

        Returns:
            The same ``prediction`` dict.
        """
        ih, iw = image.shape[:2]
        rw, rh = self.img_size
        a = np.array([
                (prediction['x1'], prediction['y1']),
                (prediction['x1'] + prediction['x2'], prediction['y1'] + prediction['y2'])
                    ])
        b = np.array([iw/rw, ih/rh])
        c = a * b
        prediction['img_width'] = iw
        prediction['img_height'] = ih
        prediction['x1'] = int(c[0,0].round())
        prediction['x2'] = int(c[1,0].round())
        prediction['y1'] = int(c[0,1].round())
        prediction['y2'] = int(c[1,1].round())
        # Width/height are taken from the unrounded corners, so they may differ
        # slightly from x2 - x1 / y2 - y1.
        prediction['face_width'] = (c[1,0] - c[0,0])
        prediction['face_height'] = (c[1,1] - c[0,1])
        prediction['area'] = prediction['face_width'] * prediction['face_height']
        prediction['pct_of_frame'] = prediction['area']/(prediction['img_width'] * prediction['img_height'])
        return prediction

    def detect(self, image):
        """Detect faces in ``image``.

        Args:
            image: Either an image array or a path string that is read with
                ``cv2.imread``.

        Returns:
            ``None`` when the detector reports no faces, otherwise a list of
            prediction dicts (see ``parse_predictions``).

        Raises:
            ValueError: If ``image`` is a path that OpenCV cannot read.
        """
        if isinstance(image, str):
            image_path = image
            image = cv2.imread(str(image_path))
            # cv2.imread signals failure by returning None; fail here with a
            # clear message rather than inside cvtColor.
            if image is None:
                raise ValueError(f"Could not read image: {image_path}")
        img = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        img = cv2.resize(img, self.img_size)
        self.fd.setInputSize(self.img_size)
        _, faces = self.fd.detect(img)
        if faces is None:
            return None
        else:
            predictions = self.parse_predictions(image, faces)
            return predictions

    def parse_predictions(self,
                          image,
                          faces):
        """Convert raw detector rows into prediction dicts in ``image`` coordinates.

        Each row is read as: 4 box values (x, y, width, height), then (x, y)
        landmark pairs, then the confidence as the last value. Both the box
        and the landmarks are rescaled from the detector input size to the
        size of ``image``.

        Returns:
            A list with one dict per face containing ``x1, y1, x2, y2``,
            ``face_num``, ``landmarks``, ``confidence``, ``model`` and the
            fields added by ``scale_coords``.
        """
        ih, iw = image.shape[:2]
        rw, rh = self.img_size
        scale = np.array([iw / rw, ih / rh])
        # NOTE(review): OpenCV documents the landmark order as right eye, left
        # eye, nose tip, right mouth corner, left mouth corner - confirm which
        # left/right convention these eye labels are meant to follow.
        positions = ['left_eye', 'right_eye', 'nose', 'right_mouth', 'left_mouth']

        data = []
        for num, face in enumerate(list(faces)):
            x1, y1, x2, y2 = list(map(int, face[:4]))
            # Landmarks come back in detector-input coordinates, just like the
            # box, so they are scaled with the same factors to stay aligned.
            points = np.array(list(map(int, face[4:len(face)-1]))).reshape(-1, 2)
            points = (points * scale).round().astype(int)
            landmarks = {positions[idx]: point.tolist() for idx, point in enumerate(points)}
            confidence = face[-1]
            datum = {'x1': x1,
                     'y1': y1,
                     'x2': x2,
                     'y2': y2,
                     'face_num': num,
                     'landmarks': landmarks,
                     'confidence': confidence,
                     'model': 'yunet'}
            d = self.scale_coords(image, datum)
            data.append(d)
        return data


def show_mask(mask, ax, random_color=False, display=False):
    """Render a binary mask on ``ax`` as a colour image.

    Args:
        mask: Array-like whose last two dimensions are (height, width) and
            which holds exactly height * width elements, e.g. shape (H, W) or
            (1, H, W). Truthy entries are the masked pixels.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: When True, masked pixels get a random RGB colour with
            alpha 0.6 and all other pixels are fully transparent. Otherwise
            masked pixels are white (255) and the rest is black (0).
        display: Unused; kept so existing calls keep working.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # color to inpaint (white)
        color = np.array([255, 255, 255])

    # Background (kept intact) is all zeros with the SAME number of channels as
    # `color`; a fixed 3-channel background cannot be combined with the
    # 4-channel random colour.
    background_color = np.zeros_like(color)

    # Coerce to bool so integer/float masks select pixels correctly.
    mask = np.asarray(mask).astype(bool)
    h, w = mask.shape[-2:]
    mask = mask.reshape(h, w, 1)

    # Pick `color` where the mask is set and the background elsewhere.
    mask_image = np.where(mask, color.reshape(1, 1, -1), background_color.reshape(1, 1, -1))
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 as green stars, label 0 as red stars.

    ``coords`` is indexed as an (N, 2) array of (x, y) and ``labels`` as a
    matching length-N array; both are expected to be NumPy arrays.
    """
    foreground = coords[labels == 1]
    background = coords[labels == 0]
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for selected, colour in ((foreground, 'green'), (background, 'red')):
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)

def show_box(box, ax):
    """Outline ``box`` on ``ax`` as an unfilled green rectangle.

    ``box`` is indexed as (x0, y0, x1, y1): two opposite corners.
    """
    left, top = box[0], box[1]
    width = box[2] - left
    height = box[3] - top
    outline = plt.Rectangle(
        (left, top), width, height,
        edgecolor='green', facecolor=(0,0,0,0), lw=2,
    )
    ax.add_patch(outline)


## Batch Processing

Extract the uploaded archive `original_images.tar.xz` into `images/original_images/`.

In [5]:
# !unzip ./food.zip -d ./images/input_images/
# !rm -rf ./food.zip

!tar -xf original_images.tar.xz
!mv original_images/* images/original_images/
!rm -rf original_images/
!rm -rf ./original_images.tar.xz

In [8]:
!rm -r images/original_images

In [9]:
!rm -r images/masks

Perform batch processing.

In [15]:
def batch_detect_faces(input_img_dir, output_mask_dir, max_num_faces=2, save_boxes=False,
                       min_confidence=0.55):
    """Detect faces in every image of a directory and save a blurred SAM mask for each.

    For each ``.jpg`` / ``.jpeg`` / ``.png`` file in ``input_img_dir`` the
    centres of the detected faces are used as positive point prompts for SAM,
    and the resulting mask is written to ``output_mask_dir`` as
    ``<image name>_mask.png``. Images that cannot be read or contain no
    eligible face are skipped with a message.

    Relies on the notebook-level globals ``fd`` (a ``FaceDetectorYunet``) and
    ``predictor`` (a ``SamPredictor``), which must exist when this is called.

    Args:
        input_img_dir: Directory containing the input images.
        output_mask_dir: Directory the masks are written to (created if missing).
        max_num_faces: Maximum number of faces, highest confidence first, used
            as prompts per image.
        save_boxes: When True, also save a copy of each image with the face
            boxes drawn to ``images/box_images/<image name>_box.png``.
        min_confidence: Faces with a confidence below this value are discarded.
    """
    # Create necessary directories if required
    os.makedirs(output_mask_dir, exist_ok=True)
    if save_boxes:
        os.makedirs('images/box_images', exist_ok=True)

    # Loop through all images in the directory (sorted for a stable order)
    for image_path in tqdm(sorted(os.listdir(input_img_dir))):
        full_image_path = os.path.join(input_img_dir, image_path)

        # Skip non-image files (e.g., .DS_Store, other system files)
        if image_path.startswith('.') or not image_path.lower().endswith(('.jpg', '.jpeg', '.png')):
            print(f"Skipping non-image file: {image_path}")
            continue

        # Load the image
        img = cv2.imread(full_image_path)

        # Check if the image was loaded successfully
        if img is None:
            print(f"Failed to load image: {full_image_path}")
            continue

        faces = fd.detect(img)

        # If no faces are detected, skip to the next image
        if not faces:
            print(f"No face detected in {image_path}.")
            continue

        # Remove faces with low confidence. A new list is built instead of
        # calling faces.remove() inside a loop over `faces`: removing while
        # iterating skips the element that follows each removed one.
        faces = [face for face in faces if face['confidence'] >= min_confidence]

        # Only select faces with max_num_faces highest confidence at most
        faces = sorted(faces, key=lambda x: x['confidence'], reverse=True)[:max_num_faces]

        # If faces become empty after filtering
        if not faces:
            print(f"No eligible faces detected in {image_path}.")
            continue

        # Prepare face center coordinates and labels (all prompts are positive)
        face_centers = [((face['x1'] + face['x2']) / 2, (face['y1'] + face['y2']) / 2) for face in faces]
        input_point = np.array(face_centers)
        input_label = np.ones(len(input_point))

        # Optionally save boxes around faces
        if save_boxes:
            box_file_name = os.path.splitext(os.path.basename(image_path))[0] + "_box.png"
            img_copy = img.copy()
            fd.draw_faces(img_copy, faces, show_confidence=True)
            cv2.imwrite(f'images/box_images/{box_file_name}', img_copy)

        # Process image and make predictions
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        predictor.set_image(img)
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )

        # Select the second mask, which usually includes just face and body parts
        mask_index = 1
        mask_input = logits[mask_index, :, :]

        # Feed the logits of that candidate back in to get a single refined mask
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )

        # Create a mask image
        # NOTE(review): the mask is rendered through a 10x10 matplotlib figure,
        # so the saved PNG has the figure's pixel size rather than the source
        # image's - confirm this is what downstream consumers expect.
        mask_file_name = os.path.splitext(os.path.basename(image_path))[0] + "_mask.png"
        save_mask_path = os.path.join(output_mask_dir, mask_file_name)
        plt.figure(figsize=(10, 10))
        show_mask(masks, plt.gca())
        plt.axis('off')

        # Save the mask with blur applied
        buf = io.BytesIO()
        plt.savefig(buf, format='png',bbox_inches='tight')
        buf.seek(0)
        plt.close()
        blur = GaussianBlur(kernel_size=23, sigma=20)
        blurred_mask = blur(Image.open(buf))
        blurred_mask.save(save_mask_path)


In [16]:
fd = FaceDetectorYunet()
# Relative paths, matching the `images/...` directories created in the set-up
# cell, so the batch run does not depend on the notebook living under /content.
input_img_dir = 'images/original_images/'
output_mask_dir = 'images/masks/'
batch_detect_faces(input_img_dir, output_mask_dir)

100%|██████████| 1/1 [00:02<00:00,  2.59s/it]


In [None]:
!zip -r masks.zip /content/images/masks

  adding: content/images/masks/ (stored 0%)
  adding: content/images/masks/US_Festival_11_mask.png (deflated 9%)
  adding: content/images/masks/US_Festival_14_mask.png (deflated 20%)
  adding: content/images/masks/US_Festival_6_mask.png (deflated 11%)
  adding: content/images/masks/US_Festival_7_mask.png (deflated 25%)
  adding: content/images/masks/US_Festival_0_mask.png (deflated 2%)
  adding: content/images/masks/US_Festival_21_mask.png (deflated 13%)
  adding: content/images/masks/US_Festival_27_mask.png (deflated 10%)
  adding: content/images/masks/US_Festival_3_mask.png (deflated 27%)
  adding: content/images/masks/US_Festival_17_mask.png (deflated 3%)
  adding: content/images/masks/US_Festival_33_mask.png (deflated 10%)
  adding: content/images/masks/US_Festival_13_mask.png (deflated 13%)
  adding: content/images/masks/US_Festival_1_mask.png (deflated 17%)
  adding: content/images/masks/US_Festival_24_mask.png (deflated 5%)
  adding: content/images/masks/US_Festival_4_mask.png (

## Individual Processing with Gradio

You can safely ignore any pip dependency error that is shown when this cell is run.

In [5]:
!pip install -q gradio gradio_image_prompter

Launch a Gradio app.

In [None]:
import gradio as gr
import gradio_image_prompter as gr_ext
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import gc
from google.colab import files

# Page title, reused for the heading rendered at the top of the app.
title = "Manual Masking with Segment Anything Model (SAM)"
header = f"<div align='center'><h1>{title}</h1></div>"

# Gradio theme name and custom CSS passed to gr.Blocks.
theme = "soft"
css = """#anno-img .mask {opacity: 0.5; transition: all 0.2s ease-in-out;}
            #anno-img .mask.active {opacity: 0.7}"""

# Blur applied to every generated mask (same settings as the batch workflow).
blur = GaussianBlur(kernel_size=23, sigma=20)

def on_click_submit_btn(click_input_img):
    """Run SAM on the clicked points and return three blurred candidate masks.

    Args:
        click_input_img: Value of the ImagePrompter component; read as a dict
            with an ``'image'`` array and a ``'points'`` list whose rows carry
            x, y in the first two columns and a type code in column 5.

    Returns:
        Three (H, W, 3) float arrays with values in [0, 1]: one white-on-black,
        Gaussian-blurred mask per SAM candidate.

    Raises:
        gr.Error: If no image was uploaded or no point was clicked, so the
            user sees a message in the UI instead of an IndexError traceback.
    """
    image = click_input_img.get('image') if click_input_img else None
    if image is None:
        raise gr.Error("Please upload an image before clicking Submit.")

    np_points = np.array(click_input_img.get('points'))
    # With no prompts the array is not 2-D, and column indexing below would fail.
    if np_points.ndim != 2 or np_points.shape[1] < 6:
        raise gr.Error("Please click at least one point on the image before clicking Submit.")

    # Get only points where the last column ([:, 5]) is 4 (click points)
    point_condition = (np_points[:, 5] == 4)
    input_points = np_points[point_condition][:, :2]  # Get x,y coordinates
    if len(input_points) == 0:
        raise gr.Error("Please click at least one point on the image before clicking Submit.")

    # Drop duplicate coordinates while keeping the click order
    unique_tuples = list(dict.fromkeys(map(tuple, input_points)))
    input_points = np.array(unique_tuples)

    # All points are positive points (label = 1)
    input_labels = np.ones(len(input_points))

    # Get prediction from SAM
    predictor.set_image(image)
    masks, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )

    # Create a black blank canvas: (H, W, RGB, mask index)
    height, width = image.shape[:2]
    mask_all = np.zeros((height, width, 3, len(masks)))

    for i, mask in enumerate(masks):
        # Paint the masked pixels white (1.0 on every channel)
        mask_all[mask, :, i] = 1.0

        # Convert to a (C, H, W) PyTorch tensor to apply GaussianBlur, then back
        tensor_image = torch.from_numpy(mask_all[:, :, :, i]).permute(2, 0, 1)
        blurred_tensor = blur(tensor_image)
        mask_all[:, :, :, i] = blurred_tensor.permute(1, 2, 0).numpy()

    gc.collect()
    torch.cuda.empty_cache()

    return mask_all[..., 0], mask_all[..., 1], mask_all[..., 2]

def on_click_save_btn(mask):
    """Save ``mask`` as the next ``mask_<n>.png`` in ``images/masks`` and queue its download.

    Args:
        mask: Image array taken from one of the output mask components, or
            ``None`` when no mask has been generated yet.

    Returns:
        None.

    Raises:
        gr.Error: If there is no mask to save.
    """
    if mask is None:
        raise gr.Error("There is no mask to save yet. Click Submit first.")

    mask_dir = 'images/masks'
    # The directory may have been removed by a cleanup cell; recreate it if so.
    os.makedirs(mask_dir, exist_ok=True)

    # Number the new file after the count of files already in the directory.
    num_files = sum(1 for f in os.listdir(mask_dir) if os.path.isfile(os.path.join(mask_dir, f)))
    mask_path = os.path.join(mask_dir, f"mask_{num_files}.png")
    Image.fromarray(mask).save(mask_path)
    files.download(mask_path)
    gr.Info(f'Mask successfully saved as {mask_path}. All the masks will be downloaded together once you stop running the cell.', duration=13)

    return None

# Build the Gradio UI: an image prompter on the left, three candidate masks in
# tabs on the right, one save button per mask, and Clear / Submit controls.
with gr.Blocks(title=title, theme=theme, css=css) as demo:
    gr.Markdown(header)

    gr.Markdown("""
      Manually select the objects to be masked by clicking on them.
      - Click `Submit` **after clicking at least once** to receive three different masks.
      - Click one of the three buttons (`Mask 1`, `Mask 2`, `Mask 3`) to save and download. Downloading will begin when the cell that initiated the demo is stopped.
      - With multiple people, some faces clicked may not be fully masked. Then, click another point on the face again **without clearing** and `Submit` again.
      - In general, the mask of `Mask 1`, `Mask 2`, and `Mask 3` becomes increasingly inclusive. For example, given **a coordinate on the face**, `Mask 1` masks only the face while `Mask 3` masks the whole body.
        - Therefore, clicking once only on the face is enough to capture other exposed body parts like hands, which is usually generated as `Mask 2`.
    """)

    # The same labels name the output tabs and the save buttons.
    mask_labels = ["Mask 1", "Mask 2", "Mask 3"]

    with gr.Row():
        with gr.Column():
            click_input_img = gr_ext.ImagePrompter(
                show_label=True,
                label="Input Image",
                interactive=True,
                sources='upload'
            )
        with gr.Column():
            output_masks = []
            for mask_label in mask_labels:
                with gr.Tab(mask_label):
                    output_masks.append(
                        gr.Image(
                            interactive=False,
                            show_label=False,
                            show_download_button=False
                        )
                    )
            output_mask_1, output_mask_2, output_mask_3 = output_masks

    with gr.Row():
        save_buttons = [gr.Button(mask_label) for mask_label in mask_labels]
        click_save_btn_1, click_save_btn_2, click_save_btn_3 = save_buttons

    with gr.Row():
        click_clr_btn = gr.ClearButton(components=[click_input_img, *output_masks])
        click_submit_btn = gr.Button("Submit")

    # Submit fills all three mask outputs from the clicked points.
    click_submit_btn.click(
        fn=on_click_submit_btn,
        inputs=[click_input_img],
        outputs=output_masks
    )

    # Each save button stores the mask shown in the tab with the same label.
    for save_button, output_mask in zip(save_buttons, output_masks):
        save_button.click(
            fn=on_click_save_btn,
            inputs=[output_mask],
            outputs=None
        )


if __name__ == "__main__":
    demo.launch(debug=True)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://820c5cdcb5a2329625.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


In [None]:
!zip -r masks.zip /content/images/masks