# SAM3 Mask Merge（从 Checkpoint 合并）

**运行顺序**：`mask_config.yaml` → imports → config → init → Merge + 导出

- 仅读取 checkpoint PNG 文件，不需要 SAM3 / GPU
- 需要先分别运行 `mask_1_arm.ipynb` 和 `mask_2_gripper.ipynb`
- 逻辑：`arm_only = dilated_arm AND NOT dilated_gripper`


In [None]:
# imports

# ============================================================
# Imports
# ============================================================
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import yaml
from PIL import Image

from mask_pipeline_tools import (
    get_frame_size,
    load_video_frames_for_visualization,
)

# Use the same font size (12) for axes titles and figure-level titles.
plt.rcParams.update(
    {
        "axes.titlesize": 12,
        "figure.titlesize": 12,
    }
)


In [None]:
# 配置加载

# ============================================================
# Configuration loading — edit mask_config.yaml to change parameters
# ============================================================
import yaml

_CONFIG_PATH = Path("mask_config.yaml")
# Pin the encoding: without it open() decodes with the locale default, so any
# non-ASCII value in the YAML could be mis-decoded on a non-UTF-8 system.
with open(_CONFIG_PATH, "r", encoding="utf-8") as _f:
    _cfg = yaml.safe_load(_f)

# Every key below is mandatory (a missing key raises KeyError here), except
# 'arm_obj_id_2', which may be absent or null and then maps to None.
# NOTE(review): several of these values are not referenced by the visible merge
# cells; presumably they are kept so all notebooks share one config — confirm.
TASK_NAME = _cfg['task_name']
SCENE_NAME = _cfg['scene_name']
CHECKPOINT_PATH = _cfg['sam3_checkpoint']
CUDA_VISIBLE_DEVICES = _cfg['cuda_visible_devices']
APPLY_TEMPORAL_DISAMBIGUATION = bool(_cfg['apply_temporal_disambiguation'])
ARM_OBJ_ID = int(_cfg['arm_obj_id'])
ARM_OBJ_ID_2 = int(_cfg['arm_obj_id_2']) if _cfg.get('arm_obj_id_2') is not None else None
GRIPPER_LEFT_OBJ_ID = int(_cfg['gripper_left_obj_id'])
GRIPPER_RIGHT_OBJ_ID = int(_cfg['gripper_right_obj_id'])
ARM_TEXT_PROMPT = _cfg['arm_text_prompt']
ARM_TEXT_BOOTSTRAP_FRAME_INDEX = int(_cfg['arm_text_bootstrap_frame_index'])
VIS_FRAME_STRIDE = int(_cfg['vis_frame_stride'])
VIS_MAX_PLOTS = int(_cfg['vis_max_plots'])
EXPORT_ARM_DILATE_RADIUS = int(_cfg['export_arm_dilate_radius'])
EXPORT_GRIPPER_DILATE_RADIUS = int(_cfg['export_gripper_dilate_radius'])
EXPORT_LOG_EVERY = int(_cfg['export_log_every'])

# Derived paths
_data_base = _cfg['data_base_dir']
_export_base_str = _cfg['export_base_dir']
VIDEO_PATH = f"{_data_base}/{TASK_NAME}/train/{SCENE_NAME}/cam_105422061350/color"
# Annotation JSON files sit next to the colour-frame directory (its parent).
ANNOTATION_JSON_PATH = str(Path(VIDEO_PATH).parent / "annotation_prompts_gripper_points.json")
ARM_ANNOTATION_JSON_PATH = str(Path(VIDEO_PATH).parent / "annotation_prompts_arm_points.json")
EXPORT_OUTPUT_DIR = f"{_export_base_str}/{TASK_NAME}/{SCENE_NAME}"
_export_base = Path(EXPORT_OUTPUT_DIR)
# Checkpoint directories are siblings of the export directory, named
# "<scene>_ckpt_arm" and "<scene>_ckpt_gripper".
CHECKPOINT_ARM_DIR     = str(_export_base.parent / (_export_base.name + "_ckpt_arm"))
CHECKPOINT_GRIPPER_DIR = str(_export_base.parent / (_export_base.name + "_ckpt_gripper"))

