In [1]:
import cv2
import json

video_path = "hand/test.mp4"
keypoints_file = "hand_keypoints.json"

with open(keypoints_file, "r") as f:
    hand_keypoints = json.load(f)

cap = cv2.VideoCapture(video_path)

if not cap.isOpened():
    print("Error: Could not open video file.")
    exit()

print("Video file opened successfully.")

Video file opened successfully.


In [2]:
import torch
from sam2.build_sam import build_sam2_video_predictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

print("SAM 2 model initialized.")

SAM 2 model initialized.


In [3]:
frame_width = int(cap.get(3))
frame_height = int(cap.get(4))
fps = int(cap.get(cv2.CAP_PROP_FPS))

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
output_video_path = "output_masked.mp4"
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

print(f"Video properties - Width: {frame_width}, Height: {frame_height}, FPS: {fps}")

Video properties - Width: 1280, Height: 720, FPS: 29


In [4]:
ret, frame = cap.read()
cap.release()

if not ret:
    print("Error: Could not read the first frame.")
    exit()

print("First frame loaded successfully.")

First frame loaded successfully.


In [5]:
inference_state = predictor.init_state(video_path=video_path)

print("Inference state initialized.")

Inference state initialized.


In [13]:
import cv2
import json
import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor

video_path = "hand/test.mp4"
keypoints_file = "hand_keypoints.json"
output_video_path = "output_masked_debug.mp4"

with open(keypoints_file, "r") as f:
    hand_keypoints = json.load(f)

cap = cv2.VideoCapture(video_path)
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))
cap.release()

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

inference_state = predictor.init_state(video_path=video_path)
predictor.reset_state(inference_state)

points = np.array([kp for hand in hand_keypoints for kp in hand["keypoints"]], dtype=np.float32)
labels = np.ones(len(points), np.int32)

predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
)

video_segments = {}
for f_idx, f_obj_ids, f_mask_logits in predictor.propagate_in_video(inference_state):
    if f_mask_logits is None or len(f_mask_logits) == 0:
        print(f"Frame {f_idx}: f_mask_logits is None or empty.")
        continue
    video_segments[f_idx] = {}
    for i, o_id in enumerate(f_obj_ids):
        bin_mask = (f_mask_logits[i] > 0).cpu().numpy().astype(np.uint8)
        nonzero_count = np.count_nonzero(bin_mask)
        if nonzero_count == 0:
            print(f"Frame {f_idx}, obj {o_id}: mask is empty (0 nonzero pixels).")
        else:
            print(f"Frame {f_idx}, obj {o_id}: mask nonzero pixels = {nonzero_count}")
        video_segments[f_idx][o_id] = bin_mask

cap = cv2.VideoCapture(video_path)
frame_idx = 0

max_idx = max(video_segments.keys()) if video_segments else -1
print("Max seg frame index:", max_idx)

while frame_idx <= max_idx:
    ret, frame = cap.read()
    if not ret:
        print(f"Frame {frame_idx}: OpenCV returned no frame.")
        break

    if frame is None or frame.shape[0] == 0 or frame.shape[1] == 0:
        print(f"Frame {frame_idx}: invalid shape {None if frame is None else frame.shape}.")
        out.write(frame if frame is not None else np.zeros((1, 1, 3), dtype=np.uint8))
        frame_idx += 1
        continue

    if frame_idx in video_segments:
        for oid, mask in video_segments[frame_idx].items():
            if mask is None or mask.shape[0] == 0 or mask.shape[1] == 0:
                print(f"Frame {frame_idx}, obj {oid}: mask shape is invalid.")
                continue
            if np.count_nonzero(mask) == 0:
                print(f"Frame {frame_idx}, obj {oid}: mask is all zeros.")
                continue
            mask_255 = mask * 255
            try:
                resized_mask = cv2.resize(mask_255, (frame.shape[1], frame.shape[0]))
            except Exception as e:
                print(f"Frame {frame_idx}, obj {oid}: cv2.resize error: {e}")
                continue

            colored_mask = cv2.applyColorMap(resized_mask, cv2.COLORMAP_JET)
            frame = cv2.addWeighted(frame, 0.6, colored_mask, 0.4, 0)

    out.write(frame)
    frame_idx += 1

cap.release()
out.release()

propagate in video: 100%|█████████████████████| 210/210 [17:58<00:00,  5.14s/it]


