In [1]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Video segmentation with SAM 2

This notebook shows how to use SAM 2 for interactive segmentation in videos. It will cover the following:

- adding clicks on a frame to get and refine _masklets_ (spatio-temporal masks) 
- propagating clicks to get _masklets_ throughout the video
- segmenting and tracking multiple objects at the same time

We use the terms _segment_ or _mask_ to refer to the model prediction for an object on a single frame, and _masklet_ to refer to the spatio-temporal masks across the entire video. 

If running locally using jupyter, first install `segment-anything-2` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything-2#installation) in the repository.

In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm

In [2]:
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
os.environ['CUDA_VISIBLE_DEVICES']="1"

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

### Loading the SAM 2 video predictor

In [3]:
# Print the installed PyTorch version for reference.
import torch
print(torch.__version__)

2.5.1+cu121


In [4]:
from sam2.build_sam import build_sam2_video_predictor

# Checkpoint to load; the commented-out paths are alternative checkpoints.
# sam2_checkpoint = "./checkpoints/sam2.1_hiera_small.pt"
# sam2_checkpoint = "./checkpoints/sam2.1_hiera_s_ioct.pt"
sam2_checkpoint = "./sam2_logs/configs/sam2.1_training/sam2.1_hiera_s_ioct/checkpoints/checkpoint.pt"

# Model configuration file passed to the builder together with the checkpoint.
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"

# `predictor` is the object every later cell calls (init_state, reset_state, propagate_in_video, ...).
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