print(f"[config] task={TASK_NAME}  scene={SCENE_NAME}")
print(f"[config] VIDEO_PATH={VIDEO_PATH}")
print(f"[config] ANNOTATION_JSON_PATH={ANNOTATION_JSON_PATH}")
print(f"[config] ARM_ANNOTATION_JSON_PATH={ARM_ANNOTATION_JSON_PATH}")
print(f"[config] CHECKPOINT_ARM_DIR={CHECKPOINT_ARM_DIR}")
print(f"[config] CHECKPOINT_GRIPPER_DIR={CHECKPOINT_GRIPPER_DIR}")
print(f"[config] EXPORT_OUTPUT_DIR={EXPORT_OUTPUT_DIR}")


In [None]:
# 初始化

# ============================================================
# Load the video frames (used only for visualisation and to obtain the frame size)
# ============================================================
video_frames_for_vis = load_video_frames_for_visualization(VIDEO_PATH)
TOTAL_FRAMES = len(video_frames_for_vis)
IMG_WIDTH, IMG_HEIGHT = get_frame_size(video_frames_for_vis)

# Echo the frame statistics and the directories the merge step will use,
# one item per output line.
print(
    f'[init] frames={TOTAL_FRAMES}  size={IMG_WIDTH}x{IMG_HEIGHT}',
    f'[init] CHECKPOINT_ARM_DIR={CHECKPOINT_ARM_DIR}',
    f'[init] CHECKPOINT_GRIPPER_DIR={CHECKPOINT_GRIPPER_DIR}',
    f'[init] EXPORT_OUTPUT_DIR={EXPORT_OUTPUT_DIR}',
    sep="\n",
)


In [None]:
# Merge + 导出 arm_only 掩膜

# ============================================================
# Merge + export the arm_only masks
# Logic: dilate arm and gripper separately, then subtract them as booleans
#
# Data source: read only from checkpoint PNGs (no dependency on in-memory session state)
#   arm:     CHECKPOINT_ARM_DIR (produced by mask_1_arm.ipynb)
#   gripper: CHECKPOINT_GRIPPER_DIR (produced by mask_2_gripper.ipynb)
# ============================================================

# --- Validate the checkpoint directories ---
# Only *.png entries count as checkpoint frames, so stray files cannot make an
# otherwise empty directory look populated or inflate the logged frame count.
_arm_ckpt_files = sorted(
    _name for _name in os.listdir(CHECKPOINT_ARM_DIR) if _name.endswith(".png")
) if os.path.isdir(CHECKPOINT_ARM_DIR) else []
_gri_ckpt_files = sorted(
    _name for _name in os.listdir(CHECKPOINT_GRIPPER_DIR) if _name.endswith(".png")
) if os.path.isdir(CHECKPOINT_GRIPPER_DIR) else []

if not _arm_ckpt_files:
    raise RuntimeError(
        f"[merge] arm checkpoint 不存在或为空: {CHECKPOINT_ARM_DIR}\n"
        "请先运行 mask_1_arm.ipynb。"
    )
if not _gri_ckpt_files:
    raise RuntimeError(
        f"[merge] gripper checkpoint 不存在或为空: {CHECKPOINT_GRIPPER_DIR}\n"
        "请先运行 mask_2_gripper.ipynb。"
    )

