# SAM3 远程无 GUI 两阶段分割（机械臂 + gripper 新 obj_id）

本 notebook 面向远程服务器（无 GUI）逐 cell 交互执行：

- 阶段 A：先做一次普通传播（bootstrap）填充缓存，再进行机械臂 points/labels refinement。
- 阶段 B：在关键帧使用 points + point_labels + **全新 obj_id** 新增 gripper，再传播。
- 导出支持对象集合选择：`arm-only` / `gripper-only` / `union` / `custom`。

> 注意：主流程（阶段 A / 阶段 B / 导出）仅使用 inline 可视化与日志打印，不依赖 `%matplotlib widget`、按钮、鼠标事件回调；只有可选的 `1.5) 可视化点标注` 一节需要可交互的前端（在图上点击取点、点击按钮导出）。

In [None]:
import os
import glob
from pathlib import Path

import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

import sam3
from sam3.model_builder import build_sam3_video_predictor
from sam3.visualization_utils import prepare_masks_for_visualization, visualize_formatted_frame_output

# Use a compact title font for the inline figures.
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["figure.titlesize"] = 12

# ==============================
# Key configuration (ideally the only section you need to edit)
# ==============================

# 1) Paths
VIDEO_PATH = "/data/haoxiang/data/airexo2/task_0013/train/scene_0001/cam_105422061350/color"
CHECKPOINT_PATH = "/data/haoxiang/sam3/models/facebook/sam3/sam3.pt"

# 2) Inference settings
# NOTE(review): this value is written to os.environ by the init cell, i.e. after
# torch has been imported — presumably it only takes effect if CUDA has not yet
# been initialised in this kernel; confirm.
CUDA_VISIBLE_DEVICES = "0,1,2,3"
APPLY_TEMPORAL_DISAMBIGUATION = False
PROPAGATION_DIRECTION_STAGE_A = "forward"  # "forward" is recommended
PROPAGATION_DIRECTION_STAGE_B = "forward"  # propagate forward again after adding objects on keyframes

# 3) Object IDs (left/right arm + left/right gripper; all four must be distinct)
arm_left_obj_id = 0
arm_right_obj_id = 1
gripper_left_obj_id = 2
gripper_right_obj_id = 3

# 4) Stage A: bootstrap text prompts for the left/right arm (configured independently)
# One ordinary propagation has to finish first, otherwise the first points-based
# add_prompt trips the cached-outputs assertion.
ARM_LEFT_BOOTSTRAP_TEXT_PROMPT = ""  # optional but recommended, e.g. "left robot arm"
ARM_LEFT_BOOTSTRAP_FALLBACK_TEXT_PROMPT = "left robot arm"
ARM_LEFT_BOOTSTRAP_FRAME_INDEX = None

ARM_RIGHT_BOOTSTRAP_TEXT_PROMPT = ""  # optional but recommended, e.g. "right robot arm"
ARM_RIGHT_BOOTSTRAP_FALLBACK_TEXT_PROMPT = "right robot arm"
ARM_RIGHT_BOOTSTRAP_FRAME_INDEX = None

# 5) Stage A: point refinement for the left/right arm (may be empty; sides are independent)
# coord_type: "abs" (pixel coordinates) / "rel" (coordinates normalised to [0, 1])
# Each entry's "obj_id" must equal the matching *_obj_id above; the init cell
# rejects mismatches via validate_obj_id_constraints.
ARM_LEFT_INITIAL_PROMPTS = [
    {
        "frame_index": 120,
        "obj_id": 0,
        "coord_type": "abs",
        "points": [[962, 350]],
        "labels": [0],
    }
]

ARM_RIGHT_INITIAL_PROMPTS = []

# 6) Stage B: add the left/right gripper on keyframes (points + point_labels + a new obj_id)
# Several keyframes are supported; left and right obj_id are configured independently.
GRIPPER_LEFT_KEYFRAME_PROMPTS = [
    {
        "frame_index": 120,
        "obj_id": 2,
        "coord_type": "abs",
        "points": [[973, 356], [930, 330]],
        "labels": [1, 0],
    }
]

GRIPPER_RIGHT_KEYFRAME_PROMPTS = []

# 6.5) Hook for prompts exported by the visual annotator (switches the data source only; runs no inference)
# True: Stage A/B use ANNOTATION_EXPORT_PROMPTS
# False: Stage A/B use the hand-written *_PROMPTS above
USE_VISUAL_ANNOTATION_EXPORT = False
ANNOTATION_EXPORT_PROMPTS = None

# 7) Visualisation settings
VIS_FRAME_STRIDE = 60
VIS_MAX_PLOTS = 8

# 8) Export settings
# EXPORT_MODE: "arm-only" | "gripper-only" | "union" | "custom"
EXPORT_MODE = "union"
EXPORT_CUSTOM_OBJ_IDS = []
EXPORT_OUTPUT_DIR = "/data/haoxiang/propainter/masks_airexo_arm_gripper_union"
EXPORT_DILATE_RADIUS = 15
# Runtime state (managed by the cells below; do not edit by hand).
# predictor / session_id own live resources (model workers, distributed process
# group). Keep any existing handle when this cell is re-run so that the init
# cell's "predictor is not None" guard can still find and release it; resetting
# them to None here would skip that cleanup and leak the old predictor.
predictor = globals().get("predictor")
session_id = globals().get("session_id")
video_frames_for_vis = None
TOTAL_FRAMES = None
IMG_WIDTH = None
IMG_HEIGHT = None
arm_left_prompts_norm = None
arm_right_prompts_norm = None
gripper_left_prompts_norm = None
gripper_right_prompts_norm = None
outputs_stage_a = None
outputs_stage_b = None


## 工具函数（校验 / 坐标转换 / 推理流程 / 导出 / 清理）

