# SAM3 Arm 分割（Session 1）

**运行顺序**：`mask_config.yaml` → imports → config → init → [Arm 标注 UI] → Session 1 Arm 分割 → cleanup

- text bootstrap 初始化，可选点标注精化
- 分割结果保存到 checkpoint（`_ckpt_arm` 目录）
- 与 `mask_2_gripper.ipynb` 完全独立，可任意顺序运行
- 运行完成后在 `mask_3_merge.ipynb` 中合并


In [None]:
# Scene to process. It is set here by hand; the matching
# `_cfg['scene_name']` lookup in the config-loading section is commented out,
# so this value is the one used to build VIDEO_PATH and EXPORT_OUTPUT_DIR.
SCENE_NAME = "scene_0018"

In [None]:
# imports

# ============================================================
# Imports
# ============================================================
%load_ext autoreload
%autoreload 2

import json
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from PIL import Image

import sam3
from sam3.model_builder import build_sam3_video_predictor

from mask_pipeline_tools import (
    add_text_prompt,
    apply_prompt_list,
    cleanup_process_group,
    cleanup_resources,
    get_frame_size,
    iter_object_masks_from_frame_output,
    load_video_frames_for_visualization,
    propagate_bidirectional_and_merge,
    validate_and_normalize_prompt_list,
    visualize_outputs,
)
from annotation_ui_tools import (
    create_annotation_store,
    create_annotation_ui,
    load_annotation_prompts_json,
    seed_store_from_prompt_map,
)

# Global matplotlib defaults: use size 12 for both per-axes titles and
# figure-level titles in every plot drawn after this point.
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["figure.titlesize"] = 12

# Config loading

# ============================================================
# Config loading — edit mask_config.yaml to adjust parameters
# ============================================================
import yaml

_CONFIG_PATH = Path("mask_config.yaml")
# Read explicitly as UTF-8 so non-ASCII values (e.g. text prompts) load the
# same way regardless of the platform's default locale encoding.
with open(_CONFIG_PATH, "r", encoding="utf-8") as _f:
    _cfg = yaml.safe_load(_f)
# yaml.safe_load returns None for an empty file; fail with a clear message
# instead of an opaque "NoneType is not subscriptable" on the first lookup.
if not isinstance(_cfg, dict):
    raise ValueError(
        f"[config] {_CONFIG_PATH} must contain a YAML mapping, got {type(_cfg).__name__}"
    )

TASK_NAME = _cfg['task_name']
# NOTE: SCENE_NAME is defined earlier in the notebook instead of being read
# from _cfg['scene_name'].
CHECKPOINT_PATH = _cfg['sam3_checkpoint']

# CUDA_VISIBLE_DEVICES must be a comma-separated *string*: it is later written
# to os.environ (which rejects non-str values) and split on ','. YAML parses
# `cuda_visible_devices_arm: 0` as an int and `[0, 1]` as a list, so normalise
# every accepted spelling to the string form here.
_cuda_devices_raw = _cfg['cuda_visible_devices_arm']
if _cuda_devices_raw is None:
    raise ValueError("[config] cuda_visible_devices_arm must not be empty")
if isinstance(_cuda_devices_raw, (list, tuple)):
    CUDA_VISIBLE_DEVICES = ",".join(str(_d).strip() for _d in _cuda_devices_raw)
else:
    CUDA_VISIBLE_DEVICES = str(_cuda_devices_raw).strip()

APPLY_TEMPORAL_DISAMBIGUATION = bool(_cfg['apply_temporal_disambiguation'])
ARM_OBJ_ID = int(_cfg['arm_obj_id'])
# The second arm is optional: a missing or null `arm_obj_id_2` disables it.
ARM_OBJ_ID_2 = int(_cfg['arm_obj_id_2']) if _cfg.get('arm_obj_id_2') is not None else None
GRIPPER_LEFT_OBJ_ID = int(_cfg['gripper_left_obj_id'])
GRIPPER_RIGHT_OBJ_ID = int(_cfg['gripper_right_obj_id'])
ARM_TEXT_PROMPT = _cfg['arm_text_prompt']
ARM_TEXT_BOOTSTRAP_FRAME_INDEX = int(_cfg['arm_text_bootstrap_frame_index'])
VIS_FRAME_STRIDE = int(_cfg['vis_frame_stride'])
VIS_MAX_PLOTS = int(_cfg['vis_max_plots'])
EXPORT_ARM_DILATE_RADIUS = int(_cfg['export_arm_dilate_radius'])
EXPORT_GRIPPER_DILATE_RADIUS = int(_cfg['export_gripper_dilate_radius'])
EXPORT_LOG_EVERY = int(_cfg['export_log_every'])

