# SAM3 远程无 GUI 两阶段分割（机械臂 + gripper 新 obj_id）

本 notebook 面向远程服务器（无 GUI）逐 cell 交互执行：

- 阶段 A：先做一次普通传播（bootstrap）填充缓存，再进行机械臂 points/labels refinement。
- 阶段 B：必须在关键帧注入 gripper 独立对象（左右 obj_id），并进行双向传播。
- 导出固定为 arm-only：使用 Stage B 的 gripper obj_id 掩码执行扣除 `arm_only = arm_mask AND NOT gripper_mask`。


In [None]:
import os
import glob
from pathlib import Path

import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

import sam3
from sam3.model_builder import build_sam3_video_predictor
from sam3.visualization_utils import prepare_masks_for_visualization, visualize_formatted_frame_output

# Use the same compact font size for per-axes titles and figure-level titles.
plt.rcParams.update({"axes.titlesize": 12, "figure.titlesize": 12})

# ==============================
# Key configuration (ideally edit only this section)
# ==============================

# 1) Paths
VIDEO_PATH = "/data/haoxiang/data/airexo2/task_0013/train/scene_0001/cam_105422061350/color"
CHECKPOINT_PATH = "/data/haoxiang/sam3/models/facebook/sam3/sam3.pt"

# 2) Inference settings
# CUDA_VISIBLE_DEVICES is written into os.environ by the init cell.
CUDA_VISIBLE_DEVICES = "0,1,2,3"
APPLY_TEMPORAL_DISAMBIGUATION = False
# NOTE(review): the Stage A/B cells visible in this notebook always call
# propagate_bidirectional_and_merge and never read the two direction settings
# below, so changing them has no effect there -- confirm before relying on them.
PROPAGATION_DIRECTION_STAGE_A = "forward"  # forward recommended
PROPAGATION_DIRECTION_STAGE_B = "forward"  # forward again after adding objects at keyframes

# 3) Object IDs (left/right arm + left/right gripper; all four must be distinct)
arm_left_obj_id = 0
arm_right_obj_id = 1
gripper_left_obj_id = 2
gripper_right_obj_id = 3

# 4) Stage A: bootstrap text prompts for the left/right arm (configured independently)
# One ordinary propagation must complete first so that the first points add_prompt
# does not trip the cached-outputs assertion.
ARM_LEFT_BOOTSTRAP_TEXT_PROMPT = ""  # optional, e.g. "left robot arm"
ARM_LEFT_BOOTSTRAP_FALLBACK_TEXT_PROMPT = "left robot arm"
ARM_LEFT_BOOTSTRAP_FRAME_INDEX = None

ARM_RIGHT_BOOTSTRAP_TEXT_PROMPT = ""  # optional, e.g. "right robot arm"
ARM_RIGHT_BOOTSTRAP_FALLBACK_TEXT_PROMPT = "right robot arm"
ARM_RIGHT_BOOTSTRAP_FRAME_INDEX = None

# 5) Stage A: point refinement for the left/right arm (may be empty; sides are independent)
# coord_type: "abs" (pixel coordinates) / "rel" (coordinates normalized to [0, 1])
# The obj_id fields reference the ID variables from section 3 instead of repeating
# the literal, so changing an ID there cannot leave a prompt pointing at the old one.
ARM_LEFT_INITIAL_PROMPTS = [
    {
        "frame_index": 120,
        "obj_id": arm_left_obj_id,
        "coord_type": "abs",
        "points": [[962, 350]],
        "labels": [0],
    }
]

ARM_RIGHT_INITIAL_PROMPTS = []

# 6) Stage B: add the left/right gripper at keyframes (points + point labels + new obj_id)
# Multiple keyframes are supported; left and right obj_id are configured independently.
GRIPPER_LEFT_KEYFRAME_PROMPTS = [
    {
        "frame_index": 120,
        "obj_id": gripper_left_obj_id,
        "coord_type": "abs",
        "points": [[973, 356], [930, 330]],
        "labels": [1, 0],
    }
]

