# SAM3 解耦两 Session 分割：Arm（text bootstrap）+ Gripper（标注点）

**设计目标**：消除 arm 与 gripper 分割之间的相互干扰。

## 架构

| | Session 1 — Arm | Session 2 — Gripper |
|---|---|---|
| 初始化方式 | text bootstrap `"robot and cable"`（可选：arm 标注点细化） | text bootstrap（仅用于填充缓存）+ 标注点（JSON） |
| 是否需要标注 | **不需要** | **必须**（来自 JSON） |
| 另一方对象是否存在 | 无 gripper 对象 | bootstrap 会产生 arm 对象，但 merge 时只提取 gripper obj_id，arm 对象被忽略 |
| 传播方式 | 双向 | 双向 |
| 输出 | `outputs_arm` | `outputs_gripper` |

**Merge 步骤**：`arm_only = dilated_arm AND NOT dilated_gripper`，仅使用 NumPy 与 OpenCV（`cv2.dilate`），无 SAM3 依赖。

## 与原 notebook 的区别

原 `generate_mask_airexo_data_gripper_points.ipynb`：
- Stage A（arm text bootstrap）→ Stage B（同一 session 注入 gripper）
- gripper 对象写入 arm session 的时序记忆，造成干扰

本 notebook：
- Session 1 完全关闭后，Session 2 才初始化——两者不共享任何状态
- Merge 逻辑从两个独立 outputs dict 分别读取，不存在 obj_id 混淆

In [1]:
# imports

# ============================================================
# Imports
# ============================================================
%load_ext autoreload
%autoreload 2

import json
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

import sam3
from sam3.model_builder import build_sam3_video_predictor

from mask_pipeline_tools import (
    add_text_prompt,
    apply_prompt_list,
    cleanup_process_group,
    cleanup_resources,
    get_frame_size,
    iter_object_masks_from_frame_output,
    load_video_frames_for_visualization,
    propagate_bidirectional_and_merge,
    validate_and_normalize_prompt_list,
    visualize_outputs,
)
from annotation_ui_tools import (
    create_annotation_store,
    create_annotation_ui,
    load_annotation_prompts_json,
    seed_store_from_prompt_map,
)

# Use the same 12-pt size for axes titles and figure titles in every plot.
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["figure.titlesize"] = 12

  import pkg_resources


In [None]:
# Task / scene identifiers; both are interpolated into VIDEO_PATH and EXPORT_OUTPUT_DIR.
TASK_NAME = "task_0013"
SCENE_NAME = "scene_0003" 

In [None]:
# 初始化

# ============================================================
# Configuration — this is the only section that needs editing
# ============================================================

# Paths
VIDEO_PATH = f"/data/haoxiang/data/airexo2/{TASK_NAME}/train/{SCENE_NAME}/cam_105422061350/color"
CHECKPOINT_PATH = "/data/haoxiang/sam3/models/facebook/sam3/sam3.pt"
# Annotation JSON (gripper points), located in the parent directory of VIDEO_PATH
ANNOTATION_JSON_PATH = str(
    Path(VIDEO_PATH).resolve().parent / "annotation_prompts_gripper_points.json"
)
# Annotation JSON (arm points), located in the parent directory of VIDEO_PATH
ARM_ANNOTATION_JSON_PATH = str(
    Path(VIDEO_PATH).resolve().parent / "annotation_prompts_arm_points.json"
)

# Inference
CUDA_VISIBLE_DEVICES = "0,1,2,3"
APPLY_TEMPORAL_DISAMBIGUATION = False

# Object IDs
ARM_OBJ_ID = 0       # primary arm (assigned automatically by the text bootstrap, usually 0)
ARM_OBJ_ID_2 = 1     # second arm (left arm); None = single-arm mode

# Session 2: gripper obj_ids must match the values stored in the JSON annotation file
GRIPPER_LEFT_OBJ_ID = 2
GRIPPER_RIGHT_OBJ_ID = 3

# Session 1 — arm text bootstrap
ARM_TEXT_PROMPT = "robot and cable"
ARM_TEXT_BOOTSTRAP_FRAME_INDEX = 0

# Visualization
VIS_FRAME_STRIDE = 60
VIS_MAX_PLOTS = 8

# Export
# EXPORT_OUTPUT_DIR = "/data/haoxiang/propainter/masks_airexo_arm_only_decoupled"
EXPORT_OUTPUT_DIR = f"/data/haoxiang/data/airexo2_processed/{TASK_NAME}/{SCENE_NAME}"
EXPORT_ARM_DILATE_RADIUS = 15    # dilation radius for the arm mask
EXPORT_GRIPPER_DILATE_RADIUS = 15  # dilation radius for the gripper mask (used to carve the gripper region out of the arm)
EXPORT_LOG_EVERY = 50

# Checkpoint directories (intermediate masks are saved after each session so that a
# finished session can be skipped after a restart).
# Recovery: if Session 2 hangs and the kernel is restarted, the arm checkpoint is
# loaded automatically and only Session 2 has to be re-run; and vice versa.
# Both directories are siblings of EXPORT_OUTPUT_DIR with a "_ckpt_*" suffix.
_export_base = Path(EXPORT_OUTPUT_DIR)
CHECKPOINT_ARM_DIR     = str(_export_base.parent / (_export_base.name + "_ckpt_arm"))
CHECKPOINT_GRIPPER_DIR = str(_export_base.parent / (_export_base.name + "_ckpt_gripper"))

# ============================================================
# Runtime state (no need to edit)
# ============================================================

# Re-running this cell: release any predictor left over from a previous run
# BEFORE the state variables below are reset to None. If the reset ran first,
# the old predictor references would be overwritten and this cleanup could
# never trigger. (Per the original note, the cleanup is meant to prevent NCCL
# initialization errors on re-runs.) globals().get() keeps the very first run,
# when the names are not defined yet, from raising NameError.
_stale_predictor_arm = globals().get("predictor_arm")
if _stale_predictor_arm is not None:
    print("[init] cleaning up previous arm predictor")
    cleanup_resources(
        predictor_obj=_stale_predictor_arm,
        session_id_value=globals().get("session_id_arm"),
    )

_stale_predictor_gripper = globals().get("predictor_gripper")
if _stale_predictor_gripper is not None:
    print("[init] cleaning up previous gripper predictor")
    cleanup_resources(
        predictor_obj=_stale_predictor_gripper,
        session_id_value=globals().get("session_id_gripper"),
    )

# Drop the temporary references so they do not keep the old predictors alive.
_stale_predictor_arm = None
_stale_predictor_gripper = None

