In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch

def show_anns(anns, borders=True, ax=None, contour_epsilon=0.01):
    """Overlay mask annotations as semi-transparent colored regions on an axes.

    Masks are painted largest-first (by ``'area'``), so smaller masks end up on
    top and stay visible. Each mask gets a random RGB color at alpha 0.5 drawn
    from NumPy's global RNG. Call ``np.random.seed`` beforehand for
    reproducible colors.

    Args:
        anns: sequence of dicts, each with a ``'segmentation'`` 2-D mask
            (boolean, or 0/1 integers) and an ``'area'`` value used for
            ordering. The overlay canvas takes its height/width from the
            largest mask; all masks are assumed to share that size.
        borders: if True, outline each mask's external contours in blue
            (alpha 0.4). This requires OpenCV (``cv2``).
        ax: matplotlib axes to draw on. Defaults to the current pyplot axes.
        contour_epsilon: tolerance in pixels passed to ``cv2.approxPolyDP``.
            The default 0.01 is below one pixel, so it only removes redundant
            points on straight runs and does not visibly smooth the outline.

    Returns:
        None. The overlay is drawn onto ``ax`` with ``imshow``. Nothing is
        drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    if ax is None:
        ax = plt.gca()
    if borders:
        # Import lazily so OpenCV is only needed when borders are requested.
        # Doing it here, not inside the loop, means it runs once.
        import cv2

    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    # Keep the image's extent fixed when the overlay is added.
    ax.set_autoscale_on(False)

    height, width = np.shape(sorted_anns[0]['segmentation'])[:2]
    # RGBA canvas that starts fully transparent (alpha channel = 0).
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in sorted_anns:
        # Cast to bool so 0/1 integer masks select pixels. Without the cast
        # they would be treated as integer (fancy) row indices.
        mask = np.asarray(ann['segmentation'], dtype=bool)
        overlay[mask] = np.concatenate([np.random.random(3), [0.5]])
        if borders:
            # [-2] selects the contour list on both OpenCV 3 (3 return values)
            # and OpenCV 4 (2 return values).
            contours = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
            contours = [cv2.approxPolyDP(contour, epsilon=contour_epsilon, closed=True) for contour in contours]
            cv2.drawContours(overlay, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(overlay)

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Load the test image as an RGB NumPy array. The context manager closes the
# underlying file once the pixels have been copied out of the PIL image.
with Image.open('./img/test.png') as pil_image:
    image = np.array(pil_image.convert("RGB"))

# SAM 2.1 checkpoint and the model config intended to match it
# (Hiera Base+ per the file names).
sam2_1_checkpoint = "./models/sam2/sam2.1_hiera_base_plus.pt"
model_cfg_1 = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at model construction.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sam2_1 = build_sam2(model_cfg_1, sam2_1_checkpoint, device=device, apply_postprocessing=False)

# NOTE(review): use_m2m=False presumably disables SAM2's extra mask-to-mask
# refinement step. Confirm against the SAM2AutomaticMaskGenerator docs.
mask_generator_1 = SAM2AutomaticMaskGenerator(
    model=sam2_1,
    use_m2m=False,
)

In [None]:
# Run automatic mask generation over the whole RGB image. The result is later
# passed to show_anns, which reads each record's 'segmentation' and 'area'.
masks_sam2_1 = mask_generator_1.generate(image)

In [None]:
# Two panels: the original image and the SAM 2.1 mask overlay. The previous
# 1x3 layout left an unused, framed middle panel.
fig, (ax_original, ax_sam2_1) = plt.subplots(1, 2, figsize=(16, 8))

# Original image
ax_original.imshow(image)
ax_original.set_title("Original Image")
ax_original.axis('off')

# SAM2.1 result
ax_sam2_1.imshow(image)
# show_anns draws on the current axes and takes its colors from NumPy's
# global RNG. Seed right before it so the colors are reproducible.
np.random.seed(3)
plt.sca(ax_sam2_1)
show_anns(masks_sam2_1, borders=True)
ax_sam2_1.set_title(f"SAM2.1 result: \n mask_num = {len(masks_sam2_1)}")
ax_sam2_1.axis('off')

plt.tight_layout()
plt.show()