GRIPPER_RIGHT_KEYFRAME_PROMPTS = []

# 6.5) Visual-annotation export hook (only switches the prompt source; runs no inference)
# True: Stage A/B use ANNOTATION_EXPORT_PROMPTS
# False: Stage A/B use the hand-written *_PROMPTS above
USE_VISUAL_ANNOTATION_EXPORT = False
ANNOTATION_EXPORT_PROMPTS = None

# 7) Visualization settings
VIS_FRAME_STRIDE = 60
VIS_MAX_PLOTS = 8

# 8) Export settings (always arm-only)
EXPORT_OUTPUT_DIR = "/data/haoxiang/propainter/masks_airexo_arm_only"
EXPORT_DILATE_RADIUS = 15
EXPORT_LOG_EVERY = 50

# Runtime state (filled in by later cells; no need to edit by hand).
# predictor / session_id keep their current value when this cell is re-executed in a
# live kernel: overwriting them with None would drop the only handle to a running
# predictor, and the init cell's "clean previous predictor" guard would then be skipped.
predictor = globals().get("predictor")
session_id = globals().get("session_id")
video_frames_for_vis = None
TOTAL_FRAMES = None
IMG_WIDTH = None
IMG_HEIGHT = None
arm_left_prompts_norm = None
arm_right_prompts_norm = None
gripper_left_prompts_norm = None
gripper_right_prompts_norm = None
outputs_stage_a = None
outputs_stage_b = None


## 工具函数（校验 / 坐标转换 / 推理流程 / 导出 / 清理）

In [None]:
from mask_pipeline_tools import (
    cleanup_process_group,
    cleanup_resources,
    load_video_frames_for_visualization,
    get_frame_size,
    abs_to_rel_points,
    validate_prompt_entry,
    normalize_prompt_entry,
    validate_and_normalize_prompt_list,
    validate_obj_id_constraints,
    propagate_in_video,
    propagate_bidirectional_and_merge,
    add_point_prompt,
    apply_prompt_list,
    add_text_prompt,
    resolve_stage_a_bootstrap_configs,
    visualize_outputs,
    save_arm_only_masks_for_propainter,
)

## 1) 载入可视化帧 + 校验配置 + 初始化 predictor/session

In [None]:
# Initialization: release leftovers from a previous run, load the frames used for
# visualization, validate and normalize every prompt list, then build the predictor
# and start a session on VIDEO_PATH.

# When the notebook is re-executed, release the old resources first to avoid NCCL
# initialization errors.
if "predictor" in globals() and predictor is not None:
    print("[init] cleaning previous predictor/session before re-run")
    cleanup_resources(predictor_obj=predictor, session_id_value=session_id)
    # Forget the released handles immediately: if a later step in this cell raises
    # (frame loading, prompt validation, predictor build), the next run must not hand
    # the already-cleaned predictor to cleanup_resources a second time.
    predictor = None
    session_id = None

# NOTE(review): presumably this only takes effect if CUDA has not been initialized
# yet in this kernel -- confirm when re-running with a different device list.
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
print(f"[init] CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")

video_frames_for_vis = load_video_frames_for_visualization(VIDEO_PATH)
TOTAL_FRAMES = len(video_frames_for_vis)
IMG_WIDTH, IMG_HEIGHT = get_frame_size(video_frames_for_vis)
print(f"[init] loaded frames={TOTAL_FRAMES}, size={IMG_WIDTH}x{IMG_HEIGHT}")