video_frames_for_vis = None
TOTAL_FRAMES = None
IMG_WIDTH = None
IMG_HEIGHT = None

predictor_arm = None
session_id_arm = None
outputs_arm = None

predictor_gripper = None
session_id_gripper = None
outputs_gripper = None

arm_prompts_norm = None          # arm point annotations (stored here once loaded)
gripper_left_prompts_norm = None
gripper_right_prompts_norm = None

# ============================================================
# Load the video frames (shared by both sessions)
# ============================================================

# Publish the configured GPU list through the environment, then echo what was stored.
os.environ.update(CUDA_VISIBLE_DEVICES=CUDA_VISIBLE_DEVICES)
print(f"[init] CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")

# The frames are loaded once and reused by the UIs, both sessions and the merge step.
video_frames_for_vis = load_video_frames_for_visualization(VIDEO_PATH)
TOTAL_FRAMES = len(video_frames_for_vis)
_frame_size = get_frame_size(video_frames_for_vis)
IMG_WIDTH, IMG_HEIGHT = _frame_size

print(f"[init] frames={TOTAL_FRAMES}  size={IMG_WIDTH}x{IMG_HEIGHT}")
for _cfg_name, _cfg_value in (
    ("VIDEO_PATH", VIDEO_PATH),
    ("ANNOTATION_JSON_PATH", ANNOTATION_JSON_PATH),
    ("ARM_ANNOTATION_JSON_PATH", ARM_ANNOTATION_JSON_PATH),
    ("CHECKPOINT_ARM_DIR", CHECKPOINT_ARM_DIR),
    ("CHECKPOINT_GRIPPER_DIR", CHECKPOINT_GRIPPER_DIR),
):
    print(f"[init] {_cfg_name}={_cfg_value}")

In [None]:
# Arm 标注 UI
# ============================================================
# Arm annotation UI
# Annotate key-frame points for arm_cable (primary/right arm, obj_id=ARM_OBJ_ID)
# and arm_cable_2 (left arm, obj_id=ARM_OBJ_ID_2).
# When finished, click Export to save to ARM_ANNOTATION_JSON_PATH.
# If a valid JSON already exists, this cell can be skipped.
# ============================================================

ARM_ANNOTATION_OBJECT_SPECS = dict(
    arm_cable=dict(
        display="Arm + Cable (主臂/右)",
        obj_id=int(ARM_OBJ_ID),
        target="ARM_CABLE_INITIAL_PROMPTS",
    ),
)
# The second arm is only offered when ARM_OBJ_ID_2 is configured (dual-arm mode).
if ARM_OBJ_ID_2 is not None:
    ARM_ANNOTATION_OBJECT_SPECS.update(
        arm_cable_2=dict(
            display="Arm 2 (左臂)",
            obj_id=int(ARM_OBJ_ID_2),
            target="ARM_CABLE_2_INITIAL_PROMPTS",
        )
    )

# Start from one empty prompt list per key; an existing JSON overrides them below.
_arm_seed_prompt_map = {
    _prompt_key: []
    for _prompt_key in (
        "ARM_CABLE_INITIAL_PROMPTS",
        "ARM_CABLE_2_INITIAL_PROMPTS",
        "GRIPPER_LEFT_KEYFRAME_PROMPTS",
        "GRIPPER_RIGHT_KEYFRAME_PROMPTS",
    )
}
_arm_json_file = Path(ARM_ANNOTATION_JSON_PATH)
if _arm_json_file.exists():
    # Seeding is best-effort: a broken JSON only produces a warning and the UI starts empty.
    try:
        _arm_existing = load_annotation_prompts_json(
            status_prefix="[arm-annotation/seed]",
            json_path=ARM_ANNOTATION_JSON_PATH,
        )
        _arm_seed_prompt_map.update(_arm_existing)
        print(f"[arm-annotation] seeding UI from existing JSON: {ARM_ANNOTATION_JSON_PATH}")
    except Exception as _e:
        print(f"[arm-annotation][warn] could not seed from JSON: {_e}")

# Create the store for the configured objects and pre-fill it from the seed map.
_arm_annotation_store = create_annotation_store(ARM_ANNOTATION_OBJECT_SPECS)
seed_store_from_prompt_map(
    prompt_map=_arm_seed_prompt_map,
    object_specs=ARM_ANNOTATION_OBJECT_SPECS,
    store=_arm_annotation_store,
    img_h=IMG_HEIGHT,
    img_w=IMG_WIDTH,
)

# Export callback: fills in every key the schema requires, then saves the JSON itself.
# The UI is created with save_json_on_export=False for the same reason as the gripper
# UI: per the original note, the built-in save would call validate_export_prompt_map
# before the missing fields are added and fail with a KeyError.
_arm_export_result = {}


def _on_arm_export(export_prompts):
    """Store the exported arm prompts in ``_arm_export_result`` and write them to disk.

    Args:
        export_prompts: mapping handed over by the annotation UI; only the two
            ``ARM_CABLE*_INITIAL_PROMPTS`` entries are read (missing ones become ``[]``).
    """
    global _arm_export_result
    payload = {
        arm_key: export_prompts.get(arm_key, [])
        for arm_key in ("ARM_CABLE_INITIAL_PROMPTS", "ARM_CABLE_2_INITIAL_PROMPTS")
    }
    # Gripper keys are required by the schema but intentionally left empty here.
    payload["GRIPPER_LEFT_KEYFRAME_PROMPTS"] = []
    payload["GRIPPER_RIGHT_KEYFRAME_PROMPTS"] = []
    _arm_export_result = payload

    from annotation_ui_tools import save_annotation_prompts_json as _save_json

    _save_json(
        json_path=ARM_ANNOTATION_JSON_PATH,
        export_prompts=_arm_export_result,
        status_prefix="[arm-annotation]",
    )
    print(f"[arm-annotation] saved to {ARM_ANNOTATION_JSON_PATH}")

# Build (and display) the arm annotation widget; saving is delegated to _on_arm_export.
_arm_annotation_ui = create_annotation_ui(
    # data shown in the widget
    video_frames_for_vis=video_frames_for_vis,
    total_frames=TOTAL_FRAMES,
    img_width=IMG_WIDTH,
    img_height=IMG_HEIGHT,
    # what can be annotated and where the points are kept
    object_specs=ARM_ANNOTATION_OBJECT_SPECS,
    annotation_store=_arm_annotation_store,
    # export handling
    export_json_path=ARM_ANNOTATION_JSON_PATH,
    save_json_on_export=False,
    on_export=_on_arm_export,
    # presentation
    status_prefix="[arm-annotation]",
    auto_display=True,
)

