In [None]:
import numpy as np
import PIL
import torch

import fiftyone as fo
import fiftyone.zoo as foz 

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

In [None]:
# Checkpoint and architecture of the Segment Anything model to load.
# NOTE(review): this is a machine-specific absolute path; adjust it to wherever
# the ViT-B weights live on your machine.
sam_checkpoint = "C:/Users/combus-rnd/SAM_weights/sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of crashing in `sam.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device);  # trailing ';' keeps the cell from printing the whole module repr

In [None]:
# Load a random 100-sample subset of the Open Images V7 validation split with
# detection and point labels. The seed pins the shuffle so that re-running the
# notebook selects the same subset every time.
dataset = foz.load_zoo_dataset(
    "open-images-v7",
    split="validation",
    max_samples=100,
    label_types=["detections", "points"],
    shuffle=True,
    seed=51,
)

In [None]:
# Keep the dataset around between sessions under a stable name.
dataset_name = "openimages_sam"
if dataset.name != dataset_name:
    if fo.dataset_exists(dataset_name):
        # A previous run already persisted a dataset under this name. Dataset
        # names must be unique, so reuse the existing one rather than renaming
        # (which would fail on a re-run of this cell).
        dataset = fo.load_dataset(dataset_name)
    else:
        dataset.name = dataset_name
dataset.persistent = True
dataset.compute_metadata()

## visualize the dataset
session = fo.launch_app(dataset)


In [None]:
# Automatic ("segment everything") mask generator built on the loaded SAM model.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    # Prompt sampling: point grid density and extra passes over image crops.
    points_per_side=32,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    # Quality filters applied to the candidate masks.
    pred_iou_thresh=0.9,
    stability_score_thresh=0.92,
    # Post-processing threshold for small regions (in pixels).
    min_mask_region_area=400,
)

In [None]:
def add_SAM_auto_segmentation(sample):
    """Run SAM's automatic mask generator on one sample and store the result.

    The image at ``sample.filepath`` is segmented with the module-level
    ``mask_generator``. All returned masks are merged into a single label map
    in which pixels of the i-th mask (0-based) get the value ``i + 1`` and
    pixels covered by no mask stay 0. Where masks overlap, the later mask in
    the returned list wins. The label map is written to the sample's
    ``"auto_SAM"`` field as a ``fo.Segmentation``.

    Args:
        sample: a FiftyOne sample exposing ``filepath`` and item assignment.

    Returns:
        None. The sample is modified in place; saving it is up to the caller.
    """
    # `import PIL` alone does not guarantee the Image submodule is loaded.
    from PIL import Image

    # Close the file handle promptly and always hand SAM an H x W x 3 array,
    # even when the file on disk is grayscale, palettized or has alpha.
    with Image.open(sample.filepath) as img:
        image = np.array(img.convert("RGB"))

    masks = mask_generator.generate(image)

    # uint8 can only hold labels up to 255; widen the dtype when SAM returns
    # more masks so that labels do not wrap around to 0 (background).
    # NOTE(review): confirm that the installed FiftyOne version renders 16-bit
    # segmentation masks in the App.
    if len(masks) <= np.iinfo(np.uint8).max:
        label_dtype = np.uint8
    else:
        label_dtype = np.uint16

    # Sized from the image rather than masks[0], so an empty result yields an
    # all-background label map instead of an IndexError.
    full_mask = np.zeros(image.shape[:2], dtype=label_dtype)
    for label, mask in enumerate(masks, start=1):
        region = np.asarray(mask["segmentation"], dtype=bool)
        full_mask[region] = label

    sample["auto_SAM"] = fo.Segmentation(mask=full_mask)

In [None]:
def add_SAM_auto_segmentations(dataset):
    """Apply ``add_SAM_auto_segmentation`` to every sample of ``dataset``.

    Samples are visited through ``dataset.iter_samples`` with ``autosave`` and
    ``progress`` enabled, so no explicit save call is made here.
    """
    samples = dataset.iter_samples(autosave=True, progress=True)
    for current_sample in samples:
        add_SAM_auto_segmentation(current_sample)

In [None]:
# Peek at the first keypoint stored in the `points` field of the first sample.
first_sample_points = dataset.first().points
first_sample_points.keypoints[0]