<td>
   <a target="_blank" href="https://labelbox.com" ><img src="https://labelbox.com/blog/content/images/2021/02/logo-v4.svg" alt="Labelbox" width="256"/></a>
</td>


<td>
<a href="https://colab.research.google.com/github/Labelbox/labelbox-python/blob/develop/examples/integrations/sam/meta_sam_labelbox_video.ipynb" target="_blank"><img
src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
</td>

<td>
<a href="https://github.com/Labelbox/labelbox-python/blob/develop/examples/integrations/sam/meta_sam_labelbox_video.ipynb" target="_blank"><img
src="https://img.shields.io/badge/GitHub-100000?logo=github&logoColor=white" alt="GitHub"></a>
</td>

# Setting the stage

First, we import and prepare the prerequisites to process the video.

### General dependencies

In [None]:
# Show the attached NVIDIA GPU(s), if any. The SAM cell below falls back to the
# CPU when CUDA is not available.
!nvidia-smi

In [None]:
import os
# Base working directory; later cells download the video and SAM weights relative to it
HOME = os.getcwd()
print(HOME)

import sys
# NOTE(review): google.colab is presumably only importable inside Google Colab;
# cv2_imshow is not used by any of the cells shown here
from google.colab.patches import cv2_imshow
import cv2
import PIL
from PIL import Image
import numpy as np
import uuid
import tempfile

# Clear the output printed so far in this cell
from IPython import display
display.clear_output()
# NOTE(review): this rebinds `display` (the IPython.display module imported just
# above) and `Image` (PIL's Image module imported earlier) to IPython.display's
# objects. Code that needs PIL must go through `PIL.Image`, as the helper
# functions do.
from IPython.display import display, Image
from io import BytesIO

In [None]:
# You can also use the Labelbox Client API to get specific videos or an entire
# dataset from your Catalog. Refer to these docs:
# https://labelbox-python.readthedocs.io/en/latest/#labelbox.client.Client.get_data_row

# Public URL of the sample video. Later cells open it straight from this URL
# with cv2.VideoCapture(VIDEO_PATH) and register it in Labelbox as row_data.
VIDEO_PATH = "https://storage.googleapis.com/labelbox-datasets/image_sample_data/skateboarding.mp4"

# Download a local copy into HOME.
# NOTE(review): later cells read VIDEO_PATH (the URL) directly, so this local copy
# does not appear to be used.
%cd {HOME}
!wget -v {VIDEO_PATH}

### YOLOv8 dependencies

In [None]:
# Dependencies for YOLOv8

# Pinned so the YOLO API used below (YOLO(...), predict(..., conf=...), boxes.xyxy/conf/cls)
# does not change underneath the notebook.
# NOTE(review): `%pip install` targets the running kernel's environment more reliably than `!pip`.
!pip install ultralytics==8.0.20

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting ultralytics==8.0.20
  Downloading ultralytics-8.0.20-py3-none-any.whl.metadata (24 kB)
Collecting pandas>=1.1.4 (from ultralytics==8.0.20)
  Downloading pandas-2.2.1-cp310-cp310-macosx_11_0_arm64.whl.metadata (19 kB)
Collecting seaborn>=0.11.0 (from ultralytics==8.0.20)
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting thop>=0.1.1 (from ultralytics==8.0.20)
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Collecting sentry-sdk (from ultralytics==8.0.20)
  Downloading sentry_sdk-1.42.0-py2.py3-none-any.whl.metadata (9.8 kB)
Collecting pytz>=2020.1 (from pandas>=1.1.4->ultralytics==8.0.20)
  Downloading pytz-2024.1-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas>=1.1.4->ultralytics==8.0.20)
  Downloading tzdata-2024.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting matplotlib>=3.2.2 (from ultralytics==8.0.20)
  Downloading m

In [None]:
# Import packages

import ultralytics
# Run ultralytics' built-in environment checks and print their summary
ultralytics.checks()
from ultralytics import YOLO

In [None]:
# Instantiate YOLOv8 model

model = YOLO(f'{HOME}/yolov8n.pt')

# One random color per YOLOv8 class. The colors are used for the box/mask
# visualization videos AND as the color_rgb of the mask instances uploaded to
# Labelbox, so the RNG is seeded to make them identical on every run.
COLOR_SEED = 42
colors = np.random.default_rng(COLOR_SEED).integers(0, 256, size=(len(model.names), 3))

print(model.names)

# Specify which classes you care about. The rest of classes will be filtered out.
chosen_class_ids = [0] # person

### SAM dependencies

In [None]:
# Download SAM model SDK