# Optionally replace the hand-written prompt lists with the exported annotation.
# Keys missing from the export keep the hand-written value.
if USE_VISUAL_ANNOTATION_EXPORT:
    if isinstance(ANNOTATION_EXPORT_PROMPTS, dict):
        ARM_LEFT_INITIAL_PROMPTS = ANNOTATION_EXPORT_PROMPTS.get(
            "ARM_LEFT_INITIAL_PROMPTS", ARM_LEFT_INITIAL_PROMPTS
        )
        ARM_RIGHT_INITIAL_PROMPTS = ANNOTATION_EXPORT_PROMPTS.get(
            "ARM_RIGHT_INITIAL_PROMPTS", ARM_RIGHT_INITIAL_PROMPTS
        )
        GRIPPER_LEFT_KEYFRAME_PROMPTS = ANNOTATION_EXPORT_PROMPTS.get(
            "GRIPPER_LEFT_KEYFRAME_PROMPTS", GRIPPER_LEFT_KEYFRAME_PROMPTS
        )
        GRIPPER_RIGHT_KEYFRAME_PROMPTS = ANNOTATION_EXPORT_PROMPTS.get(
            "GRIPPER_RIGHT_KEYFRAME_PROMPTS", GRIPPER_RIGHT_KEYFRAME_PROMPTS
        )
        print("[annotation] using prompts exported from visual annotation cells")
    else:
        print(
            "[annotation][warn] USE_VISUAL_ANNOTATION_EXPORT=True 但 ANNOTATION_EXPORT_PROMPTS 无效，"
            "继续使用手工配置的 *_PROMPTS"
        )

# Validate each prompt list against the loaded video (frame count and frame size)
# and keep the normalized result; empty lists are accepted at this point.
arm_left_prompts_norm = validate_and_normalize_prompt_list(
    ARM_LEFT_INITIAL_PROMPTS,
    total_frames=TOTAL_FRAMES,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
    tag="ARM_LEFT_INITIAL_PROMPTS",
    allow_empty=True,
)

arm_right_prompts_norm = validate_and_normalize_prompt_list(
    ARM_RIGHT_INITIAL_PROMPTS,
    total_frames=TOTAL_FRAMES,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
    tag="ARM_RIGHT_INITIAL_PROMPTS",
    allow_empty=True,
)

gripper_left_prompts_norm = validate_and_normalize_prompt_list(
    GRIPPER_LEFT_KEYFRAME_PROMPTS,
    total_frames=TOTAL_FRAMES,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
    tag="GRIPPER_LEFT_KEYFRAME_PROMPTS",
    allow_empty=True,
)

gripper_right_prompts_norm = validate_and_normalize_prompt_list(
    GRIPPER_RIGHT_KEYFRAME_PROMPTS,
    total_frames=TOTAL_FRAMES,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
    tag="GRIPPER_RIGHT_KEYFRAME_PROMPTS",
    allow_empty=True,
)

# Cross-check the configured object IDs against the IDs used inside the prompts.
validate_obj_id_constraints(
    arm_left_obj_id=arm_left_obj_id,
    arm_right_obj_id=arm_right_obj_id,
    gripper_left_obj_id=gripper_left_obj_id,
    gripper_right_obj_id=gripper_right_obj_id,
    arm_left_prompts=arm_left_prompts_norm,
    arm_right_prompts=arm_right_prompts_norm,
    gripper_left_prompts=gripper_left_prompts_norm,
    gripper_right_prompts=gripper_right_prompts_norm,
)

print("[init] prompt validation passed")
print(
    f"[init] ARM prompt count: left={len(arm_left_prompts_norm)}, right={len(arm_right_prompts_norm)} | "
    f"GRIPPER prompt count: left={len(gripper_left_prompts_norm)}, right={len(gripper_right_prompts_norm)}"
)
print(
    f"[init] configured object IDs: arms={[arm_left_obj_id, arm_right_obj_id]}, "
    f"grippers={[gripper_left_obj_id, gripper_right_obj_id]}"
)

# Build the predictor over every device torch reports, then open a session.
gpus_to_use = range(torch.cuda.device_count())
predictor = build_sam3_video_predictor(
    checkpoint_path=CHECKPOINT_PATH,
    gpus_to_use=gpus_to_use,
    apply_temporal_disambiguation=APPLY_TEMPORAL_DISAMBIGUATION,
)

