# Object masks from prompts with SAM

author: https://github.com/Fluorite-Eyes
2023/10/26

In [1]:
from IPython.display import display, HTML
display(HTML(
"""
<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
"""
))

If running locally using jupyter, first install segment_anything in your environment using the installation instructions in the repository. If running from Google Colab, set using_colab=True below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'.

In [2]:
using_colab = False

In [3]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
    
    !mkdir images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg
        
    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [4]:
import numpy as np
import torch
import sys
import matplotlib.pyplot as plt
import cv2

# 加载自定义背景图片
image = cv2.imread('images/185.jpg')
if image is None:
    print("Could not load the custom background image.")
    exit()

# Selecting objects with SAM

First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results.

In [5]:
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

Process the image to produce an image embedding by calling SamPredictor.set_image. SamPredictor remembers this embedding and will use it for subsequent mask prediction.

In [6]:
predictor.set_image(image)

In [7]:
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    


In [8]:
# 调整背景图片的大小以匹配画布
image = cv2.imread('images/185.jpg')
canvas = image.copy()
canvas_backup = canvas.copy()  # 创建canvas的备份

# 用于记录左键和右键点击的坐标
left_clicks = []
right_clicks = []

# 回调函数，用于处理鼠标事件
def mouse_callback(event, x, y, flags, param):
    global canvas
    global canvas_backup
    if event == cv2.EVENT_LBUTTONDOWN:
        left_clicks.append((x, y))
        # 在鼠标点击的位置绘制一个蓝色点
        cv2.circle(canvas, (x, y), 5, (255, 105, 180), -1)
        positive_point = np.array(left_clicks)
        negtive_point = np.array(right_clicks)
        print("positive_point: ", positive_point)
        if negtive_point.size != 0:
            input_point = np.concatenate((positive_point, negtive_point), axis = 0)
        else:
            input_point = positive_point
        print("input_point: ", input_point)
        positive_nbr = len(positive_point)
        negtive_nbr = len(negtive_point)
        input_label = np.zeros(positive_nbr + negtive_nbr)
        input_label[:positive_nbr] = 1
        print("input_label: ", input_label)
        
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )

        # 使用 masks[0] 作为掩码
        mask = masks[0]

        # 创建颜色覆盖（例如，红色覆盖）
        color_overlay = np.zeros_like(canvas)
        color_overlay[:, :] = (0, 0, 100)  # 这里使用红色，你可以更改为其他颜色

        # 将掩码区域用颜色覆盖
        canvas = canvas_backup.copy()
        canvas[mask] = color_overlay[mask]
        
        for point in positive_point:
            cv2.circle(canvas, tuple(point), 5, (255, 105, 180), -1)

        # 在图像上显示 negative_point 中的点
        for point in negtive_point:
            cv2.circle(canvas, tuple(point), 5, (100,149,237), -1)
        
    elif event == cv2.EVENT_RBUTTONDOWN:
        right_clicks.append((x, y))
        # 在鼠标点击的位置绘制一个绿色点
        cv2.circle(canvas, (x, y), 5, (100,149,237), -1)
        positive_point = np.array(left_clicks)
        negtive_point = np.array(right_clicks)
        print("negtive_point: ", negtive_point)
        input_point = np.concatenate((positive_point, negtive_point), axis = 0)
        positive_nbr = len(positive_point)
        negtive_nbr = len(negtive_point)
        input_label = np.zeros(positive_nbr + negtive_nbr)
        input_label[:positive_nbr] = 1
        print("input_label: ", input_label)
        
        
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )

        # 使用 masks[0] 作为掩码
        mask = masks[0]
        print(masks)

        # 创建颜色覆盖（例如，红色覆盖）
        color_overlay = np.zeros_like(canvas)
        color_overlay[:, :] = (0, 0, 100)  # 这里使用红色，你可以更改为其他颜色

        # 将掩码区域用颜色覆盖
        canvas = canvas_backup.copy()
        canvas[mask] = color_overlay[mask]

        for point in positive_point:
            cv2.circle(canvas, tuple(point), 5, (255, 105, 180), -1)

        # 在图像上显示 negative_point 中的点
        for point in negtive_point:
            cv2.circle(canvas, tuple(point), 5, (100,149,237), -1)

# 创建窗口并设置鼠标回调函数
cv2.namedWindow("Canvas")
cv2.setMouseCallback("Canvas", mouse_callback)

while True:
    cv2.imshow("Canvas", canvas)
    key = cv2.waitKey(1) & 0xFF
    if key == 27:  # 按下Esc键退出
        break

# 输出左键和右键点击的坐标
print("Positive Points:", left_clicks)
print("Negtive Points:", right_clicks)

cv2.destroyAllWindows()



positive_point:  [[240 231]]
input_point:  [[240 231]]
input_label:  [1.]
negtive_point:  [[475 236]]
input_label:  [1. 0.]
[[[False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]
  ...
  [False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]]]