# Without a working widget, tell the user to edit the JSON file by hand.
_arm_widget_ready = _arm_annotation_ui.get("widget_ready", False)
if not _arm_widget_ready:
    for _fallback_line in (
        f"[arm-annotation][fallback] Widget not available: {_arm_annotation_ui.get('widget_error')}",
        f"[arm-annotation][fallback] 请直接编辑 JSON 文件: {ARM_ANNOTATION_JSON_PATH}",
    ):
        print(_fallback_line)

In [None]:
# Session 1 — Arm 分割

# ============================================================
# Session 1 — arm segmentation
# Step 1: text bootstrap -> initial bidirectional propagation (fills the cache)
# Step 2: if arm point annotations exist, apply_prompt_list -> final bidirectional propagation
# ============================================================

print("[session1/arm] initializing predictor ...")
_arm_gpu_count = torch.cuda.device_count()
gpus_to_use = range(_arm_gpu_count)
predictor_arm = build_sam3_video_predictor(
    apply_temporal_disambiguation=APPLY_TEMPORAL_DISAMBIGUATION,
    checkpoint_path=CHECKPOINT_PATH,
    gpus_to_use=gpus_to_use,
)

_arm_start_request = {"type": "start_session", "resource_path": VIDEO_PATH}
start_response = predictor_arm.handle_request(request=_arm_start_request)
session_id_arm = start_response["session_id"]
print(f"[session1/arm] session started: session_id={session_id_arm}")

# Both calls below address the same predictor/session and log under the same stage.
_arm_stage_kwargs = dict(
    predictor_obj=predictor_arm,
    session_id_value=session_id_arm,
    stage_name="session1/arm",
)

# ----------------------------------------------------------
# Step 1: text bootstrap -> initial bidirectional propagation (fills the cache)
# ----------------------------------------------------------
print(f"[session1/arm] text bootstrap: frame={ARM_TEXT_BOOTSTRAP_FRAME_INDEX}  prompt={ARM_TEXT_PROMPT!r}")
add_text_prompt(
    frame_index=ARM_TEXT_BOOTSTRAP_FRAME_INDEX,
    text_prompt=ARM_TEXT_PROMPT,
    **_arm_stage_kwargs,
)

print("[session1/arm] propagating (bidirectional, initial pass) ...")
outputs_arm = propagate_bidirectional_and_merge(**_arm_stage_kwargs)
print(f"[session1/arm] initial propagation done: {len(outputs_arm)} frames")

# ----------------------------------------------------------
# Step 2: apply arm point annotations (if any), then propagate again.
#
# The arm annotation UI only exports ARM_ANNOTATION_JSON_PATH; nothing else in
# this notebook reads it back. The JSON is therefore (re)loaded here on every
# run, so points exported or edited since the last run are actually used.
# If the file does not exist, or cannot be loaded/validated, whatever value
# arm_prompts_norm already holds is kept (None or [] skips the refinement).
# ----------------------------------------------------------
if Path(ARM_ANNOTATION_JSON_PATH).exists():
    print(f"[session1/arm] loading arm annotation from JSON: {ARM_ANNOTATION_JSON_PATH}")
    try:
        _arm_annotation_prompts = load_annotation_prompts_json(
            json_path=ARM_ANNOTATION_JSON_PATH, status_prefix="[session1/arm]"
        )
        # The second arm's prompts are only used in dual-arm mode.
        _arm_prompt_tags = ["ARM_CABLE_INITIAL_PROMPTS"]
        if ARM_OBJ_ID_2 is not None:
            _arm_prompt_tags.append("ARM_CABLE_2_INITIAL_PROMPTS")
        _arm_prompts_loaded = []
        for _arm_tag in _arm_prompt_tags:
            _arm_prompts_loaded.extend(
                validate_and_normalize_prompt_list(
                    _arm_annotation_prompts.get(_arm_tag, []),
                    total_frames=TOTAL_FRAMES,
                    img_w=IMG_WIDTH,
                    img_h=IMG_HEIGHT,
                    tag=_arm_tag,
                    allow_empty=True,
                )
            )
    except Exception as _e:
        # Arm refinement is optional: warn loudly instead of aborting, so the
        # finished bootstrap propagation is not lost and the session is still
        # shut down by the code that follows.
        print(
            f"[session1/arm][WARN] could not load arm annotations from "
            f"{ARM_ANNOTATION_JSON_PATH}: {_e}"
        )
    else:
        arm_prompts_norm = _arm_prompts_loaded
        print(f"[session1/arm] loaded {len(arm_prompts_norm)} arm prompt entries from JSON")

if arm_prompts_norm:
    print(f"[session1/arm] applying {len(arm_prompts_norm)} arm point prompt entries ...")
    apply_prompt_list(
        predictor_obj=predictor_arm,
        session_id_value=session_id_arm,
        prompt_list=arm_prompts_norm,
        stage_name="session1/arm_refine",
    )
    print("[session1/arm] re-propagating after arm point refinement ...")
    outputs_arm = propagate_bidirectional_and_merge(
        predictor_obj=predictor_arm,
        session_id_value=session_id_arm,
        stage_name="session1/arm_refine",
    )
    print(f"[session1/arm] refinement done: {len(outputs_arm)} frames")
else:
    print("[session1/arm] no arm point prompts, skipping refinement (text bootstrap only)")

# ----------------------------------------------------------
# Check whether the configured arm obj_ids appear in the outputs (warnings only)
# ----------------------------------------------------------
_arm_ids_found = {
    int(_found_oid)
    for _frame_output in outputs_arm.values()
    for _found_oid, _unused_mask in iter_object_masks_from_frame_output(_frame_output)
}
_arm_ids_sorted = sorted(_arm_ids_found)
print(f"[session1/arm] obj_ids in outputs: {_arm_ids_sorted}")
if ARM_OBJ_ID not in _arm_ids_found:
    print(
        f"[session1/arm][WARN] ARM_OBJ_ID={ARM_OBJ_ID} not found. "
        f"Found: {_arm_ids_sorted}. "
        f"Please update ARM_OBJ_ID in config to match the actual id."
    )
if ARM_OBJ_ID_2 is not None:
    if ARM_OBJ_ID_2 not in _arm_ids_found:
        print(
            f"[session1/arm][WARN] ARM_OBJ_ID_2={ARM_OBJ_ID_2} not found in outputs. "
            f"Left arm may not have been detected. Check arm point annotations."
        )

