<a href="https://colab.research.google.com/github/1ucky40nc3/ml4me/blob/main/vision/HQ-SAM/run_hq_sam_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Copyright 2023 Louis Wendler

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# A [**Segment Anything in High Quality (HQ-SAM)**](https://github.com/syscv/sam-hq) Demo


## Set up the Notebook

In [None]:
# Print the NVIDIA driver/GPU status to check whether this runtime has a GPU attached.
!nvidia-smi

In [None]:
# @title Clone the [`HQ-SAM`](https://github.com/syscv/sam-hq) Repository
!git clone https://github.com/SysCV/sam-hq.git

In [None]:
# @title Download a Pretrained SAM Model
import os


checkpoint_dir = "/content/ckpts"
checkpoint_url = "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
checkpoint_path = os.path.join(
    checkpoint_dir,
    os.path.split(checkpoint_url)[-1]
)

!mkdir -p $checkpoint_dir
!wget $checkpoint_url -P $checkpoint_dir

In [None]:
# @title Implement Utils
import os

import numpy as np

import cv2
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def show_mask(mask: np.ndarray, ax: Axes) -> None:
    """Overlay a mask on `ax` using a random, semi-transparent colour.

    The last two dimensions of `mask` are taken as height and width. Each mask
    value is multiplied with a random RGB colour plus a fixed alpha of 0.6 and
    the resulting `(height, width, 4)` array is drawn with `ax.imshow`.
    """
    # Three random colour channels followed by the fixed alpha channel.
    rgba = np.append(np.random.random(3), 0.6)
    height, width = mask.shape[-2:]
    # Broadcast (height, width, 1) against (1, 1, 4) to colour every pixel.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)


def show_points(
    coords: np.ndarray,
    labels: np.ndarray,
    ax: Axes,
    marker_size: int=375
) -> None:
    """Draw the prompt points on `ax` as star markers.

    Points labelled `0` are drawn in red and points labelled `1` in green.

    Args:
        coords: Point coordinates, one `(x, y)` pair per row.
        labels: One label per point (`0` or `1`).
        ax: The axes to draw on.
        marker_size: Marker size, passed to `ax.scatter` as `s`.
    """
    # Normalise the inputs so that an empty prompt list (shape `(0,)`) can
    # still be indexed as an `(N, 2)` array below.
    coords = np.asarray(coords).reshape(-1, 2)
    labels = np.asarray(labels).reshape(-1)

    # `enumerate` pairs each colour with the label value it stands for;
    # iterating the bare list would try to unpack the colour string itself
    # and raise a ValueError.
    for label, color in enumerate(['red', 'green']):
        points = coords[labels == label]
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=color,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25
        )


def show_results(
    masks: np.ndarray,
    scores: np.ndarray,
    point_coords: np.ndarray,
    point_labels: np.ndarray,
    image: np.ndarray
) -> None:
    """Print every mask's score and plot the mask with the prompts on the image.

    One 10x10 figure is created per `(mask, score)` pair: the image is drawn
    first, then the mask overlay, then the prompt points, with the axis hidden.
    """
    for mask, score in zip(masks, scores):
        print(f"Score: {score:.3f}")

        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        # Both overlays are drawn on the axes that hold the image.
        axes = plt.gca()
        show_mask(mask, axes)
        show_points(point_coords, point_labels, axes)
        plt.axis('off')
        plt.show()

In [None]:
# @title Load a Pretrained SAM Model
%cd /content/sam-hq

import torch
from segment_anything import sam_model_registry, SamPredictor


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# @markdown Select the model type
model_type = "vit_l" # @param ['default', 'vit_h', 'vit_l', 'vit_b']
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device=device)
predictor = SamPredictor(sam)

## Do some Segmentation!

In [None]:
# @title Prepare the SAM Prompts
import IPython
from IPython.display import HTML

from google.colab import output

from base64 import b64encode


def load_data_url(path: str) -> str:
    """Read an image file and return its content as a base64 `data:` URL.

    Args:
        path: Path to a `.png`, `.jpg` or `.jpeg` file. The extension is
            matched case-insensitively.

    Returns:
        A string of the form `data:<mediatype>;base64,<payload>`.

    Raises:
        KeyError: If the file extension is not one of the supported types.
    """
    mediatype_map = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg"
    }
    # Lower-case the extension so that e.g. `photo.JPG` is accepted as well.
    ext = os.path.splitext(path)[-1].lower()
    mediatype = mediatype_map[ext]
    # The context manager closes the file handle even if reading fails.
    with open(path, 'rb') as file:
        data = b64encode(file.read()).decode()
    return f"data:{mediatype};base64,{data}"


