In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

In [None]:


base_path =r'./tzb_data/val/val2'
list_path = os.path.join(base_path,'list.txt')
result_image_base = './tzb_data/result/images-val2' 
result_box_base = './tzb_data/result/box_val-2'

# 张量计算时的精度设置
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True



# sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
# model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# 加载模型权重和配置文件
sam2_checkpoint = "/home/guest/tangxiangkai/sam2/sam2_logs/configs/sam2.1_training/sam2.1_hiera_b+_tzb_finetune.yaml/checkpoints/checkpoint.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint,device='cuda')



In [None]:
# 定义相关显示函数
def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

import cv2

colors = [
    np.array([0, 128, 255], dtype=np.uint8),    # 橙色
    np.array([255, 128, 0], dtype=np.uint8),    # 蓝绿色
    np.array([0, 255, 0], dtype=np.uint8),      # 绿色
    np.array([0, 0, 255], dtype=np.uint8),      # 红色
    np.array([255, 0, 0], dtype=np.uint8),      # 蓝色
    np.array([255, 255, 0], dtype=np.uint8),    # 黄色
    np.array([255, 0, 255], dtype=np.uint8),    # 紫色
    np.array([0, 255, 255], dtype=np.uint8),    # 青色
    np.array([128, 0, 128], dtype=np.uint8),    # 深紫色
    np.array([128, 128, 0], dtype=np.uint8),    # 橄榄色
    np.array([0, 128, 128], dtype=np.uint8),    # 深青色
    np.array([128, 0, 0], dtype=np.uint8),      # 深蓝色
    np.array([0, 0, 128], dtype=np.uint8),      # 深红色
]

def add_mask2(image, mask, color_id):
    # 单通道
    int_mask = mask.astype(np.uint8)
    int_mask_3d = np.dstack((int_mask, int_mask, int_mask))
    # 创建橙色掩码图像
    # 橙色 (B, G, R)
    mask_color = colors[color_id]
    mask = np.full_like(image, mask_color)

    # 将掩码应用于橙色掩码图像
    mask[int_mask == 0] = 0

    # 使用 cv2.addWeighted 叠加原始图像和橙色掩码图像
    alpha = 0.6  # 原始图像权重
    beta = 1 - 0.6  # 橙色掩码权重
    gamma = 0  # 偏移量

    # 使用掩码矩阵来控制叠加
    res = cv2.addWeighted(image, alpha, mask, beta, gamma, dtype=cv2.CV_8U)

    # 将mask中为黑色部分保留原图,0的区域为True, 非零区域为False
    # 获取黑色区域
    black_areas = int_mask_3d == 0
    res[black_areas] = image[black_areas]
    return res


def get_first_frame_bbox_xyxy(bbox_dir,bbox_list):
    # print(frame_list)
    data_frame = []
    with open(os.path.join(bbox_dir,bbox_list[0]),'r') as f:
        for line in f:
            parts = line.strip().split()  # 按空格分割
            data_frame.append(parts)
    # print(data_frame)
    return data_frame

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

import numpy as np
import torch


def mask_to_yolo_bbox(mask):
    """
    将掩码转为 YOLO 格式 bbox: x_center y_center width height (归一化)
    mask: torch.Tensor(H, W) 或 np.ndarray(H, W)
    """
    # 如果是 torch 张量，转 numpy
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()

    mask = mask[0]

    H, W = mask.shape

    # 阈值化，确保是二值
    binary_mask = (mask > 0.5).astype(np.uint8)

    # 找到所有非零像素位置
    ys, xs = np.where(binary_mask > 0)

    if len(xs) == 0 or len(ys) == 0:
        return None  # 没有目标

    # 计算边界框
    x_min, x_max = xs.min(), xs.max()
    y_min, y_max = ys.min(), ys.max()

    # 计算中心点和宽高
    bbox_w = x_max - x_min + 1
    bbox_h = y_max - y_min + 1
    x_center = x_min + bbox_w / 2
    y_center = y_min + bbox_h / 2

    # 归一化到 YOLO 格式
    x_center /= W
    y_center /= H
    bbox_w   /= W
    bbox_h   /= H

    return x_center, y_center, bbox_w, bbox_h



In [None]:

# box模式

# video_num = 'data_06'
#  读取包含所有视频帧的图片路径



data_list = []

with open(list_path,'r') as f:
    for line in f:
        parts = line.strip().split()  
        data_list.append(parts)

