<td>
   <a target="_blank" href="https://labelbox.com" ><img src="https://labelbox.com/blog/content/images/2021/02/logo-v4.svg" width=256/></a>
</td>

<td>
<a href="https://colab.research.google.com/github/Labelbox/labelbox-python/blob/develop/examples/integrations/sam/meta_sam_labelbox.ipynb" target="_blank"><img
src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
</td>

<td>
<a href="https://github.com/Labelbox/labelbox-python/blob/develop/examples/integrations/sam/meta_sam_labelbox.ipynb" target="_blank"><img
src="https://img.shields.io/badge/GitHub-100000?logo=github&logoColor=white" alt="GitHub"></a>
</td>

# Setup
This notebook shows how to use Meta's Segment Anything Model (SAM) to create segmentation masks that can then be uploaded to a Labelbox project.

In [None]:
!pip install -q "labelbox[data]"
!pip install -q ultralytics==8.0.20
!pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'

In [None]:
# Detect whether the notebook is running inside Google Colab.
try:
    import google.colab  # noqa: F401  (imported only as a probe)
    IN_COLAB = True
except ImportError:
    # Catch only the failed import; a bare `except:` would also hide
    # KeyboardInterrupt and unrelated errors.
    IN_COLAB = False

In [None]:
# Imports: standard library, then third-party, then Labelbox.
# `display` is imported as the IPython display *function* only. Previously the name
# was first bound to the IPython.display module and then rebound to the function.
import os
import urllib.request
import uuid

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import Image, clear_output, display

import ultralytics
from ultralytics import YOLO
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry

import labelbox as lb
import labelbox.types as lb_types

# Clear anything this cell has printed so far (e.g. import noise), then print
# Ultralytics' summary of the Python/torch/hardware environment.
clear_output()
ultralytics.checks()

# Directory the notebook runs from; model weights are stored here.
HOME = os.getcwd()

if IN_COLAB:
    from google.colab.patches import cv2_imshow

Ultralytics YOLOv8.0.20 🚀 Python-3.10.13 torch-2.2.0 CPU
Setup complete ✅ (10 CPUs, 16.0 GB RAM, 251.9/460.4 GB disk)


# API Key and Client
Provide a valid API key below to connect to the Labelbox client.

In [None]:
# Labelbox API key. To create one, go to: Workspace settings -> API -> Create API Key.
# Do not paste the key into the notebook, because it would end up in saved files and
# version control. It is read from the LABELBOX_API_KEY environment variable when
# set; otherwise you are prompted for it and the input is not echoed.
from getpass import getpass

API_KEY = os.environ.get("LABELBOX_API_KEY") or getpass("Labelbox API key: ")
client = lb.Client(api_key=API_KEY)

# Predicting bounding boxes around common objects using YOLOv8

First, we load the YOLOv8 model, fetch a sample image, and run the model on it to generate bounding boxes around common objects.

### Use YOLOv8 to Create Bounding Boxes

In this demo we use YOLOv8 to obtain bounding boxes around the objects in our image. We later feed these boxes to SAM as prompts to generate the masks.

Below we run inference on an image using the YOLOv8 model.

In [None]:
# Sample image for the detector. To work with your own data instead, the Labelbox
# Client can fetch individual data rows or whole datasets from Catalog; see:
# https://labelbox-python.readthedocs.io/en/latest/#labelbox.client.Client.get_data_row

IMAGE_PATH = (
    "https://storage.googleapis.com/labelbox-datasets/"
    "image_sample_data/chairs.jpeg"
)

In [None]:
# Detections scored below this confidence are discarded.
YOLO_CONFIDENCE = 0.25

# Load the YOLOv8 checkpoint from the notebook's working directory and run it on the image.
model = YOLO(os.path.join(HOME, "yolov8n.pt"))
results = model.predict(source=IMAGE_PATH, conf=YOLO_CONFIDENCE)

# The detections for the (single) input image can be inspected via:
#   results[0].boxes.xyxy  -> bounding box coordinates
#   results[0].boxes.conf  -> confidence scores
#   results[0].boxes.cls   -> class indices; model.names[int(c)] gives the class name

Below we draw the bounding boxes onto the image with OpenCV (`cv2`) and display the result.