start_response = predictor.handle_request(
    request=dict(
        type="start_session",
        resource_path=VIDEO_PATH,
    )
)
session_id = start_response["session_id"]
print(f"[init] session started: {session_id}")


## 1.5) 可视化点标注（解耦采集，不触发推理）

最小使用说明：
1. 先运行 `1) 初始化`（上一段 code cell，确保已加载 `video_frames_for_vis` / `TOTAL_FRAMES`）。
2. 运行下方可视化标注 cell：
   - `Frame` 选择标注帧；
   - `Object` 选择对象（左臂/右臂/左夹爪/右夹爪）；
   - `Point Label` 选择 positive(1) / negative(0)；
   - 在图上点击添加点；
   - `Clear Current Obj@Frame` 清空当前对象在当前帧的点。
3. 点击 `Export Prompts` 后会生成 `ANNOTATION_EXPORT_PROMPTS`，并自动将 `USE_VISUAL_ANNOTATION_EXPORT` 设为 `True`。
4. 继续运行 Stage A / Stage B，主流程不变，仅数据来源切到导出的结构化提示。

In [None]:
# Visual annotation: pick points interactively before any coordinates are consumed,
# and export them in a structure compatible with the *_PROMPTS variables.
import json
import traceback


def _print_manual_prompt_fallback_hint():
    """Print the instructions, shared by both fallback paths, for entering prompts by hand."""
    print("[annotation][fallback] 请直接编辑 ANNOTATION_MANUAL_TEMPLATE 或四个 *_PROMPTS 变量。")
    print("[annotation][fallback] 然后设置 USE_VISUAL_ANNOTATION_EXPORT=True 并赋值：")
    print("    ANNOTATION_EXPORT_PROMPTS = ANNOTATION_MANUAL_TEMPLATE")


# Snapshot of the current prompt configuration, keyed by variable name.
ANNOTATION_MANUAL_TEMPLATE = dict(
    ARM_LEFT_INITIAL_PROMPTS=ARM_LEFT_INITIAL_PROMPTS,
    ARM_RIGHT_INITIAL_PROMPTS=ARM_RIGHT_INITIAL_PROMPTS,
    GRIPPER_LEFT_KEYFRAME_PROMPTS=GRIPPER_LEFT_KEYFRAME_PROMPTS,
    GRIPPER_RIGHT_KEYFRAME_PROMPTS=GRIPPER_RIGHT_KEYFRAME_PROMPTS,
)

# The UI helpers are optional: remember any import failure instead of raising, so the
# notebook keeps running and the failure can be reported below.
try:
    from annotation_ui_tools import (
        build_annotation_object_specs,
        create_annotation_store,
        create_annotation_ui,
        seed_store_from_prompt_map,
    )
except Exception as import_exc:
    _annotation_module_error = import_exc
    _annotation_module_traceback = traceback.format_exc()
else:
    _annotation_module_error = None

# Until something is exported, the export payload defaults to the current configuration
# (also usable as a fallback that can be assigned by hand).
if not isinstance(ANNOTATION_EXPORT_PROMPTS, dict):
    ANNOTATION_EXPORT_PROMPTS = dict(ANNOTATION_MANUAL_TEMPLATE)

if video_frames_for_vis is None or TOTAL_FRAMES is None:
    # The init cell has not run yet, so there are no frames to annotate.
    print("[annotation][warn] 未检测到可视化帧，请先运行上一个初始化 cell 再执行本 cell。")
elif _annotation_module_error is not None:
    # Import failed: report why and describe the manual route.
    print("[annotation][error] annotation_ui_tools 模块导入失败，已回退到手工结构输入，不中断 notebook。")
    print(f"[annotation][error] 触发原因: {_annotation_module_error}")
    print("[annotation][error] traceback:")
    print(_annotation_module_traceback)
    _print_manual_prompt_fallback_hint()