# The export loop opens "<frame index, 5 digits>.png" in both directories for every
# video frame. Check that all of them exist up front, so an incomplete checkpoint
# fails here instead of aborting mid-export and leaving a partial output directory.
for _ckpt_label, _ckpt_dir, _ckpt_files, _ckpt_notebook in (
    ("arm", CHECKPOINT_ARM_DIR, _arm_ckpt_files, "mask_1_arm.ipynb"),
    ("gripper", CHECKPOINT_GRIPPER_DIR, _gri_ckpt_files, "mask_2_gripper.ipynb"),
):
    _ckpt_present = set(_ckpt_files)
    _ckpt_missing = [
        f"{_i:05d}.png"
        for _i in range(TOTAL_FRAMES)
        if f"{_i:05d}.png" not in _ckpt_present
    ]
    if _ckpt_missing:
        raise RuntimeError(
            f"[merge] {_ckpt_label} checkpoint 缺少 {len(_ckpt_missing)}/{TOTAL_FRAMES} 帧"
            f"（例如 {_ckpt_missing[:5]}）: {_ckpt_dir}\n"
            f"请重新运行 {_ckpt_notebook}。"
        )

print(f"[merge] arm checkpoint:     {len(_arm_ckpt_files)} frames ← {CHECKPOINT_ARM_DIR}")
print(f"[merge] gripper checkpoint: {len(_gri_ckpt_files)} frames ← {CHECKPOINT_GRIPPER_DIR}")
print(f"[merge] arm_dilate_radius={EXPORT_ARM_DILATE_RADIUS}  gripper_dilate_radius={EXPORT_GRIPPER_DILATE_RADIUS}")
print(f"[merge] output_dir={EXPORT_OUTPUT_DIR}")

os.makedirs(EXPORT_OUTPUT_DIR, exist_ok=True)

def _make_kernel(radius):
    """Return a square all-ones uint8 kernel for the given dilation radius.

    Args:
        radius: Dilation radius; truncated with ``int()`` when sizing the kernel.

    Returns:
        An array of ones with shape ``(2 * int(radius) + 1, 2 * int(radius) + 1)``
        and dtype ``uint8``, or ``None`` when ``radius`` is zero or negative.
    """
    if radius > 0:
        side = int(radius) * 2 + 1
        return np.ones((side, side), dtype=np.uint8)
    # Non-positive radius: no kernel.
    return None

# Build both dilation kernels; a None kernel means the export loop skips
# dilation for that mask.
_arm_kernel, _gripper_kernel = (
    _make_kernel(_r)
    for _r in (EXPORT_ARM_DILATE_RADIUS, EXPORT_GRIPPER_DILATE_RADIUS)
)
# NOTE: this logs the nominal 2*radius+1 size even when the kernel is None.
print(
    f"[merge] arm kernel: {2*EXPORT_ARM_DILATE_RADIUS+1}x{2*EXPORT_ARM_DILATE_RADIUS+1}  "
    f"gripper kernel: {2*EXPORT_GRIPPER_DILATE_RADIUS+1}x{2*EXPORT_GRIPPER_DILATE_RADIUS+1}"
)

# Frame size from get_frame_size; not referenced again in the visible cells.
_img_w, _img_h = get_frame_size(video_frames_for_vis)

# --- Main loop ---
# For every video frame: load the arm and gripper checkpoint masks, dilate each,
# keep arm pixels not covered by the gripper, and write the result as a PNG.
_saved_paths = []
_total_gripper_px = 0
# Loop-invariant log interval, clamped to >= 1 so the modulo below cannot divide by zero.
_log_every = max(int(EXPORT_LOG_EVERY), 1)