In [None]:
# Load the same image YOLO ran on. cv2.imread cannot open URLs, so when IMAGE_PATH is
# a URL, keep a local copy in the notebook directory (downloaded once).
if os.path.isfile(IMAGE_PATH):
    local_image_path = IMAGE_PATH
else:
    local_image_path = os.path.join(HOME, os.path.basename(IMAGE_PATH))
    if not os.path.isfile(local_image_path):
        urllib.request.urlretrieve(IMAGE_PATH, local_image_path)

image_bgr = cv2.imread(local_image_path)
if image_bgr is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise FileNotFoundError(f"Could not read image from {local_image_path!r}")

# Draw on a copy: image_bgr is later passed to SAM and used for the mask overlay,
# so it must not contain the green box outlines.
image_with_boxes = image_bgr.copy()
for x1, y1, x2, y2 in results[0].boxes.xyxy.tolist():
    cv2.rectangle(image_with_boxes, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)

# Render inline with matplotlib (after converting BGR -> RGB). This works both in
# Colab and in a local kernel, without a native window that blocks on cv2.waitKey().
plt.figure(figsize=(10, 10))
plt.imshow(cv2.cvtColor(image_with_boxes, cv2.COLOR_BGR2RGB))
plt.axis("off")
plt.title("YOLOv8 detections")
plt.show()

### Predicting segmentation masks using Meta's Segment Anything model

Now we load Meta's Segment Anything Model (SAM) and feed it the bounding boxes as prompts, so it can generate a segmentation mask within each box.

In [None]:
# Download the SAM ViT-H weights into the notebook directory (skipped if already present).
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
CHECKPOINT_PATH = os.path.join(HOME, os.path.basename(SAM_CHECKPOINT_URL))

if not os.path.isfile(CHECKPOINT_PATH):
    # Save to CHECKPOINT_PATH itself, not the current directory, so the existence check
    # and the model loader always look at the same file. Download to a temporary name
    # and rename afterwards, so an interrupted download never leaves a truncated
    # checkpoint that would be mistaken for a complete one on the next run.
    partial_path = CHECKPOINT_PATH + ".part"
    urllib.request.urlretrieve(SAM_CHECKPOINT_URL, partial_path)
    os.replace(partial_path, CHECKPOINT_PATH)


In [None]:
# Run SAM on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# SAM backbone variant; it must match the downloaded sam_vit_h checkpoint.
MODEL_TYPE = "vit_h"

In [None]:
# Build the SAM variant selected by MODEL_TYPE from the downloaded weights, move it to
# DEVICE (nn.Module.to works in place), and wrap it in a predictor that takes prompts.
build_sam = sam_model_registry[MODEL_TYPE]
sam = build_sam(checkpoint=CHECKPOINT_PATH)
sam.to(device=DEVICE)

mask_predictor = SamPredictor(sam)

In [None]:
# Register the image with the predictor. SamPredictor.set_image assumes RGB input by
# default; image_bgr comes from cv2.imread (BGR), so declare the format and let SAM
# swap the channels itself.
mask_predictor.set_image(image_bgr, image_format="BGR")

# Map the YOLO boxes (original-image pixel coordinates) into the predictor's input
# frame, on the same device as the model, so predict_torch does not hit a device mismatch.
input_boxes = results[0].boxes.xyxy.to(mask_predictor.device)
transformed_boxes = mask_predictor.transform.apply_boxes_torch(input_boxes, image_bgr.shape[:2])

# One mask per box (multimask_output=False); only box prompts, no point prompts.
masks, scores, logits = mask_predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

# Bring the masks back to host memory as a NumPy array for plotting and upload.
masks = masks.cpu().numpy()

Here we visualize the union of all predicted segmentation masks overlaid on the image.

In [None]:
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# Union of all predicted masks. Each entry of `masks` holds one mask per box at
# index 0, presumably shaped (1, H, W); confirm if changing multimask_output.
# Starting from an all-False canvas also handles the 0- and 1-detection cases. The
# previous pairwise loop left final_mask as None there, which crashed plt.imshow.
final_mask = np.zeros(image_rgb.shape[:2], dtype=bool)
for box_masks in masks:
    final_mask |= box_masks[0].astype(bool)