# Close Session 1; outputs_arm stays in memory for the merge step.
print("[session1/arm] shutting down predictor ...")
cleanup_resources(session_id_value=session_id_arm, predictor_obj=predictor_arm)
predictor_arm = session_id_arm = None
print(f"[session1/arm] done. outputs_arm preserved ({len(outputs_arm)} frames).")

# ============================================================
# Checkpoint: write the per-frame union of all arm masks to disk
# Purpose: if Session 2 hangs and a restart is needed, the arm result can be
# loaded from this checkpoint instead of re-running Session 1.
# ============================================================
print(f"[ckpt/arm] saving pre-dilate arm masks → {CHECKPOINT_ARM_DIR}")
os.makedirs(CHECKPOINT_ARM_DIR, exist_ok=True)
_ckpt_w, _ckpt_h = get_frame_size(video_frames_for_vis)
for _ckpt_fi in range(TOTAL_FRAMES):
    # Frames without any output are written as all-zero masks.
    _ckpt_frame_output = outputs_arm.get(_ckpt_fi, {})
    _ckpt_u = np.zeros((_ckpt_h, _ckpt_w), np.uint8)
    for _, _ckpt_m in iter_object_masks_from_frame_output(_ckpt_frame_output):
        if isinstance(_ckpt_m, torch.Tensor):
            _ckpt_m = _ckpt_m.detach().cpu().numpy()
        if _ckpt_m.ndim > 2:
            _ckpt_m = np.squeeze(_ckpt_m)
        # Binarize (> 0 -> 255) and OR the object into the running union.
        _ckpt_binary = (_ckpt_m > 0).astype(np.uint8) * 255
        _ckpt_u = np.maximum(_ckpt_u, _ckpt_binary)
    _ckpt_png = os.path.join(CHECKPOINT_ARM_DIR, f"{_ckpt_fi:05d}.png")
    Image.fromarray(_ckpt_u, "L").save(_ckpt_png)
print(f"[ckpt/arm] done: {TOTAL_FRAMES} frames → {CHECKPOINT_ARM_DIR}")

# ============================================================
# Visualize the Session 1 arm result
# Inspect the arm mask quality; if it is not satisfactory, adjust the
# annotation points and re-run Session 1.
# ============================================================
visualize_outputs(
    outputs_per_frame=outputs_arm,
    video_frames=video_frames_for_vis,
    stride=VIS_FRAME_STRIDE,
    max_plots=VIS_MAX_PLOTS,
    title="Session 1: Arm Segmentation (text bootstrap + point refinement)",
)

In [None]:
# Gripper 标注 UI

# ============================================================
# Gripper annotation UI
# Add annotation points for gripper_left / gripper_right here, then click Export to save the JSON.
# If a valid JSON already exists, this cell can be skipped and the JSON-loading cell below run directly.
# ============================================================

# Gripper objects only; the arm is not part of this UI.
GRIPPER_ANNOTATION_OBJECT_SPECS = dict(
    gripper_left=dict(
        display="Gripper Left",
        obj_id=int(GRIPPER_LEFT_OBJ_ID),
        target="GRIPPER_LEFT_KEYFRAME_PROMPTS",
    ),
    gripper_right=dict(
        display="Gripper Right",
        obj_id=int(GRIPPER_RIGHT_OBJ_ID),
        target="GRIPPER_RIGHT_KEYFRAME_PROMPTS",
    ),
)

# Start from one empty prompt list per key; an existing JSON (optional, used only
# to pre-fill the UI) overrides them below.
_seed_prompt_map = {
    _prompt_key: []
    for _prompt_key in (
        "ARM_CABLE_INITIAL_PROMPTS",
        "GRIPPER_LEFT_KEYFRAME_PROMPTS",
        "GRIPPER_RIGHT_KEYFRAME_PROMPTS",
    )
}
_gripper_json_file = Path(ANNOTATION_JSON_PATH)
if _gripper_json_file.exists():
    # Seeding is best-effort: a broken JSON only produces a warning and the UI starts empty.
    try:
        _existing = load_annotation_prompts_json(
            status_prefix="[annotation/seed]",
            json_path=ANNOTATION_JSON_PATH,
        )
        _seed_prompt_map.update(_existing)
        print(f"[annotation] seeding UI from existing JSON: {ANNOTATION_JSON_PATH}")
    except Exception as _e:
        print(f"[annotation][warn] could not seed from JSON: {_e}")

# Create the store for the gripper objects and pre-fill it from the seed map.
_annotation_store = create_annotation_store(GRIPPER_ANNOTATION_OBJECT_SPECS)
seed_store_from_prompt_map(
    prompt_map=_seed_prompt_map,
    object_specs=GRIPPER_ANNOTATION_OBJECT_SPECS,
    store=_annotation_store,
    img_h=IMG_HEIGHT,
    img_w=IMG_WIDTH,
)

# Export callback: adds the ARM_CABLE_INITIAL_PROMPTS field and saves the JSON itself.
# Note: the UI must be created with save_json_on_export=False. Per the original note,
# the built-in save would call validate_export_prompt_map before the ARM field is
# added and fail with a KeyError.
_gripper_export_result = {}


def _on_gripper_export(export_prompts):
    """Store the exported gripper prompts in ``_gripper_export_result`` and write them to disk.

    Args:
        export_prompts: mapping handed over by the annotation UI; only the two
            ``GRIPPER_*_KEYFRAME_PROMPTS`` entries are read (missing ones become ``[]``).
    """
    global _gripper_export_result
    # The arm key is added (empty) to satisfy the JSON schema.
    payload = {"ARM_CABLE_INITIAL_PROMPTS": []}
    for gripper_key in ("GRIPPER_LEFT_KEYFRAME_PROMPTS", "GRIPPER_RIGHT_KEYFRAME_PROMPTS"):
        payload[gripper_key] = export_prompts.get(gripper_key, [])
    _gripper_export_result = payload

    from annotation_ui_tools import save_annotation_prompts_json as _save_json

    _save_json(
        json_path=ANNOTATION_JSON_PATH,
        export_prompts=_gripper_export_result,
        status_prefix="[annotation/gripper]",
    )
    print(f"[annotation] saved to {ANNOTATION_JSON_PATH}")

# Build (and display) the gripper annotation widget.
_annotation_ui = create_annotation_ui(
    # data shown in the widget
    video_frames_for_vis=video_frames_for_vis,
    total_frames=TOTAL_FRAMES,
    img_width=IMG_WIDTH,
    img_height=IMG_HEIGHT,
    # what can be annotated and where the points are kept
    object_specs=GRIPPER_ANNOTATION_OBJECT_SPECS,
    annotation_store=_annotation_store,
    # export handling: the built-in save is disabled, _on_gripper_export saves instead
    export_json_path=ANNOTATION_JSON_PATH,
    save_json_on_export=False,
    on_export=_on_gripper_export,
    # presentation
    status_prefix="[annotation/gripper]",
    auto_display=True,
)

