In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Run Segment Anything Model 2 on a live video stream


In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from IPython import display
import time

In [2]:
def gpu_optimize():
    """Turn on bfloat16 autocast and, where supported, TF32 math on CUDA.

    The autocast context is entered and deliberately never exited, so it stays
    active for every later cell of the notebook. When CUDA is not available the
    function returns without changing anything, instead of failing while
    querying the properties of a GPU that does not exist.
    """
    if not torch.cuda.is_available():
        return

    # use bfloat16 for the entire notebook
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

In [3]:
# Apply the GPU settings defined above once, before the model is built.
gpu_optimize()

### Loading the SAM 2 camera predictor


In [4]:
from sam2.build_sam import build_sam2_camera_predictor

def get_predictor(model='tiny'):
    """Build a SAM 2.1 camera predictor for the requested model size.

    Args:
        model: One of 'tiny', 'small' or 'large'.

    Returns:
        Whatever `build_sam2_camera_predictor` returns for the matching
        config / checkpoint pair.

    Raises:
        ValueError: If `model` is not one of the supported sizes.
    """
    # model size -> (checkpoint path, config path)
    variants = {
        'tiny': ("./checkpoints/sam2.1_hiera_tiny.pt",
                 "configs/sam2.1/sam2.1_hiera_t.yaml"),
        'small': ("./checkpoints/sam2.1_hiera_small.pt",
                  "configs/sam2.1/sam2.1_hiera_s.yaml"),
        'large': ("./checkpoints/sam2.1_hiera_large.pt",
                  "configs/sam2.1/sam2.1_hiera_l.yaml"),
    }
    if model not in variants:
        # Fail early with a readable message instead of an UnboundLocalError
        # at the return statement.
        raise ValueError(
            f"Unknown model {model!r}; expected one of {sorted(variants)}"
        )
    sam2_checkpoint, model_cfg = variants[model]
    return build_sam2_camera_predictor(model_cfg, sam2_checkpoint)

In [5]:
def show_mask_cv(mask, background, obj_id=None, random_color=False, alpha=0.6, pause=1000):
    """Overlay a segmentation mask on a frame and show it in an OpenCV window.

    Args:
        mask: Mask with as many elements as `background` has pixels; extra
            length-1 axes (e.g. a leading channel axis) are accepted. Non-zero
            entries are painted.
        background: Frame to draw on, indexed as (row, column, channel); it is
            passed to OpenCV unchanged, so it is presumably a BGR uint8 image.
        obj_id: Integer object id used to pick a colour from matplotlib's
            "tab10" palette; ids wrap around the palette size. `None` uses the
            first colour.
        random_color: If true, ignore `obj_id` and draw with a random colour.
        alpha: Weight of the colour layer added on top of the frame.
        pause: Unused; kept so existing callers keep working. The window is
            always refreshed with a 1 ms `cv2.waitKey`.

    Side effects:
        Shows the result in the "Masked Frame" window. If 'q' is read during
        the refresh, all OpenCV windows are destroyed.
    """
    if random_color:
        color = np.random.randint(0, 256, size=(3,), dtype=np.uint8)
    else:
        cmap = plt.get_cmap("tab10")
        # Wrap around the palette so ids past its end do not all clamp to the
        # last colour.
        cmap_idx = 0 if obj_id is None else int(obj_id) % cmap.N
        rgb_color = (np.array(cmap(cmap_idx)[:3]) * 255).astype(np.uint8)
        color = rgb_color[::-1]  # matplotlib yields RGB, OpenCV draws BGR

    # Make the (rows, cols) shape explicit instead of relying on numpy dropping
    # a leading length-1 axis during assignment.
    mask_2d = np.asarray(mask).reshape(background.shape[:2]).astype(bool)

    color_mask = np.zeros_like(background, dtype=np.uint8)
    color_mask[mask_2d, :3] = color

    overlay = cv2.addWeighted(background, 1.0, color_mask, alpha, 0)

    cv2.imshow("Masked Frame", overlay)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        cv2.destroyAllWindows()

#### Step 1: Add a first click on the first frame
Mask prompts are not supported here.  
By default the object is selected with a point prompt; a bounding box can be passed instead.

In [6]:
def selecting_obj(predictor, frame, points=None, labels=None,
                  bbox=None, ann_frame_idx=0, ann_obj_id=1):
    """Register the object to track on the first frame and preview its mask.

    The frame is loaded into the predictor, then a single prompt is added:
    a bounding box when `bbox` is given, otherwise point clicks.

    Args:
        predictor: Camera predictor providing `load_first_frame` and
            `add_new_prompt`.
        frame: First frame of the stream.
        points: One `[x, y]` pair or a sequence of them. Defaults to
            `[[310, 360]]`. Ignored when `bbox` is given.
        labels: One label per point. Defaults to `1` for every point
            (presumably a positive click in SAM 2's convention).
        bbox: Four box coordinates; when given it replaces the point prompt.
        ann_frame_idx: Frame index passed to the predictor.
        ann_obj_id: Object id passed to the predictor.

    Returns:
        The same `predictor`, now holding the prompt.

    Raises:
        ValueError: If the number of labels differs from the number of points.
    """
    predictor.load_first_frame(frame)

    if bbox is not None:
        boxes = np.array(bbox, dtype=np.float32).reshape(1, 4)
        points_arr = None
        labels_arr = None
    else:
        boxes = None
        raw_points = points if points is not None else [[310, 360]]
        # reshape lets callers pass a single flat [x, y] as well
        points_arr = np.array(raw_points, dtype=np.float32).reshape(-1, 2)
        if labels is None:
            # One label per point, so several points without explicit labels
            # no longer end up with a single-element label array.
            labels_arr = np.ones(len(points_arr), dtype=np.int32)
        else:
            labels_arr = np.array(labels, dtype=np.int32).reshape(-1)
            if len(labels_arr) != len(points_arr):
                raise ValueError(
                    f"Got {len(points_arr)} point(s) but {len(labels_arr)} label(s)"
                )

    _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points_arr,
        labels=labels_arr,
        bbox=boxes
    )

    # Logits above zero are treated as belonging to the object.
    mask = (out_mask_logits[0] > 0.0).cpu().numpy()
    show_mask_cv(mask, frame, obj_id=out_obj_ids[0])

    return predictor

