In [12]:
import torch
import torchvision
import torchaudio
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
import sys
!{sys.executable} -m pip install opencv-python matplotlib supervision pillow mediapipe
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam2.git'

!mkdir -p /content/checkpoints/
!mkdir -p /content/data/
!mkdir -p /content/data/frames/
!mkdir -p /content/logs


!wget -P /content/checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

PyTorch version: 2.5.1+cu121
Torchvision version: 0.20.1+cu121
CUDA is available: True
Collecting git+https://github.com/facebookresearch/sam2.git
  Cloning https://github.com/facebookresearch/sam2.git to /tmp/pip-req-build-9btigizx
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/sam2.git /tmp/pip-req-build-9btigizx
  Resolved https://github.com/facebookresearch/sam2.git to commit 2b90b9f5ceec907a1c18123530e92e794ad901a4
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
--2025-01-26 01:13:07--  https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.14, 3.163.189.96, 3.163.189.108, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 

Upload the input video as 'test.mp4' to '/content/data/'

In [14]:
!wget -q /content/checkpoints/ https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

using device: cuda


In [18]:
import logging
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import numpy as np
import cv2
import json
import os
from PIL import Image, ImageDraw
from sam2.build_sam import build_sam2_video_predictor
import supervision as sv

In [4]:
DATA_DIR = '/content/data'
FRAMES_DIR = '/content/data/frames'
CHECKPOINTS = '/content/checkpoints'
LOGS = 'content/logs'

device = torch.device("cpu")
CHECKPOINT = f"{CHECKPOINTS}/sam2.1_hiera_large.pt"
CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"

predictor = build_sam2_video_predictor(CONFIG, CHECKPOINT, device=device)

In [5]:
#UTIL FUNCTIONS

def frames_generator(source_video):
    videoInfo = sv.VideoInfo.from_video_path(source_video)

    START_IDX = 0
    END_IDX = videoInfo.total_frames

    frames_generator = sv.get_video_frames_generator(source_video, start=START_IDX, end=END_IDX)
    images_sink = sv.ImageSink(
        target_dir_path=FRAMES_DIR,
        overwrite=True,
        image_name_pattern="{:05d}.jpg"
    )

    with images_sink:
        for frame in frames_generator:
            images_sink.save_image(frame)


def save_bounding_boxes(detection_result):
    hand_landmarks_list = detection_result.hand_landmarks
    bounding_boxes = []

    frames = sorted(f for f in os.listdir(FRAMES_DIR) if f.endswith(('.jpg', '.png')))
    first_frame_path = os.path.join(FRAMES_DIR, frames[0])
    img = Image.open(first_frame_path)


    image_width = img.width
    image_height = img.height

    for idx, hand_landmarks in enumerate(hand_landmarks_list):

        x_coordinates = [landmark.x for landmark in hand_landmarks]
        y_coordinates = [landmark.y for landmark in hand_landmarks]

        #Computing the Bounding Box
        x_min = min(x_coordinates)
        y_min = min(y_coordinates)
        x_max = max(x_coordinates)
        y_max = max(y_coordinates)


        x_min_pixel = int(x_min * image_width)
        y_min_pixel = int(y_min * image_height)
        x_max_pixel = int(x_max * image_width)
        y_max_pixel = int(y_max * image_height)

        # Store the bounding box with hand index
        bounding_boxes.append({
            "hand_index": idx,
            "bounding_box": {
                "x_min": x_min_pixel,
                "y_min": y_min_pixel,
                "x_max": x_max_pixel,
                "y_max": y_max_pixel
            }
        })

    with open(f'{DATA_DIR}/hand_bbox.json', "w") as json_file:
        json.dump(bounding_boxes, json_file, indent=4)