# Without a working widget, tell the user to edit the JSON file by hand.
_gripper_widget_ready = _annotation_ui.get("widget_ready", False)
if not _gripper_widget_ready:
    for _fallback_line in (
        f"[annotation][fallback] Widget not available: {_annotation_ui.get('widget_error')}",
        f"[annotation][fallback] 请直接编辑 JSON 文件: {ANNOTATION_JSON_PATH}",
    ):
        print(_fallback_line)

In [None]:
# Gripper 分割

# ============================================================
# Load and validate the gripper annotations from JSON
# Before running this cell, make sure the JSON has been exported from the
# annotation UI (or edited by hand).
# ============================================================

print(f"[session2/gripper] loading annotation from JSON: {ANNOTATION_JSON_PATH}")

try:
    annotation_prompts = load_annotation_prompts_json(
        status_prefix="[session2/gripper]", json_path=ANNOTATION_JSON_PATH
    )
except Exception as _e:
    # A missing file gets a dedicated hint; any other failure is reported generically.
    if isinstance(_e, FileNotFoundError):
        raise RuntimeError(
            f"[session2/gripper][FATAL] JSON not found: {ANNOTATION_JSON_PATH}\n"
            "请先在标注 UI 中点击 Export Prompts 导出 JSON。"
        ) from _e
    raise RuntimeError(f"[session2/gripper][FATAL] 读取 JSON 失败: {_e}") from _e


def _normalize_gripper_prompts(tag):
    """Validate and normalize the prompt list stored under ``tag`` (an empty list is allowed)."""
    return validate_and_normalize_prompt_list(
        annotation_prompts.get(tag, []),
        total_frames=TOTAL_FRAMES,
        img_w=IMG_WIDTH,
        img_h=IMG_HEIGHT,
        tag=tag,
        allow_empty=True,
    )


gripper_left_prompts_norm = _normalize_gripper_prompts("GRIPPER_LEFT_KEYFRAME_PROMPTS")
gripper_right_prompts_norm = _normalize_gripper_prompts("GRIPPER_RIGHT_KEYFRAME_PROMPTS")

# At least one side must be annotated.
_left_is_empty = len(gripper_left_prompts_norm) == 0
_right_is_empty = len(gripper_right_prompts_norm) == 0
if _left_is_empty and _right_is_empty:
    raise RuntimeError(
        "[session2/gripper][FATAL] GRIPPER_LEFT 和 GRIPPER_RIGHT 均无标注点。\n"
        "请在标注 UI 中为至少一侧 gripper 添加点位并重新导出。"
    )

# Every prompt's obj_id must match the configured id of its side (left is checked first).
for _side_name, _side_prompts, _side_obj_id in (
    ("LEFT", gripper_left_prompts_norm, GRIPPER_LEFT_OBJ_ID),
    ("RIGHT", gripper_right_prompts_norm, GRIPPER_RIGHT_OBJ_ID),
):
    for _p in _side_prompts:
        if _p["obj_id"] != _side_obj_id:
            raise ValueError(
                f"[session2/gripper] GRIPPER_{_side_name} prompt obj_id={_p['obj_id']} "
                f"!= GRIPPER_{_side_name}_OBJ_ID={_side_obj_id}。请检查 JSON 或配置。"
            )

# Summary: number of entries and the distinct key frames per side.
_left_frames = sorted({_entry["frame_index"] for _entry in gripper_left_prompts_norm})
_right_frames = sorted({_entry["frame_index"] for _entry in gripper_right_prompts_norm})
print(f"[session2/gripper] left  prompt entries={len(gripper_left_prompts_norm)}  keyframes={_left_frames}")
print(f"[session2/gripper] right prompt entries={len(gripper_right_prompts_norm)}  keyframes={_right_frames}")
print("[session2/gripper] validation passed. Ready for Session 2.")

# ============================================================
# Session 2 — gripper segmentation
# A brand-new predictor: a text bootstrap fills the cache first, then the
# gripper annotation points are added.
#
# Why is the text bootstrap needed? (per the original notes)
#   SAM3's add_tracker_new_points (the point-annotation path) requires
#   cached_frame_outputs to exist, so one propagation must run before points
#   can be added. Session 2 starts from scratch and therefore bootstraps first.
#   The bootstrap reuses Session 1's ARM_TEXT_PROMPT purely to initialize the cache.
#   The merge step only extracts the gripper obj_ids from outputs_gripper; the
#   arm part is ignored. Session 1's arm segmentation is fully independent.
# ============================================================

# Guard BEFORE building the second predictor: Session 1 must already be shut
# down. Checking afterwards (as the original assert did) would let both
# predictors coexist and leave the new one un-released when the check fails.
# An explicit raise is used because assert statements are stripped under -O.
if predictor_arm is not None:
    raise RuntimeError("[BUG] predictor_arm should be None at this point")

print("[session2/gripper] initializing fresh predictor (independent of Session 1) ...")
gpus_to_use = range(torch.cuda.device_count())
predictor_gripper = build_sam3_video_predictor(
    checkpoint_path=CHECKPOINT_PATH,
    gpus_to_use=gpus_to_use,
    apply_temporal_disambiguation=APPLY_TEMPORAL_DISAMBIGUATION,
)

start_response = predictor_gripper.handle_request(
    request=dict(type="start_session", resource_path=VIDEO_PATH)
)
session_id_gripper = start_response["session_id"]
print(f"[session2/gripper] session started: session_id={session_id_gripper}")

# ----------------------------------------------------------
# Step 1: text bootstrap + propagation to fill the cache
# Per the original note, add_tracker_new_points needs cached_frame_outputs,
# so a propagation has to run first.
# ----------------------------------------------------------
_bootstrap_stage_kwargs = dict(
    predictor_obj=predictor_gripper,
    session_id_value=session_id_gripper,
    stage_name="session2/bootstrap",
)

print(f"[session2/gripper] bootstrapping cache: frame={ARM_TEXT_BOOTSTRAP_FRAME_INDEX}  prompt={ARM_TEXT_PROMPT!r}")
add_text_prompt(
    frame_index=ARM_TEXT_BOOTSTRAP_FRAME_INDEX,
    text_prompt=ARM_TEXT_PROMPT,
    **_bootstrap_stage_kwargs,
)
print("[session2/gripper] propagating to fill cache (bootstrap pass) ...")
_outputs_bootstrap = propagate_bidirectional_and_merge(**_bootstrap_stage_kwargs)
print(f"[session2/gripper] cache populated ({len(_outputs_bootstrap)} frames)")