In [None]:
def cleanup_process_group():
    """Destroy the torch.distributed process group if one is initialised.

    Does nothing when distributed support is unavailable or no group exists.
    A failure while destroying the group is reported with a warning print
    instead of being raised.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return

    try:
        dist.destroy_process_group()
        print("[cleanup] distributed process group destroyed")
    except Exception as exc:
        print(f"[warn] destroy_process_group() failed: {exc}")


def cleanup_resources(predictor_obj=None, session_id_value=None):
    """Best-effort teardown of a predictor, its session and the process group.

    Steps, each isolated so that one failure does not block the next:
      1. send a ``close_session`` request (only when both arguments are given),
      2. call ``predictor_obj.shutdown()`` (only when a predictor is given),
      3. call ``cleanup_process_group()`` (always).

    Failures in steps 1 and 2 are printed as warnings, never raised.
    """
    has_predictor = predictor_obj is not None

    if has_predictor and session_id_value is not None:
        try:
            predictor_obj.handle_request(
                request={"type": "close_session", "session_id": session_id_value}
            )
            print(f"[cleanup] session closed: {session_id_value}")
        except Exception as exc:
            print(f"[warn] close_session failed: {exc}")

    if has_predictor:
        try:
            predictor_obj.shutdown()
            print("[cleanup] predictor shutdown finished")
        except Exception as exc:
            print(f"[warn] predictor.shutdown() failed: {exc}")

    cleanup_process_group()


def load_video_frames_for_visualization(video_path):
    """Load all frames of a video as RGB arrays, for visualisation only.

    The returned frames are not fed to the model; they are used for plotting
    and for deriving the frame count / image size.

    Args:
        video_path: either a path ending in ``.mp4`` or a directory containing
            ``*.jpg`` / ``*.jpeg`` / ``*.png`` frames (loaded in sorted path order).

    Returns:
        A non-empty list of RGB frames, one array per frame.

    Raises:
        ValueError: if the path is neither an ``.mp4`` string nor a directory,
            the video cannot be opened or yields no frame, the directory holds
            no recognised frame file, or a frame file cannot be decoded.
    """
    if isinstance(video_path, str) and video_path.endswith(".mp4"):
        cap = cv2.VideoCapture(video_path)
        try:
            # A missing/corrupt file would otherwise just produce an empty list.
            if not cap.isOpened():
                raise ValueError(f"无法打开视频文件: {video_path}")
            frames = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; convert to RGB for plotting.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the capture handle even if decoding/conversion raised.
            cap.release()
        if not frames:
            raise ValueError(f"视频文件未读取到任何帧: {video_path}")
        return frames

    if isinstance(video_path, str) and os.path.isdir(video_path):
        frame_names = sorted(
            glob.glob(os.path.join(video_path, "*.jpg"))
            + glob.glob(os.path.join(video_path, "*.jpeg"))
            + glob.glob(os.path.join(video_path, "*.png"))
        )
        if not frame_names:
            raise ValueError(f"视频目录为空或无可识别帧文件: {video_path}")
        frames = []
        for fp in frame_names:
            img = cv2.imread(fp)
            # cv2.imread signals a decode failure by returning None.
            if img is None:
                raise ValueError(f"读取帧失败: {fp}")
            frames.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        return frames

    raise ValueError(f"不支持的 VIDEO_PATH: {video_path}")


def get_frame_size(video_frames):
    """Return ``(width, height)`` taken from the first frame.

    A torch tensor frame is converted to a numpy array first. A 4-D frame is
    reduced to its first element before the first two axes are read as
    (height, width).

    Raises:
        ValueError: if ``video_frames`` is empty.
    """
    if not video_frames:
        raise ValueError("video_frames 为空")

    first = video_frames[0]
    if isinstance(first, torch.Tensor):
        first = first.detach().cpu().numpy()
    first = first[0] if first.ndim == 4 else first

    height, width = first.shape[:2]
    return int(width), int(height)


def abs_to_rel_points(points_abs, img_w, img_h):
    """Convert pixel ``(x, y)`` pairs to relative ones: x / img_w, y / img_h."""
    points_rel = []
    for px, py in points_abs:
        points_rel.append([float(px) / img_w, float(py) / img_h])
    return points_rel


def validate_prompt_entry(entry, total_frames, img_w, img_h, tag="prompt"):
    """Validate one raw point-prompt dict, raising on the first problem found.

    Args:
        entry: dict with keys ``frame_index``, ``obj_id``, ``points``,
            ``labels`` and ``coord_type``.
        total_frames: number of frames; ``frame_index`` must be in
            ``[0, total_frames)``.
        img_w, img_h: image size used to bound ``"abs"`` coordinates
            (``0 <= x < img_w``, ``0 <= y < img_h``). ``"rel"`` coordinates
            must lie in ``[0, 1]``.
        tag: prefix put in every error message to identify the entry.

    Returns:
        None. The entry is not modified.

    Raises:
        ValueError: for a missing key, a non-int (or bool) ``frame_index`` /
            ``obj_id``, an out-of-range frame, an unknown ``coord_type``,
            empty or length-mismatched ``points`` / ``labels``, a malformed,
            non-numeric or out-of-bounds point, or a label that is not
            exactly 0 or 1.
    """
    for key in ("frame_index", "obj_id", "points", "labels", "coord_type"):
        if key not in entry:
            raise ValueError(f"[{tag}] 缺少字段: {key}; entry={entry}")

    frame_index = entry["frame_index"]
    obj_id = entry["obj_id"]
    points = entry["points"]
    labels = entry["labels"]
    coord_type = entry["coord_type"]

    # bool is a subclass of int, so True/False would otherwise pass as 1/0.
    if isinstance(frame_index, bool) or not isinstance(frame_index, int):
        raise ValueError(f"[{tag}] frame_index 必须为 int，当前: {frame_index}")
    if frame_index < 0 or frame_index >= total_frames:
        raise ValueError(f"[{tag}] frame_index 越界: {frame_index}, 合法范围 [0, {total_frames - 1}]")

    if isinstance(obj_id, bool) or not isinstance(obj_id, int):
        raise ValueError(f"[{tag}] obj_id 不能为空且必须为 int，当前: {obj_id}")

    if coord_type not in {"abs", "rel"}:
        raise ValueError(f"[{tag}] coord_type 仅支持 abs/rel，当前: {coord_type}")

    if not isinstance(points, (list, tuple)) or len(points) == 0:
        raise ValueError(f"[{tag}] points 不能为空")
    if not isinstance(labels, (list, tuple)) or len(labels) == 0:
        raise ValueError(f"[{tag}] labels 不能为空")
    if len(points) != len(labels):
        raise ValueError(f"[{tag}] points/labels 长度不一致: {len(points)} vs {len(labels)}")

    for i, (p, lb) in enumerate(zip(points, labels)):
        if not isinstance(p, (list, tuple)) or len(p) != 2:
            raise ValueError(f"[{tag}] 第{i}个点格式错误，应为 [x, y]，当前: {p}")
        try:
            x, y = float(p[0]), float(p[1])
        except (TypeError, ValueError):
            # Report non-numeric coordinates with the same tagged ValueError
            # as every other validation failure.
            raise ValueError(f"[{tag}] 第{i}个点坐标必须为数值，当前: {p}") from None

        if coord_type == "abs":
            if not (0 <= x < img_w and 0 <= y < img_h):
                raise ValueError(
                    f"[{tag}] 第{i}个 abs 点越界: ({x}, {y}), 图像尺寸=({img_w}, {img_h})"
                )
        else:
            if not (0.0 <= x <= 1.0 and 0.0 <= y <= 1.0):
                raise ValueError(f"[{tag}] 第{i}个 rel 点越界: ({x}, {y}), 应在 [0,1]")

        try:
            lb_int = int(lb)
        except (TypeError, ValueError, OverflowError):
            lb_int = None
        # int() truncates toward zero, so a float such as 0.7 would silently be
        # accepted as a negative click (0); only integral values are allowed.
        is_fractional = isinstance(lb, (float, np.floating)) and lb != lb_int
        if lb_int not in {0, 1} or is_fractional:
            raise ValueError(f"[{tag}] 第{i}个 label 非法: {lb}, 仅支持 0/1")


def normalize_prompt_entry(entry, img_w, img_h):
    """Convert a validated raw prompt into the normalised internal form.

    Coordinates are cast to float and labels to int. ``"abs"`` points are
    converted with ``abs_to_rel_points``; any other ``coord_type`` is taken
    as already relative and passed through.

    Returns:
        dict with keys ``frame_index``, ``obj_id``, ``points_rel``, ``labels``.
    """
    float_points = [[float(pt[0]), float(pt[1])] for pt in entry["points"]]
    int_labels = list(map(int, entry["labels"]))

    is_absolute = entry["coord_type"] == "abs"
    rel_points = abs_to_rel_points(float_points, img_w, img_h) if is_absolute else float_points

    return {
        "frame_index": int(entry["frame_index"]),
        "obj_id": int(entry["obj_id"]),
        "points_rel": rel_points,
        "labels": int_labels,
    }


def validate_and_normalize_prompt_list(
    prompt_list,
    total_frames,
    img_w,
    img_h,
    tag,
    allow_empty=False,
):
    """Validate every prompt in ``prompt_list`` and return the normalised list.

    Each item must be a dict; it is checked with ``validate_prompt_entry``
    (tagged ``"<tag>[<index>]"``) and then converted with
    ``normalize_prompt_entry``. Processing stops at the first invalid item.

    Raises:
        ValueError: if ``prompt_list`` is not a list/tuple, is empty while
            ``allow_empty`` is false, or contains a non-dict / invalid item.
    """
    if not isinstance(prompt_list, (list, tuple)):
        raise ValueError(f"[{tag}] 必须为 list")
    if not allow_empty and len(prompt_list) == 0:
        raise ValueError(f"[{tag}] 不能为空")

    def _check_and_convert(position, item):
        if not isinstance(item, dict):
            raise ValueError(f"[{tag}] 第{position}项必须为 dict")
        validate_prompt_entry(item, total_frames, img_w, img_h, tag=f"{tag}[{position}]")
        return normalize_prompt_entry(item, img_w, img_h)

    return [_check_and_convert(pos, item) for pos, item in enumerate(prompt_list)]


def validate_obj_id_constraints(
    arm_left_obj_id,
    arm_right_obj_id,
    gripper_left_obj_id,
    gripper_right_obj_id,
    arm_left_prompts,
    arm_right_prompts,
    gripper_left_prompts,
    gripper_right_prompts,
):
    """Check that the four object IDs are distinct ints and prompts use them.

    Every prompt in each ``*_prompts`` list must carry the ``obj_id``
    configured for its own side/part.

    Raises:
        ValueError: if an ID is not an int, two IDs collide, or a prompt's
            ``obj_id`` differs from the expected one.
    """
    configured_ids = {
        "arm_left_obj_id": arm_left_obj_id,
        "arm_right_obj_id": arm_right_obj_id,
        "gripper_left_obj_id": gripper_left_obj_id,
        "gripper_right_obj_id": gripper_right_obj_id,
    }

    for id_name, id_value in configured_ids.items():
        if id_value is None or not isinstance(id_value, int):
            raise ValueError(f"{id_name} 不能为空且必须为 int，当前: {id_value}")

    id_values = list(configured_ids.values())
    if len(set(id_values)) != 4:
        raise ValueError(
            "对象 ID 冲突：arm_left_obj_id / arm_right_obj_id / "
            "gripper_left_obj_id / gripper_right_obj_id 必须互不重复，当前="
            f"{id_values}"
        )

    # (prompts, id they must carry, config name used in the error message)
    expectations = (
        (arm_left_prompts, arm_left_obj_id, "ARM_LEFT_INITIAL_PROMPTS"),
        (arm_right_prompts, arm_right_obj_id, "ARM_RIGHT_INITIAL_PROMPTS"),
        (gripper_left_prompts, gripper_left_obj_id, "GRIPPER_LEFT_KEYFRAME_PROMPTS"),
        (gripper_right_prompts, gripper_right_obj_id, "GRIPPER_RIGHT_KEYFRAME_PROMPTS"),
    )
    for prompts, expected_id, prompt_tag in expectations:
        for position, prompt in enumerate(prompts):
            actual_id = prompt["obj_id"]
            if actual_id != expected_id:
                raise ValueError(
                    f"{prompt_tag}[{position}].obj_id={actual_id} 与期望 obj_id={expected_id} 不一致"
                )


def propagate_in_video(predictor_obj, session_id_value, propagation_direction="forward"):
    """Run one ``propagate_in_video`` stream request and collect its outputs.

    Returns:
        dict mapping each response's ``frame_index`` to its ``outputs``; a
        later response for the same frame replaces an earlier one.
    """
    stream_request = {
        "type": "propagate_in_video",
        "session_id": session_id_value,
        "propagation_direction": propagation_direction,
    }
    responses = predictor_obj.handle_stream_request(request=stream_request)
    return {resp["frame_index"]: resp["outputs"] for resp in responses}


def propagate_bidirectional_and_merge(predictor_obj, session_id_value, stage_name=""):
    """Propagate forward, then backward, and merge both results per frame.

    For a frame produced by both passes the backward result wins, because it
    is merged last. A one-line summary with the three frame counts is printed.

    Returns:
        dict mapping frame index to outputs.
    """
    per_direction = {
        direction: propagate_in_video(
            predictor_obj=predictor_obj,
            session_id_value=session_id_value,
            propagation_direction=direction,
        )
        for direction in ("forward", "backward")
    }
    forward_result = per_direction["forward"]
    backward_result = per_direction["backward"]

    # Backward entries are unpacked last and therefore override forward ones.
    combined = {**forward_result, **backward_result}

    prefix = f"[{stage_name}] " if stage_name else ""
    print(
        f"{prefix}propagation summary | forward_frames={len(forward_result)} "
        f"backward_frames={len(backward_result)} merged_frames={len(combined)}"
    )
    return combined


def add_point_prompt(predictor_obj, session_id_value, prompt, stage_name=""):
    """Send one normalised point prompt to the predictor as ``add_prompt``.

    ``prompt`` must provide ``points_rel``, ``labels``, ``frame_index`` and
    ``obj_id``. Points are sent as a float32 tensor and labels as an int32
    tensor. The response is discarded and a one-line confirmation is printed.
    """
    rel_points = torch.tensor(prompt["points_rel"], dtype=torch.float32)
    point_labels = torch.tensor(prompt["labels"], dtype=torch.int32)

    request = {
        "type": "add_prompt",
        "session_id": session_id_value,
        "frame_index": prompt["frame_index"],
        "points": rel_points,
        "point_labels": point_labels,
        "obj_id": prompt["obj_id"],
    }
    predictor_obj.handle_request(request=request)

    frame = prompt["frame_index"]
    obj_id = prompt["obj_id"]
    num_points = len(prompt["points_rel"])
    print(f"[{stage_name}] add_prompt done | frame={frame} obj_id={obj_id} points={num_points}")


def apply_prompt_list(predictor_obj, session_id_value, prompt_list, stage_name=""):
    """Send every prompt of ``prompt_list``, in order, via ``add_point_prompt``."""
    for single_prompt in prompt_list:
        add_point_prompt(
            predictor_obj,
            session_id_value,
            single_prompt,
            stage_name=stage_name,
        )


def add_text_prompt(predictor_obj, session_id_value, frame_index, text_prompt, stage_name=""):
    """Send a text prompt to the predictor as an ``add_prompt`` request.

    Args:
        predictor_obj: object exposing ``handle_request(request=...)``.
        session_id_value: session the prompt belongs to.
        frame_index: frame to attach the prompt to; cast with ``int()``.
        text_prompt: prompt text; surrounding whitespace is stripped.
        stage_name: label used in the confirmation print.

    Raises:
        ValueError: if ``text_prompt`` is not a string or is blank.
    """
    if not isinstance(text_prompt, str) or not text_prompt.strip():
        # Name the settings that actually exist (per-side *_BOOTSTRAP_* values).
        raise ValueError(
            "text_prompt 不能为空，请设置对应侧的 *_BOOTSTRAP_TEXT_PROMPT "
            "或 *_BOOTSTRAP_FALLBACK_TEXT_PROMPT"
        )

    text = text_prompt.strip()
    frame = int(frame_index)

    _ = predictor_obj.handle_request(
        request=dict(
            type="add_prompt",
            session_id=session_id_value,
            frame_index=frame,
            text=text,
        )
    )
    print(f"[{stage_name}] add_text_prompt done | frame={frame} text={text!r}")


def resolve_stage_a_bootstrap_configs(stage_side_configs, total_frames):
    """Decide, per side, which text prompt and frame bootstrap stage A.

    Each config dict may carry ``side_name``, ``prompt_list``,
    ``bootstrap_text_prompt``, ``fallback_text_prompt`` and
    ``bootstrap_frame_index``. Per side:

    * a non-blank configured text is used as is;
    * otherwise, if the side has point prompts, the fallback text is used
      (and announced with a print);
    * a side with neither configured text nor points is skipped.

    The frame is ``bootstrap_frame_index`` when given, else the frame of the
    side's first point prompt, else 0.

    Returns:
        list of dicts with ``side_name``, ``text_prompt``, ``text_source``
        (``"configured"`` or ``"fallback"``) and ``frame_index``.

    Raises:
        ValueError: if the input is not a non-empty list/tuple, a side has
            points but no usable text, ``bootstrap_frame_index`` is not an
            int or is out of range, or no side could be resolved at all.
    """
    if not isinstance(stage_side_configs, (list, tuple)) or len(stage_side_configs) == 0:
        raise ValueError("stage_side_configs 必须是非空 list")

    def _stripped(value):
        # Non-string values count as "no text".
        return value.strip() if isinstance(value, str) else ""

    resolved = []
    for position, side_cfg in enumerate(stage_side_configs):
        side_name = str(side_cfg.get("side_name", f"side-{position}"))
        side_prompts = side_cfg.get("prompt_list", [])
        configured_text = _stripped(side_cfg.get("bootstrap_text_prompt", ""))
        fallback_text = _stripped(side_cfg.get("fallback_text_prompt", ""))
        requested_frame = side_cfg.get("bootstrap_frame_index", None)
        has_points = len(side_prompts) > 0

        if configured_text:
            chosen_text, source = configured_text, "configured"
        elif not has_points:
            # Nothing to bootstrap for this side.
            continue
        elif fallback_text:
            chosen_text, source = fallback_text, "fallback"
            print(
                f"[stage A][bootstrap][{side_name}] 文本提示为空，自动使用 fallback: {chosen_text!r}"
            )
        else:
            raise ValueError(
                f"[{side_name}] 检测到 arm points，但 bootstrap 文本为空且无 fallback。"
                "请设置对应侧的 *_BOOTSTRAP_TEXT_PROMPT 或 *_BOOTSTRAP_FALLBACK_TEXT_PROMPT。"
            )

        if requested_frame is None:
            chosen_frame = int(side_prompts[0]["frame_index"]) if has_points else 0
        elif not isinstance(requested_frame, int):
            raise ValueError(
                f"[{side_name}] bootstrap_frame_index 必须为 int 或 None，当前: {requested_frame}"
            )
        elif not (0 <= requested_frame < total_frames):
            raise ValueError(
                f"[{side_name}] bootstrap_frame_index 越界: {requested_frame}, "
                f"合法范围 [0, {total_frames - 1}]"
            )
        else:
            chosen_frame = int(requested_frame)

        resolved.append(
            {
                "side_name": side_name,
                "text_prompt": chosen_text,
                "text_source": source,
                "frame_index": chosen_frame,
            }
        )

    if not resolved:
        raise ValueError(
            "阶段A缺少可用提示：左右机械臂均未提供 bootstrap 文本，也没有可触发 fallback 的 points。"
        )

    return resolved


def visualize_outputs(outputs_per_frame, video_frames, stride=60, max_plots=8, title="SAM3 outputs"):
    """Plot the outputs on a sample of frames (every ``stride``-th frame).

    At most ``max_plots`` frames are drawn; if the sampling yields no frame,
    frame 0 is used. All open matplotlib figures are closed first.
    """
    formatted_outputs = prepare_masks_for_visualization(outputs_per_frame)

    # Fall back to frame 0 when the stride/limit selects nothing.
    sampled_frames = list(range(0, len(video_frames), stride))[:max_plots] or [0]

    plt.close("all")
    for sampled_idx in sampled_frames:
        visualize_formatted_frame_output(
            sampled_idx,
            video_frames,
            outputs_list=[formatted_outputs],
            titles=[title],
            figsize=(6, 4),
        )


def resolve_export_obj_ids(export_mode, arm_obj_ids, gripper_obj_ids, custom_obj_ids=None):
    """Return the sorted, de-duplicated object IDs selected by ``export_mode``.

    Modes: ``"arm-only"``, ``"gripper-only"``, ``"union"`` (arms + grippers)
    and ``"custom"`` (uses ``custom_obj_ids``). All IDs are cast to int.

    Raises:
        ValueError: for an unknown mode, or ``"custom"`` with no custom IDs.
    """
    custom_ids = custom_obj_ids or []
    arm_ids = sorted({int(v) for v in arm_obj_ids})
    gripper_ids = sorted({int(v) for v in gripper_obj_ids})

    if export_mode == "custom":
        if len(custom_ids) == 0:
            raise ValueError("EXPORT_MODE=custom 时，EXPORT_CUSTOM_OBJ_IDS 不能为空")
        selected = sorted({int(v) for v in custom_ids})
    elif export_mode == "union":
        selected = sorted(set(arm_ids) | set(gripper_ids))
    elif export_mode == "gripper-only":
        selected = gripper_ids
    elif export_mode == "arm-only":
        selected = arm_ids
    else:
        raise ValueError(f"不支持的 EXPORT_MODE: {export_mode}")

    return selected


def save_masks_for_propainter(
    outputs_per_frame,
    video_frames,
    output_dir,
    target_obj_ids,
    dilate_radius=8,
):
    """Write one binary mask PNG per frame, merging the selected objects.

    For every frame index in ``range(len(video_frames))`` the masks of all
    objects in ``target_obj_ids`` are OR-ed into one 0/255 uint8 image,
    optionally dilated, and saved as ``<output_dir>/<frame:05d>.png``.
    Frames without any selected object produce an all-black mask.

    Args:
        outputs_per_frame: mapping ``frame_idx -> {obj_id: mask}``; masks may
            be torch tensors or numpy arrays, pixels ``> 0`` are foreground.
            NOTE(review): assumes the per-object dict layout — confirm against
            what the caller passes in.
        video_frames: frames used only for the frame count and image size.
        output_dir: destination directory (created if missing).
        target_obj_ids: object IDs to include.
        dilate_radius: dilation radius in pixels; ``<= 0`` disables dilation.

    Returns:
        list of the written file paths, in frame order.

    Raises:
        ValueError: if a selected mask does not match the frame size.
    """
    os.makedirs(output_dir, exist_ok=True)

    img_w, img_h = get_frame_size(video_frames)
    num_frames = len(video_frames)
    expected_shape = (img_h, img_w)

    kernel = None
    if dilate_radius > 0:
        kernel_size = 2 * int(dilate_radius) + 1
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        print(f"[export] dilation enabled: radius={dilate_radius}, kernel={kernel_size}x{kernel_size}")

    print(f"[export] size={img_w}x{img_h}, frames={num_frames}, target_obj_ids={target_obj_ids}")

    saved_paths = []
    target_set = set(target_obj_ids)

    for frame_idx in range(num_frames):
        combined_mask = np.zeros(expected_shape, dtype=np.uint8)

        obj_dict = outputs_per_frame.get(frame_idx, {})
        for obj_id, mask in obj_dict.items():
            if int(obj_id) not in target_set:
                continue

            if isinstance(mask, torch.Tensor):
                mask = mask.detach().cpu().numpy()
            if mask.ndim > 2:
                mask = np.squeeze(mask)

            # Without this check np.maximum would either broadcast a
            # wrong-sized mask silently (e.g. shape (1, W)) or fail with an
            # error that names neither the frame nor the object.
            if mask.shape != expected_shape:
                raise ValueError(
                    f"[export] frame {frame_idx} obj_id={obj_id} 的 mask 尺寸 {mask.shape} "
                    f"与帧尺寸 {expected_shape} 不一致"
                )

            binary = (mask > 0).astype(np.uint8) * 255
            combined_mask = np.maximum(combined_mask, binary)

        if kernel is not None and np.any(combined_mask):
            combined_mask = cv2.dilate(combined_mask, kernel, iterations=1)

        out_fp = os.path.join(output_dir, f"{frame_idx:05d}.png")
        # A 2-D uint8 array is inferred as mode "L"; the explicit ``mode``
        # argument is redundant and was deprecated in Pillow 11.3.
        Image.fromarray(combined_mask).save(out_fp)
        saved_paths.append(out_fp)

        if frame_idx % 50 == 0:
            print(f"[export] frame {frame_idx:05d}/{num_frames - 1:05d} done")

    print(f"[export] finished, saved {len(saved_paths)} masks to: {output_dir}")
    return saved_paths

## 1) 载入可视化帧 + 校验配置 + 初始化 predictor/session

In [None]:
# When the notebook is re-executed, release the previous predictor/session
# first to avoid NCCL initialisation errors.
if "predictor" in globals() and predictor is not None:
    print("[init] cleaning previous predictor/session before re-run")
    cleanup_resources(predictor_obj=predictor, session_id_value=session_id)

# Expose the configured GPU list to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
print(f"[init] CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")

# Frames loaded here serve visualisation and prompt validation only; the
# predictor receives VIDEO_PATH itself in the start_session request below.
video_frames_for_vis = load_video_frames_for_visualization(VIDEO_PATH)
TOTAL_FRAMES = len(video_frames_for_vis)
IMG_WIDTH, IMG_HEIGHT = get_frame_size(video_frames_for_vis)
print(f"[init] loaded frames={TOTAL_FRAMES}, size={IMG_WIDTH}x{IMG_HEIGHT}")

# Optionally replace the hand-written prompt lists with the ones exported by
# the visual annotator; a list missing from the export keeps its manual value.
if USE_VISUAL_ANNOTATION_EXPORT:
    if isinstance(ANNOTATION_EXPORT_PROMPTS, dict):
        ARM_LEFT_INITIAL_PROMPTS = ANNOTATION_EXPORT_PROMPTS.get(
            "ARM_LEFT_INITIAL_PROMPTS", ARM_LEFT_INITIAL_PROMPTS
        )
        ARM_RIGHT_INITIAL_PROMPTS = ANNOTATION_EXPORT_PROMPTS.get(
            "ARM_RIGHT_INITIAL_PROMPTS", ARM_RIGHT_INITIAL_PROMPTS
        )
        GRIPPER_LEFT_KEYFRAME_PROMPTS = ANNOTATION_EXPORT_PROMPTS.get(
            "GRIPPER_LEFT_KEYFRAME_PROMPTS", GRIPPER_LEFT_KEYFRAME_PROMPTS
        )
        GRIPPER_RIGHT_KEYFRAME_PROMPTS = ANNOTATION_EXPORT_PROMPTS.get(
            "GRIPPER_RIGHT_KEYFRAME_PROMPTS", GRIPPER_RIGHT_KEYFRAME_PROMPTS
        )
        print("[annotation] using prompts exported from visual annotation cells")
    else:
        # The flag is set but no export dict exists: warn and keep the manual lists.
        print(
            "[annotation][warn] USE_VISUAL_ANNOTATION_EXPORT=True 但 ANNOTATION_EXPORT_PROMPTS 无效，"
            "继续使用手工配置的 *_PROMPTS"
        )

# Validate each prompt list against the loaded frames and convert it to the
# normalised form (relative coordinates). Every list is allowed to be empty.
arm_left_prompts_norm = validate_and_normalize_prompt_list(
    ARM_LEFT_INITIAL_PROMPTS,
    total_frames=TOTAL_FRAMES,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
    tag="ARM_LEFT_INITIAL_PROMPTS",
    allow_empty=True,
)

arm_right_prompts_norm = validate_and_normalize_prompt_list(
    ARM_RIGHT_INITIAL_PROMPTS,
    total_frames=TOTAL_FRAMES,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
    tag="ARM_RIGHT_INITIAL_PROMPTS",
    allow_empty=True,
)

gripper_left_prompts_norm = validate_and_normalize_prompt_list(
    GRIPPER_LEFT_KEYFRAME_PROMPTS,
    total_frames=TOTAL_FRAMES,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
    tag="GRIPPER_LEFT_KEYFRAME_PROMPTS",
    allow_empty=True,
)

gripper_right_prompts_norm = validate_and_normalize_prompt_list(
    GRIPPER_RIGHT_KEYFRAME_PROMPTS,
    total_frames=TOTAL_FRAMES,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
    tag="GRIPPER_RIGHT_KEYFRAME_PROMPTS",
    allow_empty=True,
)

# The four object IDs must be distinct and every prompt must carry the ID of
# its own side/part.
validate_obj_id_constraints(
    arm_left_obj_id=arm_left_obj_id,
    arm_right_obj_id=arm_right_obj_id,
    gripper_left_obj_id=gripper_left_obj_id,
    gripper_right_obj_id=gripper_right_obj_id,
    arm_left_prompts=arm_left_prompts_norm,
    arm_right_prompts=arm_right_prompts_norm,
    gripper_left_prompts=gripper_left_prompts_norm,
    gripper_right_prompts=gripper_right_prompts_norm,
)

print("[init] prompt validation passed")
print(
    f"[init] ARM prompt count: left={len(arm_left_prompts_norm)}, right={len(arm_right_prompts_norm)} | "
    f"GRIPPER prompt count: left={len(gripper_left_prompts_norm)}, right={len(gripper_right_prompts_norm)}"
)
print(
    f"[init] configured object IDs: arms={[arm_left_obj_id, arm_right_obj_id]}, "
    f"grippers={[gripper_left_obj_id, gripper_right_obj_id]}"
)

# Build the predictor on every CUDA device torch reports, only after all
# validation above has passed.
gpus_to_use = range(torch.cuda.device_count())
predictor = build_sam3_video_predictor(
    checkpoint_path=CHECKPOINT_PATH,
    gpus_to_use=gpus_to_use,
    apply_temporal_disambiguation=APPLY_TEMPORAL_DISAMBIGUATION,
)

# Open a session on the video; later requests are addressed by session_id.
start_response = predictor.handle_request(
    request=dict(
        type="start_session",
        resource_path=VIDEO_PATH,
    )
)
session_id = start_response["session_id"]
print(f"[init] session started: {session_id}")


## 1.5) 可视化点标注（解耦采集，不触发推理）

最小使用说明：
1. 先运行 `1) 初始化`（上一段 code cell，确保已加载 `video_frames_for_vis` / `TOTAL_FRAMES` / `IMG_WIDTH` / `IMG_HEIGHT`）。
2. 运行下方可视化标注 cell：
   - `Frame` 选择标注帧；
   - `Object` 选择对象（左臂/右臂/左夹爪/右夹爪）；
   - `Point Label` 选择 positive(1) / negative(0)；
   - 在图上点击添加点；
   - `Clear Current Obj@Frame` 清空当前对象在当前帧的点。
3. 点击 `Export Prompts` 后会生成 `ANNOTATION_EXPORT_PROMPTS`，并自动将 `USE_VISUAL_ANNOTATION_EXPORT=True`。
4. 继续运行 Stage A / Stage B，主流程不变，仅数据来源切到导出的结构化提示。

In [None]:
# 可视化标注：在消费坐标前，先做可见点选并导出为 *_PROMPTS 兼容结构
import json

# Annotatable objects: annotator key -> display name, configured object id and
# the name of the prompt list that the object's clicks are exported into.
ANNOTATION_OBJECT_SPECS = dict(
    arm_left=dict(
        display="左臂",
        obj_id=int(arm_left_obj_id),
        target="ARM_LEFT_INITIAL_PROMPTS",
    ),
    arm_right=dict(
        display="右臂",
        obj_id=int(arm_right_obj_id),
        target="ARM_RIGHT_INITIAL_PROMPTS",
    ),
    gripper_left=dict(
        display="左夹爪",
        obj_id=int(gripper_left_obj_id),
        target="GRIPPER_LEFT_KEYFRAME_PROMPTS",
    ),
    gripper_right=dict(
        display="右夹爪",
        obj_id=int(gripper_right_obj_id),
        target="GRIPPER_RIGHT_KEYFRAME_PROMPTS",
    ),
)


def _append_click(store, obj_key, frame_idx, x, y, label):
    frame_idx = int(frame_idx)
    store[obj_key].setdefault(frame_idx, []).append(
        {"x": int(x), "y": int(y), "label": int(label)}
    )


def _seed_store_from_prompt_list(store, obj_key, prompt_list, img_w, img_h):
    """Pre-fill the annotation store with the clicks of an existing prompt list.

    Args:
        store: annotation store; ``store[obj_key]`` must already exist.
        obj_key: annotator key the clicks belong to.
        prompt_list: raw prompt dicts (``frame_index``, ``points``, ``labels``,
            optional ``coord_type`` defaulting to ``"abs"``); ``None`` or an
            empty list seeds nothing.
        img_w, img_h: image size, used to convert ``"rel"`` points to pixels
            and to clamp every point into the image.

    Points and labels are paired with ``zip``, so surplus items of the longer
    sequence are ignored.
    """
    width, height = int(img_w), int(img_h)

    for entry in (prompt_list or []):
        frame_idx = int(entry["frame_index"])
        is_relative = entry.get("coord_type", "abs") == "rel"
        # rel -> abs must be the inverse of abs_to_rel_points (x / img_w,
        # y / img_h); scaling by (size - 1) made seeded points drift when
        # they were exported and normalised again.
        x_scale, y_scale = (width, height) if is_relative else (1, 1)

        points = entry.get("points", [])
        labels = entry.get("labels", [])
        for p, lb in zip(points, labels):
            x = int(round(float(p[0]) * x_scale))
            y = int(round(float(p[1]) * y_scale))
            # Clamp into the image; rel == 1.0 maps to `size`, one past the
            # last valid pixel.
            x = max(0, min(width - 1, x))
            y = max(0, min(height - 1, y))
            _append_click(store, obj_key, frame_idx, x, y, int(lb))


def _prompt_list_from_store(store, obj_key, obj_id):
    prompt_list = []
    for frame_idx in sorted(store[obj_key].keys()):
        clicks = store[obj_key][frame_idx]
        if len(clicks) == 0:
            continue
        prompt_list.append(
            {
                "frame_index": int(frame_idx),
                "obj_id": int(obj_id),
                "coord_type": "abs",
                "points": [[int(c["x"]), int(c["y"])] for c in clicks],
                "labels": [int(c["label"]) for c in clicks],
            }
        )
    return prompt_list


def export_annotation_prompts(annotation_store):
    """Build the export dict from the click store.

    The result has the four keys ``ARM_LEFT_INITIAL_PROMPTS``,
    ``ARM_RIGHT_INITIAL_PROMPTS``, ``GRIPPER_LEFT_KEYFRAME_PROMPTS`` and
    ``GRIPPER_RIGHT_KEYFRAME_PROMPTS`` (in that order). Each value is the
    prompt list produced by ``_prompt_list_from_store`` for the matching
    object, using the ``obj_id`` registered in ``ANNOTATION_OBJECT_SPECS``.
    """
    export_layout = (
        ("ARM_LEFT_INITIAL_PROMPTS", "arm_left"),
        ("ARM_RIGHT_INITIAL_PROMPTS", "arm_right"),
        ("GRIPPER_LEFT_KEYFRAME_PROMPTS", "gripper_left"),
        ("GRIPPER_RIGHT_KEYFRAME_PROMPTS", "gripper_right"),
    )
    return {
        export_key: _prompt_list_from_store(
            annotation_store,
            obj_key,
            ANNOTATION_OBJECT_SPECS[obj_key]["obj_id"],
        )
        for export_key, obj_key in export_layout
    }


# The hand-written prompt configuration, keyed by the same names the export
# structure uses.
ANNOTATION_MANUAL_TEMPLATE = dict(
    ARM_LEFT_INITIAL_PROMPTS=ARM_LEFT_INITIAL_PROMPTS,
    ARM_RIGHT_INITIAL_PROMPTS=ARM_RIGHT_INITIAL_PROMPTS,
    GRIPPER_LEFT_KEYFRAME_PROMPTS=GRIPPER_LEFT_KEYFRAME_PROMPTS,
    GRIPPER_RIGHT_KEYFRAME_PROMPTS=GRIPPER_RIGHT_KEYFRAME_PROMPTS,
)

# One independent {frame index: [clicks]} mapping per annotatable object.
annotation_store = {obj_key: {} for obj_key in ANNOTATION_OBJECT_SPECS}

# Pre-fill: load the current hand-written configuration into the annotator
# store so it can be edited incrementally.
_seed_store_from_prompt_list(
    store=annotation_store,
    obj_key="arm_left",
    prompt_list=ARM_LEFT_INITIAL_PROMPTS,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
)
_seed_store_from_prompt_list(
    store=annotation_store,
    obj_key="arm_right",
    prompt_list=ARM_RIGHT_INITIAL_PROMPTS,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
)
_seed_store_from_prompt_list(
    store=annotation_store,
    obj_key="gripper_left",
    prompt_list=GRIPPER_LEFT_KEYFRAME_PROMPTS,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
)
_seed_store_from_prompt_list(
    store=annotation_store,
    obj_key="gripper_right",
    prompt_list=GRIPPER_RIGHT_KEYFRAME_PROMPTS,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
)

# Default export content mirrors the current configuration (used in fallback
# mode, or when the export button is never clicked and the value is assigned
# by hand).
if not isinstance(ANNOTATION_EXPORT_PROMPTS, dict):
    ANNOTATION_EXPORT_PROMPTS = {**ANNOTATION_MANUAL_TEMPLATE}

# Probe for an interactive Jupyter widget stack. Any failure is kept in
# widget_error and reported later by the fallback branch instead of aborting
# the notebook.
widget_ready, widget_error = False, None

try:
    from IPython import get_ipython

    ip = get_ipython()
    if ip is None:
        raise RuntimeError("当前不在 IPython/Jupyter 环境")

    # Match the interactive experience of the reference notebook as closely
    # as possible.
    ip.run_line_magic("matplotlib", "widget")

    import ipywidgets as widgets
    from IPython.display import display
except Exception as probe_error:
    widget_error = probe_error
else:
    widget_ready = True

# Pick one of three paths: frames not loaded yet (warn only), widget stack
# unavailable (print manual-entry instructions), or build the interactive
# point annotator.
if video_frames_for_vis is None or TOTAL_FRAMES is None:
    print("[annotation][warn] 未检测到可视化帧，请先运行上一个初始化 cell 再执行本 cell。")
elif not widget_ready:
    print(
        "[annotation][fallback] ipympl/ipywidgets 不可用，已回退到手工结构输入，不中断 notebook。"
    )
    print(f"[annotation][fallback] 触发原因: {widget_error}")
    print("[annotation][fallback] 请直接编辑 ANNOTATION_MANUAL_TEMPLATE 或四个 *_PROMPTS 变量。")
    print("[annotation][fallback] 然后设置 USE_VISUAL_ANNOTATION_EXPORT=True 并赋值：")
    print("    ANNOTATION_EXPORT_PROMPTS = ANNOTATION_MANUAL_TEMPLATE")
else:
    # --- Controls: frame selector, object selector, point-label selector ---
    frame_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=max(0, int(TOTAL_FRAMES) - 1),
        step=1,
        description="Frame",
        continuous_update=False,
        layout=widgets.Layout(width="380px"),
    )

    object_dropdown = widgets.Dropdown(
        options=[
            (
                f"{spec['display']} (obj_id={spec['obj_id']})",
                key,
            )
            for key, spec in ANNOTATION_OBJECT_SPECS.items()
        ],
        value="arm_left",
        description="Object",
        layout=widgets.Layout(width="380px"),
    )

    label_toggle = widgets.ToggleButtons(
        options=[("positive(1)", 1), ("negative(0)", 0)],
        value=1,
        description="Point Label",
        layout=widgets.Layout(width="380px"),
    )

    # --- Action buttons ---
    clear_btn = widgets.Button(
        description="Clear Current Obj@Frame",
        button_style="warning",
        layout=widgets.Layout(width="220px"),
    )
    refresh_btn = widgets.Button(
        description="Refresh",
        button_style="",
        layout=widgets.Layout(width="100px"),
    )
    export_btn = widgets.Button(
        description="Export Prompts",
        button_style="success",
        layout=widgets.Layout(width="140px"),
    )

    # Separate output areas for the live status line and the export dump.
    status_out = widgets.Output(layout=widgets.Layout(border="1px solid #aaa"))
    export_out = widgets.Output(layout=widgets.Layout(border="1px solid #aaa"))

    # Start from a clean slate so figures from earlier runs of this cell do
    # not linger.
    plt.close("all")
    ann_fig, ann_ax = plt.subplots(1, 1, figsize=(9, 6))
    ann_fig.canvas.toolbar_visible = True

    def _current_ctx():
        """Return the current (frame index, object key, point label) selection."""
        return int(frame_slider.value), str(object_dropdown.value), int(label_toggle.value)

    def _draw_annotation_canvas():
        """Redraw the selected frame with the selected object's clicks on top.

        Positive clicks are drawn as lime "o" markers, negative clicks as red
        "x" markers; each click is tagged with "<index>:<label>". The status
        output is refreshed with the click count for the current selection.
        """
        frame_idx, obj_key, _ = _current_ctx()
        frame = video_frames_for_vis[frame_idx]

        ann_ax.clear()
        ann_ax.imshow(frame)
        ann_ax.set_title(
            f"Frame={frame_idx} | Object={ANNOTATION_OBJECT_SPECS[obj_key]['display']} "
            f"(obj_id={ANNOTATION_OBJECT_SPECS[obj_key]['obj_id']}) | "
            f"CurrentLabel={label_toggle.value}"
        )
        ann_ax.set_axis_off()

        clicks = annotation_store[obj_key].get(frame_idx, [])
        for idx, c in enumerate(clicks):
            x, y, lb = int(c["x"]), int(c["y"]), int(c["label"])
            color = "lime" if lb == 1 else "red"
            marker = "o" if lb == 1 else "x"
            ann_ax.plot(x, y, marker=marker, color=color, markersize=8, markeredgewidth=2)
            ann_ax.text(
                x + 6,
                y,
                f"{idx}:{lb}",
                color="white",
                fontsize=9,
                bbox=dict(boxstyle="round,pad=0.2", facecolor="black", alpha=0.5),
            )

        with status_out:
            status_out.clear_output()
            print(
                f"[annotation] 当前对象={ANNOTATION_OBJECT_SPECS[obj_key]['display']} "
                f"frame={frame_idx} 点数={len(clicks)}"
            )
            print("[annotation] 点击图像可添加点；绿色=o=positive(1)，红色=x=negative(0)")

        ann_fig.canvas.draw_idle()

    def _on_canvas_click(event):
        """Mouse handler: store a click for the current frame/object/label.

        Clicks outside the annotation axes (or without data coordinates) are
        ignored. Coordinates are rounded and clamped into the image bounds.
        """
        if event.inaxes != ann_ax or event.xdata is None or event.ydata is None:
            return

        frame_idx, obj_key, point_label = _current_ctx()
        x = int(round(float(event.xdata)))
        y = int(round(float(event.ydata)))
        x = max(0, min(int(IMG_WIDTH) - 1, x))
        y = max(0, min(int(IMG_HEIGHT) - 1, y))

        _append_click(annotation_store, obj_key, frame_idx, x, y, point_label)
        _draw_annotation_canvas()

    def _on_clear_clicked(_):
        """Button handler: drop all clicks of the selected object on this frame."""
        frame_idx, obj_key, _label = _current_ctx()
        # pop() with a default is already a no-op when the frame has no clicks.
        annotation_store[obj_key].pop(frame_idx, None)
        _draw_annotation_canvas()

    def _on_refresh_clicked(_):
        """Button handler: redraw the canvas without changing any state."""
        _draw_annotation_canvas()

    def _on_export_clicked(_):
        """Button handler: publish the click store as the export prompts.

        Rebinds the notebook globals ANNOTATION_EXPORT_PROMPTS (to the freshly
        exported structure) and USE_VISUAL_ANNOTATION_EXPORT (to True), then
        prints a per-object entry count and the full structure as JSON.
        """
        global ANNOTATION_EXPORT_PROMPTS, USE_VISUAL_ANNOTATION_EXPORT

        # Imported locally so this callback does not depend on another cell
        # having imported json.
        import json

        ANNOTATION_EXPORT_PROMPTS = export_annotation_prompts(annotation_store)
        USE_VISUAL_ANNOTATION_EXPORT = True

        with export_out:
            export_out.clear_output()
            summary = {
                k: len(v)
                for k, v in ANNOTATION_EXPORT_PROMPTS.items()
            }
            print("[annotation] 导出完成，已自动设置 USE_VISUAL_ANNOTATION_EXPORT=True")
            print(f"[annotation] 各对象关键帧条目数: {summary}")
            print("[annotation] 导出结构（可直接被 Stage A/B 消费）:")
            print(json.dumps(ANNOTATION_EXPORT_PROMPTS, ensure_ascii=False, indent=2))

    # Any change of selection triggers a redraw; buttons and canvas clicks are
    # routed to their handlers.
    frame_slider.observe(lambda _: _draw_annotation_canvas(), names="value")
    object_dropdown.observe(lambda _: _draw_annotation_canvas(), names="value")
    label_toggle.observe(lambda _: _draw_annotation_canvas(), names="value")
    clear_btn.on_click(_on_clear_clicked)
    refresh_btn.on_click(_on_refresh_clicked)
    export_btn.on_click(_on_export_clicked)
    ann_fig.canvas.mpl_connect("button_press_event", _on_canvas_click)

    controls = widgets.VBox(
        [
            widgets.HBox([frame_slider, object_dropdown]),
            widgets.HBox([label_toggle]),
            widgets.HBox([clear_btn, refresh_btn, export_btn]),
            status_out,
            export_out,
        ]
    )

    display(controls)
    _draw_annotation_canvas()
    plt.show()


## 2) 阶段 A：先 bootstrap 普通传播，再做机械臂 points refinement


In [None]:
# Stage A, step 1: resolve the bootstrap prompt configuration for each arm
# side. A plain propagation has to run before any point prompt is added,
# otherwise the first points add_prompt can trip the cached-outputs assertion.
stage_a_bootstrap_configs = resolve_stage_a_bootstrap_configs(
    stage_side_configs=[
        {
            "side_name": "left",
            "bootstrap_text_prompt": ARM_LEFT_BOOTSTRAP_TEXT_PROMPT,
            "fallback_text_prompt": ARM_LEFT_BOOTSTRAP_FALLBACK_TEXT_PROMPT,
            "bootstrap_frame_index": ARM_LEFT_BOOTSTRAP_FRAME_INDEX,
            "prompt_list": arm_left_prompts_norm,
        },
        {
            "side_name": "right",
            "bootstrap_text_prompt": ARM_RIGHT_BOOTSTRAP_TEXT_PROMPT,
            "fallback_text_prompt": ARM_RIGHT_BOOTSTRAP_FALLBACK_TEXT_PROMPT,
            "bootstrap_frame_index": ARM_RIGHT_BOOTSTRAP_FRAME_INDEX,
            "prompt_list": arm_right_prompts_norm,
        },
    ],
    total_frames=TOTAL_FRAMES,
)

stage_a_active_obj_ids = sorted({arm_left_obj_id, arm_right_obj_id})
print(f"[stage A] active_obj_ids={stage_a_active_obj_ids}")

# Step 2: add one text prompt per side, then run the bootstrap propagation.
for cfg in stage_a_bootstrap_configs:
    print(
        f"[stage A][bootstrap] side={cfg['side_name']} frame={cfg['frame_index']} "
        f"text_source={cfg['text_source']}"
    )
    add_text_prompt(
        predictor_obj=predictor,
        session_id_value=session_id,
        frame_index=cfg["frame_index"],
        text_prompt=cfg["text_prompt"],
        stage_name=f"stage A/bootstrap/{cfg['side_name']}",
    )

print("[stage A][bootstrap] propagating bidirectional (forward + backward) ...")
outputs_stage_a = propagate_bidirectional_and_merge(
    predictor_obj=predictor,
    session_id_value=session_id,
    stage_name="stage A/bootstrap",
)
print(f"[stage A][bootstrap] done, merged frame outputs={len(outputs_stage_a)}")

# Step 3: merge left/right point prompts, ordered by (frame index, obj id).
arm_prompts_merged = sorted(
    [*arm_left_prompts_norm, *arm_right_prompts_norm],
    key=lambda prompt: (prompt["frame_index"], prompt["obj_id"]),
)

if arm_prompts_merged:
    print("[stage A][refinement] applying left/right arm point prompts after bootstrap ...")
    apply_prompt_list(
        predictor_obj=predictor,
        session_id_value=session_id,
        prompt_list=arm_prompts_merged,
        stage_name="stage A/refinement",
    )

    # The refinement pass replaces the bootstrap outputs as the Stage A result.
    print("[stage A][refinement] propagating bidirectional (forward + backward) ...")
    outputs_stage_a = propagate_bidirectional_and_merge(
        predictor_obj=predictor,
        session_id_value=session_id,
        stage_name="stage A/refinement",
    )
    print(f"[stage A][refinement] done, merged frame outputs={len(outputs_stage_a)}")
else:
    print("[stage A][refinement] no arm point prompts configured; keeping bootstrap outputs as stage A result")


In [None]:
# Plot the Stage A result on top of the loaded frames. Frame stride and plot
# count are taken from VIS_FRAME_STRIDE / VIS_MAX_PLOTS; the rendering itself
# is done by visualize_outputs, defined in an earlier cell.
visualize_outputs(
    outputs_per_frame=outputs_stage_a,
    video_frames=video_frames_for_vis,
    stride=VIS_FRAME_STRIDE,
    max_plots=VIS_MAX_PLOTS,
    title="Stage A: Arm segmentation",
)

## 3) 阶段 B：关键帧 points 新增 gripper（全新 obj_id）并再次传播

In [None]:
# Stage B: the grippers are added under their own obj_ids on top of the arms.
# Both id lists below are only used for logging in this cell.
stage_b_new_obj_ids = sorted({gripper_left_obj_id, gripper_right_obj_id})
stage_b_union_obj_ids = sorted(
    {
        arm_left_obj_id,
        arm_right_obj_id,
        gripper_left_obj_id,
        gripper_right_obj_id,
    }
)
print(f"[stage B] active_obj_ids(new grippers)={stage_b_new_obj_ids}")
print(f"[stage B] expected_obj_ids_after_merge={stage_b_union_obj_ids}")

# Merge left/right gripper prompts, ordered by (frame index, obj id).
gripper_prompts_merged = sorted(
    [*gripper_left_prompts_norm, *gripper_right_prompts_norm],
    key=lambda prompt: (prompt["frame_index"], prompt["obj_id"]),
)

if gripper_prompts_merged:
    print("[stage B] injecting left/right gripper prompts...")
    apply_prompt_list(
        predictor_obj=predictor,
        session_id_value=session_id,
        prompt_list=gripper_prompts_merged,
        stage_name="stage B",
    )
else:
    print("[stage B] no gripper prompts configured; skipping prompt injection")

# Propagation runs whether or not gripper prompts were injected.
print("[stage B] propagating bidirectional (forward + backward) ...")
outputs_stage_b = propagate_bidirectional_and_merge(
    predictor_obj=predictor,
    session_id_value=session_id,
    stage_name="stage B",
)
print(f"[stage B] done, merged frame outputs={len(outputs_stage_b)}")


In [None]:
# Plot the Stage B result on top of the loaded frames, with the same stride
# and plot-count settings as the Stage A preview.
visualize_outputs(
    outputs_per_frame=outputs_stage_b,
    video_frames=video_frames_for_vis,
    stride=VIS_FRAME_STRIDE,
    max_plots=VIS_MAX_PLOTS,
    title="Stage B: Arm + Gripper segmentation",
)

## 4) 导出 mask（支持 arm-only / gripper-only / union / custom）

In [None]:
# Work out which object ids to export from EXPORT_MODE; custom ids are passed
# through EXPORT_CUSTOM_OBJ_IDS. The mode handling lives in
# resolve_export_obj_ids, defined in an earlier cell.
export_obj_ids = resolve_export_obj_ids(
    export_mode=EXPORT_MODE,
    arm_obj_ids=[arm_left_obj_id, arm_right_obj_id],
    gripper_obj_ids=[gripper_left_obj_id, gripper_right_obj_id],
    custom_obj_ids=EXPORT_CUSTOM_OBJ_IDS,
)

print(f"[export] EXPORT_MODE={EXPORT_MODE}, resolved_obj_ids={export_obj_ids}")

# Masks are always written from the Stage B outputs, restricted to the
# resolved ids.
# NOTE(review): mask_paths is used below as an indexable sequence of written
# paths; confirm against save_masks_for_propainter.
mask_paths = save_masks_for_propainter(
    outputs_per_frame=outputs_stage_b,
    video_frames=video_frames_for_vis,
    output_dir=EXPORT_OUTPUT_DIR,
    target_obj_ids=export_obj_ids,
    dilate_radius=EXPORT_DILATE_RADIUS,
)

# Show the first and last entry as a quick sanity check; 'N/A' when nothing
# was returned.
print(f"[export] sample: first={mask_paths[0] if mask_paths else 'N/A'}")
print(f"[export] sample: last={mask_paths[-1] if mask_paths else 'N/A'}")


## 5) 资源清理（session close + predictor shutdown + 进程组清理）

In [None]:
# Hand the predictor and session to cleanup_resources (defined in an earlier
# cell; per this section's heading it covers session close, predictor shutdown
# and process-group cleanup).
cleanup_resources(
    predictor_obj=predictor,
    session_id_value=session_id,
)

# Clear the notebook globals so later cells cannot keep using the old handles.
predictor = None
session_id = None
print("[cleanup] globals reset done")