# 배경 제거 Application

Colab 환경에서 배경 제거 애플리케이션을 만들어봅시다. 애플리케이션 사용자의 유스케이스는 아래와 같습니다.

- 사용자는 이미지 파일을 업로드할 수 있다.
- 사용자는 이미지에서 원하는 객체를 클릭한다.
- 사용자는 배경 제거 이미지의 결과를 확인하고 다운로드 받을 수 있다.

## Colab 환경 설정
필요한 Python 패키지들을 설치하고, 예제로 사용할 이미지들도 다운로드 받습니다.

In [None]:
# Switch to False when running this notebook locally instead of on Colab.
using_colab: bool = True

In [None]:
if using_colab:
    !wget https://raw.githubusercontent.com/mrsyee/dl_apps/main/segmentation/requirements.txt
    !pip install -r requirements.txt
    !wget https://raw.githubusercontent.com/mrsyee/dl_apps/main/segmentation/app.py

    !mkdir examples
    !cd examples && wget https://github.com/mrsyee/dl_apps/raw/main/segmentation/examples/dog.jpg
    !cd examples && wget https://github.com/mrsyee/dl_apps/raw/main/segmentation/examples/mannequin.jpg

## Import dependency

In [None]:
import os
import urllib
from typing import Tuple

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

## UI 구성

In [None]:
# First iteration of the UI: two coordinate read-outs and input/output image panes.
# No event listeners are wired up yet. Components appear in the order they are
# created inside each `with` context.
# NOTE(review): `.style()` is deprecated in later Gradio 3.x releases and removed in
# 4.x (there, pass `height=600` to `gr.Image` directly) -- confirm against the pinned version.
with gr.Blocks() as app:
    gr.Markdown("# Interactive Remove Background from Image")
    with gr.Row():
        coord_x = gr.Number(label="Mouse coords x")
        coord_y = gr.Number(label="Mouse coords y")

    with gr.Row():
        input_img = gr.Image(label="Input image").style(height=600)
        output_img = gr.Image(label="Output image").style(height=600)

In [None]:
# Serve the UI currently bound to `app`. Per Gradio's `Blocks.launch`, share=True
# requests a public share link and inline=False skips embedding it in the notebook.
app.launch(inline=False, share=True)

In [None]:
# Stop the server started by `app.launch` before the UI is redefined in later cells.
app.close()

## 마우스 클릭 이벤트

In [None]:
def get_coords(evt: gr.SelectData):
    """Return the first two components of a Gradio select event's `index`.

    For an image click these are the clicked pixel position; the UI shows them as
    the x and y coordinates. The `gr.SelectData` annotation must stay: Gradio uses
    it to decide to inject the event object.
    """
    position = evt.index
    return position[0], position[1]

# Second iteration: the same layout, now echoing the clicked position.
with gr.Blocks() as app:
    gr.Markdown("# Interactive Remove Background from Image")
    with gr.Row():
        coord_x = gr.Number(label="Mouse coords x")
        coord_y = gr.Number(label="Mouse coords y")

    with gr.Row():
        input_img = gr.Image(label="Input image").style(height=600)
        output_img = gr.Image(label="Output image").style(height=600)

    # Clicking the input image fires a select event. get_coords takes no component
    # inputs (only the injected event data) and writes into the two Number fields.
    input_img.select(get_coords, None, [coord_x, coord_y])

In [None]:
# Serve the UI currently bound to `app` (public share link, not embedded inline).
app.launch(inline=False, share=True)

In [None]:
# Stop the server started by the previous `app.launch` call.
app.close()

## SAM 추론기