# ----------------------------------------------------------
# Step 2: add the gripper annotation points (the cache is ready now)
# Left is processed before right; a side without prompts is skipped.
# ----------------------------------------------------------
for _side_label, _side_prompt_list in (
    ("left", gripper_left_prompts_norm),
    ("right", gripper_right_prompts_norm),
):
    if len(_side_prompt_list) == 0:
        print(f"[session2/gripper] no {_side_label} gripper prompts (skipping)")
        continue
    print(f"[session2/gripper] applying {len(_side_prompt_list)} {_side_label} gripper prompt entries ...")
    apply_prompt_list(
        predictor_obj=predictor_gripper,
        session_id_value=session_id_gripper,
        prompt_list=_side_prompt_list,
        stage_name=f"session2/gripper_{_side_label}",
    )

# ----------------------------------------------------------
# Step 3: final bidirectional propagation
# ----------------------------------------------------------
print("[session2/gripper] final propagation (bidirectional) ...")
outputs_gripper = propagate_bidirectional_and_merge(
    predictor_obj=predictor_gripper,
    session_id_value=session_id_gripper,
    stage_name="session2/gripper",
)
print(f"[session2/gripper] propagation done: {len(outputs_gripper)} frames")

# Close Session 2 right after propagation, BEFORE the validation below can
# raise; otherwise a failed validation would leave the predictor un-released.
# outputs_gripper stays in memory for the merge step, which only extracts the
# gripper obj_ids (the bootstrap's arm objects are filtered out there).
print("[session2/gripper] shutting down predictor ...")
cleanup_resources(predictor_obj=predictor_gripper, session_id_value=session_id_gripper)
predictor_gripper = None
session_id_gripper = None

# Verify that every annotated gripper obj_id shows up in the outputs.
_gripper_ids_found = set()
for _fo in outputs_gripper.values():
    for _oid, _ in iter_object_masks_from_frame_output(_fo):
        _gripper_ids_found.add(int(_oid))
print(f"[session2/gripper] obj_ids in outputs: {sorted(_gripper_ids_found)}")

# Only sides that actually have prompts are expected in the outputs.
_expected_gripper_ids = set()
if len(gripper_left_prompts_norm) > 0:
    _expected_gripper_ids.add(GRIPPER_LEFT_OBJ_ID)
if len(gripper_right_prompts_norm) > 0:
    _expected_gripper_ids.add(GRIPPER_RIGHT_OBJ_ID)
_missing = _expected_gripper_ids - _gripper_ids_found
if _missing:
    raise ValueError(
        f"[session2/gripper][FATAL] Expected obj_ids {_missing} not in outputs. "
        f"Check annotation point quality and keyframe positions."
    )

print(f"[session2/gripper] done. outputs_gripper preserved ({len(outputs_gripper)} frames).")

# ============================================================
# Checkpoint: write the per-frame union of the gripper masks to disk
# (restricted to the gripper obj_ids that have prompts)
# Purpose: if Session 1 hangs and a restart is needed, the gripper result can
# be loaded from this checkpoint instead of re-running Session 2.
# ============================================================
print(f"[ckpt/gripper] saving pre-dilate gripper masks → {CHECKPOINT_GRIPPER_DIR}")
os.makedirs(CHECKPOINT_GRIPPER_DIR, exist_ok=True)
_ckpt_w, _ckpt_h = get_frame_size(video_frames_for_vis)

# Only sides that were annotated contribute to the checkpoint.
_ckpt_gripper_set = set()
for _ckpt_side_prompts, _ckpt_side_id in (
    (gripper_left_prompts_norm, GRIPPER_LEFT_OBJ_ID),
    (gripper_right_prompts_norm, GRIPPER_RIGHT_OBJ_ID),
):
    if len(_ckpt_side_prompts) > 0:
        _ckpt_gripper_set.add(_ckpt_side_id)

for _ckpt_fi in range(TOTAL_FRAMES):
    # Frames without any output are written as all-zero masks.
    _ckpt_frame_output = outputs_gripper.get(_ckpt_fi, {})
    _ckpt_u = np.zeros((_ckpt_h, _ckpt_w), np.uint8)
    for _ckpt_oid, _ckpt_m in iter_object_masks_from_frame_output(_ckpt_frame_output):
        if int(_ckpt_oid) in _ckpt_gripper_set:
            if isinstance(_ckpt_m, torch.Tensor):
                _ckpt_m = _ckpt_m.detach().cpu().numpy()
            if _ckpt_m.ndim > 2:
                _ckpt_m = np.squeeze(_ckpt_m)
            # Binarize (> 0 -> 255) and OR the object into the running union.
            _ckpt_binary = (_ckpt_m > 0).astype(np.uint8) * 255
            _ckpt_u = np.maximum(_ckpt_u, _ckpt_binary)
    _ckpt_png = os.path.join(CHECKPOINT_GRIPPER_DIR, f"{_ckpt_fi:05d}.png")
    Image.fromarray(_ckpt_u, "L").save(_ckpt_png)
print(f"[ckpt/gripper] done: {TOTAL_FRAMES} frames → {CHECKPOINT_GRIPPER_DIR}")

# ============================================================
# Visualize the Session 2 gripper result
# Inspect the gripper mask quality; if it is not satisfactory, edit the
# annotation points and re-run Session 2.
# ============================================================
visualize_outputs(
    outputs_per_frame=outputs_gripper,
    video_frames=video_frames_for_vis,
    stride=VIS_FRAME_STRIDE,
    max_plots=VIS_MAX_PLOTS,
    title="Session 2: Gripper Segmentation (annotation points only)",
)

In [None]:
# Merge + 导出 arm_only 掩膜

# ============================================================
# Merge + export of the arm_only masks
# Logic: dilate arm and gripper separately, then subtract them as booleans.
#
# Data sources (both fall back to a checkpoint):
#   arm:     1) outputs_arm (memory)      2) CHECKPOINT_ARM_DIR (disk)
#   gripper: 1) outputs_gripper (memory)  2) CHECKPOINT_GRIPPER_DIR (disk)
#
# Recovery scenarios:
#   Session 2 hung, kernel restarted -> outputs_arm=None -> checkpoint_arm is
#       loaded automatically, only Session 2 has to be re-run
#   Session 1 hung, kernel restarted -> outputs_gripper=None -> checkpoint_gripper
#       is loaded automatically, only Session 1 has to be re-run
# ============================================================

# --- decide where each input comes from ---
_USE_ARM_CKPT = False
_USE_GRI_CKPT = False