# Derived paths
_data_base = _cfg['data_base_dir']
_export_base_str = _cfg['export_base_dir']
VIDEO_PATH = f"{_data_base}/{TASK_NAME}/train/{SCENE_NAME}/cam_105422061350/color"
# Annotation JSON files live next to the colour-frame directory.
ANNOTATION_JSON_PATH = str(Path(VIDEO_PATH).parent / "annotation_prompts_gripper_points.json")
ARM_ANNOTATION_JSON_PATH = str(Path(VIDEO_PATH).parent / "annotation_prompts_arm_points.json")
EXPORT_OUTPUT_DIR = f"{_export_base_str}/{TASK_NAME}/{SCENE_NAME}"
_export_base = Path(EXPORT_OUTPUT_DIR)
# Checkpoint directories are siblings of the export directory, named
# "<scene>_ckpt_arm" and "<scene>_ckpt_gripper".
CHECKPOINT_ARM_DIR     = str(_export_base.parent / (_export_base.name + "_ckpt_arm"))
CHECKPOINT_GRIPPER_DIR = str(_export_base.parent / (_export_base.name + "_ckpt_gripper"))

print(f"[config] task={TASK_NAME}  scene={SCENE_NAME}")
print(f"[config] VIDEO_PATH={VIDEO_PATH}")
print(f"[config] ANNOTATION_JSON_PATH={ANNOTATION_JSON_PATH}")
print(f"[config] ARM_ANNOTATION_JSON_PATH={ARM_ANNOTATION_JSON_PATH}")
print(f"[config] CHECKPOINT_ARM_DIR={CHECKPOINT_ARM_DIR}")
print(f"[config] CHECKPOINT_GRIPPER_DIR={CHECKPOINT_GRIPPER_DIR}")
print(f"[config] EXPORT_OUTPUT_DIR={EXPORT_OUTPUT_DIR}")

# Initialization

# ============================================================
# Release any predictor left over from a previous run
# ============================================================
# When this cell is re-executed, clean up the previous predictor first to
# avoid NCCL initialization errors. This has to happen BEFORE the runtime
# state below is reset: once the names are rebound to None the old predictor
# can no longer be reached. globals().get() also covers the very first run,
# when the names do not exist yet.
_prev_predictor_arm = globals().get('predictor_arm')
if _prev_predictor_arm is not None:
    print('[init] cleaning up previous arm predictor')
    cleanup_resources(
        predictor_obj=_prev_predictor_arm,
        session_id_value=globals().get('session_id_arm'),
    )
del _prev_predictor_arm

# ============================================================
# Runtime state (no need to edit)
# ============================================================
video_frames_for_vis = None
TOTAL_FRAMES = None
IMG_WIDTH = None
IMG_HEIGHT = None

predictor_arm = None
session_id_arm = None
outputs_arm = None

arm_prompts_norm = None          # arm point annotations (stored here once loaded)

# ============================================================
# Load video frames
# ============================================================

# Restrict the GPUs visible to this process to the ones configured for the
# arm session.
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
print(f'[init] CUDA_VISIBLE_DEVICES={os.environ.get("CUDA_VISIBLE_DEVICES")}')

video_frames_for_vis = load_video_frames_for_visualization(VIDEO_PATH)
TOTAL_FRAMES = len(video_frames_for_vis)
IMG_WIDTH, IMG_HEIGHT = get_frame_size(video_frames_for_vis)

print(f'[init] frames={TOTAL_FRAMES}  size={IMG_WIDTH}x{IMG_HEIGHT}')
print(f'[init] VIDEO_PATH={VIDEO_PATH}')
print(f'[init] ARM_ANNOTATION_JSON_PATH={ARM_ANNOTATION_JSON_PATH}')
print(f'[init] CHECKPOINT_ARM_DIR={CHECKPOINT_ARM_DIR}')