else:
    ANNOTATION_OBJECT_SPECS = build_annotation_object_specs(
        arm_left_obj_id=arm_left_obj_id,
        arm_right_obj_id=arm_right_obj_id,
        gripper_left_obj_id=gripper_left_obj_id,
        gripper_right_obj_id=gripper_right_obj_id,
    )

    # Pre-populate the store with the points that are already configured.
    annotation_store = create_annotation_store(ANNOTATION_OBJECT_SPECS)
    seed_store_from_prompt_map(
        store=annotation_store,
        object_specs=ANNOTATION_OBJECT_SPECS,
        prompt_map=ANNOTATION_MANUAL_TEMPLATE,
        img_w=IMG_WIDTH,
        img_h=IMG_HEIGHT,
    )

    def _on_annotation_export(export_prompts):
        """Store the exported prompts in the notebook globals and enable their use."""
        global ANNOTATION_EXPORT_PROMPTS, USE_VISUAL_ANNOTATION_EXPORT
        ANNOTATION_EXPORT_PROMPTS = dict(export_prompts)
        USE_VISUAL_ANNOTATION_EXPORT = True

    ANNOTATION_UI = create_annotation_ui(
        video_frames_for_vis=video_frames_for_vis,
        total_frames=TOTAL_FRAMES,
        img_width=IMG_WIDTH,
        img_height=IMG_HEIGHT,
        object_specs=ANNOTATION_OBJECT_SPECS,
        annotation_store=annotation_store,
        on_export=_on_annotation_export,
        auto_display=True,
        status_prefix="[annotation]",
    )

    if not ANNOTATION_UI.get("widget_ready", False):
        # The widget backend is unavailable: describe the manual route instead.
        widget_error = ANNOTATION_UI.get("widget_error")
        print(
            "[annotation][fallback] ipympl/ipywidgets 不可用，已回退到手工结构输入，不中断 notebook。"
        )
        print(f"[annotation][fallback] 触发原因: {widget_error}")
        _print_manual_prompt_fallback_hint()

    # If the UI state already carries a non-empty export, adopt it together with its flag.
    _annotation_state = ANNOTATION_UI.get("state")
    _state_export_prompts = getattr(_annotation_state, "export_prompts", None)
    if isinstance(_state_export_prompts, dict) and _state_export_prompts:
        ANNOTATION_EXPORT_PROMPTS = dict(_state_export_prompts)
        USE_VISUAL_ANNOTATION_EXPORT = bool(
            getattr(_annotation_state, "use_visual_annotation_export", False)
        )




## 2) 阶段 A：先 bootstrap 普通传播，再做机械臂 points refinement


In [None]:
# Stage A: bootstrap both arms with text prompts, propagate once, then (optionally)
# refine with point prompts and propagate again.