if outputs_arm is None or len(outputs_arm) == 0:
    _arm_ckpt_files = os.listdir(CHECKPOINT_ARM_DIR) if os.path.isdir(CHECKPOINT_ARM_DIR) else []
    if not _arm_ckpt_files:
        raise RuntimeError(
            "[merge] outputs_arm 为空且无 arm checkpoint。\n"
            f"请先运行 Session 1 cell，或确认 checkpoint 目录存在: {CHECKPOINT_ARM_DIR}"
        )
    # A non-empty directory is not enough: an interrupted save leaves a partial
    # checkpoint, which would otherwise only fail in the middle of the export
    # loop after some masks were already written. Check every frame up front.
    _arm_ckpt_missing = [
        _fi for _fi in range(TOTAL_FRAMES)
        if not os.path.isfile(os.path.join(CHECKPOINT_ARM_DIR, f"{_fi:05d}.png"))
    ]
    if _arm_ckpt_missing:
        raise RuntimeError(
            f"[merge] arm checkpoint is incomplete: {len(_arm_ckpt_missing)} of {TOTAL_FRAMES} "
            f"frame masks are missing (first missing: {_arm_ckpt_missing[0]:05d}.png) "
            f"in {CHECKPOINT_ARM_DIR}.\n"
            "Re-run the Session 1 cell to regenerate it."
        )
    print(f"[merge] outputs_arm 不在内存，从 checkpoint 加载: {CHECKPOINT_ARM_DIR}  ({len(_arm_ckpt_files)} files)")
    _USE_ARM_CKPT = True

if outputs_gripper is None or len(outputs_gripper) == 0:
    _gri_ckpt_files = os.listdir(CHECKPOINT_GRIPPER_DIR) if os.path.isdir(CHECKPOINT_GRIPPER_DIR) else []
    if not _gri_ckpt_files:
        raise RuntimeError(
            "[merge] outputs_gripper 为空且无 gripper checkpoint。\n"
            f"请先运行 Session 2 cell，或确认 checkpoint 目录存在: {CHECKPOINT_GRIPPER_DIR}"
        )
    # Same completeness check as for the arm checkpoint.
    _gri_ckpt_missing = [
        _fi for _fi in range(TOTAL_FRAMES)
        if not os.path.isfile(os.path.join(CHECKPOINT_GRIPPER_DIR, f"{_fi:05d}.png"))
    ]
    if _gri_ckpt_missing:
        raise RuntimeError(
            f"[merge] gripper checkpoint is incomplete: {len(_gri_ckpt_missing)} of {TOTAL_FRAMES} "
            f"frame masks are missing (first missing: {_gri_ckpt_missing[0]:05d}.png) "
            f"in {CHECKPOINT_GRIPPER_DIR}.\n"
            "Re-run the Session 2 cell to regenerate it."
        )
    print(f"[merge] outputs_gripper 不在内存，从 checkpoint 加载: {CHECKPOINT_GRIPPER_DIR}  ({len(_gri_ckpt_files)} files)")
    _USE_GRI_CKPT = True

# --- gripper obj_id set (only needed for the in-memory path; the checkpoint is already filtered) ---
_gripper_set = set()
if not _USE_GRI_CKPT:
    # None is treated like an empty prompt list.
    _gl = gripper_left_prompts_norm or []
    _gr = gripper_right_prompts_norm or []
    for _side_prompts, _side_obj_id in ((_gl, GRIPPER_LEFT_OBJ_ID), (_gr, GRIPPER_RIGHT_OBJ_ID)):
        if len(_side_prompts) > 0:
            _gripper_set.add(_side_obj_id)

# --- debug information ---
if _USE_ARM_CKPT:
    print(f"[merge] arm  source: checkpoint")
else:
    _arm_obj_ids_present = {
        int(_present_oid)
        for _frame_output in outputs_arm.values()
        for _present_oid, _unused_mask in iter_object_masks_from_frame_output(_frame_output)
    }
    print(f"[merge] arm  source: memory  obj_ids={sorted(_arm_obj_ids_present)}")

if _USE_GRI_CKPT:
    _gripper_source_desc = "checkpoint"
else:
    _gripper_source_desc = f"memory  obj_ids={sorted(_gripper_set)}"
print(f"[merge] gripper source: {_gripper_source_desc}")
print(f"[merge] arm_dilate_radius={EXPORT_ARM_DILATE_RADIUS}  gripper_dilate_radius={EXPORT_GRIPPER_DILATE_RADIUS}")
print(f"[merge] output_dir={EXPORT_OUTPUT_DIR}")

os.makedirs(EXPORT_OUTPUT_DIR, exist_ok=True)

def _make_kernel(radius):
    """Return a square all-ones uint8 kernel for dilation, or None to disable it.

    Args:
        radius: dilation radius; values <= 0 mean "no dilation".

    Returns:
        None when ``radius <= 0``; otherwise an array of ones with side length
        ``2 * int(radius) + 1`` and dtype uint8.
    """
    if radius <= 0:
        return None
    side = int(radius) * 2 + 1
    return np.full((side, side), 1, dtype=np.uint8)

# Build one kernel per mask type; a radius <= 0 yields None (no dilation).
_arm_kernel = _make_kernel(EXPORT_ARM_DILATE_RADIUS)
_gripper_kernel = _make_kernel(EXPORT_GRIPPER_DILATE_RADIUS)

# The logged sizes are derived from the configured radii, not from the kernel arrays.
_arm_kernel_side = 2 * EXPORT_ARM_DILATE_RADIUS + 1
_gripper_kernel_side = 2 * EXPORT_GRIPPER_DILATE_RADIUS + 1
print(
    f"[merge] arm kernel: {_arm_kernel_side}x{_arm_kernel_side}  "
    f"gripper kernel: {_gripper_kernel_side}x{_gripper_kernel_side}"
)

_merge_frame_size = get_frame_size(video_frames_for_vis)
_img_w, _img_h = _merge_frame_size

# --- union mask helpers (same interface whether the data comes from memory or a checkpoint) ---
def _get_arm_union(frame_idx):
    """Return the arm mask for one frame as a 2-D uint8 array.

    With ``_USE_ARM_CKPT`` set, the frame's PNG is read from CHECKPOINT_ARM_DIR
    as a grayscale image. Otherwise every object mask in ``outputs_arm`` for
    that frame is binarized (> 0 -> 255) and OR-ed together; a frame without
    output yields an all-zero mask.
    """
    if not _USE_ARM_CKPT:
        union = np.zeros((_img_h, _img_w), np.uint8)
        frame_output = outputs_arm.get(frame_idx, {})
        for _obj_id, mask in iter_object_masks_from_frame_output(frame_output):
            if isinstance(mask, torch.Tensor):
                mask = mask.detach().cpu().numpy()
            if mask.ndim > 2:
                mask = np.squeeze(mask)
            binary = (mask > 0).astype(np.uint8) * 255
            union = np.maximum(union, binary)
        return union

    ckpt_png = os.path.join(CHECKPOINT_ARM_DIR, f"{frame_idx:05d}.png")
    return np.array(Image.open(ckpt_png).convert("L"))