for video_list in data_list:
    video_num = video_list[0]
    # if int(video_num.split('_')[1])!=8:
    #     continue
    video_dir = os.path.join(base_path,'images',f'{video_num}')
    bbox_path = os.path.join(base_path,'labels',f'{video_num }')
    # scan all the JPEG frame names in this directory
    frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]]
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

    bbox_names = [t for t in  os.listdir(bbox_path) if os.path.splitext(t)[-1] == '.txt']
    bbox_names.sort(key= lambda t:int(os.path.splitext(t)[0]))

    first_bboxes = get_first_frame_bbox_xyxy(bbox_path,bbox_names)

    # if video_num == 'data_04':
    #     print(first_bboxes)
        # 以下全部缩进
    # if video_num == 'data_03':
    first_img = cv.imread(os.path.join(video_dir,frame_names[0]))
    cls_dict = {}
    anno_frame_idx = 0

    inference_state = predictor.init_state(
        video_path=video_dir,
        offload_video_to_cpu=True
        )


    # 添加首帧信息到模型中
    for idx,cls_bbox in enumerate(first_bboxes):
        print(idx,cls_bbox)
        # 跟踪目标  id,用以映射回原类型
        anno_obj_id = idx

        #跟踪类别
        cls = cls_bbox[0]
        bbox = cls_bbox[1:]

        # 将跟踪目标唯一ID与目标的类别对应
        cls_dict[anno_obj_id] = cls

        img_h,img_w,_ = first_img.shape
        x,y,w,h = map(float,bbox)
        h = h * img_h
        w = w * img_w

        x1 = round(x*img_w - w/2)
        y1 = round(y*img_h - h/2)
        x2 = round(x1+w)
        y2 = round(y1+h)

        #用的点，看能否用目标框
        point = np.array([[int(x*img_w),int(y*img_h)]], dtype=np.float32)
        label = np.array([1], np.int32)

        box = np.array([x1,y1,x2,y2],dtype=np.float32)
        # label = np.array([1], np.int32)

        _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( # 为预测模型 添加首帧信息
            inference_state=inference_state,
            frame_idx=anno_frame_idx,
            obj_id=anno_obj_id,
            box = box,
            points = point,
            labels = label
            )

    video_segments = {}  
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }



    for out_frame_idx in range(0, len(frame_names)):
        color_num = 0
        # 结果保存路径
        now_img = cv2.imread(os.path.join(video_dir, frame_names[out_frame_idx]))
        result_img_path = os.path.join(result_image_base,f'{video_num}')
        result_box_path = os.path.join(result_box_base,f'{video_num}')
        result_bbox_path = os.path.join('./tzb_data/score_test/data/predictions_val2',f'{video_num}')
        
        os.makedirs(result_img_path,exist_ok=True)
        os.makedirs(result_box_path,exist_ok=True)
        os.makedirs(result_bbox_path,exist_ok=True)
        
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            bbox = []
            bbox.append(cls_dict[out_obj_id])
            yolo_bbox = mask_to_yolo_bbox(out_mask)
            print(yolo_bbox)
            # print("YOLO bbox:", yolo_bbox)
            if yolo_bbox is None:
                print("警告：掩码没有检测到有效目标框")
                # 根据情况选择跳过，或用默认框代替
                new_line=''
            else:
                bbox.extend(yolo_bbox)
                # bbox.append(0.9)
                # print('box',bbox)
                new_line = ' '.join(str(x) for x in bbox)
                # print("newline",new_line)
            txt_file = os.path.join(result_box_path,str(out_frame_idx)+'.txt')
            with open(txt_file,'a') as f:
                f.write(new_line+'\n')
            txt_file = os.path.join(result_bbox_path,str(out_frame_idx)+'.txt')
            with open(txt_file,'a') as f:
                f.write(new_line+'\n')
            now_img = add_mask2(now_img, np.squeeze(out_mask),int(cls_dict[anno_obj_id]))
            color_num += 1
        mask_out_name = os.path.join(result_img_path,'{}.jpg'.format(str(out_frame_idx).zfill(5)))
        cv2.imwrite(mask_out_name, now_img)

    predictor.reset_state(inference_state)
    torch.cuda.empty_cache()

frame loading (JPEG): 100%|██████████| 240/240 [00:06<00:00, 37.11it/s]


0 ['1', '0.435156', '0.180664', '0.029687', '0.041016']
1 ['1', '0.356250', '0.188477', '0.031250', '0.029297']
2 ['1', '0.403125', '0.327148', '0.025000', '0.037109']
3 ['1', '0.415625', '0.355469', '0.025000', '0.027344']
4 ['1', '0.421875', '0.381836', '0.028125', '0.025391']
5 ['1', '0.432812', '0.405273', '0.028125', '0.025391']
6 ['1', '0.441406', '0.432617', '0.029687', '0.033203']
7 ['1', '0.462500', '0.414062', '0.025000', '0.042969']
8 ['1', '0.325000', '0.517578', '0.021875', '0.042969']
9 ['1', '0.346875', '0.512695', '0.028125', '0.052734']
10 ['1', '0.335156', '0.455078', '0.039062', '0.054688']
11 ['1', '0.321094', '0.410156', '0.039062', '0.042969']
12 ['1', '0.185938', '0.413086', '0.021875', '0.033203']
13 ['1', '0.208594', '0.414062', '0.023438', '0.054688']
14 ['1', '0.228906', '0.411133', '0.020313', '0.064453']
15 ['1', '0.245312', '0.406250', '0.028125', '0.062500']
16 ['1', '0.271875', '0.410156', '0.031250', '0.058594']
17 ['1', '0.313281', '0.346680', '0.02968

propagate in video: 100%|██████████| 240/240 [06:47<00:00,  1.70s/it]


(np.float64(0.4359375), np.float64(0.1826171875), np.float64(0.025), np.float64(0.033203125))
(np.float64(0.35546875), np.float64(0.189453125), np.float64(0.0234375), np.float64(0.02734375))
(np.float64(0.403125), np.float64(0.326171875), np.float64(0.021875), np.float64(0.03125))
(np.float64(0.41640625), np.float64(0.35546875), np.float64(0.0203125), np.float64(0.02734375))
(np.float64(0.42421875), np.float64(0.380859375), np.float64(0.0234375), np.float64(0.02734375))
(np.float64(0.434375), np.float64(0.40625), np.float64(0.021875), np.float64(0.03125))
(np.float64(0.44375), np.float64(0.431640625), np.float64(0.025), np.float64(0.03515625))
(np.float64(0.4625), np.float64(0.416015625), np.float64(0.025), np.float64(0.03515625))
(np.float64(0.32421875), np.float64(0.5205078125), np.float64(0.0234375), np.float64(0.037109375))
(np.float64(0.346875), np.float64(0.51171875), np.float64(0.028125), np.float64(0.0546875))
(np.float64(0.33671875), np.float64(0.455078125), np.float64(0.03281