Resize error on frame 0, obj 1: OpenCV(4.11.0) /Users/xperience/GHA-Actions-OpenCV/_work/opencv-python/opencv-python/opencv/modules/imgproc/src/resize.cpp:3845: error: (-215:Assertion failed) !dsize.empty() in function 'resize'

Resize error on frame 1, obj 1: OpenCV(4.11.0) /Users/xperience/GHA-Actions-OpenCV/_work/opencv-python/opencv-python/opencv/modules/imgproc/src/resize.cpp:3845: error: (-215:Assertion failed) !dsize.empty() in function 'resize'

Resize error on frame 2, obj 1: OpenCV(4.11.0) /Users/xperience/GHA-Actions-OpenCV/_work/opencv-python/opencv-python/opencv/modules/imgproc/src/resize.cpp:3845: error: (-215:Assertion failed) !dsize.empty() in function 'resize'

Resize error on frame 3, obj 1: OpenCV(4.11.0) /Users/xperience/GHA-Actions-OpenCV/_work/opencv-python/opencv-python/opencv/modules/imgproc/src/resize.cpp:3845: error: (-215:Assertion failed) !dsize.empty() in function 'resize'

Resize error on frame 4, obj 1: OpenCV(4.11.0) /Users/xperience/GHA-Actions-OpenCV/_

In [15]:
import os
import warnings
from threading import Thread

import numpy as np
import torch
from PIL import Image
import cv2


def load_video_frames_from_video_file(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    compute_device=torch.device("cuda"),
):
    """
    Load video frames using OpenCV, preserving aspect ratio via letterboxing
    so each frame is padded to (image_size, image_size). 
    This avoids forcing a square resize.
    """

    cap = cv2.VideoCapture(video_path)
    frames = []
    original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Convert BGR to RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        h, w = frame.shape[:2]

        # Scale while preserving aspect ratio:
        # We'll letterbox to fit into (image_size, image_size).
        # 1) Compute scale so max dimension = image_size
        scale = image_size / max(w, h)
        new_w = int(w * scale)
        new_h = int(h * scale)

        # 2) Resize while preserving aspect ratio
        resized = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)

        # 3) Letterbox: pad with zeros to (image_size, image_size)
        canvas = np.zeros((image_size, image_size, 3), dtype=np.float32)
        # Center the resized frame
        start_x = (image_size - new_w) // 2
        start_y = (image_size - new_h) // 2
        canvas[start_y:start_y+new_h, start_x:start_x+new_w, :] = resized

        # Convert to torch tensor shape [3, H, W]
        tensor_frame = torch.from_numpy(canvas).permute(2, 0, 1)
        frames.append(tensor_frame)

    cap.release()

    # Stack into [N, 3, H, W]
    images = torch.stack(frames, dim=0) if len(frames) > 0 else torch.zeros(0)

    # For the "video_height, video_width" return, let's keep the original
    # (before letterboxing). 
    # The letterboxed frames are (image_size, image_size).
    # We'll return the original video height/width from the capture.
    video_height, video_width = original_height, original_width

    if not offload_video_to_cpu and images.numel() > 0:
        images = images.to(compute_device)

    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None].to(images.device)
    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None].to(images.device)

    if images.numel() > 0:
        images -= img_mean
        images /= img_std

    return images, video_height, video_width

In [17]:
from sam2.utils.misc import load_video_frames_from_video_file

video_path = "hand/test.mp4"
image_size = 224  # or whatever size you want
offload_video_to_cpu = False

images, video_height, video_width = load_video_frames_from_video_file(
    video_path=video_path,
    image_size=image_size,
    offload_video_to_cpu=offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    compute_device=torch.device("cpu")
)

print("Loaded images:", images.shape)
print("Original video size:", video_height, "x", video_width)


Loaded images: torch.Size([210, 3, 224, 224])
Original video size: 224 x 224


In [19]:
import cv2
import json
import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor

video_path = "hand/test.mp4"
keypoints_file = "hand_keypoints.json"

with open(keypoints_file, "r") as f:
    hand_keypoints = json.load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

inference_state = predictor.init_state(video_path=video_path)

# Do NOT reset again if you want to keep your prompts:
# predictor.reset_state(inference_state)

points = np.array([kp for hand in hand_keypoints for kp in hand["keypoints"]], dtype=np.float32)
labels = np.ones(len(points), np.int32)

predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
)