positive_point:  [[240 231]
 [500 164]]
input_point:  [[240 231]
 [500 164]
 [475 236]]
input_label:  [1. 1. 0.]
Positive Points: [(240, 231), (500, 164)]
Negtive Points: [(475, 236)]


In [9]:
from scipy.ndimage import zoom
class Segmentix:

    def resize_mask(
            self, ref_mask: np.ndarray, longest_side: int = 256
    ) -> tuple[np.ndarray, int, int]:
        """
        Resize an image to have its longest side equal to the specified value.

        Args:
            ref_mask (np.ndarray): The image to be resized.
            longest_side (int, optional): The length of the longest side after resizing. Default is 256.

        Returns:
            tuple[np.ndarray, int, int]: The resized image and its new height and width.
        """
        height, width = ref_mask.shape[:2]
        print("dfvsxvxvxcv",height, width, ref_mask.shape)
        if height > width:
            new_height = 256
            new_width = int(width * (256 / height))
            scale_factor_height = new_height / ref_mask.shape[0]
            scale_factor_width = new_width / ref_mask.shape[1]
            ref_mask = zoom(ref_mask, zoom=(scale_factor_height, scale_factor_width), order=3)
            if ref_mask.shape[0] > 256:
                ref_mask = np.resize(ref_mask, (256, new_width))
            print("sdfsfd", ref_mask.shape)
        else:
            new_width = 256
            new_height = int(height * (256 / width))
            scale_factor_height = new_height / ref_mask.shape[0]
            scale_factor_width = new_width / ref_mask.shape[1]
            ref_mask = zoom(ref_mask, zoom=(scale_factor_height, scale_factor_width), order=3)
            print("sdfsfd", ref_mask.shape)
            if ref_mask.shape[1] > 256:
                ref_mask = np.resize(ref_mask, (new_height, 256))
        return (
            ref_mask,
            new_height,
            new_width
        )

    def pad_mask(
        self,
        ref_mask: np.ndarray,
        new_height: int,
        new_width: int,
        pad_all_sides: bool = False,
    ) -> np.ndarray:
        """
        Add padding to an image to make it square.

        Args:
            ref_mask (np.ndarray): The image to be padded.
            new_height (int): The height of the image after resizing.
            new_width (int): The width of the image after resizing.
            pad_all_sides (bool, optional): Whether to pad all sides of the image equally. If False, padding will be added to the bottom and right sides. Default is False.

        Returns:
            np.ndarray: The padded image.
        """
        result = np.full((256, 256), -1)
        result[-new_height:, -new_width:] = ref_mask
        print("Asfasfafsasf",result)
        return result

    def reference_to_sam_mask(
            self, ref_mask: np.ndarray, threshold: int = 127, pad_all_sides: bool = False
    ) -> np.ndarray:
        """
        Convert a grayscale mask to a binary mask, resize it to have its longest side equal to 256, and add padding to make it square.

        Args:
            ref_mask (np.ndarray): The grayscale mask to be processed.
            threshold (int, optional): The threshold value for the binarization. Default is 127.
            pad_all_sides (bool, optional): Whether to pad all sides of the image equally. If False, padding will be added to the bottom and right sides. Default is False.

        Returns:
            np.ndarray: The processed binary mask.
        """

        # Resize to have the longest side 256.
        resized_mask, new_height, new_width = self.resize_mask(ref_mask)

        # Add padding to make it square.
        square_mask = self.pad_mask(resized_mask, new_height, new_width, pad_all_sides)

        # Expand SAM mask's dimensions to 1xHxW (1x256x256).
        return square_mask

In [10]:

canvas = image.copy()
canvas_backup = canvas.copy()  # 创建canvas的备份
# 用于记录鼠标点击的坐标
points = []
i = 0
# 回调函数，用于处理鼠标事件
def mouse_callback2(event, x, y, flags, param):
    global points, canvas, canvas_backup, mask, ref_mask

    if event == cv2.EVENT_LBUTTONDOWN:
        canvas = canvas_backup.copy()
        points.append((x, y))
        cv2.circle(canvas, (x, y), 3, (255, 255, 255), -1)
        # 将鼠标点击的点坐标转换为填充多边形
        polygon = np.array(points, dtype=np.int32)
        cv2.fillPoly(canvas, [polygon], 255)
        # 获取canvas的形状
        canvas_shape = canvas.shape[:2]

        # 创建一个与canvas的前两维相同尺寸的全零数组，数据类型为int32
        mask = np.full(canvas_shape[:2], -1, dtype=np.int32)
        
        # 使用cv2.fillPoly函数填充多边形来在mask上创建相应区域的白色区域
        cv2.fillPoly(mask, [polygon], 1)
        count_ones = (mask == 1).sum()
        print("Number of 1s in the array:", count_ones)
        
        # 找到所有值为1的像素的坐标
        ones_indices = np.argwhere(mask == 1)

        # 随机选择8个点
        num_points = 8
        random_indices = np.random.choice(len(ones_indices), num_points, replace=False)
        # 获取随机选择的点的坐标并存储在数组中
        random_points = ones_indices[random_indices]
        random_points = random_points[:, [1, 0]]
        # 打印随机选择的值为1的像素的坐标
        print("Random points with value 1:")
        print(random_points)

        # Convert reference mask to SAM format & run predictor.
        segmentix = Segmentix()
        sam_mask: np.ndarray = segmentix.reference_to_sam_mask(mask)

        # 计算 random_points 的长度
        length = len(random_points)

        # 生成相应长度的包含1的向量
        ones_vector = np.ones(length, dtype=np.int32)

        mask, scores, logits = predictor.predict(
            point_coords=random_points,
            point_labels=ones_vector,
            multimask_output=False,
        )
        mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
        masks, _, _ = predictor.predict(
            point_coords=random_points,
            point_labels=ones_vector,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )
        
        # 使用 masks[0] 作为掩码
        mask3 = masks[0]

        # 创建颜色覆盖（例如，红色覆盖）
        color_overlay = np.zeros_like(canvas)
        color_overlay[:, :] = (0, 0, 100)  # 这里使用红色，你可以更改为其他颜色

        # 将掩码区域用颜色覆盖
        canvas = canvas_backup.copy()
        canvas[mask3] = color_overlay[mask3]
        
        cv2.fillPoly(canvas, [polygon], 255)
        
        for point in random_points:
            cv2.circle(canvas, tuple(point), 5, (255, 105, 180), -1)
        
# 创建一个窗口用于显示图像
cv2.namedWindow('Mask Creation')

# 创建窗口并绑定鼠标事件处理函数
cv2.setMouseCallback('Mask Creation', mouse_callback2)

while True:
    cv2.imshow('Mask Creation', canvas)
    key = cv2.waitKey(1)

    if key == 27:  # 按下ESC键退出
        break
    elif key == ord('c'):  # 按下 'c' 键清除所有点击点
        points = []
        c = np.zeros((400, 400, 3), dtype=np.uint8)
        cv2.imshow('Mask Creation', c)

cv2.destroyAllWindows()



# print(sam_mask)
# plt.imshow(sam_mask, cmap='viridis')  # 使用'viridis'颜色映射，您可以根据需要选择其他颜色映射
# plt.colorbar()  # 添加颜色条
# plt.title("Array Visualization")
# plt.show()


Number of 1s in the array: 1


ValueError: Cannot take a larger sample than population when 'replace=False'

Number of 1s in the array: 36
Random points with value 1:
[[236 221]
 [234 223]
 [230 228]
 [227 232]
 [222 239]
 [220 241]
 [244 211]
 [226 234]]
dfvsxvxvxcv 788 695 (788, 695)
sdfsfd (256, 225)
Asfasfafsasf [[-1 -1 -1 ... -1 -1 -1]
 [-1 -1 -1 ... -1 -1 -1]
 [-1 -1 -1 ... -1 -1 -1]
 ...
 [-1 -1 -1 ... -1 -1 -1]
 [-1 -1 -1 ... -1 -1 -1]
 [-1 -1 -1 ... -1 -1 -1]]
Number of 1s in the array: 1141
Random points with value 1:
[[249 227]
 [230 252]
 [232 227]
 [246 232]
 [246 259]
 [251 258]
 [249 259]
 [224 240]]
dfvsxvxvxcv 788 695 (788, 695)
sdfsfd (256, 225)
Asfasfafsasf [[-1 -1 -1 ... -1 -1 -1]
 [-1 -1 -1 ... -1 -1 -1]
 [-1 -1 -1 ... -1 -1 -1]
 ...
 [-1 -1 -1 ... -1 -1 -1]
 [-1 -1 -1 ... -1 -1 -1]
 [-1 -1 -1 ... -1 -1 -1]]
Number of 1s in the array: 2229
Random points with value 1:
[[246 214]
 [262 232]
 [248 225]
 [263 255]
 [255 234]
 [266 233]
 [246 228]
 [259 215]]
dfvsxvxvxcv 788 695 (788, 695)
sdfsfd (256, 225)
Asfasfafsasf [[-1 -1 -1 ... -1 -1 -1]
 [-1 -1 -1 ... -1 -1 -1]
 [-1 -1

In [11]:
plt.imshow(image)
show_mask(masks, plt.gca())
plt.axis('off')
plt.show() 


# 使用 masks[0] 作为掩码
        mask = masks[0]

        # 创建颜色覆盖（例如，红色覆盖）
        color_overlay = np.zeros_like(canvas)
        color_overlay[:, :] = (0, 0, 100)  # 这里使用红色，你可以更改为其他颜色

        # 将掩码区域用颜色覆盖
        canvas = canvas_backup.copy()
        canvas[mask] = color_overlay[mask]

IndentationError: unexpected indent (3705293468.py, line 8)