In [1]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm


# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

using device: cuda


In [2]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

predictor = SAM2ImagePredictor(sam2_model)

In [3]:
import sys
import os
# 获取当前 Notebook 文件的目录路径
current_dir = os.path.dirname(os.path.abspath("__file__"))
#假设 switch nerf 目录在项目根目录下，获取项目根目录
project_root = os.path.abspath(os.path.join(current_dir, ".."))
#将项目根目录添加到 sys.path 中
if project_root not in sys.path:
    sys.path.append(project_root)
#打印路径和目录内容以确保正确性
print("current directory:",current_dir)
print("project root directory:",project_root)
print("sys.path:",sys.path)

from utils_pan import get_box_from_labelme_json, load_json, update_json_with_polygon, save_json

current directory: e:\PycharmProjects\1-BigModel\sam2-main\notebooks
project root directory: e:\PycharmProjects\1-BigModel\sam2-main
sys.path: ['d:\\SoftWare\\Anaconda\\envs\\auto-labelme\\python310.zip', 'd:\\SoftWare\\Anaconda\\envs\\auto-labelme\\DLLs', 'd:\\SoftWare\\Anaconda\\envs\\auto-labelme\\lib', 'd:\\SoftWare\\Anaconda\\envs\\auto-labelme', '', 'C:\\Users\\ZealousQun\\AppData\\Roaming\\Python\\Python310\\site-packages', 'C:\\Users\\ZealousQun\\AppData\\Roaming\\Python\\Python310\\site-packages\\win32', 'C:\\Users\\ZealousQun\\AppData\\Roaming\\Python\\Python310\\site-packages\\win32\\lib', 'C:\\Users\\ZealousQun\\AppData\\Roaming\\Python\\Python310\\site-packages\\Pythonwin', 'd:\\SoftWare\\Anaconda\\envs\\auto-labelme\\lib\\site-packages', 'e:\\PycharmProjects\\1-BigModel\\sam2-main']


In [4]:
root_path = 'E:\\PycharmProjects\\yolov8_datasets\\ChuangDian\\20250108_plastic_bottle\\0111\\updated_json'
label = 'pingti'

files = os.listdir(root_path)
for file in tqdm(files, desc='Processing'):
    if file.endswith('.bmp'):
        img = Image.open(os.path.join(root_path, file))
        img = np.array(img.convert("RGB"))
        json_path = os.path.join(root_path, file.replace('.bmp', '.json'))
        input_box = get_box_from_labelme_json(json_path, label, shape_type='polygon')
        
        predictor.set_image(img)
        masks, scores, _ = predictor.predict(
        point_coords=None, point_labels=None, box=input_box[None, :], multimask_output=False)
        
        # 转换为 8 位无符号整数类型
        mask = masks[0].astype(np.uint8) * 255  # 将前景变为255，背景为0

        # 查找轮廓
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # 选择最大的轮廓作为前景的边缘
        # 如果有多个轮廓，可以选择面积最大的轮廓作为主体
        max_contour = max(contours, key=cv2.contourArea)

        # 获取该轮廓的多边形逼近
        epsilon = 0.001 * cv2.arcLength(max_contour, True)  # 误差范围，可以调整
        polygon = cv2.approxPolyDP(max_contour, epsilon, True)
        
        # 保存更新后的 JSON 文件
        json_data = load_json(json_path)
        update_json_with_polygon(json_data, label, polygon)
        save_json(json_data, json_path)


Processing: 100%|██████████| 898/898 [01:37<00:00,  9.22it/s]