In [None]:
# Arm annotation UI
# ============================================================
# Arm annotation UI
# ============================================================
# Annotation mode
# "overwrite" : saving replaces the JSON; the UI starts empty (default behaviour)
# "append"    : the existing JSON is loaded into the UI automatically, so points
#               can be added to already-annotated frames or to new frames
# ============================================================
ANNOTATION_MODE = "append"  # <- switch to "overwrite" as needed

# Fail fast on a typo: any other value would otherwise be treated silently as
# "not append" by the mode checks in this notebook.
if ANNOTATION_MODE not in ("overwrite", "append"):
    raise ValueError(
        f"[arm-annotation] unknown ANNOTATION_MODE={ANNOTATION_MODE!r}; "
        "expected 'overwrite' or 'append'"
    )

# Annotate keyframe points for arm_cable (primary/right arm, obj_id=ARM_OBJ_ID)
#    and arm_cable_2 (left arm, obj_id=ARM_OBJ_ID_2).
# When done, click Export to save to ARM_ANNOTATION_JSON_PATH.
# If a valid JSON already exists this cell can be skipped; run the JSON-loading
# cell below directly.
# ============================================================

ARM_ANNOTATION_OBJECT_SPECS = {
    "arm_cable": {
        "display": "Arm + Cable (主臂/右)",
        "obj_id": int(ARM_OBJ_ID),
        "target": "ARM_CABLE_INITIAL_PROMPTS",
    },
}
# The second arm is only offered in the UI when it is configured.
if ARM_OBJ_ID_2 is not None:
    ARM_ANNOTATION_OBJECT_SPECS["arm_cable_2"] = {
        "display": "Arm 2 (左臂)",
        "obj_id": int(ARM_OBJ_ID_2),
        "target": "ARM_CABLE_2_INITIAL_PROMPTS",
    }

# Seed map with every prompt list present (empty by default).
_arm_seed_prompt_map = {
    "ARM_CABLE_INITIAL_PROMPTS": [],
    "ARM_CABLE_2_INITIAL_PROMPTS": [],
    "GRIPPER_LEFT_KEYFRAME_PROMPTS": [],
    "GRIPPER_RIGHT_KEYFRAME_PROMPTS": [],
}
# Only "append" mode starts from the existing JSON. "overwrite" is documented
# as an empty start, so it must neither load the old points nor report that
# the UI is being seeded from them.
if ANNOTATION_MODE == "append" and Path(ARM_ANNOTATION_JSON_PATH).exists():
    try:
        _arm_existing = load_annotation_prompts_json(
            json_path=ARM_ANNOTATION_JSON_PATH, status_prefix="[arm-annotation/seed]"
        )
        _arm_seed_prompt_map.update(_arm_existing)
        print(f"[arm-annotation] seeding UI from existing JSON: {ARM_ANNOTATION_JSON_PATH}")
    except Exception as _e:
        # Best effort: an unreadable JSON must not prevent annotating from scratch.
        print(f"[arm-annotation][warn] could not seed from JSON: {_e}")

_arm_annotation_store = create_annotation_store(ARM_ANNOTATION_OBJECT_SPECS)
seed_store_from_prompt_map(
    store=_arm_annotation_store,
    object_specs=ARM_ANNOTATION_OBJECT_SPECS,
    prompt_map=_arm_seed_prompt_map,
    img_w=IMG_WIDTH,
    img_h=IMG_HEIGHT,
)

# Export callback: fill in every key the schema requires, then save.
# save_json_on_export=False for the same reason as in the gripper UI: the
# internal save would run validate_export_prompt_map before the ARM fields
# are filled in, which ends in a KeyError.
_arm_export_result = {}