# The annotation cell runs after the init cell, so the *_prompts_norm lists built during
# init predate any exported annotation. Re-resolve the arm prompts here so that an export
# (or a manual ANNOTATION_EXPORT_PROMPTS assignment) takes effect without re-running
# init, which would tear down and rebuild the predictor.
if USE_VISUAL_ANNOTATION_EXPORT and isinstance(ANNOTATION_EXPORT_PROMPTS, dict):
    arm_left_prompts_norm = validate_and_normalize_prompt_list(
        ANNOTATION_EXPORT_PROMPTS.get("ARM_LEFT_INITIAL_PROMPTS", ARM_LEFT_INITIAL_PROMPTS),
        total_frames=TOTAL_FRAMES,
        img_w=IMG_WIDTH,
        img_h=IMG_HEIGHT,
        tag="ARM_LEFT_INITIAL_PROMPTS",
        allow_empty=True,
    )
    arm_right_prompts_norm = validate_and_normalize_prompt_list(
        ANNOTATION_EXPORT_PROMPTS.get("ARM_RIGHT_INITIAL_PROMPTS", ARM_RIGHT_INITIAL_PROMPTS),
        total_frames=TOTAL_FRAMES,
        img_w=IMG_WIDTH,
        img_h=IMG_HEIGHT,
        tag="ARM_RIGHT_INITIAL_PROMPTS",
        allow_empty=True,
    )
    # Same ID cross-check as the init cell, now against the refreshed arm prompts.
    validate_obj_id_constraints(
        arm_left_obj_id=arm_left_obj_id,
        arm_right_obj_id=arm_right_obj_id,
        gripper_left_obj_id=gripper_left_obj_id,
        gripper_right_obj_id=gripper_right_obj_id,
        arm_left_prompts=arm_left_prompts_norm,
        arm_right_prompts=arm_right_prompts_norm,
        gripper_left_prompts=gripper_left_prompts_norm,
        gripper_right_prompts=gripper_right_prompts_norm,
    )
    print(
        "[stage A] arm prompts re-resolved from ANNOTATION_EXPORT_PROMPTS: "
        f"left={len(arm_left_prompts_norm)}, right={len(arm_right_prompts_norm)}"
    )

# Resolve, per side, which text prompt and which frame to bootstrap with.
stage_a_bootstrap_configs = resolve_stage_a_bootstrap_configs(
    stage_side_configs=[
        dict(
            side_name="left",
            bootstrap_text_prompt=ARM_LEFT_BOOTSTRAP_TEXT_PROMPT,
            fallback_text_prompt=ARM_LEFT_BOOTSTRAP_FALLBACK_TEXT_PROMPT,
            bootstrap_frame_index=ARM_LEFT_BOOTSTRAP_FRAME_INDEX,
            prompt_list=arm_left_prompts_norm,
        ),
        dict(
            side_name="right",
            bootstrap_text_prompt=ARM_RIGHT_BOOTSTRAP_TEXT_PROMPT,
            fallback_text_prompt=ARM_RIGHT_BOOTSTRAP_FALLBACK_TEXT_PROMPT,
            bootstrap_frame_index=ARM_RIGHT_BOOTSTRAP_FRAME_INDEX,
            prompt_list=arm_right_prompts_norm,
        ),
    ],
    total_frames=TOTAL_FRAMES,
)

stage_a_active_obj_ids = sorted({arm_left_obj_id, arm_right_obj_id})
print(f"[stage A] active_obj_ids={stage_a_active_obj_ids}")

# Bootstrap: one text prompt per resolved side config.
for cfg in stage_a_bootstrap_configs:
    print(
        f"[stage A][bootstrap] side={cfg['side_name']} frame={cfg['frame_index']} "
        f"text_source={cfg['text_source']}"
    )
    add_text_prompt(
        predictor_obj=predictor,
        session_id_value=session_id,
        frame_index=cfg["frame_index"],
        text_prompt=cfg["text_prompt"],
        stage_name=f"stage A/bootstrap/{cfg['side_name']}",
    )

print("[stage A][bootstrap] propagating bidirectional (forward + backward) ...")
outputs_stage_a = propagate_bidirectional_and_merge(
    predictor_obj=predictor,
    session_id_value=session_id,
    stage_name="stage A/bootstrap",
)
print(f"[stage A][bootstrap] done, merged frame outputs={len(outputs_stage_a)}")

# Apply left and right point prompts in one pass, ordered by frame and then object ID.
arm_prompts_merged = sorted(
    list(arm_left_prompts_norm) + list(arm_right_prompts_norm),
    key=lambda x: (x["frame_index"], x["obj_id"]),
)

if arm_prompts_merged:
    print("[stage A][refinement] applying left/right arm point prompts after bootstrap ...")
    apply_prompt_list(
        predictor_obj=predictor,
        session_id_value=session_id,
        prompt_list=arm_prompts_merged,
        stage_name="stage A/refinement",
    )

    print("[stage A][refinement] propagating bidirectional (forward + backward) ...")
    outputs_stage_a = propagate_bidirectional_and_merge(
        predictor_obj=predictor,
        session_id_value=session_id,
        stage_name="stage A/refinement",
    )
    print(f"[stage A][refinement] done, merged frame outputs={len(outputs_stage_a)}")