video_segments = {}
for f_idx, f_obj_ids, f_mask_logits in predictor.propagate_in_video(inference_state):
    if f_mask_logits is None or len(f_mask_logits) == 0:
        continue
    video_segments[f_idx] = {}
    for i, o_id in enumerate(f_obj_ids):
        bin_mask = (f_mask_logits[i] > 0).cpu().numpy().astype(np.uint8)
        video_segments[f_idx][o_id] = bin_mask

print("Propagation done. Check video_segments for per-frame masks.")

propagate in video: 100%|█████████████████████| 210/210 [18:10<00:00,  5.19s/it]

Propagation done. Check video_segments for per-frame masks.





In [22]:
import cv2
import numpy as np

output_video_path = "output_masked.mp4"

cap = cv2.VideoCapture(video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

frame_idx = 0
while True:
    ret, frame = cap.read()
    if not ret:
        break

    if frame is None or frame.shape[0] == 0 or frame.shape[1] == 0:
        out.write(frame if frame is not None else np.zeros((1, 1, 3), dtype=np.uint8))
        frame_idx += 1
        continue

    if frame_idx in video_segments:
        for obj_id, bin_mask in video_segments[frame_idx].items():
            if bin_mask is None or bin_mask.size == 0:
                continue

            # Ensure the mask is not zero-dimension
            if bin_mask.ndim != 2 or bin_mask.shape[0] == 0 or bin_mask.shape[1] == 0:
                continue

            mask_255 = bin_mask * 255

            # Double-check that (frame.shape[1], frame.shape[0]) is not (0,0)
            target_w, target_h = frame.shape[1], frame.shape[0]
            if target_w == 0 or target_h == 0:
                continue

            try:
                mask_resized = cv2.resize(mask_255, (target_w, target_h))
            except Exception as e:
                print(f"Skipping frame {frame_idx}, obj {obj_id} due to resize error: {e}")
                continue

            if mask_resized is None or mask_resized.size == 0:
                print(f"Skipping frame {frame_idx}, obj {obj_id} because mask_resized is empty.")
                continue

            mask_colored = cv2.applyColorMap(mask_resized, cv2.COLORMAP_JET)
            frame = cv2.addWeighted(frame, 0.6, mask_colored, 0.4, 0)

    out.write(frame)
    frame_idx += 1

cap.release()
out.release()

print(f"Done. Masked video saved to {output_video_path}.")

Done. Masked video saved to output_masked.mp4.


In [23]:
first_hand = hand_keypoints[0]["keypoints"]
# pick just wrist + palm center
points = np.array([first_hand[0], first_hand[9]], dtype=np.float32)  # example
labels = np.ones(len(points), np.int32)
predictor.add_new_points_or_box(inference_state, frame_idx=0, obj_id=1, points=points, labels=labels)

(0,
 [1],
 tensor([[[[-6.0038, -6.0038, -6.0299,  ..., -6.0885, -6.0223, -6.0223],
           [-6.0038, -6.0038, -6.0299,  ..., -6.0885, -6.0223, -6.0223],
           [-6.0998, -6.0998, -6.1374,  ..., -6.1178, -6.0138, -6.0138],
           ...,
           [-4.3484, -4.3484, -4.5250,  ..., -5.6619, -5.6162, -5.6162],
           [-4.1608, -4.1608, -4.3607,  ..., -5.6260, -5.6095, -5.6095],
           [-4.1608, -4.1608, -4.3607,  ..., -5.6260, -5.6095, -5.6095]]]]))

In [24]:
mask_logits = out_mask_logits[0]  # For frame 0, obj 1, for example

print("Mask logits shape:", mask_logits.shape)
print("Max logit value:", torch.max(mask_logits))
print("Min logit value:", torch.min(mask_logits))

positive_vals = (mask_logits > 0).sum().item()
total_vals = mask_logits.numel()

print(f"Number of logits > 0: {positive_vals} / {total_vals}")


Mask logits shape: torch.Size([1, 1024, 1024])
Max logit value: tensor(4.8348)
Min logit value: tensor(-3.7363)
Number of logits > 0: 709540 / 1048576


In [25]:
bin_mask = (mask_logits > 0).cpu().numpy().astype(np.uint8)
print("Nonzero pixels in bin_mask:", np.count_nonzero(bin_mask))

Nonzero pixels in bin_mask: 709540


In [30]:
import cv2
import json
import numpy as np
import torch
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2_video_predictor

# Setup
video_path = "hand/test2.mp4"
keypoints_file = "hand_keypoints.json"
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

# Load flattened keypoints
with open(keypoints_file, "r") as f:
    hand_keypoints = json.load(f)
points = np.array([kp for hand in hand_keypoints for kp in hand["keypoints"]], dtype=np.float32)
labels = np.ones(len(points), np.int32)

# Initialize and add prompts (once)
inference_state = predictor.init_state(video_path=video_path)
predictor.reset_state(inference_state)
predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
)