plt.figure(figsize=(10, 10))
plt.imshow(image_rgb)
plt.imshow(final_mask, cmap="gray", alpha=0.7)
plt.axis("off")
plt.title("SAM masks (union)")
plt.show()

### Uploading predicted segmentation masks with class names to Labelbox using the Python SDK

In [None]:
# Pixel value (on every channel) that Labelbox should treat as "inside the mask".
# The boolean SAM masks are cast to uint8 below, so they hold 1 inside and 0 outside.
color = (1, 1, 1)

# Class name for every YOLO detection, in detection order (parallel to `masks`).
class_names = [model.names[int(class_id)] for class_id in results[0].boxes.cls]

# One mask ObjectAnnotation per predicted mask, labelled with the class YOLO
# predicted for the box that prompted it.
annotations = [
    lb_types.ObjectAnnotation(
        name=class_names[i],
        value=lb_types.Mask(
            mask=lb_types.MaskData.from_2D_arr(np.asarray(masks[i][0], dtype="uint8")),
            color=color,
        ),
    )
    for i in range(len(masks))
]

In [None]:
# Create a new dataset containing the sample image as a single data row.

# Global keys must be unique (read more here:
# https://docs.labelbox.com/reference/data-row-global-keys). A fixed key makes every
# re-run of this cell fail to create the data row, so a random suffix is appended.
global_key = f"{os.path.basename(IMAGE_PATH)}-{uuid.uuid4()}"

test_img_url = {
    "row_data": IMAGE_PATH,
    "global_key": global_key,
}

dataset = client.create_dataset(name="auto-mask-classification-dataset")
task = dataset.create_data_rows([test_img_url])
task.wait_till_done()

print(f"Errors: {task.errors}")
print(f"Failed data rows: {task.failed_data_rows}")

In [None]:
# Create a new ontology if you don't have one.

# One raster-segmentation tool per unique class detected by YOLO. Sorting keeps the
# tool order stable; iterating a set of strings can vary between runs because of
# hash randomization.
tools = [
    lb.Tool(tool=lb.Tool.Type.RASTER_SEGMENTATION, name=name)
    for name in sorted(set(class_names))
]

ontology_builder = lb.OntologyBuilder(classifications=[], tools=tools)

ontology = client.create_ontology(
    "auto-mask-classification-ontology",
    ontology_builder.asdict(),
    media_type=lb.MediaType.Image,
)

# Or get an existing ontology by name or ID (uncomment one of the below):
# ontology = client.get_ontologies("Demo Chair").get_one()
# ontology = client.get_ontology("clhee8kzt049v094h7stq7v25")

In [None]:
# Create a new image project if you don't have one.
project = client.create_project(
    name="auto-mask-classification-project",
    media_type=lb.MediaType.Image,
)

# Or get an existing project by ID (uncomment the below):
# project = client.get_project("fill_in_project_id")

# Attach the ontology to the project. If the project already has an ontology set up,
# comment out this line.
project.connect_ontology(ontology)

In [None]:
# Send the data row created above, referenced directly by its global key, to the
# project as a new batch. This avoids a separate global-key -> ID lookup whose
# errors were previously ignored.
batch = project.create_batch(
    "auto-mask-classification-batch",  # each batch in a project must have a unique name
    global_keys=[global_key],
    priority=1,  # priority between 1 (highest) and 5 (lowest)
)

print(f"Batch: {batch}")

In [None]:
# A single Label for the uploaded data row, carrying every SAM mask annotation.
labels = [
    lb_types.Label(
        data=lb_types.ImageData(global_key=global_key),
        annotations=annotations,
    )
]

In [None]:
# Import the predicted masks into the project as pre-labels (model-assisted labeling).
# Each import job gets a unique name via a random UUID suffix.
upload_job = lb.MALPredictionImport.create_from_objects(
    client=client,
    project_id=project.uid,
    name=f"mal_job{uuid.uuid4()}",
    predictions=labels,
)
upload_job.wait_until_done()

print(f"Errors: {upload_job.errors}")
print(f"Status of uploads: {upload_job.statuses}")

### Cleanup

In [None]:
# Set to True to delete the dataset and project created by this notebook once you
# no longer need them. Left off by default so "Run All" never deletes your work.
CLEANUP = False

if CLEANUP:
    dataset.delete()
    project.delete()