def _on_arm_export(export_prompts):
    """Persist the prompts exported by the arm annotation UI.

    The two arm prompt lists are taken from ``export_prompts`` (an empty list
    when a key is absent), the two gripper lists required by the schema are
    left empty, and the result is both kept in the module-level
    ``_arm_export_result`` and written to ``ARM_ANNOTATION_JSON_PATH``.
    """
    global _arm_export_result

    _payload = {
        _key: export_prompts.get(_key, [])
        for _key in ("ARM_CABLE_INITIAL_PROMPTS", "ARM_CABLE_2_INITIAL_PROMPTS")
    }
    # Required by the schema; the arm notebook never fills these in.
    _payload["GRIPPER_LEFT_KEYFRAME_PROMPTS"] = []
    _payload["GRIPPER_RIGHT_KEYFRAME_PROMPTS"] = []
    _arm_export_result = _payload

    from annotation_ui_tools import save_annotation_prompts_json as _save_json

    _save_json(
        export_prompts=_arm_export_result,
        json_path=ARM_ANNOTATION_JSON_PATH,
        status_prefix="[arm-annotation]",
    )
    print(f"[arm-annotation] saved to {ARM_ANNOTATION_JSON_PATH}")

# Build the arm annotation widget over the loaded frames. Saving is delegated
# to the `on_export` callback, which is why `save_json_on_export` is False.
# The return value is read below with `.get("draw")`, `.get("widget_ready")`
# and `.get("widget_error")`, i.e. it is used as a dict-like handle.
_arm_annotation_ui = create_annotation_ui(
    video_frames_for_vis=video_frames_for_vis,
    total_frames=TOTAL_FRAMES,
    img_width=IMG_WIDTH,
    img_height=IMG_HEIGHT,
    object_specs=ARM_ANNOTATION_OBJECT_SPECS,
    annotation_store=_arm_annotation_store,
    on_export=_on_arm_export,
    auto_display=True,
    status_prefix="[arm-annotation]",
    export_json_path=ARM_ANNOTATION_JSON_PATH,
    save_json_on_export=False,
)


# -- append mode: reload the existing annotations once the UI exists, so they
# are visible and can be extended -------------------------------------------
# Per the original note, create_annotation_ui empties annotation_store
# internally; re-seeding here puts the existing points back on the canvas so
# new ones can be added on top. On export the UI emits old + new points and
# the export callback stores them as-is (no extra merge needed).
_arm_json_file = Path(ARM_ANNOTATION_JSON_PATH)
_arm_append_wanted = ANNOTATION_MODE == "append" and _arm_json_file.exists()

if _arm_append_wanted:
    try:
        _arm_append_data = load_annotation_prompts_json(
            json_path=ARM_ANNOTATION_JSON_PATH,
            status_prefix="[arm-annotation/append]",
        )
        seed_store_from_prompt_map(
            store=_arm_annotation_store,
            object_specs=ARM_ANNOTATION_OBJECT_SPECS,
            prompt_map=_arm_append_data,
            img_w=IMG_WIDTH,
            img_h=IMG_HEIGHT,
        )
        # Redraw only when the UI handle exposes a draw function.
        _arm_redraw = _arm_annotation_ui.get("draw")
        if _arm_redraw:
            _arm_redraw(force_draw=True, full_reset=True)
        # Total number of entries across all objects in the store.
        _n_arm_frames = sum(map(len, _arm_annotation_store.values()))
        print(f"[arm-annotation/append] 已有标注已加载到 UI（{_n_arm_frames} 个对象-帧），可继续追加新点")
    except Exception as _e:
        # Best effort: fall back to an empty UI instead of aborting the cell.
        print(f"[arm-annotation/append][warn] 加载已有标注失败，UI 以空白启动: {_e}")

# Without a working widget, point the user at the JSON file for manual edits.
_arm_widget_ready = _arm_annotation_ui.get("widget_ready", False)
if not _arm_widget_ready:
    print(f"[arm-annotation][fallback] Widget not available: {_arm_annotation_ui.get('widget_error')}")
    print(f"[arm-annotation][fallback] 请直接编辑 JSON 文件: {ARM_ANNOTATION_JSON_PATH}")

In [None]:
# Session 1 — Arm segmentation

# ============================================================
# Load and validate the arm point annotations from JSON (optional)
# If ARM_ANNOTATION_JSON_PATH does not exist, arm_prompts_norm = [] and
# Session 1 relies on the text bootstrap only
# ============================================================