# Single call to propagate
video_segments = {}
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
    frame_masks = {}
    for i, obj_id in enumerate(obj_ids):
        frame_masks[obj_id] = (mask_logits[i] > 0).cpu().numpy().astype(np.uint8)
    video_segments[frame_idx] = frame_masks

# Now we have all results in `video_segments`. No more re-propagation needed.
# Example: visualize the mask on frame 0
cap = cv2.VideoCapture(video_path)
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
ret, frame = cap.read()
cap.release()

if not ret or frame is None:
    print("Could not read frame 0.")
else:
    if 0 in video_segments:
        mask_dict = video_segments[0]
        if 1 in mask_dict:
            bin_mask = mask_dict[1]
            mask_255 = bin_mask * 255
            mask_resized = cv2.resize(mask_255, (frame.shape[1], frame.shape[0]))
            color_mask = cv2.applyColorMap(mask_resized, cv2.COLORMAP_JET)
            overlay = cv2.addWeighted(frame, 0.6, color_mask, 0.4, 0)

            fig, ax = plt.subplots(1,2, figsize=(10,5))
            ax[0].imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            ax[0].set_title("Original Frame 0")
            ax[0].axis("off")

            ax[1].imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
            ax[1].set_title("Mask Overlay (obj=1)")
            ax[1].axis("off")

            plt.show()
        else:
            print("No mask for obj=1 on frame 0.")
    else:
        print("No mask entry for frame=0 in video_segments.")

propagate in video: 100%|███████████████████████| 60/60 [09:45<00:00,  9.77s/it]


error: OpenCV(4.11.0) /Users/xperience/GHA-Actions-OpenCV/_work/opencv-python/opencv-python/opencv/modules/imgproc/src/resize.cpp:3845: error: (-215:Assertion failed) !dsize.empty() in function 'resize'


In [40]:
mask_255.shape

(1, 1024, 1024)

In [41]:
frame.shape

(720, 1280, 3)

In [42]:
video_segments