In [5]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a semi-transparent mask on a matplotlib axes.

    The overlay colour is random when `random_color` is set; otherwise it is
    the "tab10" colormap entry for `obj_id` (entry 0 when `obj_id` is None).
    The alpha value is always 0.6.
    """
    alpha = np.array([0.6])
    if not random_color:
        palette = plt.get_cmap("tab10")
        rgb = np.array(palette(obj_id if obj_id is not None else 0)[:3])
    else:
        rgb = np.random.random(3)
    rgba = np.concatenate([rgb, alpha], axis=0)
    # Only the last two dimensions of `mask` are used as (height, width).
    height, width = mask.shape[-2:]
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))


def show_points(coords, labels, ax, marker_size=200):
    """Draw click prompts as stars: label 1 in green, label 0 in red."""
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Positive clicks are drawn first, then negative ones.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)

### Select an example video

In [6]:
# Name of the video file (without extension); also used later to name the output directory.
video_name = 'OD-2025-01-14_153544_test_1_cropped' # 'OD-2024-11-26_151419_test_1_cropped' 'OD-2025-01-14_153544_test_1_cropped' 
video_path = f"../data/videos/{video_name}.mp4"
# Support set root: frames are read from `JPEGImages`, masks from `Annotations`.
support_dir ="../data/SUP"
support_frame_dir = os.path.join(support_dir, 'JPEGImages')
support_mask_dir = os.path.join(support_dir, 'Annotations')
# Create the inference state from the video and the support frame directory, then reset it.
inference_state = predictor.init_state(video_path=video_path, support_path=support_frame_dir)
predictor.reset_state(inference_state)

frame loading (JPEG): 100%|██████████| 1/1 [00:00<00:00, 99.79it/s]


#### Method 1: Add support masks

In [7]:
from collections import defaultdict

def load_ann_png(path):
    """Load a PNG file as a mask and its palette.

    Args:
        path: Path of the annotation image to open.

    Returns:
        A `(mask, palette)` tuple: `mask` is the image content converted to a
        `np.uint8` array, `palette` is the value returned by the image's
        `getpalette()`.
    """
    # The context manager closes the underlying file once the pixel data has
    # been copied into the array, instead of leaving the handle open.
    with Image.open(path) as mask_image:
        palette = mask_image.getpalette()
        mask = np.array(mask_image).astype(np.uint8)
    return mask, palette

def get_per_obj_mask(mask):
    """Build one boolean mask per foreground label found in `mask`.

    Labels that are not greater than 0 are skipped. The result maps each
    remaining label to the boolean array `mask == label`.
    """
    per_obj_mask = {}
    for value in np.unique(mask):
        if value > 0:
            label = value.item()
            per_obj_mask[label] = mask == label
    return per_obj_mask

def load_masks_from_dir(
    input_mask_dir, frame_name, allow_missing=False
):
    """Load masks from a directory as a dict of per-object masks.

    Args:
        input_mask_dir: Directory that contains the annotation file.
        frame_name: File name of the annotation inside `input_mask_dir`
            (joined to the directory as-is).
        allow_missing: If True, a missing file yields `({}, None)` instead of
            an error.

    Returns:
        `(per_obj_input_mask, input_palette)`: the per-object masks produced
        by `get_per_obj_mask` and the palette returned by `load_ann_png`.

    Raises:
        FileNotFoundError: If the file does not exist and `allow_missing` is
            False.
    """
    input_mask_path = os.path.join(input_mask_dir, frame_name)
    if not os.path.exists(input_mask_path):
        if allow_missing:
            return {}, None
        # Report the offending path directly rather than relying on the image
        # loader to fail further down.
        raise FileNotFoundError(f"annotation mask not found: {input_mask_path}")
    input_mask, input_palette = load_ann_png(input_mask_path)
    per_obj_input_mask = get_per_obj_mask(input_mask)

    return per_obj_input_mask, input_palette


# collect all the object ids and the support set
use_all_masks=False
# object id -> {support index -> boolean mask of that object in the support frame}
inputs_per_object = defaultdict(dict)
support_mask_dir = os.path.join(support_dir, 'Annotations')
video_segments = defaultdict(dict)
if os.path.exists(support_mask_dir):
    # `os.listdir` returns entries in arbitrary order; sort them so the support
    # index `idx` assigned to each file is the same on every run.
    # NOTE(review): presumably `idx` has to match the order in which the
    # predictor loaded the support frames -- confirm against `init_state`.
    sup_frame_names = sorted(os.listdir(support_mask_dir))
    for idx, name in enumerate(sup_frame_names):
        per_obj_input_mask, input_palette = load_masks_from_dir(
            input_mask_dir=support_mask_dir,
            frame_name=name,
            allow_missing=False,
        )
        for object_id, object_mask in per_obj_input_mask.items():
            # skip empty masks
            if not np.any(object_mask):
                continue
            # if `use_all_masks=False`, we only use the first mask for each object
            if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
                continue
            inputs_per_object[object_id][idx] = object_mask

# Video frame indices on which the support masks are registered as prompts.
anno_frame_ids = [0, 50, 100, 150, 200]
# Object ids for which at least one support mask was collected.
object_ids = sorted(inputs_per_object)
# Reset the inference state before adding any prompt.
predictor.reset_state(inference_state)

for anno_frame_id in anno_frame_ids:
    for object_id in object_ids:
        # add those input masks to SAM 2 inference state before propagation
        input_frame_inds = sorted(inputs_per_object[object_id])
        # predictor.reset_state(inference_state)
        for input_frame_idx in input_frame_inds:
            # Every collected support mask of this object is attached to the current video frame.
            predictor.add_sup_mask(
                inference_state=inference_state,
                frame_idx=anno_frame_id,
                obj_id=object_id,
                sup_frame_idx=input_frame_idx,
                sup_mask=inputs_per_object[object_id][input_frame_idx],
            )
            print(f"adding mask from support frame {input_frame_idx} for frame {anno_frame_id} as input for {object_id=}")

# Run propagation and keep one integer label map per frame in `video_segments`.
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    # The logits are unpacked as a 4-D shape; only the last two dims (h, w) are used.
    _, _, h, w = out_mask_logits.shape
    output_mask = np.zeros((1, h, w), dtype=int)
    for i, out_obj_id in enumerate(out_obj_ids):
        # Pixels with a positive logit are labelled with the object id. The
        # boolean tensor is converted to a NumPy array before indexing; later
        # objects overwrite earlier ones where they overlap.
        output_mask[(out_mask_logits[i] > 0.0).cpu().numpy()] = out_obj_id
    output_mask = np.squeeze(output_mask)
    video_segments[out_frame_idx] = output_mask


Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(


adding mask from support frame 0 for frame 0 as input for object_id=1
adding mask from support frame 0 for frame 0 as input for object_id=2
adding mask from support frame 0 for frame 0 as input for object_id=3
adding mask from support frame 0 for frame 50 as input for object_id=1
adding mask from support frame 0 for frame 50 as input for object_id=2
adding mask from support frame 0 for frame 50 as input for object_id=3
adding mask from support frame 0 for frame 100 as input for object_id=1
adding mask from support frame 0 for frame 100 as input for object_id=2
adding mask from support frame 0 for frame 100 as input for object_id=3
adding mask from support frame 0 for frame 150 as input for object_id=1
adding mask from support frame 0 for frame 150 as input for object_id=2
adding mask from support frame 0 for frame 150 as input for object_id=3
adding mask from support frame 0 for frame 200 as input for object_id=1
adding mask from support frame 0 for frame 200 as input for object_id=2
a

propagate in video: 100%|██████████| 643/643 [00:35<00:00, 18.32it/s]


In [8]:
plot_dir = f'./results/demo/offline_fs/{video_name}'
os.makedirs(plot_dir, exist_ok=True)

# Overlay colour per label value stored in the per-frame masks.
cmap = plt.get_cmap("tab10")
class_to_color = {
    1: cmap(1),
    2: cmap(2),
    3: cmap(3),
}

cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise IOError(f"cannot open video: {video_path}")
# Take the frame size from the capture properties instead of decoding a frame:
# reading one here would consume frame 0, and the loop below would then pair
# mask k with video frame k + 1.
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(plot_dir + "/output.mp4", fourcc, 30.0, (width, height))

# save an overlay PNG for every frame and append the blended frame to the output video
out_frame_idx = 0
vis_gap = 50  # progress is printed every `vis_gap` frames
plt.close("all")
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        if (out_frame_idx % vis_gap) == 0: print(f"frame {out_frame_idx}")
        output_mask = video_segments[out_frame_idx]
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        colored_mask = np.zeros((output_mask.shape[0], output_mask.shape[1], 4))
        for class_id, color in class_to_color.items():
            colored_mask[output_mask == class_id] = color  # Assign the corresponding color from the colormap
        fig = plt.figure(figsize=(6, 3))
        plt.imshow(frame)
        plt.imshow(colored_mask, alpha=0.6)  # Adjust alpha to control transparency
        plt.axis('off')  # Turn off axis for cleaner display
        plt.tight_layout()
        plt.savefig(os.path.join(plot_dir, str(out_frame_idx) + '.png'), dpi=100)
        plt.close(fig)

        # Blend frame and mask colours in [0, 1], then convert back to 8-bit for the writer.
        overlay = 0.6 * frame.astype(np.float32) / 255.0 + 0.6 * colored_mask[:, :, :3]
        overlay = np.clip(overlay, 0, 1)
        overlay = (overlay * 255).astype(np.uint8)

        out.write(cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

        out_frame_idx += 1
finally:
    # Release the capture and finalize the output file even if rendering fails midway.
    cap.release()
    out.release()


frame 0
frame 50
frame 100
frame 150
frame 200
frame 250
frame 300
frame 350
frame 400
frame 450
frame 500
frame 550
frame 600


### Select the testing set

In [None]:
# `video_dir` is a directory of image frames with filenames like `<frame_index>.jpg`
# video_dir = "./datasets/iOCT_lSNR/valid/JPEGImages/seq_1"
video_dir = "./datasets/videos/OD-2025-01-14_153544_test_1_demo"
support_dir ="./datasets/iOCT_lSNR/valid/SUP"
support_frame_dir = os.path.join(support_dir, 'JPEGImages')
support_mask_dir = os.path.join(support_dir, 'Annotations')

# gather the image files of this directory, ordered by the integer in their file name
frame_names = sorted(
    (
        entry for entry in os.listdir(video_dir)
        if os.path.splitext(entry)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", '.png']
    ),
    key=lambda entry: int(os.path.splitext(entry)[0]),
)

# take a look at the first video frame
frame_idx = 0
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

In [None]:
# Create the inference state for the frame directory, passing the support frame directory alongside it.
inference_state = predictor.init_state(video_path=video_dir, support_path=support_frame_dir)

#### Method 1: Add support masks

In [None]:
from collections import defaultdict

def load_ann_png(path):
    """Load a PNG file as a mask and its palette.

    Args:
        path: Path of the annotation image to open.

    Returns:
        A `(mask, palette)` tuple, where `mask` is the image content as a
        `np.uint8` array and `palette` is the result of `getpalette()` on the
        opened image.
    """
    # Open the image in a `with` block so its file handle is closed right
    # after the pixels have been copied into the array.
    with Image.open(path) as png:
        palette = png.getpalette()
        mask = np.array(png).astype(np.uint8)
    return mask, palette

def get_per_obj_mask(mask):
    """Return `{label: mask == label}` for every label in `mask` greater than 0."""
    foreground_labels = [label for label in np.unique(mask).tolist() if label > 0]
    return dict((label, mask == label) for label in foreground_labels)

def load_masks_from_dir(
    input_mask_dir, frame_name, per_obj_png_file=False, allow_missing=False
):
    """Load masks from a directory as a dict of per-object masks.

    Args:
        input_mask_dir: Directory that contains the annotation file.
        frame_name: File name of the annotation inside `input_mask_dir`
            (joined to the directory as-is).
        per_obj_png_file: Accepted for call compatibility but not used: the
            mask is always read from a single combined PNG. It defaults to
            False so calls that omit it keep working.
        allow_missing: If True, a missing file yields `({}, None)` instead of
            an error.

    Returns:
        `(per_obj_input_mask, input_palette)`: the per-object masks produced
        by `get_per_obj_mask` and the palette returned by `load_ann_png`.

    Raises:
        FileNotFoundError: If the file does not exist and `allow_missing` is
            False.
    """
    input_mask_path = os.path.join(input_mask_dir, frame_name)
    if not os.path.exists(input_mask_path):
        if allow_missing:
            return {}, None
        # Name the missing path explicitly instead of letting the image loader fail later.
        raise FileNotFoundError(f"annotation mask not found: {input_mask_path}")
    input_mask, input_palette = load_ann_png(input_mask_path)
    per_obj_input_mask = get_per_obj_mask(input_mask)

    return per_obj_input_mask, input_palette


# collect all the object ids and the support set
use_all_masks=False
# object id -> {support index -> boolean mask of that object in the support frame}
inputs_per_object = defaultdict(dict)
support_mask_dir = os.path.join(support_dir, 'Annotations')
video_segments = defaultdict(dict)
if os.path.exists(support_mask_dir):
    # Sort the listing: `os.listdir` gives no ordering guarantee, and the
    # position in this list becomes the support index passed to the predictor.
    # NOTE(review): presumably that index has to match the order in which the
    # predictor loaded the support frames -- confirm against `init_state`.
    sup_frame_names = sorted(os.listdir(support_mask_dir))
    for idx, name in enumerate(sup_frame_names):
        per_obj_input_mask, input_palette = load_masks_from_dir(
            input_mask_dir=support_mask_dir,
            frame_name=name,
            per_obj_png_file=False,  # our dataset combines all object masks into a single PNG file
            allow_missing=False,
        )
        for object_id, object_mask in per_obj_input_mask.items():
            # skip empty masks
            if not np.any(object_mask):
                continue
            # if `use_all_masks=False`, we only use the first mask for each object
            if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
                continue
            inputs_per_object[object_id][idx] = object_mask

# Video frame indices on which the support masks are registered as prompts.
anno_frame_ids = [0, 50, 100]
# Object ids for which at least one support mask was collected.
object_ids = sorted(inputs_per_object)
# Reset the inference state before adding any prompt.
predictor.reset_state(inference_state)

for anno_frame_id in anno_frame_ids:
    for object_id in object_ids:
        # add those input masks to SAM 2 inference state before propagation
        input_frame_inds = sorted(inputs_per_object[object_id])
        # predictor.reset_state(inference_state)
        for input_frame_idx in input_frame_inds:
            # Every collected support mask of this object is attached to the current video frame.
            predictor.add_sup_mask(
                inference_state=inference_state,
                frame_idx=anno_frame_id,
                obj_id=object_id,
                sup_frame_idx=input_frame_idx,
                sup_mask=inputs_per_object[object_id][input_frame_idx],
            )
            print(f"adding mask from support frame {input_frame_idx} for frame {anno_frame_id} as input for {object_id=}")

# Propagate through the video and store one integer label map per frame.
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    # Only the spatial size (last two of the four dims) is needed here.
    _, _, h, w = out_mask_logits.shape
    output_mask = np.zeros((1, h, w), dtype=int)
    for i, out_obj_id in enumerate(out_obj_ids):
        # Label the pixels whose logit is positive. The boolean tensor is
        # turned into a NumPy array before it is used as an index; where
        # objects overlap, the later one wins.
        output_mask[(out_mask_logits[i] > 0.0).cpu().numpy()] = out_obj_id
    output_mask = np.squeeze(output_mask)
    video_segments[out_frame_idx] = output_mask

In [None]:
# plot_dir = './results/demo/offline_sup/lsnr_seq_1'
plot_dir = './results/demo/offline_sup/OD-2025-01-14_153544_test_1'
os.makedirs(plot_dir, exist_ok=True)

# overlay colour for each label value found in the per-frame masks
cmap = plt.get_cmap("tab10")
class_to_color = {label: cmap(label) for label in (1, 2, 3)}

# render the segmentation results every few frames
vis_frame_stride = 1
plt.close("all")
for out_frame_idx in tqdm(range(0, len(frame_names), vis_frame_stride)):
    output_mask = video_segments[out_frame_idx]
    frame_file = frame_names[out_frame_idx]

    # RGBA overlay: pixels without a known label stay fully transparent
    colored_mask = np.zeros((*output_mask.shape[:2], 4))
    for class_id, color in class_to_color.items():
        colored_mask[output_mask == class_id] = color
    raw_image = np.array(Image.open(os.path.join(video_dir, frame_file)))

    plt.figure(figsize=(6, 3))
    plt.imshow(raw_image)
    plt.imshow(colored_mask, alpha=0.6)  # alpha controls the overlay transparency
    plt.axis('off')  # no axes, for a cleaner picture
    plt.tight_layout()
    # every frame is saved to disk when nothing is skipped; otherwise the figure is shown inline
    if vis_frame_stride==1:
        plt.savefig(os.path.join(plot_dir, frame_file), dpi=100)
    else:
        plt.show()
    plt.close()

In [None]:
# plot_dir = './results/demo/offline_sup'
# os.makedirs(plot_dir, exist_ok=True)

# cmap = plt.get_cmap("tab10")
# class_to_color = {
#     1: cmap(1),
#     2: cmap(2),
#     3: cmap(3),
# }

# # render the segmentation results every few frames
# vis_frame_stride = 1
# plt.close("all")
# plt.figure(figsize=(6, 3))
# for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
#     # output_mask = video_segments[out_frame_idx][2]
#     output_mask = np.zeros_like(video_segments[out_frame_idx][1])
#     for object_id in object_ids:
#         current_mask = video_segments[out_frame_idx][object_id]
#         output_mask[current_mask == object_id] = object_id

#     colored_mask = np.zeros((output_mask.shape[0], output_mask.shape[1], 4))
#     for class_id, color in class_to_color.items():
#         colored_mask[output_mask == class_id] = color  # Assign the corresponding color from the colormap
#     raw_image = np.array(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
#     plt.imshow(raw_image)
#     plt.imshow(colored_mask, alpha=0.6)  # Adjust alpha to control transparency
#     plt.axis('off')  # Turn off axis for cleaner display
#     plt.tight_layout()
#     plt.savefig(os.path.join(plot_dir, frame_names[out_frame_idx]), dpi=100) if vis_frame_stride==1 else plt.show() 

#### Method 2: Add a first click on a frame

To get started, let's segment the three objects of interest on the first frame: the tissue, the tool, and the artifacts.

Here we make clicks for each object, e.g. a **positive click** at (x, y) = (100, 200) with label `1` for the tissue, by sending their coordinates and labels into the `add_new_points_or_box` API.

Note: label `1` indicates a *positive click (to add a region)* while label `0` indicates a *negative click (to remove a region)*.

In [None]:
# Reset the inference state before adding the click prompts below.
predictor.reset_state(inference_state)

In [None]:
ann_frame_idx = 0  # the frame index we interact with

# Click prompts per object id, stored as (points, labels): points are (x, y)
# coordinates, label 1 is a positive click and label 0 a negative click.
prompts = {
    # ----------------- annotate the tissue -----------------
    1: (
        np.array([[100, 200], [150, 240]], dtype=np.float32),
        np.array([1, 0], np.int32),
    ),
    # ----------------- annotate the tool -----------------
    2: (
        np.array([[262, 120], [282, 137], [356, 19]], dtype=np.float32),
        np.array([1, 1, 1], np.int32),
    ),
    # ----------------- annotate artifacts -----------------
    3: (
        np.array([[414, 7], [520, 95], [275, 147], [374, 3]], dtype=np.float32),
        np.array([1, 1, 1, 0], np.int32),
    ),
}

# Register the clicks of every object on the annotated frame, in object-id
# order. The outputs of the last call are the ones displayed below.
for ann_obj_id, (points, labels) in prompts.items():
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

# ----------------- show the results on the first frame -----------------
# left: the raw annotated frame; right: the same frame with clicks and masks of all objects
fig_size = (10, 3)
fig, axs = plt.subplots(1, 2, figsize=fig_size)

# Both panels show the annotated frame, so this cell does not depend on a
# `frame_idx` variable set in another cell.
annotated_frame = Image.open(os.path.join(video_dir, frame_names[ann_frame_idx]))
axs[0].imshow(annotated_frame)
axs[1].imshow(annotated_frame)
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], axs[1])
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), axs[1], obj_id=out_obj_id)

fig.suptitle(f"{frame_names[ann_frame_idx]}")
plt.tight_layout()
plt.show()

#### Propagate the prompts to get the masklet across the video

To get the masklet throughout the entire video, we propagate the prompts using the `propagate_in_video` API.

In [None]:
# run propagation throughout the video and collect the results in a dict
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    # The logits are unpacked as a 4-D shape; only the last two dims (h, w) are used.
    _, _, h, w = out_mask_logits.shape
    output_mask = np.zeros((1, h, w), dtype=int)
    for i, out_obj_id in enumerate(out_obj_ids):
        # Pixels with a positive logit get the object id. The boolean tensor
        # is converted to a NumPy array before indexing; later objects
        # overwrite earlier ones where they overlap.
        output_mask[(out_mask_logits[i] > 0.0).cpu().numpy()] = out_obj_id
    output_mask = np.squeeze(output_mask)
    video_segments[out_frame_idx] = output_mask

In [None]:
plot_dir = './results/demo/offline_sup'
os.makedirs(plot_dir, exist_ok=True)

# Overlay colour per label value stored in the per-frame masks.
cmap = plt.get_cmap("tab10")
class_to_color = {
    1: cmap(1),
    2: cmap(2),
    3: cmap(3),
}

# render the segmentation results every few frames
vis_frame_stride = 1
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    output_mask = video_segments[out_frame_idx]

    colored_mask = np.zeros((output_mask.shape[0], output_mask.shape[1], 4))
    for class_id, color in class_to_color.items():
        colored_mask[output_mask == class_id] = color  # Assign the corresponding color from the colormap
    raw_image = np.array(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    # Use a fresh figure for every frame and close it afterwards: drawing all
    # frames into one shared axes keeps every earlier image attached to it, so
    # each save has to redraw all of them.
    fig = plt.figure(figsize=(6, 3))
    plt.imshow(raw_image)
    plt.imshow(colored_mask, alpha=0.6)  # Adjust alpha to control transparency
    plt.axis('off')  # Turn off axis for cleaner display
    plt.tight_layout()
    # every frame is saved to disk when nothing is skipped; otherwise the figure is shown inline
    if vis_frame_stride == 1:
        plt.savefig(os.path.join(plot_dir, frame_names[out_frame_idx]), dpi=100)
    else:
        plt.show()
    plt.close(fig)