In [None]:
# Local directory that caches the SAM checkpoint, plus where to fetch it from.
CHECKPOINT_PATH = "checkpoint"
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/" + CHECKPOINT_NAME
# Key into segment_anything's `sam_model_registry`; there "default" maps to the
# ViT-H builder, which matches the checkpoint above.
MODEL_TYPE = "default"
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class SAMInferencer:
    """Point-prompted object segmentation built on segment_anything's SamPredictor.

    The constructor downloads the checkpoint if it is missing and builds the model.
    `inference` returns one mask per call for the given point prompts.
    """

    def __init__(
        self,
        checkpoint_path: str,
        checkpoint_name: str,
        checkpoint_url: str,
        model_type: str,
        device: torch.device,
    ):
        """Load (downloading first if needed) the SAM checkpoint and build a predictor.

        Args:
            checkpoint_path: Directory holding the checkpoint; created if missing.
            checkpoint_name: Checkpoint file name inside `checkpoint_path`.
            checkpoint_url: URL to download the checkpoint from when it is absent.
            model_type: Key into `sam_model_registry` (e.g. "default").
            device: Device the model is moved to.
        """
        # `import urllib` alone does not load the `request` submodule; import it
        # explicitly instead of relying on some other library having done so.
        import urllib.request

        print("[INFO] Initialize inferencer")
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint = os.path.join(checkpoint_path, checkpoint_name)
        if not os.path.exists(checkpoint):
            # Download under a temporary name and rename only after success.
            # Otherwise an interrupted download would leave a truncated file that
            # later runs mistake for a complete checkpoint.
            partial = checkpoint + ".part"
            urllib.request.urlretrieve(checkpoint_url, partial)
            os.replace(partial, checkpoint)
        sam = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
        self.predictor = SamPredictor(sam)

    def inference(
        self,
        image: np.ndarray,
        point_coords: np.ndarray,
        points_labels: np.ndarray,
    ) -> np.ndarray:
        """Segment the object indicated by the point prompts.

        Args:
            image: Image handed to `SamPredictor.set_image` (assumed HxWx3 uint8 RGB,
                SamPredictor's default format -- TODO confirm for all callers).
            point_coords: (N, 2) point prompts in (x, y) pixel coordinates.
            points_labels: (N,) prompt labels; 1 = foreground, 0 = background,
                -1 = padding point that SAM ignores.

        Returns:
            The selected mask with a trailing channel axis, shape (H, W, 1).
        """
        self.predictor.set_image(image)
        # Padding points (label -1) are not real prompts, so do not count them.
        num_prompts = int(np.count_nonzero(np.asarray(points_labels) != -1))
        # Same intent as SAM's ONNX heuristic. A single click is ambiguous, so ask
        # for the three multimask candidates. With more clicks, ask for the single
        # mask output, which SAM recommends for unambiguous prompts.
        masks, scores, _ = self.predictor.predict(
            point_coords, points_labels, multimask_output=num_prompts <= 1
        )
        mask, _ = self.select_masks(masks, scores, point_coords.shape[0])
        return mask

    def select_masks(
        self, masks: np.ndarray, iou_preds: np.ndarray, num_points: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Keep the candidate mask with the highest predicted IoU.

        SAM's ONNX export adds a `num_points`-dependent bonus/penalty to index 0
        because its decoder output still contains the single-mask prediction there:
        https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/utils/onnx.py#L92-L105
        `SamPredictor.predict` drops that slot; it returns either the multimask
        candidates or only the single mask. Re-applying the reweighting would
        therefore wrongly exclude (or force) the first multimask candidate.
        `inference` makes the multimask/single choice up front instead.

        Args:
            masks: (C, H, W) candidate masks.
            iou_preds: (C,) predicted IoU per candidate.
            num_points: Number of prompt points. Kept for interface compatibility;
                it no longer affects the selection.

        Returns:
            The chosen mask as (H, W, 1) and its score as a length-1 array.
        """
        best_idx = int(np.argmax(iou_preds))
        mask = np.expand_dims(masks[best_idx, :, :], axis=-1)
        score = np.expand_dims(iou_preds[best_idx], axis=0)
        return mask, score


# Build the shared inferencer once. It fetches the checkpoint if it is not cached
# yet, and extract_object uses it for every click.
inferencer = SAMInferencer(
    checkpoint_path=CHECKPOINT_PATH,
    checkpoint_name=CHECKPOINT_NAME,
    checkpoint_url=CHECKPOINT_URL,
    model_type=MODEL_TYPE,
    device=DEVICE,
)

## 배경 제거 후처리

In [None]:
def extract_object(image: np.ndarray, point_h: int, point_w: int):
    """Black out everything except the object under the given point.

    Despite their names, `point_h` and `point_w` become the first and second
    coordinate of a SAM point prompt, which SamPredictor interprets as (x, y).
    `extract_object_by_event` passes Gradio's `evt.index` in that order.

    Args:
        image: Image array to segment and mask.
        point_h: First prompt coordinate (x).
        point_w: Second prompt coordinate (y).

    Returns:
        A copy of `image` where pixels outside the predicted mask are 0 (black).
    """
    # One foreground click, followed by a (0, 0) point labelled -1 as padding.
    prompts = np.array([[point_h, point_w], [0, 0]])
    labels = np.array([1, -1])

    object_mask = inferencer.inference(image, prompts, labels)

    # cv2 wants an 8-bit mask; pixels where the mask is non-zero are kept.
    mask_u8 = object_mask.astype(np.uint8) * 255
    return cv2.bitwise_and(image, image, mask=mask_u8)


def extract_object_by_event(image: np.ndarray, evt: gr.SelectData):
    """Gradio `select` handler: segment the object the user clicked on.

    `evt.index` is the clicked pixel position reported by Gradio (presumably
    [x, y]). Both components are forwarded unchanged as the SAM point prompt.
    The `gr.SelectData` annotation must stay so Gradio injects the event.
    """
    first, second = evt.index
    return extract_object(image, point_h=first, point_w=second)

In [None]:
# Final demo: clicking the input image segments the object under the cursor and
# shows it, with the background blacked out, in the output pane.
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Extracting Object from Image")
    with gr.Row():
        # get_coords fills these with evt.index[0] / evt.index[1], the same values
        # the earlier UIs label as x / y, so label them that way here too.
        coord_h = gr.Number(label="Mouse coords x")
        coord_w = gr.Number(label="Mouse coords y")

    with gr.Row():
        input_img = gr.Image(label="Input image").style(height=600)
        output_img = gr.Image(label="Output image").style(height=600)

    # Two listeners on the same click: one runs SAM, the other echoes the position.
    input_img.select(extract_object_by_event, [input_img], output_img)
    input_img.select(get_coords, None, [coord_h, coord_w])

# The launch/close cells that follow operate on `app`. Rebind it so they serve this
# demo instead of the earlier coordinate-only UI.
app = demo

In [None]:
# Serve whatever Blocks object is currently bound to `app`.
# NOTE(review): the extraction UI in the previous cell is created as `demo`. Make sure
# `app` refers to it (e.g. `app = demo`), or this relaunches the earlier coordinate-only UI.
app.launch(inline=False, share=True)

In [None]:
# Stop the server launched from `app` in the previous cell.
app.close()