else:
    print("[stage A][refinement] no arm point prompts configured; keeping bootstrap outputs as stage A result")


In [None]:
# Preview the Stage A result: at most VIS_MAX_PLOTS plots, sampled with stride
# VIS_FRAME_STRIDE.
visualize_outputs(
    title="Stage A: Arm segmentation",
    video_frames=video_frames_for_vis,
    outputs_per_frame=outputs_stage_a,
    max_plots=VIS_MAX_PLOTS,
    stride=VIS_FRAME_STRIDE,
)

## 3) 阶段 B：关键帧 points 新增 gripper（全新 obj_id）并再次传播

In [None]:
# Stage B: inject the gripper objects at their keyframes, propagate again, and refuse
# to continue unless at least one gripper object shows up in the outputs.

# The annotation cell runs after the init cell, so the gripper *_prompts_norm lists built
# during init predate any exported annotation. Re-resolve them here so that an export
# (or a manual ANNOTATION_EXPORT_PROMPTS assignment) takes effect without re-running
# init, which would tear down and rebuild the predictor.
if USE_VISUAL_ANNOTATION_EXPORT and isinstance(ANNOTATION_EXPORT_PROMPTS, dict):
    gripper_left_prompts_norm = validate_and_normalize_prompt_list(
        ANNOTATION_EXPORT_PROMPTS.get(
            "GRIPPER_LEFT_KEYFRAME_PROMPTS", GRIPPER_LEFT_KEYFRAME_PROMPTS
        ),
        total_frames=TOTAL_FRAMES,
        img_w=IMG_WIDTH,
        img_h=IMG_HEIGHT,
        tag="GRIPPER_LEFT_KEYFRAME_PROMPTS",
        allow_empty=True,
    )
    gripper_right_prompts_norm = validate_and_normalize_prompt_list(
        ANNOTATION_EXPORT_PROMPTS.get(
            "GRIPPER_RIGHT_KEYFRAME_PROMPTS", GRIPPER_RIGHT_KEYFRAME_PROMPTS
        ),
        total_frames=TOTAL_FRAMES,
        img_w=IMG_WIDTH,
        img_h=IMG_HEIGHT,
        tag="GRIPPER_RIGHT_KEYFRAME_PROMPTS",
        allow_empty=True,
    )
    # Same ID cross-check as the init cell, now against the refreshed gripper prompts.
    validate_obj_id_constraints(
        arm_left_obj_id=arm_left_obj_id,
        arm_right_obj_id=arm_right_obj_id,
        gripper_left_obj_id=gripper_left_obj_id,
        gripper_right_obj_id=gripper_right_obj_id,
        arm_left_prompts=arm_left_prompts_norm,
        arm_right_prompts=arm_right_prompts_norm,
        gripper_left_prompts=gripper_left_prompts_norm,
        gripper_right_prompts=gripper_right_prompts_norm,
    )
    print(
        "[stage B] gripper prompts re-resolved from ANNOTATION_EXPORT_PROMPTS: "
        f"left={len(gripper_left_prompts_norm)}, right={len(gripper_right_prompts_norm)}"
    )

stage_b_new_obj_ids = sorted({gripper_left_obj_id, gripper_right_obj_id})
stage_b_union_obj_ids = sorted({
    arm_left_obj_id,
    arm_right_obj_id,
    gripper_left_obj_id,
    gripper_right_obj_id,
})
print(f"[stage B] active_obj_ids(new grippers)={stage_b_new_obj_ids}")
print(f"[stage B] expected_obj_ids_after_merge={stage_b_union_obj_ids}")