arm_prompts_norm = []

if not Path(ARM_ANNOTATION_JSON_PATH).exists():
    print("[session1/arm] ARM_ANNOTATION_JSON_PATH not found, skipping arm point prompts.")
    print(f"[session1/arm] path: {ARM_ANNOTATION_JSON_PATH}")
else:
    # An existing but unreadable JSON is fatal: silently ignoring it would run
    # the session without the annotations the user expects to be applied.
    try:
        _arm_json = load_annotation_prompts_json(
            json_path=ARM_ANNOTATION_JSON_PATH, status_prefix="[session1/arm]"
        )
    except Exception as _e:
        raise RuntimeError(f"[session1/arm][FATAL] 读取 arm JSON 失败: {_e}") from _e

    _arm_cable_raw   = _arm_json.get("ARM_CABLE_INITIAL_PROMPTS", [])
    _arm_cable_2_raw = _arm_json.get("ARM_CABLE_2_INITIAL_PROMPTS", [])

    _arm_cable_norm = validate_and_normalize_prompt_list(
        _arm_cable_raw,
        total_frames=TOTAL_FRAMES,
        img_w=IMG_WIDTH,
        img_h=IMG_HEIGHT,
        tag="ARM_CABLE_INITIAL_PROMPTS",
        allow_empty=True,
    )
    # Primary-arm prompts get the same obj_id sanity check as the second arm.
    # This only warns (it does not raise) so that existing JSON files keep
    # loading exactly as before.
    _arm1_foreign_ids = sorted({_p["obj_id"] for _p in _arm_cable_norm} - {ARM_OBJ_ID})
    if _arm1_foreign_ids:
        print(
            f"[session1/arm][WARN] ARM_CABLE prompts use obj_id(s) {_arm1_foreign_ids} "
            f"!= ARM_OBJ_ID={ARM_OBJ_ID}. Check the JSON or the config."
        )

    _arm_cable_2_norm = []
    if ARM_OBJ_ID_2 is not None:
        _arm_cable_2_norm = validate_and_normalize_prompt_list(
            _arm_cable_2_raw,
            total_frames=TOTAL_FRAMES,
            img_w=IMG_WIDTH,
            img_h=IMG_HEIGHT,
            tag="ARM_CABLE_2_INITIAL_PROMPTS",
            allow_empty=True,
        )
        for _p in _arm_cable_2_norm:
            if _p["obj_id"] != ARM_OBJ_ID_2:
                raise ValueError(
                    f"[session1/arm] ARM_CABLE_2 prompt obj_id={_p['obj_id']} "
                    f"!= ARM_OBJ_ID_2={ARM_OBJ_ID_2}。请检查 JSON 或配置。"
                )
    elif _arm_cable_2_raw:
        # Second-arm annotations exist but the second arm is disabled in the
        # config: say so instead of dropping the annotations silently.
        print(
            f"[session1/arm][WARN] ARM_CABLE_2_INITIAL_PROMPTS has {len(_arm_cable_2_raw)} "
            f"entries but ARM_OBJ_ID_2 is not configured; they are ignored."
        )

    arm_prompts_norm = _arm_cable_norm + _arm_cable_2_norm

    _arm1_frames = sorted({p["frame_index"] for p in _arm_cable_norm})
    _arm2_frames = sorted({p["frame_index"] for p in _arm_cable_2_norm})
    print(f"[session1/arm] arm_cable   entries={len(_arm_cable_norm)}  keyframes={_arm1_frames}")
    print(f"[session1/arm] arm_cable_2 entries={len(_arm_cable_2_norm)}  keyframes={_arm2_frames}")
    print(f"[session1/arm] total arm prompt entries: {len(arm_prompts_norm)}")
    if not arm_prompts_norm:
        print("[session1/arm][WARN] JSON 存在但无标注点，Session 1 将仅用 text bootstrap。")
    else:
        print("[session1/arm] arm point prompts loaded. Ready for Session 1 refinement.")

# ============================================================
# Session 1 — Arm segmentation
# Step 1: text bootstrap -> initial bidirectional propagation (fills the cache)
# Step 2: if arm point annotations exist, apply_prompt_list -> final
#         bidirectional propagation
# ============================================================