%cd {HOME}
# Install Meta's segment-anything package from GitHub using the kernel's own
# interpreter (sys.executable), so the package lands in the kernel's environment.
# NOTE(review): unpinned — this installs whatever is on the repository's default branch.
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

In [None]:
# Download SAM model weights

%cd {HOME}
!mkdir {HOME}/weights
%cd {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))

In [None]:
# Import packages

import torch
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

In [None]:
# Instantiate SAM model

# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
  DEVICE = torch.device('cuda:0')
else:
  DEVICE = torch.device('cpu')

# Build the ViT-H SAM variant from the downloaded checkpoint, move it to DEVICE
# and wrap it in a predictor (used below with box prompts).
sam = sam_model_registry["vit_h"](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
mask_predictor = SamPredictor(sam)

### Labelbox dependencies

In [None]:
# Install labelbox package

# NOTE(review): unpinned. This notebook relies on lb.Client, create_dataset,
# create_ontology, create_project, project.setup_editor, project.create_batch and
# lb.MALPredictionImport — pin a version that still provides all of them.
!pip install -q "labelbox[data]"

In [None]:
# Import packages

import labelbox as lb
import labelbox.types as lb_types

In [None]:
# Create a Labelbox API key for your account by following the instructions here:
# https://docs.labelbox.com/reference/create-api-key
# Then either export it as the LABELBOX_API_KEY environment variable, or paste it
# into the prompt this cell shows. The key is deliberately not written into the
# notebook so it is not saved or shared along with it.
import getpass

API_KEY = os.environ.get("LABELBOX_API_KEY") or getpass.getpass("Labelbox API key: ")
client = lb.Client(API_KEY)

### Helper functions

In [None]:
def get_color(color):
  """Return the first three channels of ``color`` as a tuple of plain Python ints.

  OpenCV drawing functions and Labelbox mask instances take a plain
  ``(int, int, int)`` color rather than a NumPy row.
  """
  return tuple(int(color[channel]) for channel in range(3))

# Get video dimensions
def get_video_dimensions(input_cap):
  """Return the ``(height, width)`` of the frames of an OpenCV video capture.

  Args:
    input_cap: an opened ``cv2.VideoCapture`` (anything with a compatible ``get``).

  Returns:
    Tuple ``(height, width)`` of ints — note height comes first.
  """
  # Read from the argument rather than a global ``cap``; otherwise the function
  # silently measures whatever capture was last bound to that global name.
  width = int(input_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  height = int(input_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  return height, width

# Get output video writer with same dimensions and fps as input video
def get_output_video_writer(input_cap, output_path, fallback_fps=30.0):
  """Create an MP4 writer whose frame size and frame rate match ``input_cap``.

  Args:
    input_cap: opened ``cv2.VideoCapture`` whose properties are copied.
    output_path: destination file for the written video.
    fallback_fps: frame rate used when the capture reports a non-positive FPS.

  Returns:
    A ``cv2.VideoWriter`` using the ``mp4v`` codec.
  """
  # Get the video's properties (width, height, FPS) from the argument, not from
  # a global capture object.
  width = int(input_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  height = int(input_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  # Keep FPS as a float: truncating e.g. 29.97 to 29 makes the output play
  # slower than the source.
  fps = float(input_cap.get(cv2.CAP_PROP_FPS))
  if fps <= 0:
    fps = fallback_fps

  # Define the output video file
  output_codec = cv2.VideoWriter_fourcc(*"mp4v")  # MP4 codec
  output_video = cv2.VideoWriter(output_path, output_codec, fps, (width, height))

  return output_video

def visualize_detections(frame, boxes, conf_thresholds, class_ids):
    """Return a copy of ``frame`` with every detection drawn onto it.

    Each detection gets a 2 px rectangle plus a ``"<class name>: <score>"``
    caption placed 10 px above its top-left corner, both in the class color
    taken from the global ``colors`` table.

    Args:
        frame: image array to annotate (left unmodified; a copy is returned).
        boxes: per-detection ``(x1, y1, x2, y2)`` corner coordinates.
        conf_thresholds: per-detection confidence scores.
        class_ids: per-detection class ids (names come from ``model.names``).
    """
    annotated = np.copy(frame)
    for det_idx, box in enumerate(boxes):
        class_id = int(class_ids[det_idx])
        score = float(conf_thresholds[det_idx])
        x1, y1, x2, y2 = (int(coord) for coord in box[:4])

        box_color = get_color(colors[class_id])
        caption = f"{model.names[class_id]}: {score:.2f}"

        cv2.rectangle(annotated, (x1, y1), (x2, y2), box_color, 2)
        cv2.putText(annotated, caption, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, box_color, 2)
    return annotated

def add_color_to_mask(mask, color):
  """Turn a 2-D mask into an ``(H, W, 3)`` image painted with ``color``.

  The mask is cast to ``uint8`` and replicated into three channels along a new
  last axis, then multiplied by ``color``. Pixels where the mask is 1 therefore
  take the color and pixels where it is 0 stay black.
  """
  three_channel = np.repeat(mask.astype(np.uint8)[..., np.newaxis], 3, axis=-1)
  return three_channel * color

# Merge masks into a single, multi-colored mask
def merge_masks_colored(masks, class_ids):
  """Merge the masks of the chosen classes into one colored ``(H, W, 3)`` image.

  Args:
    masks: sequence/array of per-detection masks, each of shape ``(1, H, W)``
      (the ``[0]`` index below selects the single mask per detection).
    class_ids: per-detection class ids aligned with ``masks``.

  Returns:
    ``uint8`` image in which each kept mask is painted with its class color
    (from the global ``colors`` table). Overlapping masks are combined with a
    bitwise OR of their colors. Only ids listed in ``chosen_class_ids`` are
    kept; when none are present the result is an all-black image of the same
    height/width instead of an IndexError.
  """
  filtered = [
    (masks[idx], int(cid))
    for idx, cid in enumerate(class_ids)
    if int(cid) in chosen_class_ids
  ]

  if not filtered:
    # Nothing of interest in this frame: return an empty (black) mask.
    height, width = np.asarray(masks).shape[-2:]
    return np.zeros((height, width, 3), dtype=np.uint8)

  first_mask, first_cid = filtered[0]
  merged_with_colors = add_color_to_mask(first_mask[0], get_color(colors[first_cid])).astype(np.uint8)

  for mask, cid in filtered[1:]:
    curr_mask_with_colors = add_color_to_mask(mask[0], get_color(colors[cid]))
    merged_with_colors = np.bitwise_or(merged_with_colors, curr_mask_with_colors)

  return merged_with_colors.astype(np.uint8)

def get_instance_uri(client, global_key, array):
    """ Uploads an image array to Labelbox as a PNG and returns its URL
    Args:
        client        :   Required (lb.Client) - Labelbox Client object
        global_key    :   Required (str) - Data row global key, used as the uploaded file name
        array         :   Required (np.ndarray) - NumPy ndarray representation of an image
    Returns:
        The URL returned by ``client.upload_data`` (requested with ``sign=True``),
        suitable as a MaskFrame ``instance_uri``
    """
    # Convert array to PIL image
    image_as_pil = PIL.Image.fromarray(array)
    # Encode the image as PNG bytes in memory. PNG is lossless, so mask pixel
    # colors stay exactly equal to the mask instances' color_rgb values.
    buffer = BytesIO()
    image_as_pil.save(buffer, format='PNG')
    image_as_bytes = buffer.getvalue()
    # Upload the bytes; the declared content type must match the PNG encoding
    # above (it was previously mislabelled as image/jpeg).
    url = client.upload_data(
        content=image_as_bytes,
        filename=global_key,
        content_type="image/png",
        sign=True
    )
    return url

def get_local_instance_uri(array, temp_dir="/content"):
    """Write an image array to a new temporary PNG file and return its path.

    The file is created with ``delete=False`` so it outlives this call; the
    caller is responsible for removing it.

    Args:
        array: NumPy image array accepted by ``PIL.Image.fromarray``.
        temp_dir: directory to create the file in. Defaults to ``/content`` (as
            in Colab); pass ``None`` to use the system temporary directory.

    Returns:
        Filesystem path of the written PNG (a local path, not a URL).
    """
    image_as_pil = PIL.Image.fromarray(array)

    with tempfile.NamedTemporaryFile(suffix='.png', dir=temp_dir, delete=False) as temp_file:
        # Pass the format explicitly instead of relying on inference from the file name
        image_as_pil.save(temp_file, format='PNG')
        file_name = temp_file.name

    return file_name

def create_mask_frame(frame_num, instance_uri):
  """Wrap one frame's mask image URI in a Labelbox ``MaskFrame``.

  Args:
    frame_num: index of the video frame the mask belongs to (the processing
      loop in this notebook passes a counter that starts at 1).
    instance_uri: URI of the mask image for that frame.
  """
  mask_frame = lb_types.MaskFrame(index=frame_num, instance_uri=instance_uri)
  return mask_frame

def create_mask_instances(class_ids):
  """Build one Labelbox ``MaskInstance`` per distinct chosen class id.

  Args:
    class_ids: iterable of class ids — ints, floats or scalar tensors such as
      the entries of YOLOv8's ``boxes.cls``.

  Returns:
    List of ``MaskInstance`` objects in ascending class-id order. Each one uses
    that class's color from ``colors`` and its name from ``model.names``.
    Ids not in ``chosen_class_ids`` are skipped.
  """
  # Deduplicate on the integer id: torch tensors hash by object identity, so a
  # plain set() of them keeps duplicates. Sorting gives a deterministic order.
  unique_ids = sorted({int(raw_id) for raw_id in class_ids})
  instances = []
  for cid in unique_ids:
    if cid in chosen_class_ids:
      color = get_color(colors[cid])
      name = model.names[cid]
      instances.append(lb_types.MaskInstance(color_rgb=color, name=name))
  return instances

def create_video_mask_annotation(frames, instance):
  """Pair a list of ``MaskFrame``s with a single ``MaskInstance`` in one
  Labelbox ``VideoMaskAnnotation``."""
  return lb_types.VideoMaskAnnotation(instances=[instance], frames=frames)

### Labelbox setup

In [None]:
# Create a new dataset

# The video's file name doubles as its data row global key.
# read more here: https://docs.labelbox.com/reference/data-row-global-keys
# NOTE(review): global keys are presumably unique per organization, so re-running
# this cell may report an error for the already-registered key — confirm.
global_key = os.path.basename(VIDEO_PATH)

asset = dict(
    row_data=VIDEO_PATH,
    global_key=global_key,
    # media_type="VIDEO",
)

# Create the dataset, register the video as a data row and block until done.
dataset = client.create_dataset(name="yolo-sam-video-masks-dataset")
task = dataset.create_data_rows([asset])
task.wait_till_done()

print(f"Errors: {task.errors}")
print(f"Failed data rows: {task.failed_data_rows}")

In [None]:
# Run YOLOv8 over the whole video once to collect the class ids that appear in it.
# The result is displayed in the next cell and later passed to
# create_mask_instances, which keeps only the ids in chosen_class_ids.

cap = cv2.VideoCapture(VIDEO_PATH)
unique_class_ids = set()

frame_num = 1
while cap.isOpened():
  # Progress message on the first frame and on every 30th frame
  if frame_num == 1 or frame_num % 30 == 0:
    print("Processing frame number", frame_num)

  ret, frame = cap.read()
  if not ret:
      break

  # frame is a numpy array; conf=0.7 is YOLOv8's confidence threshold
  detections = model.predict(frame, conf=0.7)
  unique_class_ids.update(int(cid) for cid in detections[0].boxes.cls)
  frame_num += 1

cap.release()

In [None]:
# Display the YOLOv8 class ids found by the pre-pass over the video
unique_class_ids

In [None]:
# Create a new ontology if you don't have one

# One raster-segmentation (mask) tool per chosen class. Each tool is named after
# the YOLOv8 class label, the same name create_mask_instances gives its instances.
tools = [
    lb.Tool(tool=lb.Tool.Type.RASTER_SEGMENTATION, name=model.names[cls])
    for cls in chosen_class_ids
]

ontology_builder = lb.OntologyBuilder(tools=tools, classifications=[])

ontology = client.create_ontology(
    "yolo-sam-video-masks-ontology",
    ontology_builder.asdict(),
    # media_type=lb.MediaType.Video,
)

# Or get an existing ontology by name or ID (uncomment one of the below)

# ontology = client.get_ontologies("Demo Chair").get_one()

# ontology = client.get_ontology("clhee8kzt049v094h7stq7v25")

In [None]:
# Create a new project if you don't have one

# A video project; data is attached to it below via a batch.
project = client.create_project(
    name="yolo-sam-video-masks-project",
    media_type=lb.MediaType.Video,
)

# Or get an existing project by ID (uncomment the below)

# project = client.get_project("fill_in_project_id")

# Attach the ontology created above to the project.
# If the project already has an ontology set up, comment out this line.
# NOTE(review): newer labelbox SDKs deprecate setup_editor in favor of
# connect_ontology — check the installed version.
project.setup_editor(ontology)

In [None]:
# Create a new batch of data for the project you specified above

# Uncomment if you are using `data_rows` parameter below
# data_row_ids = client.get_data_row_ids_for_global_keys([global_key])['results']

# Queue the uploaded video in the project, referenced by its global key.
# Each batch in a project must have a unique name; priority runs from
# 1 (highest) to 5 (lowest).
batch = project.create_batch(
    "yolo-sam-video-masks-project",
    global_keys=[global_key],  # or pass data_rows=data_row_ids instead
    priority=1,
)

print(f"Batch: {batch}")

In [None]:
# Map every ontology tool name to its feature schema id
tools = ontology.tools()
feature_schema_ids = {tool.name: tool.feature_schema_id for tool in tools}

print(feature_schema_ids)

### Loop through each frame of the video and process it

In [None]:
# Run YOLOv8 and then SAM on each frame, and write visualization videos to disk
# You can download /content/skateboarding_boxes.mp4 and /content/skateboarding_masks.mp4
# to visualize the results

# For the purposes of this demo, only look at the first MAX_FRAMES frames
MAX_FRAMES = 90
# Minimum YOLOv8 confidence for a detection to be kept
CONF_THRESHOLD = 0.7

cap = cv2.VideoCapture(VIDEO_PATH)

output_video_boxes = get_output_video_writer(cap, "/content/skateboarding_boxes.mp4")
output_video_masks = get_output_video_writer(cap, "/content/skateboarding_masks.mp4")
mask_frames = []

try:
  # Loop through the frames of the video. The frame limit is part of the loop
  # condition so it also applies to frames skipped via `continue` below.
  frame_num = 1
  while cap.isOpened() and frame_num <= MAX_FRAMES:
    if frame_num % 30 == 0 or frame_num == 1:
      print("Processing frames", frame_num, "-", frame_num+29)
    ret, frame = cap.read()
    if not ret:
      break

    # Run frame through YOLOv8 to get detections
    detections = model.predict(frame, conf=CONF_THRESHOLD) # frame is a numpy array
    boxes = detections[0].boxes

    # Write detections (all classes) to output video
    boxes_cpu = boxes.cpu()
    frame_with_detections = visualize_detections(frame, boxes_cpu.xyxy, boxes_cpu.conf, boxes_cpu.cls)
    output_video_boxes.write(frame_with_detections)

    # Only prompt SAM with boxes of the chosen classes; masks of other classes
    # would be discarded by merge_masks_colored anyway.
    keep = [idx for idx, cid in enumerate(boxes.cls) if int(cid) in chosen_class_ids]
    if not keep:
      print("No boxes of the chosen classes found on frame", frame_num)
      output_video_masks.write(frame)
      frame_num += 1
      continue
    frame_class_ids = boxes.cls[keep]
    # Box prompts must live on the same device as the SAM model
    frame_boxes = boxes.xyxy[keep].to(mask_predictor.device)

    # Run frame and detections through SAM to get masks. OpenCV decodes frames
    # as BGR, so say so instead of letting SAM treat them as RGB.
    mask_predictor.set_image(frame, image_format="BGR")
    # frame.shape[:2] is the (height, width) of the original frame
    transformed_boxes = mask_predictor.transform.apply_boxes_torch(frame_boxes, frame.shape[:2])
    masks, scores, logits = mask_predictor.predict_torch(
      boxes=transformed_boxes,
      multimask_output=False,
      point_coords=None,
      point_labels=None
    )
    masks = masks.cpu().numpy()
    merged_colored_mask = merge_masks_colored(masks, frame_class_ids)

    # Write masks to output video
    image_combined = cv2.addWeighted(frame, 0.7, merged_colored_mask, 0.7, 0)
    output_video_masks.write(image_combined)

    # Create video mask annotation for upload to Labelbox
    instance_uri = get_instance_uri(client, global_key, merged_colored_mask)
    mask_frames.append(create_mask_frame(frame_num, instance_uri))

    frame_num += 1
finally:
  # Always release the capture and flush/close both output videos, even if a
  # model call or upload above raised.
  cap.release()
  output_video_boxes.release()
  output_video_masks.release()
  cv2.destroyAllWindows()

In [None]:
# Create annotations for LB upload
# One VideoMaskAnnotation per chosen class instance; all of them share the same
# per-frame mask images collected in mask_frames.
mask_instances = create_mask_instances(unique_class_ids)
annotations = [create_video_mask_annotation(mask_frames, instance) for instance in mask_instances]

# A single label for the uploaded video, addressed by its global key
labels = [
    lb_types.Label(data=lb_types.VideoData(global_key=global_key),
                   annotations=annotations)
]

In [None]:
# Upload the predictions to your specified project and data rows as pre-labels
# (model-assisted labeling). The import name gets a random UUID suffix so each
# run creates a distinctly named job.
# Note: This may take a few minutes, depending on size of video and number of masks

upload_job = lb.MALPredictionImport.create_from_objects(
    client=client,
    project_id=project.uid,
    name=f"mal_import_job{uuid.uuid4()}",
    predictions=labels,
)
upload_job.wait_until_done()

print(f"Errors: {upload_job.errors}")
print(f"Status of uploads: {upload_job.statuses}")