# Apply left and right gripper prompts in one pass, ordered by frame and then object ID.
gripper_prompts_merged = sorted(
    list(gripper_left_prompts_norm) + list(gripper_right_prompts_norm),
    key=lambda x: (x["frame_index"], x["obj_id"]),
)

# The merged list is empty exactly when both sides are empty, so a single check is
# enough. The export step subtracts the gripper masks, hence at least one side is
# mandatory.
if not gripper_prompts_merged:
    raise ValueError(
        "[stage B] gripper 提示缺失：左右 gripper 均未提供关键帧 points。"
        "本流程必须注入 gripper 独立对象后才能导出 arm-only。"
    )

print("[stage B] injecting left/right gripper prompts (required)...")
apply_prompt_list(
    predictor_obj=predictor,
    session_id_value=session_id,
    prompt_list=gripper_prompts_merged,
    stage_name="stage B",
)

print("[stage B] propagating bidirectional (forward + backward) ...")
outputs_stage_b = propagate_bidirectional_and_merge(
    predictor_obj=predictor,
    session_id_value=session_id,
    stage_name="stage B",
)
print(f"[stage B] done, merged frame outputs={len(outputs_stage_b)}")

# Sanity check: at least one frame must contain one of the gripper object IDs,
# otherwise the arm-only subtraction at export time would have nothing to subtract.
_gripper_obj_id_set = set(stage_b_new_obj_ids)
has_any_gripper_obj = any(
    int(obj_id) in _gripper_obj_id_set
    for frame_outputs in outputs_stage_b.values()
    for obj_id in frame_outputs.keys()
)
if not has_any_gripper_obj:
    raise ValueError(
        "[stage B] 传播结果中未发现 gripper obj_id，视为 Stage B 无有效结果。"
        "已阻断导出以避免错误 arm-only 掩码。"
    )


In [None]:
# Preview the Stage B result: at most VIS_MAX_PLOTS plots, sampled with stride
# VIS_FRAME_STRIDE.
visualize_outputs(
    title="Stage B: Arm + Gripper segmentation",
    video_frames=video_frames_for_vis,
    outputs_per_frame=outputs_stage_b,
    max_plots=VIS_MAX_PLOTS,
    stride=VIS_FRAME_STRIDE,
)

## 4) 导出 mask（固定 arm-only：arm - gripper）

In [None]:
# Export needs the Stage B outputs; stop early when they are missing or empty.
if not outputs_stage_b:
    raise ValueError("[export] outputs_stage_b 为空，无法导出 arm-only")

print("[export] mode fixed: arm-only subtraction (arm AND NOT gripper)")

# Both arms form the "arm" side and both grippers the "gripper" side of the export.
_export_arm_ids = [arm_left_obj_id, arm_right_obj_id]
_export_gripper_ids = [gripper_left_obj_id, gripper_right_obj_id]

mask_paths = save_arm_only_masks_for_propainter(
    outputs_per_frame=outputs_stage_b,
    video_frames=video_frames_for_vis,
    output_dir=EXPORT_OUTPUT_DIR,
    arm_obj_ids=_export_arm_ids,
    gripper_obj_ids=_export_gripper_ids,
    dilate_radius=EXPORT_DILATE_RADIUS,
    log_every=EXPORT_LOG_EVERY,
)

# Show the first and last returned path as a quick check ("N/A" when none came back).
if mask_paths:
    _first_sample, _last_sample = mask_paths[0], mask_paths[-1]
else:
    _first_sample = _last_sample = "N/A"
print(f"[export] sample: first={_first_sample}")
print(f"[export] sample: last={_last_sample}")


## 5) 资源清理（session close + predictor shutdown + 进程组清理）

In [None]:
# Hand the predictor and its session to cleanup_resources, then clear the notebook-level
# handles so they no longer refer to the released objects.
cleanup_resources(predictor_obj=predictor, session_id_value=session_id)

predictor = session_id = None
print("[cleanup] globals reset done")