In [6]:
#Part 1 - Detect hands in the first frame
def detect_hands(input_video_path):

    frames_generator(input_video_path)
    model_file = open(f'{CHECKPOINTS}/hand_landmarker.task', "rb")
    model_data = model_file.read()
    model_file.close()
    base_options = python.BaseOptions(model_asset_buffer=model_data)
    options = vision.HandLandmarkerOptions(base_options=base_options,num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    frames = sorted(f for f in os.listdir(f'{FRAMES_DIR}') if f.endswith(('.jpg')))
    first_frame_path = os.path.join(FRAMES_DIR, frames[0])

    image = mp.Image.create_from_file(first_frame_path)
    detection_result = detector.detect(image)

    #Logging the detection result
    logging.basicConfig(
        filename=f'{LOGS}/detection_results.log',
        level=logging.INFO,
        format='%(asctime)s - %(message)s'
    )
    logging.info(f'Detection result: {detection_result}')

    #Save Bounding Boxes to prevent data loss
    save_bounding_boxes(detection_result)
    print('----------HAND DETECTION COMPLETED-------------------')

In [7]:
#Part 2 - SAM2 to track hands
def track_hands(input_video_path, output_video_path):
    #Initialize Inference state with all the frames
    inference_state = predictor.init_state(video_path=FRAMES_DIR)
    #Reset predictor if it has been used before
    predictor.reset_state(inference_state)

    with open(f'{DATA_DIR}/hand_bbox.json', 'r') as json_file:
        bbox = json.load(json_file)

    #types of hand classes
    hands = [0,1]
    for hand in hands:
        box = [b for b in bbox if b['hand_index'] == hand]

        #Handling the case where there is only one hand in the video
        if len(box) == 0:
            continue

        #Handling the case where there are multiple hands of the same type
        #i.e multiple left hands or multiple right hands
        boxx = np.array([[b["bounding_box"]["x_min"],
            b["bounding_box"]["y_min"],
            b["bounding_box"]['x_max'],
            b["bounding_box"]['y_max']]
            for b in box],
        dtype = np.float32)

        _, object_ids, mask_logits = predictor.add_new_points(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=hand,
            box = boxx
        )

    #Propogating the masks to original video
    video_info = sv.VideoInfo.from_video_path(input_video_path)
    video_info.width = int(video_info.width)
    video_info.height = int(video_info.height)

    COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700']
    mask_annotator = sv.MaskAnnotator(
        color=sv.ColorPalette.from_hex(COLORS),
        color_lookup=sv.ColorLookup.CLASS)

    frame_sample = []
    SOURCE_FRAME_PATHS = sorted(sv.list_files_with_extensions(FRAMES_DIR, extensions=["jpg"]))
    with sv.VideoSink(output_video_path, video_info=video_info) as sink:
        for frame_idx, object_ids, mask_logits in predictor.propagate_in_video(inference_state):
            frame_path = SOURCE_FRAME_PATHS[frame_idx]
            frame = cv2.imread(frame_path)
            masks = (mask_logits > 0.0).cpu().numpy()
            masks = np.squeeze(masks).astype(bool)

            detections = sv.Detections(
                xyxy=sv.mask_to_xyxy(masks=masks),
                mask=masks,
                class_id=np.array(object_ids)
            )

            annotated_frame = mask_annotator.annotate(scene=frame.copy(), detections=detections)

            sink.write_frame(annotated_frame)
            if frame_idx % video_info.fps == 0:
                frame_sample.append(annotated_frame)

    print(f'----------OUTPUT VIDEO SAVED AT {output_video_path}-------------------')


In [8]:
input_video_path = f'{DATA_DIR}/test.mp4'
output_video_path = f'{DATA_DIR}/test_result.mp4'

In [19]:
detect_hands(input_video_path)

----------HAND DETECTION COMPLETED-------------------


In [20]:
track_hands(input_video_path, output_video_path)

frame loading (JPEG): 100%|██████████| 210/210 [00:09<00:00, 22.00it/s]

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
propagate in video:   3%|▎         | 6/210 [04:29<2:32:48, 44.94s/it]


KeyboardInterrupt: 