for _frame_idx in range(TOTAL_FRAMES):
    _frame_name = f"{_frame_idx:05d}.png"
    # Context managers release the file handles deterministically.
    with Image.open(os.path.join(CHECKPOINT_ARM_DIR, _frame_name)) as _arm_img:
        _arm_union = np.array(_arm_img.convert("L"))
    with Image.open(os.path.join(CHECKPOINT_GRIPPER_DIR, _frame_name)) as _gripper_img:
        _gripper_union = np.array(_gripper_img.convert("L"))

    # Masks of different sizes cannot be combined pixel-wise; report the frame
    # explicitly instead of surfacing an anonymous NumPy broadcasting error.
    if _arm_union.shape != _gripper_union.shape:
        raise ValueError(
            f"[merge][frame {_frame_idx:05d}] arm/gripper mask shape mismatch: "
            f"{_arm_union.shape} vs {_gripper_union.shape}"
        )

    # Step 1: dilate each mask separately (skipped when the kernel is None or the mask is empty)
    if _arm_kernel is not None and np.any(_arm_union):
        _arm_union = cv2.dilate(_arm_union, _arm_kernel, iterations=1)
    if _gripper_kernel is not None and np.any(_gripper_union):
        _gripper_union = cv2.dilate(_gripper_union, _gripper_kernel, iterations=1)

    # Step 2: boolean subtraction (dilated_arm AND NOT dilated_gripper), stored as 0/255
    _arm_only = np.where((_arm_union > 0) & (_gripper_union == 0), 255, 0).astype(np.uint8)

    # Count the (dilated) gripper pixels once; used for the running total and the log line.
    _gripper_px = int(np.count_nonzero(_gripper_union))
    _total_gripper_px += _gripper_px

    # Log every _log_every frames and always on the last frame.
    if _frame_idx % _log_every == 0 or _frame_idx == TOTAL_FRAMES - 1:
        print(
            f"[merge][frame {_frame_idx:05d}] "
            f"arm_px={int(np.count_nonzero(_arm_union))} "
            f"gripper_px={_gripper_px} "
            f"arm_only_px={int(np.count_nonzero(_arm_only))}"
        )

    _out_fp = os.path.join(EXPORT_OUTPUT_DIR, _frame_name)
    # A 2-D uint8 array becomes an 8-bit greyscale ("L") image without an explicit mode.
    Image.fromarray(_arm_only).save(_out_fp)
    _saved_paths.append(_out_fp)

# Sanity check after the export: a gripper total of zero across all frames is
# treated as an invalid gripper checkpoint. The PNGs above are already written.
if _total_gripper_px == 0:
    raise ValueError(
        "[merge][FATAL] 所有帧的 gripper 像素总和为 0。\n"
        "checkpoint 内容可能无效。请检查 mask_2_gripper.ipynb 的分割结果。"
    )

print(f"[merge] export complete: {len(_saved_paths)} masks → {EXPORT_OUTPUT_DIR}")

# ============================================================
# Visualise the merge result (overlay check on sampled frames)
# Left: original frame   Right: arm_only mask overlay (green = mask region)
# ============================================================

# Clamp the stride to >= 1 (the same guard the export loop applies to
# EXPORT_LOG_EVERY): a stride of 0 would make range() raise, and a negative
# stride would silently select no frames.
_vis_stride = max(int(VIS_FRAME_STRIDE), 1)
_vis_frames = list(range(0, TOTAL_FRAMES, _vis_stride))[:VIS_MAX_PLOTS]
plt.close("all")

for _fidx in _vis_frames:
    _mask_fp = os.path.join(EXPORT_OUTPUT_DIR, f"{_fidx:05d}.png")
    if not os.path.exists(_mask_fp):
        print(f"[vis/merge] mask not found: {_mask_fp}")
        continue

    _mask_img = np.array(Image.open(_mask_fp).convert("L"))
    _frame_img = video_frames_for_vis[_fidx].copy()

    # Blend 40% original pixel with 60% green on the masked pixels only.
    # NOTE(review): assumes frames are 3-channel arrays with the same height and
    # width as the mask — confirm against load_video_frames_for_visualization.
    _overlay = _frame_img.copy()
    _mb = _mask_img > 0
    _overlay[_mb] = (
        _frame_img[_mb] * 0.4 + np.array([0, 220, 0], dtype=np.float32) * 0.6
    ).astype(np.uint8)

    _fig, _axes = plt.subplots(1, 2, figsize=(13, 4))
    _axes[0].imshow(_frame_img)
    _axes[0].set_title(f"Frame {_fidx}: original")
    _axes[0].axis("off")
    _axes[1].imshow(_overlay)
    _axes[1].set_title(f"Frame {_fidx}: arm_only overlay")
    _axes[1].axis("off")
    plt.tight_layout()
    plt.show()