def _get_gripper_union(frame_idx):
    """Return the gripper mask for one frame as a 2-D uint8 array.

    With ``_USE_GRI_CKPT`` set, the frame's PNG is read from
    CHECKPOINT_GRIPPER_DIR as a grayscale image. Otherwise only objects whose
    id is in ``_gripper_set`` are taken from ``outputs_gripper``, binarized
    (> 0 -> 255) and OR-ed together; a frame without output yields an
    all-zero mask.
    """
    if not _USE_GRI_CKPT:
        union = np.zeros((_img_h, _img_w), np.uint8)
        frame_output = outputs_gripper.get(frame_idx, {})
        for obj_id, mask in iter_object_masks_from_frame_output(frame_output):
            if int(obj_id) in _gripper_set:
                if isinstance(mask, torch.Tensor):
                    mask = mask.detach().cpu().numpy()
                if mask.ndim > 2:
                    mask = np.squeeze(mask)
                binary = (mask > 0).astype(np.uint8) * 255
                union = np.maximum(union, binary)
        return union

    ckpt_png = os.path.join(CHECKPOINT_GRIPPER_DIR, f"{frame_idx:05d}.png")
    return np.array(Image.open(ckpt_png).convert("L"))

# --- main loop ---
_saved_paths = []
_total_gripper_px = 0
_log_every = max(int(EXPORT_LOG_EVERY), 1)
_last_frame_idx = TOTAL_FRAMES - 1

for _frame_idx in range(TOTAL_FRAMES):
    _arm_union = _get_arm_union(_frame_idx)
    _gripper_union = _get_gripper_union(_frame_idx)

    # Step 1: dilate each mask separately (skipped for empty masks or a disabled kernel)
    if _arm_kernel is not None and _arm_union.any():
        _arm_union = cv2.dilate(_arm_union, _arm_kernel, iterations=1)
    if _gripper_kernel is not None and _gripper_union.any():
        _gripper_union = cv2.dilate(_gripper_union, _gripper_kernel, iterations=1)

    # Step 2: boolean subtraction (dilated_arm AND NOT dilated_gripper)
    _arm_only_bool = (_arm_union > 0) & (_gripper_union == 0)
    _arm_only = _arm_only_bool.astype(np.uint8) * 255

    # Gripper pixels are accumulated over all frames for the sanity check after the loop.
    _gripper_px = int(np.count_nonzero(_gripper_union))
    _total_gripper_px += _gripper_px

    # Log every _log_every-th frame and always the last one.
    if _frame_idx % _log_every == 0 or _frame_idx == _last_frame_idx:
        print(
            f"[merge][frame {_frame_idx:05d}] "
            f"arm_px={int(np.count_nonzero(_arm_union))} "
            f"gripper_px={_gripper_px} "
            f"arm_only_px={int(np.count_nonzero(_arm_only))}"
        )

    _out_fp = os.path.join(EXPORT_OUTPUT_DIR, f"{_frame_idx:05d}.png")
    Image.fromarray(_arm_only, mode="L").save(_out_fp)
    _saved_paths.append(_out_fp)

# Not a single gripper pixel in any frame means the subtraction never had an effect.
if _total_gripper_px == 0:
    raise ValueError(
        "[merge][FATAL] 所有帧的 gripper 像素总和为 0。\n"
        "GRIPPER_LEFT/RIGHT_OBJ_ID 可能与 outputs_gripper 中的实际 obj_id 不一致，\n"
        "或 checkpoint 内容无效。请检查 Session 2 可视化结果。"
    )

print(f"[merge] export complete: {len(_saved_paths)} masks → {EXPORT_OUTPUT_DIR}")

# ============================================================
# Visualize the merge result (overlay check on sampled frames)
# Left: original frame   Right: arm_only mask overlay (green = mask region)
# ============================================================

_vis_frames = list(range(0, TOTAL_FRAMES, VIS_FRAME_STRIDE)[:VIS_MAX_PLOTS])
plt.close("all")

# Green contribution of the blend (60% green + 40% original inside the mask).
_green_part = np.array([0, 220, 0], dtype=np.float32) * 0.6

for _fidx in _vis_frames:
    _mask_fp = os.path.join(EXPORT_OUTPUT_DIR, f"{_fidx:05d}.png")
    if not os.path.exists(_mask_fp):
        print(f"[vis/merge] mask not found: {_mask_fp}")
        continue

    # The exported mask is read back from disk rather than reused from memory.
    _mask_img = np.array(Image.open(_mask_fp).convert("L"))
    _frame_img = video_frames_for_vis[_fidx].copy()

    _mb = _mask_img > 0
    _overlay = _frame_img.copy()
    _blended = _frame_img[_mb] * 0.4 + _green_part
    _overlay[_mb] = _blended.astype(np.uint8)

    _fig, _axes = plt.subplots(1, 2, figsize=(13, 4))
    for _ax, _panel_img, _panel_title in zip(
        _axes,
        (_frame_img, _overlay),
        (f"Frame {_fidx}: original", f"Frame {_fidx}: arm_only overlay"),
    ):
        _ax.imshow(_panel_img)
        _ax.set_title(_panel_title)
        _ax.axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
# 清理

# ============================================================
# Defensive cleanup (makes sure resources are released after repeated runs
# or an interrupted cell)
# ============================================================

if predictor_arm is None:
    print("[cleanup] predictor_arm already cleaned")
else:
    print("[cleanup] cleaning arm predictor (not yet cleaned)")
    cleanup_resources(session_id_value=session_id_arm, predictor_obj=predictor_arm)
    predictor_arm = session_id_arm = None

if predictor_gripper is None:
    print("[cleanup] predictor_gripper already cleaned")
else:
    print("[cleanup] cleaning gripper predictor (not yet cleaned)")
    cleanup_resources(session_id_value=session_id_gripper, predictor_obj=predictor_gripper)
    predictor_gripper = session_id_gripper = None

# Runs unconditionally once both predictors have been handled.
cleanup_process_group()
print("[cleanup] all resources released")