print("[session1/arm] initializing predictor ...")
# One logical GPU index per device listed in CUDA_VISIBLE_DEVICES. Empty
# entries (e.g. from a trailing comma) are not devices and must not be counted.
_visible_devices = [_d for _d in str(CUDA_VISIBLE_DEVICES).split(',') if _d.strip()]
if not _visible_devices:
    raise ValueError(
        f"[session1/arm] CUDA_VISIBLE_DEVICES={CUDA_VISIBLE_DEVICES!r} lists no device"
    )
gpus_to_use = range(len(_visible_devices))

# Make sure a stale id from an earlier run is never handed to the cleanup below.
session_id_arm = None
predictor_arm = build_sam3_video_predictor(
    checkpoint_path=CHECKPOINT_PATH,
    gpus_to_use=gpus_to_use,
    apply_temporal_disambiguation=APPLY_TEMPORAL_DISAMBIGUATION,
)

# Everything that uses the predictor runs inside try/finally so the predictor
# is released even when a step fails part-way through.
try:
    start_response = predictor_arm.handle_request(
        request=dict(type="start_session", resource_path=VIDEO_PATH)
    )
    session_id_arm = start_response["session_id"]
    print(f"[session1/arm] session started: session_id={session_id_arm}")

    # ----------------------------------------------------------
    # Step 1: text bootstrap -> initial bidirectional propagation (fills the cache)
    # ----------------------------------------------------------
    print(f"[session1/arm] text bootstrap: frame={ARM_TEXT_BOOTSTRAP_FRAME_INDEX}  prompt={ARM_TEXT_PROMPT!r}")
    add_text_prompt(
        predictor_obj=predictor_arm,
        session_id_value=session_id_arm,
        frame_index=ARM_TEXT_BOOTSTRAP_FRAME_INDEX,
        text_prompt=ARM_TEXT_PROMPT,
        stage_name="session1/arm",
    )

    print("[session1/arm] propagating (bidirectional, initial pass) ...")
    outputs_arm = propagate_bidirectional_and_merge(
        predictor_obj=predictor_arm,
        session_id_value=session_id_arm,
        stage_name="session1/arm",
    )
    print(f"[session1/arm] initial propagation done: {len(outputs_arm)} frames")

    # ----------------------------------------------------------
    # Step 2: if arm point annotations exist, apply them and propagate again
    # ----------------------------------------------------------
    if arm_prompts_norm:
        print(f"[session1/arm] applying {len(arm_prompts_norm)} arm point prompt entries ...")
        apply_prompt_list(
            predictor_obj=predictor_arm,
            session_id_value=session_id_arm,
            prompt_list=arm_prompts_norm,
            stage_name="session1/arm_refine",
        )
        print("[session1/arm] re-propagating after arm point refinement ...")
        outputs_arm = propagate_bidirectional_and_merge(
            predictor_obj=predictor_arm,
            session_id_value=session_id_arm,
            stage_name="session1/arm_refine",
        )
        print(f"[session1/arm] refinement done: {len(outputs_arm)} frames")
    else:
        print("[session1/arm] no arm point prompts, skipping refinement (text bootstrap only)")

    # ----------------------------------------------------------
    # Check that the configured arm object ids appear in the outputs
    # ----------------------------------------------------------
    _arm_ids_found = set()
    for _fo in outputs_arm.values():
        for _oid, _ in iter_object_masks_from_frame_output(_fo):
            _arm_ids_found.add(int(_oid))
    print(f"[session1/arm] obj_ids in outputs: {sorted(_arm_ids_found)}")
    if ARM_OBJ_ID not in _arm_ids_found:
        print(
            f"[session1/arm][WARN] ARM_OBJ_ID={ARM_OBJ_ID} not found. "
            f"Found: {sorted(_arm_ids_found)}. "
            f"Please update ARM_OBJ_ID in config to match the actual id."
        )
    if ARM_OBJ_ID_2 is not None and ARM_OBJ_ID_2 not in _arm_ids_found:
        print(
            f"[session1/arm][WARN] ARM_OBJ_ID_2={ARM_OBJ_ID_2} not found in outputs. "
            f"Left arm may not have been detected. Check arm point annotations."
        )
