# Example code for the Segment Anything Model (SAM)

Colab 환경에서 SAM 모델을 사용해 이미지에 클릭한 위치의 객체를 segmentation 하는 예제입니다.

## Colab 환경 설정
예제를 실행시키기 위해 python package들을 설치합니다. 예제로 사용할 이미지들도 다운로드 받습니다.

In [None]:
# Set this to False when running locally; it gates the Colab-only
# dependency installation and example download in the next cell.
using_colab = True

In [None]:
if using_colab:
    !wget https://raw.githubusercontent.com/mrsyee/sam-remove-background/main/jupyternotebook/requirements.txt
    !pip install -r requirements.txt

    # Download examples
    !mkdir examples
    !cd examples && wget https://raw.githubusercontent.com/mrsyee/sam-remove-background/main/assets/examples/mannequin.jpeg

## Import dependencies

In [None]:
import os
import urllib
from typing import Tuple

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

## Set constants

In [2]:
# Directory (relative to the working directory) where the checkpoint is stored.
CHECKPOINT_PATH = "checkpoint"
# File name of the pre-trained SAM weights inside CHECKPOINT_PATH.
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
# Remote location the weights are fetched from when not present locally.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Key looked up in `sam_model_registry` when the model is built.
MODEL_TYPE = "default"
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

cpu


## Initialize and load pre-trained SAM

In [3]:
# `import urllib` alone does not guarantee that the `request` submodule is
# loaded, so import it explicitly before using `urllib.request.urlretrieve`.
import urllib.request

# Create the checkpoint directory if needed (no-op when it already exists).
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
checkpoint = os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME)
if not os.path.exists(checkpoint):
    # Download under a temporary name and rename only on success, so an
    # interrupted transfer never leaves a truncated file at the final path
    # that a later run would mistake for a complete checkpoint.
    partial_checkpoint = checkpoint + ".part"
    urllib.request.urlretrieve(CHECKPOINT_URL, partial_checkpoint)
    os.replace(partial_checkpoint, checkpoint)

# Build the SAM variant selected by MODEL_TYPE from the checkpoint and move
# it to the chosen device.
sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint).to(DEVICE)

In [4]:
# Wrap the loaded model in a SamPredictor; `segment` below uses this shared
# instance for `set_image` and `predict`.
predictor = SamPredictor(sam)

## Segment with one click

In [5]:
def select_masks(
    masks: np.ndarray, iou_preds: np.ndarray, num_points: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Pick the single best mask out of the candidate masks.

    The choice between the first candidate and the remaining ones is made
    from the number of prompt points: the first candidate's score is shifted
    by ``(num_points - 2.5) * 1000``, so it is strongly penalised when
    ``num_points`` is below 2.5 and strongly favoured otherwise. The
    reweighting is used to avoid control flow.
    Reference: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/utils/onnx.py#L92-L105

    Args:
        masks: Candidate masks, indexed as ``masks[i, :, :]``.
        iou_preds: One predicted quality score per candidate mask.
        num_points: Number of prompt points passed to the decoder.

    Returns:
        A tuple ``(mask, iou_pred)`` where ``mask`` is the selected mask with
        a trailing axis of length 1 appended, and ``iou_pred`` is a
        one-element array holding the selected candidate's score.
    """
    # Size the reweight vector from the actual number of candidates instead
    # of assuming exactly three; with a hard-coded length a single-candidate
    # input broadcasts to three scores and can select a non-existent index.
    num_masks = masks.shape[0]
    score_reweight = np.array([1000] + [0] * (num_masks - 1))
    score = iou_preds + (num_points - 2.5) * score_reweight
    best_idx = np.argmax(score)
    masks = np.expand_dims(masks[best_idx, :, :], axis=-1)
    iou_preds = np.expand_dims(iou_preds[best_idx], axis=0)
    return masks, iou_preds


def segment(image: np.ndarray, point_w: int, point_h: int) -> np.ndarray:
    """Segment the object under a single clicked point.

    Args:
        image: Image handed to ``predictor.set_image``.
        point_w: Horizontal (first) coordinate of the click.
        point_h: Vertical (second) coordinate of the click.

    Returns:
        A 3-channel ``uint8`` image that is 255 where the selected mask is
        positive and 0 elsewhere.
    """
    # The clicked point is labelled 1; a second point at (0, 0) is added with
    # label -1.
    # NOTE(review): the extra point looks like the padding point used in
    # SAM's ONNX example - confirm.
    prompt_xy = np.array([[point_w, point_h], [0, 0]])
    prompt_labels = np.array([1, -1])

    # Compute the image embedding, then decode candidate masks for the prompt.
    predictor.set_image(image)
    candidates, candidate_scores, _ = predictor.predict(prompt_xy, prompt_labels)

    # Keep only the best candidate for this number of prompt points.
    best_mask, _ = select_masks(candidates, candidate_scores, len(prompt_xy))

    # Binarise to {0, 255} and expand the single channel to three.
    binary = np.where(best_mask > 0, 255, 0).astype(np.uint8)
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)

## UI: Upload image and click

In [None]:
def get_coords(evt: gr.SelectData):
    """Return the first two entries of the select event's index as a pair."""
    clicked = evt.index
    return clicked[0], clicked[1]

def segment_by_click(image: np.ndarray, evt: gr.SelectData):
    """Run `segment` on `image` at the position carried by the select event."""
    w, h = evt.index
    result = segment(image, w, h)
    return result

# Build the demo UI: an input image, an output image, and two number boxes
# that show the coordinates of the last click.
with gr.Blocks() as app:
    gr.Markdown("# Example of SAM with 1 click")
    with gr.Row():
        coord_w = gr.Number(label="Mouse coords w")
        coord_h = gr.Number(label="Mouse coords h")

    with gr.Row():
        input_img = gr.Image(label="Input image").style(height=600)
        output_img = gr.Image(label="Output image").style(height=600)

    # Two listeners on the same select event: one fills in the coordinate
    # boxes, the other runs the segmentation and shows the mask.
    input_img.select(get_coords, None, [coord_w, coord_h])
    input_img.select(segment_by_click, [input_img], output_img)

    gr.Markdown("## Image Examples")
    gr.Examples(
        examples=[
            ["examples/mannequin.jpeg", 1720, 230]
        ],
        # Input order follows segment(image, point_w, point_h) and the
        # [coord_w, coord_h] order used by the click handler, so the example
        # values land in the matching "w" / "h" boxes.
        inputs=[input_img, coord_w, coord_h],
        outputs=output_img,
        fn=segment,
        run_on_click=True,
    )

In [None]:
# Start the Gradio server for the UI defined above.
# NOTE(review): `share=True` requests a publicly reachable link - confirm this
# is intended before running outside Colab.
app.launch(inline=False, share=True)

In [11]:
# Shut the Gradio server down once you are done with the demo.
app.close()

Closing server running on port: 7860
