### Step 1: Initialize SAM 2

In [None]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

# Choose the compute backend: CUDA when present, otherwise Apple MPS, otherwise CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context that is never exited, so it stays
    # active for every cell that runs afterwards.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Compute capability >= 8 (Ampere and newer): allow TF32 for matmul and cuDNN
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


In [2]:
# Seed NumPy's global RNG so the colors drawn by show_mask(random_color=True) are reproducible.
np.random.seed(42)

def show_mask(mask, ax, random_color=False, borders = True):
    """Draw a semi-transparent mask overlay on a matplotlib-style axes.

    Args:
        mask: array whose last two dimensions are (height, width); it is cast
            to uint8 before use.
        ax: object exposing ``imshow``; receives the RGBA overlay.
        random_color: if True, use a random RGB color from NumPy's global RNG
            instead of the fixed blue; alpha is 0.6 either way.
        borders: if True, outline the mask's external contours using OpenCV.
    """
    alpha = 0.6
    rgba = (
        np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, alpha])
    )
    height, width = mask.shape[-2:]
    binary = mask.astype(np.uint8)
    # Broadcast (h, w, 1) * (1, 1, 4) -> an (h, w, 4) RGBA image that is zero outside the mask.
    overlay = binary.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    if borders:
        import cv2
        outlines, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Light polygon approximation to smooth the contours a little.
        smoothed = [cv2.approxPolyDP(outline, epsilon=0.01, closed=True) for outline in outlines]
        overlay = cv2.drawContours(overlay, smoothed, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points as stars: label 1 in green, label 0 in red.

    Args:
        coords: array of (x, y) rows, indexed as ``coords[:, 0]`` / ``coords[:, 1]``.
        labels: array aligned with ``coords``; compared against 1 and 0.
        ax: object exposing ``scatter``.
        marker_size: marker area passed to ``scatter`` as ``s``.
    """
    for label_value, marker_color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=marker_color,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Add an unfilled green rectangle for ``box`` to ``ax``.

    Args:
        box: sequence read as [x0, y0, x1, y1] (only indices 0-3 are used).
        ax: object exposing ``add_patch``.
    """
    left, top = box[0], box[1]
    box_width = box[2] - box[0]
    box_height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top),
        box_width,
        box_height,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Show each mask over ``image`` in its own 10x10 figure.

    Args:
        image: image passed to ``plt.imshow`` as the background.
        masks: iterable of masks, paired element-wise with ``scores``.
        scores: per-mask scores; shown in the title only when there is more
            than one score.
        point_coords: optional prompt points to draw; requires ``input_labels``.
        box_coords: optional prompt box to draw.
        input_labels: labels for ``point_coords``.
        borders: forwarded to ``show_mask``.
    """
    for rank, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        axes = plt.gca()
        show_mask(mask, axes, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, axes)
        if box_coords is not None:
            show_box(box_coords, axes)
        if len(scores) > 1:
            plt.title(f"Mask {rank}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()

In [3]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Checkpoint weights and the matching model config; both file names point at
# the SAM 2.1 "hiera large" variant.
sam2_checkpoint = "pretrained_models/sam2.1_hiera_large.pt" # input your own model path
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Build the model on the device selected in Step 1.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# Image predictor used by the later cells (set_image / predict).
predictor = SAM2ImagePredictor(sam2_model)

### Step 2: Load ReasonSeg data

In [None]:
from datasets import load_from_disk
import glob

# Directory with the training split: one <name>.json annotation per sample
# (Step 3 derives the image path by swapping ".json" for ".jpg").
data_dir = "data/Segmentation/reason_seg/train" # input your own train data path

# Sorted so the processing order is the same on every run.
json_path_list = sorted(glob.glob(data_dir + "/*.json"))
# Show the first annotation path as a quick sanity check.
json_path_list[0]


In [None]:
import json
# Peek at the first annotation file to see the text queries stored under "text".
# A context manager closes the file handle instead of leaking it.
with open(json_path_list[0], 'r') as f:
    annotation = json.load(f)
texts = annotation["text"]
print(texts)
print(len(texts))

In [8]:
import json
import cv2

def get_mask_from_json(json_path, height, width):
    """Rasterize the polygons of an annotation JSON file into one boolean mask.

    The file must contain a "shapes" list whose entries have a "label" string
    and a "points" polygon. Shapes labelled exactly "flag" (case-insensitive)
    are skipped. The remaining polygons are drawn largest-area first.

    Args:
        json_path: path to the annotation file.
        height: mask height in pixels.
        width: mask width in pixels.

    Returns:
        Boolean array of shape (height, width), True wherever any kept polygon
        was drawn. NOTE(review): polygons whose label contains "ignore" are
        drawn with value 255, but the final cast to bool makes them
        indistinguishable from target pixels - confirm this is intended.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
        UnicodeDecodeError: if the file decodes as neither UTF-8 nor cp1252.
    """
    # Try UTF-8 first and fall back to cp1252 only on a decoding failure.
    # A bare `except:` here would also swallow unrelated errors
    # (KeyboardInterrupt, missing file, malformed JSON) and blindly retry.
    try:
        with open(json_path, "r", encoding="utf-8") as r:
            anno = json.loads(r.read())
    except UnicodeDecodeError:
        with open(json_path, "r", encoding="cp1252") as r:
            anno = json.loads(r.read())

    ### sort polygons by area
    area_list = []
    valid_poly_list = []
    for shape in anno["shapes"]:
        if shape["label"].lower() == "flag":  ## meaningless deprecated annotations
            continue

        polygon = np.array([shape["points"]], dtype=np.int32)
        tmp_mask = np.zeros((height, width), dtype=np.uint8)
        cv2.polylines(tmp_mask, polygon, True, 1, 1)
        cv2.fillPoly(tmp_mask, polygon, 1)

        area_list.append(tmp_mask.sum())
        valid_poly_list.append(shape)

    ### ground-truth mask: paint the largest polygons first so smaller ones land on top
    mask = np.zeros((height, width), dtype=np.uint8)
    for s_idx in np.argsort(area_list)[::-1].astype(np.int32):
        shape = valid_poly_list[s_idx]
        # 255 marks regions ignored during evaluation, 1 marks the target.
        label_value = 255 if "ignore" in shape["label"].lower() else 1

        polygon = np.array([shape["points"]], dtype=np.int32)
        cv2.polylines(mask, polygon, True, label_value, 1)
        cv2.fillPoly(mask, polygon, label_value)

    return mask.astype(bool)


In [9]:
from scipy import ndimage
def get_two_representative_points(m):
    """Pick two points that describe the shape of a binary mask reasonably well.

    The first point is the global maximum of the Euclidean distance transform
    of ``m``. The second is the distance-transform maximum among the mask
    pixels that lie farther from the first point than the median distance.
    If no pixel qualifies, the maximum inside a 20x20 window around the
    mask centroid is used instead.

    Args:
        m: 2-D binary array; pixels equal to 1 are treated as the mask.

    Returns:
        tuple: ([x1, y1], [x2, y2]) as lists of built-in ints, so the result
        can be JSON-serialized directly, or (None, None) when ``m`` has no
        pixel equal to 1.
    """
    ys, xs = np.where(m == 1)
    if len(xs) == 0:
        return None, None

    def _nearest_mask_pixel(x, y):
        # Snap (x, y) to the closest pixel that belongs to the mask.
        nearest = np.argmin((xs - x) ** 2 + (ys - y) ** 2)
        return xs[nearest], ys[nearest]

    # Distance transform
    dist = ndimage.distance_transform_edt(m)

    # First point: the global maximum
    y1, x1 = np.unravel_index(dist.argmax(), dist.shape)

    # Split the mask pixels by their distance to the first point
    radial = ((ys - y1) ** 2 + (xs - x1) ** 2) ** 0.5
    far = radial > np.median(radial)

    if far.any():
        # Second point: largest distance-transform value among the far pixels
        far_ys, far_xs = ys[far], xs[far]
        best = np.argmax(dist[far_ys, far_xs])
        y2, x2 = far_ys[best], far_xs[best]
    else:
        # No far pixel is available: search a window around the mask centroid
        cy, cx = int(np.mean(ys)), int(np.mean(xs))
        y_start, x_start = max(0, cy - 10), max(0, cx - 10)
        window = dist[y_start:min(m.shape[0], cy + 10), x_start:min(m.shape[1], cx + 10)]
        wy, wx = np.unravel_index(window.argmax(), window.shape)
        y2, x2 = wy + y_start, wx + x_start

    # Make sure both points lie on the mask
    if m[y1, x1] == 0:
        x1, y1 = _nearest_mask_pixel(x1, y1)
    if m[y2, x2] == 0:
        x2, y2 = _nearest_mask_pixel(x2, y2)

    # Normalize NumPy integer scalars to built-in ints on every path.
    return [int(x1), int(y1)], [int(x2), int(y2)]

In [15]:
def get_mask_from_point(predictor, input_point, input_label, box):
    """Run ``predictor.predict`` with a point + box prompt and return the masks.

    The prediction is requested with ``multimask_output=False`` and the
    returned masks are ordered by descending score.

    Args:
        predictor: object exposing ``predict(point_coords, point_labels, box,
            multimask_output)`` that returns ``(masks, scores, logits)``.
        input_point: point prompt coordinates, forwarded as ``point_coords``.
        input_label: point prompt labels, forwarded as ``point_labels``.
        box: box prompt, forwarded as ``box``.

    Returns:
        The masks array indexed by score order, best first.
    """
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=box,
        multimask_output=False,
    )
    # Only the masks are returned, so scores/logits are not reordered.
    return masks[np.argsort(scores)[::-1]]

import numpy as np

def compute_iou(mask1, mask2):
    """Intersection-over-union of two masks, compared element-wise as booleans.

    Returns 0 when both masks are empty (union of zero pixels).
    """
    union_area = np.logical_or(mask1, mask2).sum()
    if union_area == 0:
        return 0
    overlap_area = np.logical_and(mask1, mask2).sum()
    return overlap_area / union_area

def is_bbox_contained(inner_bbox, outer_bbox):
    """
    Tell whether inner_bbox lies entirely inside outer_bbox (shared edges count).
    bbox format: [x1, y1, x2, y2]
    """
    return (outer_bbox[0] <= inner_bbox[0] and  # inner left edge is not left of outer's
            outer_bbox[1] <= inner_bbox[1] and  # inner top edge is not above outer's
            outer_bbox[2] >= inner_bbox[2] and  # inner right edge is not right of outer's
            outer_bbox[3] >= inner_bbox[3])     # inner bottom edge is not below outer's

### Step 3: Generate annotation list

In [None]:
import numpy as np
from tqdm import tqdm  # 导入tqdm
import json  # 导入json模块
import cv2

# Minimum IoU between the SAM 2 prediction and the polygon mask for a polygon to be kept.
threshold_iou = 0.6  # original author's note: "IOU:  0.659445961"
cnt = 0  # number of images that contributed at least one annotation

seg_zero_annotation_list = []

# Wrap the list itself (not enumerate(...)) so tqdm knows the total and can show real progress.
for json_path in tqdm(json_path_list, desc="Processing images"):
    image_path = json_path.replace(".json", ".jpg")
    image_id = image_path.split("/")[-1].split(".")[0]

    # Same UTF-8 -> cp1252 fallback that get_mask_from_json applies; the context
    # manager closes the file handle instead of leaking it.
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            anno = json.load(f)
    except UnicodeDecodeError:
        with open(json_path, "r", encoding="cp1252") as f:
            anno = json.load(f)
    text = anno["text"][0]

    # cv2.imread returns None on failure; fail with a clear message instead of
    # an opaque cvtColor error.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image {image_path} for annotation {json_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    height, width, _ = image.shape
    predictor.set_image(image)

    # Rasterize every non-"flag" polygon once. The same 0/1 mask serves for the
    # area sort and as the ground-truth mask.
    # NOTE(review): polygons labelled "ignore" are rasterized exactly like
    # targets here (the original drew them with 255 and then binarized the
    # mask) - confirm that they should produce annotations.
    polygon_masks = []
    for shape in anno["shapes"]:
        if shape["label"].lower() == "flag":  # meaningless deprecated annotations
            continue
        polygon = np.array([shape["points"]], dtype=np.int32)
        m = np.zeros((height, width), dtype=np.uint8)
        cv2.polylines(m, polygon, True, 1, 1)
        cv2.fillPoly(m, polygon, 1)
        polygon_masks.append(m)

    # Visit polygons from largest to smallest area.
    area_order = np.argsort([poly_mask.sum() for poly_mask in polygon_masks])[::-1]

    bboxes_list = []
    points_list = []
    prev_bbox = None
    for poly_idx in area_order:
        m = polygon_masks[poly_idx]
        ys, xs = np.nonzero(m)
        if len(xs) == 0:
            # Nothing was rasterized inside the image bounds; min()/max() would raise.
            continue
        box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]

        # Skip polygons whose box is nested inside the previously visited box.
        # prev_bbox used to be assigned only inside the `prev_bbox is not None`
        # branch, so it stayed None forever and this filter never ran.
        if prev_bbox is not None and is_bbox_contained(box, prev_bbox):
            continue
        prev_bbox = box

        point, _ = get_two_representative_points(m)
        if point is None:
            continue
        point = [int(v) for v in point]  # built-in ints keep the record JSON-serializable

        # Prompt SAM 2 with one positive point plus the box, then keep the polygon
        # only if the predicted mask agrees well enough with the ground truth.
        mask_pred = get_mask_from_point(predictor, np.array([point]), np.array([1]), np.array(box))
        iou = compute_iou(mask_pred[0].astype(bool), m.astype(bool))
        if iou < threshold_iou:
            continue

        bboxes_list.append(box)
        points_list.append(point)

    if len(bboxes_list) <= 0:
        continue

    seg_zero_annotation_list.append({
        "id": "reason_seg_" + image_id,
        "image_id": image_id,
        "image_path": image_path,
        "problem": text,
        "bboxes": bboxes_list,
        "center_points": points_list
    })

    cnt += 1

print(f"总共发现 {len(seg_zero_annotation_list)} 个高于阈值IOU的案例")

In [None]:
# Inspect the first generated annotation record.
seg_zero_annotation_list[0]

In [None]:
# Cast every coordinate to a built-in int - presumably so the records can be
# written with json.dump in Step 4, which rejects NumPy integer scalars.
for item in seg_zero_annotation_list:
    item.update(
        bboxes=[[int(value) for value in bbox] for bbox in item['bboxes']],
        center_points=[[int(value) for value in center_point] for center_point in item['center_points']],
    )
seg_zero_annotation_list[0]

### Step 4: Save and show examples

In [43]:
# Write the annotation list to disk. ensure_ascii=False stores non-ASCII text
# as-is rather than as \uXXXX escapes. (The file name is a plain string; the
# placeholder-less f-string prefix was dropped.)
with open('seg_zero_reasonseg_annotation_list_all_various_item.json', 'w', encoding='utf-8') as f:
    json.dump(seg_zero_annotation_list, f, ensure_ascii=False, indent=4)

In [None]:
import cv2 

# Example to visualize: index 30, clamped so the cell still runs when fewer
# than 31 records were generated.
example_index = min(30, len(seg_zero_annotation_list) - 1)
item = seg_zero_annotation_list[example_index]

print(item['problem'])
print(item['bboxes'])
print(item['center_points'])

image_path = item['image_path']
image = cv2.imread(image_path)
if image is None:
    # cv2.imread returns None on failure instead of raising.
    raise FileNotFoundError(f"Could not read image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# The image is RGB at this point, so (0, 0, 255) draws blue boxes and
# (0, 255, 0) draws green points.
for bbox, center_point in zip(item['bboxes'], item['center_points']):
    # Cast to built-in ints so drawing does not depend on the int-cast cell having run.
    x0, y0, x1, y1 = (int(v) for v in bbox)
    cx, cy = (int(v) for v in center_point)
    cv2.rectangle(image, (x0, y0), (x1, y1), (0, 0, 255), 2)
    cv2.circle(image, (cx, cy), 5, (0, 255, 0), -1)

plt.imshow(image)
plt.show()


### Step 5: Please refer to gen_training_dataset.py