<!-- For reference:

from ultralytics.data.annotator import auto_annotate
from ultralytics.models.sam import Predictor as SAMPredictor
from ultralytics import SAM -->

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell executes, so edits to local packages are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import cv2
import numpy as np
import tlc
import tqdm
from segment_anything import SamPredictor, sam_model_registry

from tlc_tools.common import infer_torch_device

In [None]:
# Notebook configuration.
# NOTE(review): PROJECT_NAME, MODEL_URL, EMBEDDING_DIM, REDUCTION_METHOD and
# BATCH_SIZE are not referenced by any cell visible here — confirm whether they
# are used elsewhere or are leftovers.
PROJECT_NAME = "3LC Tutorials"
# Key into `sam_model_registry`; default for `bbs_to_segments(sam_model_type=...)`.
MODEL_TYPE = "vit_b"
# Presumably the download location for the checkpoint below; nothing visible
# here downloads it, so the file must already exist at CHECKPOINT.
MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
# Local path of the SAM weights; default for `bbs_to_segments(checkpoint=...)`.
CHECKPOINT = "../../transient_data/sam_vit_b_01ec64.pth"
EMBEDDING_DIM = 3
REDUCTION_METHOD = "umap"
BATCH_SIZE = 4

In [None]:
# Pick the torch device to run on. `device` is read as a module-level global
# by `bbs_to_segments` when it moves the SAM model.
device = infer_torch_device()
print(f"Using device: {device}")

In [None]:
def bbs_to_segments(
    input_table: tlc.Table,
    sam_model_type: str = MODEL_TYPE,
    checkpoint: str = CHECKPOINT,
) -> tlc.Table:
    """Write a new table whose rows hold SAM masks prompted by each row's bounding boxes.

    For every row of ``input_table`` the image is loaded, embedded once with
    SAM, and each bounding box in ``row["bbs"]["bb_list"]`` is used as a box
    prompt producing a single mask. The masks, the original box labels and
    SAM's predicted-IoU scores are written to a table named
    ``"<input name>-sam"`` in the same dataset and project as the input.

    Args:
        input_table: Table with an ``"image"`` column and a ``"bbs"`` column
            containing a ``"bb_list"`` of boxes with ``x0``, ``y0``, ``x1``,
            ``y1`` and ``label`` entries.
        sam_model_type: Key into ``sam_model_registry`` selecting the SAM
            architecture.
        checkpoint: Path to the SAM weights file matching ``sam_model_type``.

    Returns:
        The finalized output table.

    Raises:
        FileNotFoundError: If an image referenced by a row cannot be read.
    """
    sam_model = sam_model_registry[sam_model_type](checkpoint=checkpoint)
    # `device` is the module-level global set by the notebook's device cell.
    sam_model.to(device)
    sam_model.eval()

    sam_predictor = SamPredictor(sam_model)

    # Reuse the input table's class mapping so labels keep their meaning.
    value_map = input_table.get_value_map("bbs.bb_list.label")

    table_writer = tlc.TableWriter(
        f"{input_table.name}-sam",
        input_table.dataset_name,
        input_table.project_name,
        description="Added SAM-segmentations from bounding box prompts",
        column_schemas={
            "image": tlc.ImageUrlStringValue(),
            "segments": tlc.InstanceSegmentationMasks(
                "segments",
                instance_properties_structure={
                    "label": tlc.CategoricalLabel("label", classes=value_map),
                    "score": tlc.Schema(
                        value=tlc.Float32Value(0, 1), description="Confidence score (predicted IoU) of the segmentation"
                    ),
                },
                is_prediction=True,
            ),
        },
        input_tables=[input_table.url],
    )

    for row in tqdm.tqdm(input_table, desc="Processing rows", total=len(input_table)):
        segments = []
        labels = []
        confidence = []

        image_path = row["image"]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than deep inside SAM.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image_height, image_width = image.shape[:2]

        # cv2.imread yields BGR channel order, while SamPredictor.set_image
        # assumes RGB unless told otherwise. Declaring the format lets the
        # predictor reorder channels to whatever the model expects.
        sam_predictor.set_image(image, image_format="BGR")

        for bb in row["bbs"]["bb_list"]:
            # SAM expects an XYXY box. x1/y1 are added to x0/y0, i.e. they are
            # treated as width/height.
            # NOTE(review): assumes the table stores boxes as absolute-pixel
            # XYWH — confirm against the input table's bounding-box schema.
            box_arr = np.array([bb["x0"], bb["y0"], bb["x0"] + bb["x1"], bb["y0"] + bb["y1"]])

            masks, scores, _ = sam_predictor.predict(
                box=box_arr,
                multimask_output=False,
            )

            # multimask_output=False returns a single mask/score per prompt.
            segments.append(masks[0].astype(np.uint8))
            labels.append(bb["label"])
            confidence.append(float(scores[0]))

        if segments:
            # Stack to (H, W, N) in Fortran order.
            stacked_masks = np.asfortranarray(np.stack(segments, axis=-1))
        else:
            # No boxes in this row: keep the (H, W, N) layout with N == 0 so
            # the mask array stays consistent with image_height/image_width.
            stacked_masks = np.zeros((image_height, image_width, 0), dtype=np.uint8, order="F")

        output_row = {
            "image": image_path,
            "segments": {
                "image_height": image_height,
                "image_width": image_width,
                "masks": stacked_masks,
                "instance_properties": {
                    "label": labels,
                    "score": confidence,
                },
            },
        }

        table_writer.add_row(output_row)

    return table_writer.finalize()

In [None]:
# Load an existing table by (table name, dataset name, project name) and add
# SAM segmentations derived from its bounding boxes.
# Alternative input table, kept for quick switching:
# table = tlc.Table.from_names("train", "Aerial-aerial-sheep", "RF100VL")
table = tlc.Table.from_names("initial", "COCO128", "3LC Tutorials")
out_table = bbs_to_segments(table)