In [7]:
# A path to a saved video file can be used instead of a camera index
# cap = cv2.VideoCapture("videos/aquarium/aquarium.mp4")
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise RuntimeError("Could not open the video source")

# Grab the first frame; it is the frame the object gets selected on.
ret, frame = cap.read()
if not ret:
    # Do not keep the device locked if it delivers no frames.
    cap.release()
    raise RuntimeError("Could not read the first frame from the video source")
height, width = frame.shape[:2]

In [8]:
predictor = get_predictor()

## Select the object with a point prompt (the default)
# predictor = selecting_obj(predictor, frame)

# Select the object with a bounding box (example)
# NOTE(review): box is presumably [x_min, y_min, x_max, y_max] in pixels — confirm against the predictor
predictor = selecting_obj(predictor, frame, bbox=[100, 100, 300, 450])


Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(


#### Step 2: track


In [9]:
# list to track FPS and delayed time
li = []


def _mask_window_visible(window_name="Masked Frame"):
    """Return True if the preview window drawn by show_mask_cv is reported visible."""
    try:
        return cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) >= 1
    except cv2.error:
        # Some HighGUI backends raise for a missing window instead of
        # returning a negative value.
        return False


def track_selected(predictor, cap, vis_gap=1, ann_frame_idx=0):
    """Track the selected object frame by frame and display the mask.

    Every frame read from `cap` is passed to `predictor.track`, the resulting
    mask is shown with `show_mask_cv`, and every 50th frame index a summary of
    the running average latency is appended to the module-level list `li`.

    The loop ends when the stream runs out of frames, when 'q' is read by the
    loop's own key check, or when the preview window disappears after having
    been visible. The last case covers 'q' being consumed by the key check
    inside `show_mask_cv`, which destroys the window.

    Args:
        predictor: Camera predictor that already holds a prompt.
        cap: Opened `cv2.VideoCapture`; it is released when the function exits.
        vis_gap: Unused; kept for backward compatibility. Every frame is shown.
        ann_frame_idx: Index of the frame preceding the first one read here.
    """
    frames_processed = 0
    total_elapsed_ms = 0.0
    window_seen = False

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_start = time.perf_counter()
            ann_frame_idx += 1
            frames_processed += 1

            out_obj_ids, out_mask_logits = predictor.track(frame)
            show_mask_cv((out_mask_logits[0] > 0.0).cpu().numpy(), frame, obj_id=out_obj_ids[0])

            total_elapsed_ms += (time.perf_counter() - frame_start) * 1000
            # Average over the frames processed in this call, so a non-zero
            # starting ann_frame_idx does not skew the figure.
            avg_delay = total_elapsed_ms / frames_processed
            if ann_frame_idx % 50 == 0:
                li.append(f"FPS: {1000/avg_delay}, delay: {avg_delay}ms")

            # Only trust a "not visible" reading once the window has been
            # reported visible, so backends that cannot report the property
            # fall back to the plain key check below.
            if _mask_window_visible():
                window_seen = True
            elif window_seen:
                break

            if (cv2.waitKey(1) & 0xFF) == ord('q'):
                break
    finally:
        # Runs on normal exit, on errors and on a kernel interrupt, so the
        # camera is not left open.
        cap.release()
        cv2.destroyAllWindows()

In [10]:
# Track until the stream ends or 'q' is pressed; `cap` is released by the function on exit.
track_selected(predictor, cap)


Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(


KeyboardInterrupt: 

In [11]:
# Explicit cleanup: release the capture and close any OpenCV windows,
# e.g. after the tracking cell was interrupted.
cap.release()
cv2.destroyAllWindows()

In [12]:
# Latency summaries collected by track_selected (one entry per 50 frame indices).
li

['FPS: 17.839920057090414, delay: 56.054062843322754ms',
 'FPS: 17.929233818051557, delay: 55.774831771850586ms',
 'FPS: 17.92559868613261, delay: 55.786142349243164ms',
 'FPS: 17.885920956142563, delay: 55.90989708900452ms',
 'FPS: 17.74385223889356, delay: 56.357547760009766ms',
 'FPS: 17.78134339801432, delay: 56.238720417022705ms',
 'FPS: 17.804898244419217, delay: 56.16431985582624ms',
 'FPS: 17.825416440757365, delay: 56.09967112541199ms',
 'FPS: 17.846419248295565, delay: 56.03364944458008ms',
 'FPS: 17.828174325450586, delay: 56.09099292755127ms',
 'FPS: 17.800470410103237, delay: 56.17829062721946ms',
 'FPS: 17.776226037872103, delay: 56.2549102306366ms',
 'FPS: 17.749103727414674, delay: 56.34087305802565ms']