finally:
    # Close Session 1 on success and on failure alike; outputs_arm stays in
    # memory for the later steps.
    # NOTE(review): if start_session itself failed, session_id_arm is still
    # None here — confirm cleanup_resources accepts session_id_value=None.
    print("[session1/arm] shutting down predictor ...")
    cleanup_resources(predictor_obj=predictor_arm, session_id_value=session_id_arm)
    predictor_arm = None
    session_id_arm = None

print(f"[session1/arm] done. outputs_arm preserved ({len(outputs_arm)} frames).")

# ============================================================
# Checkpoint: write the per-frame union of all arm masks to disk
# Purpose: if Session 2 hangs and the kernel must be restarted, the arm result
# can be loaded from this checkpoint instead of re-running Session 1
# ============================================================
print(f"[ckpt/arm] saving pre-dilate arm masks → {CHECKPOINT_ARM_DIR}")
_ckpt_dir = Path(CHECKPOINT_ARM_DIR)
_ckpt_dir.mkdir(parents=True, exist_ok=True)
_ckpt_w, _ckpt_h = get_frame_size(video_frames_for_vis)

for _ckpt_fi in range(TOTAL_FRAMES):
    # Frames without any output produce an all-zero mask image.
    _ckpt_frame_out = outputs_arm.get(_ckpt_fi, {})
    _ckpt_u = np.zeros((_ckpt_h, _ckpt_w), np.uint8)
    for _ckpt_oid, _ckpt_m in iter_object_masks_from_frame_output(_ckpt_frame_out):
        # Bring tensor masks to host memory as NumPy arrays.
        if isinstance(_ckpt_m, torch.Tensor):
            _ckpt_m = _ckpt_m.detach().cpu().numpy()
        # Drop singleton dimensions of masks with more than two axes.
        if _ckpt_m.ndim > 2:
            _ckpt_m = np.squeeze(_ckpt_m)
        # Binarise to 0/255 and OR it into the running union for this frame.
        _ckpt_bin = np.where(_ckpt_m > 0, np.uint8(255), np.uint8(0))
        _ckpt_u = np.maximum(_ckpt_u, _ckpt_bin)
    # One 8-bit greyscale PNG per frame, named by zero-padded frame index.
    Image.fromarray(_ckpt_u, "L").save(str(_ckpt_dir / f"{_ckpt_fi:05d}.png"))

print(f"[ckpt/arm] done: {TOTAL_FRAMES} frames → {CHECKPOINT_ARM_DIR}")

# ============================================================
# Visualize the Session 1 arm result
# Inspect the arm mask quality; if it is not satisfactory, adjust the
# annotation points and run Session 1 again
# ============================================================
# Sampling is controlled by the VIS_FRAME_STRIDE / VIS_MAX_PLOTS config values.
visualize_outputs(
    outputs_per_frame=outputs_arm,
    video_frames=video_frames_for_vis,
    stride=VIS_FRAME_STRIDE,
    max_plots=VIS_MAX_PLOTS,
    title="Session 1: Arm Segmentation (text bootstrap + point refinement)",
)

In [None]:
# Cleanup

# ============================================================
# Defensive cleanup (makes sure resources are released after repeated runs
# or after an interrupted/failed run)
# ============================================================

# globals().get() keeps this cell usable even when it is run before the
# predictor names were ever defined (e.g. right after a kernel restart).
_leftover_predictor_arm = globals().get('predictor_arm')
try:
    if _leftover_predictor_arm is not None:
        print('[cleanup] cleaning arm predictor (not yet cleaned)')
        cleanup_resources(
            predictor_obj=_leftover_predictor_arm,
            session_id_value=globals().get('session_id_arm'),
        )
        # Only forget the predictor once it was actually released, so a
        # failed cleanup can be retried by re-running this cell.
        predictor_arm = None
        session_id_arm = None
    else:
        print('[cleanup] predictor_arm already cleaned')
finally:
    del _leftover_predictor_arm
    # Tear down the process group even when releasing the predictor failed.
    cleanup_process_group()

print('[cleanup] all resources released')