# Prompt state filled in by the click callback: `coordinates` collects the
# (x, y) tuples and `labels` the matching 1/0 flags, aligned by index.
coordinates = []
labels = []

def click_coordinates_callback(x: int, y: int, left_click: bool) -> None:
    """Record one click as a prompt point.

    Appends `(x, y)` to `coordinates` and `int(left_click)` to `labels`,
    i.e. `1` for a left click and `0` otherwise.
    """
    # Both lists are only mutated, never rebound, so no `global` is required.
    coordinates.append((x, y))
    labels.append(int(left_click))


# Expose the Python callback to the JavaScript below under this name.
output.register_callback(
    "notebook.ClickCoordinates",
    click_coordinates_callback
)

# @markdown Select an input image
image_path = "/content/sam-hq/demo/input_imgs/example0.png" # @param {type: "string"}
# @markdown Set the image width
width = 100 # @param {type: "number"}

# @markdown ---

# @markdown Select image pixels as prompts for the SAM model:

# @markdown `-->` Click on the image below:
# @markdown
# @markdown *   `Left` click: Positive prompt (segmentation target)
# @markdown *   `Right` click: Negative prompt (avoid segmentation)

# @markdown **Hint**: We show your prompts after processing the prompt-image-pair in the cell below.

# The template is filled with printf-style formatting, so the width has to be
# a second format argument; a `{width}` placeholder would be emitted literally
# and the browser would ignore it. Because the image can now be displayed at a
# size different from its real one, the script maps every click back to pixel
# coordinates of the original image before sending it to the kernel.
display(IPython.display.HTML('''
<img src="%s" width="%d"/>
<script>
(function() {
    var img = document.querySelector("img");

    // Map a mouse event to pixel coordinates of the original (unscaled) image.
    function toImageCoordinates(event) {
        var rect = img.getBoundingClientRect();
        var x = Math.floor((event.clientX - rect.left) * img.naturalWidth / rect.width);
        var y = Math.floor((event.clientY - rect.top) * img.naturalHeight / rect.height);
        // Clamp so that a click on the very edge still addresses a valid pixel.
        x = Math.min(img.naturalWidth - 1, Math.max(0, x));
        y = Math.min(img.naturalHeight - 1, Math.max(0, y));
        return [x, y];
    }

    img.addEventListener("click", function(event) {
        var point = toImageCoordinates(event);

        var isRightClick = false;
        if ("which" in event)  // Gecko (Firefox), WebKit (Safari/Chrome) & Opera
            isRightClick = event.which == 3;
        else if ("button" in event)  // IE, Opera
            isRightClick = event.button == 2;

        google.colab.kernel.invokeFunction(
            'notebook.ClickCoordinates',
            [point[0], point[1], !isRightClick],
            {}
        );
    });
    img.addEventListener("contextmenu", function(event) {
        // Keep the browser context menu from covering the image.
        event.preventDefault();
        var point = toImageCoordinates(event);

        google.colab.kernel.invokeFunction(
            'notebook.ClickCoordinates',
            [point[0], point[1], false],
            {}
        );
    });
})();
</script>
''' % (load_data_url(image_path), width)))

In [None]:
# @title Run the SAM Inference
def load_img(path: str) -> np.ndarray:
    """Load an image from disk and return it in RGB channel order.

    Args:
        path: Path of the image file to read.

    Returns:
        The image read by `cv2.imread`, converted from BGR to RGB.

    Raises:
        FileNotFoundError: If OpenCV could not read an image from `path`.
    """
    img = cv2.imread(path)
    # `cv2.imread` reports a missing or unreadable file by returning `None`
    # instead of raising; fail here with the offending path rather than with
    # an opaque error inside `cvtColor`.
    if img is None:
        raise FileNotFoundError(f"Could not read an image from '{path}'")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


# Fail early with an actionable message when no prompt was recorded; an empty
# coordinate array would otherwise only fail inside the predictor.
if not coordinates:
    raise ValueError(
        "No prompt points recorded: click on the image in the cell above "
        "before running the inference."
    )

image = load_img(image_path)
predictor.set_image(image)

# Turn the recorded clicks into the arrays handed to the predictor.
point_coords = np.array(coordinates)
point_labels = np.array(labels)
masks, scores, logits = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=False,
    hq_token_only=False,
)
show_results(masks, scores, point_coords, point_labels, image)

## Acknowledgments

---

Thanks to the original [HQ-SAM](https://github.com/syscv/sam-hq) and [SAM](https://github.com/facebookresearch/segment-anything) authors!