{0: {1: array([[[0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          ...,
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)},
 1: {1: array([[[0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          ...,
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)},
 2: {1: array([[[0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          ...,
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)},
 3: {1: array([[[0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          ...,
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)},
 4: {1: 

In [43]:
# Suppose you've already done:
# video_segments = {...} after calling propagate_in_video.

import pickle

# 1) SAVE:
with open("video_segments.pkl", "wb") as f:
    pickle.dump(video_segments, f)

# In a new script or cell:
import pickle

# 2) LOAD:
with open("video_segments.pkl", "rb") as f:
    loaded_segments = pickle.load(f)

# 3) Use loaded_segments to overlay masks
cap = cv2.VideoCapture("hand/test.mp4")
frame_idx = 0

while True:
    ret, frame = cap.read()
    if not ret:
        break
    if frame_idx in loaded_segments:
        # overlay...
        pass
    frame_idx += 1


In [45]:
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
import os

# Suppose you already have `video_segments` from propagate_in_video

video_path = "hand/test.mp4"
cap = cv2.VideoCapture(video_path)

if not cap.isOpened():
    print("Could not open video.")
    cap.release()
    raise SystemExit

frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))

print(f"Video info:")
print(f"  frame_count: {frame_count}")
print(f"  frame_width: {frame_width}")
print(f"  frame_height: {frame_height}")
print(f"  fps: {fps}")

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
output_path = "output_masked_skip_invalid.mp4"
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

frame_idx = 0
successful_count = 0

while True:
    ret, frame = cap.read()
    if not ret:
        print(f"Stopping at frame_idx={frame_idx} (read failure or end).")
        break

    if frame is None:
        print(f"Skipping frame {frame_idx}: frame is None.")
        out.write(np.zeros((1,1,3), dtype=np.uint8))
        frame_idx += 1
        continue

    h, w, c = frame.shape
    if h == 0 or w == 0:
        print(f"Skipping frame {frame_idx}: shape is {frame.shape}.")
        out.write(frame)
        frame_idx += 1
        continue

    # Check if we have a mask in video_segments
    if frame_idx in video_segments and 1 in video_segments[frame_idx]:
        mask = video_segments[frame_idx][1]
        # mask shape?
        print(f"Frame {frame_idx}: mask shape={mask.shape}, dtype={mask.dtype}")

        if mask is None:
            print(f"Skipping frame {frame_idx}: mask is None.")
            out.write(frame)
            frame_idx += 1
            continue

        if mask.ndim != 2:
            print(f"Skipping frame {frame_idx}: mask.ndim={mask.ndim} (expected 2).")
            out.write(frame)
            frame_idx += 1
            continue

        mh, mw = mask.shape
        if mh == 0 or mw == 0:
            print(f"Skipping frame {frame_idx}: mask shape is {mask.shape}.")
            out.write(frame)
            frame_idx += 1
            continue

        mask_255 = mask * 255
        print(f"Resizing mask from ({mh}, {mw}) -> ({w}, {h})")

        try:
            mask_resized = cv2.resize(mask_255, (w, h))
        except Exception as e:
            print(f"Skipping frame {frame_idx} due to resize error: {e}")
            out.write(frame)
            frame_idx += 1
            continue

        if mask_resized.size == 0:
            print(f"Skipping frame {frame_idx}: mask_resized.size=0.")
            out.write(frame)
            frame_idx += 1
            continue

        color_mask = cv2.applyColorMap(mask_resized, cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(frame, 0.6, color_mask, 0.4, 0)
        out.write(overlay)
        successful_count += 1
    else:
        out.write(frame)
    frame_idx += 1

cap.release()
out.release()

print(f"Done. {successful_count} frames were successfully overlaid.")
print(f"Saved output to {output_path}")

Video info:
  frame_count: 210
  frame_width: 1280
  frame_height: 720
  fps: 29
Frame 0: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 0: mask.ndim=3 (expected 2).
Frame 1: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 1: mask.ndim=3 (expected 2).
Frame 2: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 2: mask.ndim=3 (expected 2).
Frame 3: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 3: mask.ndim=3 (expected 2).
Frame 4: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 4: mask.ndim=3 (expected 2).
Frame 5: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 5: mask.ndim=3 (expected 2).
Frame 6: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 6: mask.ndim=3 (expected 2).
Frame 7: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 7: mask.ndim=3 (expected 2).
Frame 8: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 8: mask.ndim=3 (expected 2).
Frame 9: mask shape=(1, 1024, 1024), dtype=uint8
Skipping frame 9: mask.ndim=3 (ex

In [47]:
import cv2
import json
import numpy as np
import torch
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2_video_predictor

# 1) Setup
video_path = "hand/test2.mp4"
keypoints_file = "hand_keypoints.json"
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Device:", device)
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
print("SAM 2 model built.")

with open(keypoints_file, "r") as f:
    hand_keypoints = json.load(f)

points = np.array([kp for hand in hand_keypoints for kp in hand["keypoints"]], dtype=np.float32)
labels = np.ones(len(points), dtype=np.int32)

print("Points shape:", points.shape, "dtype:", points.dtype)
print("Labels shape:", labels.shape, "dtype:", labels.dtype)

# 2) Initialize state on the video
inference_state = predictor.init_state(video_path=video_path)
print("Inference state initialized.")
predictor.reset_state(inference_state)
print("State reset.")

# 3) Add prompts (keypoints) on frame 0
predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
)
print("Prompts added on frame 0 with obj_id=1.")

# 4) Single propagate call
video_segments = {}
for f_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    if out_mask_logits is None or len(out_mask_logits) == 0:
        continue

    frame_masks = {}
    for i, o_id in enumerate(out_obj_ids):
        # Convert logits to [0,1] mask
        bin_mask = (out_mask_logits[i] > 0).cpu().numpy().astype(np.uint8)

        # If shape is (1, H, W), squeeze out the leading dim → (H, W)
        if bin_mask.ndim == 3 and bin_mask.shape[0] == 1:
            bin_mask = bin_mask.squeeze(0)

        frame_masks[o_id] = bin_mask
    video_segments[f_idx] = frame_masks

print("Propagation done. video_segments now contains per-frame masks.")

# 5) Overlay on every frame, skipping invalid shapes
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    print("Could not open video.")
    cap.release()
    raise SystemExit

frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))

print(f"Video info:")
print(f"  frame_count: {frame_count}")
print(f"  frame_width: {frame_width}")
print(f"  frame_height: {frame_height}")
print(f"  fps: {fps}")

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
output_path = "output_masked_skip_invalid.mp4"
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

frame_idx = 0
successful_count = 0

while True:
    ret, frame = cap.read()
    if not ret:
        print(f"Stopping at frame_idx={frame_idx} (read failure or end).")
        break

    if frame is None:
        print(f"Skipping frame {frame_idx}: frame is None.")
        out.write(np.zeros((1,1,3), dtype=np.uint8))
        frame_idx += 1
        continue

    h, w, c = frame.shape
    if h == 0 or w == 0:
        print(f"Skipping frame {frame_idx}: shape is {frame.shape}.")
        out.write(frame)
        frame_idx += 1
        continue

    if frame_idx in video_segments and 1 in video_segments[frame_idx]:
        mask = video_segments[frame_idx][1]
        print(f"Frame {frame_idx}: mask shape={mask.shape}, dtype={mask.dtype}")

        if mask is None or mask.size == 0:
            print(f"Skipping frame {frame_idx}: empty or None mask.")
            out.write(frame)
            frame_idx += 1
            continue

        if mask.ndim != 2:
            print(f"Skipping frame {frame_idx}: mask.ndim={mask.ndim}, expected 2.")
            out.write(frame)
            frame_idx += 1
            continue

        mh, mw = mask.shape
        if mh == 0 or mw == 0:
            print(f"Skipping frame {frame_idx}: mask shape is {mask.shape}.")
            out.write(frame)
            frame_idx += 1
            continue

        mask_255 = mask * 255
        print(f"Resizing mask from ({mh}, {mw}) -> ({w}, {h})")

        try:
            mask_resized = cv2.resize(mask_255, (w, h))
        except Exception as e:
            print(f"Skipping frame {frame_idx} due to resize error: {e}")
            out.write(frame)
            frame_idx += 1
            continue

        if mask_resized.size == 0:
            print(f"Skipping frame {frame_idx}: mask_resized.size=0.")
            out.write(frame)
            frame_idx += 1
            continue

        color_mask = cv2.applyColorMap(mask_resized, cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(frame, 0.6, color_mask, 0.4, 0)
        out.write(overlay)
        successful_count += 1
    else:
        # No mask for this frame
        out.write(frame)

    frame_idx += 1

cap.release()
out.release()

print(f"Done. {successful_count} frames were successfully overlaid.")
print(f"Saved output to {output_path}")


Device: cpu
SAM 2 model built.
Points shape: (42, 2) dtype: float32
Labels shape: (42,) dtype: int32
Inference state initialized.
State reset.
Prompts added on frame 0 with obj_id=1.


propagate in video: 100%|███████████████████████| 60/60 [05:08<00:00,  5.14s/it]


Propagation done. video_segments now contains per-frame masks.
Video info:
  frame_count: 61
  frame_width: 1280
  frame_height: 720
  fps: 30
Frame 0: mask shape=(1024, 1024), dtype=uint8
Resizing mask from (1024, 1024) -> (1280, 720)
Frame 1: mask shape=(1024, 1024), dtype=uint8
Resizing mask from (1024, 1024) -> (1280, 720)
Frame 2: mask shape=(1024, 1024), dtype=uint8
Resizing mask from (1024, 1024) -> (1280, 720)
Frame 3: mask shape=(1024, 1024), dtype=uint8
Resizing mask from (1024, 1024) -> (1280, 720)
Frame 4: mask shape=(1024, 1024), dtype=uint8
Resizing mask from (1024, 1024) -> (1280, 720)
Frame 5: mask shape=(1024, 1024), dtype=uint8
Resizing mask from (1024, 1024) -> (1280, 720)
Frame 6: mask shape=(1024, 1024), dtype=uint8
Resizing mask from (1024, 1024) -> (1280, 720)
Frame 7: mask shape=(1024, 1024), dtype=uint8
Resizing mask from (1024, 1024) -> (1280, 720)
Frame 8: mask shape=(1024, 1024), dtype=uint8
Resizing mask from (1024, 1024) -> (1280, 720)
Frame 9: mask shape=