<!-- For reference:

from ultralytics.data.annotator import auto_annotate
from ultralytics.models.sam import Predictor as SAMPredictor
from ultralytics import SAM -->

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import cv2
import numpy as np
import tlc
import torch
import tqdm
from segment_anything import SamPredictor, sam_model_registry

from tlc_tools.common import infer_torch_device

In [None]:
# Notebook configuration.
# NOTE(review): PROJECT_NAME, MODEL_URL, EMBEDDING_DIM, REDUCTION_METHOD and BATCH_SIZE are not
# referenced by the code visible in this notebook — confirm whether other cells rely on them.
PROJECT_NAME = "3LC Tutorials"
# Key into `sam_model_registry`; default for `bbs_to_segments(sam_model_type=...)`.
MODEL_TYPE = "vit_b"
# Presumably the download location of the checkpoint below; nothing visible here downloads it.
MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
# Local checkpoint path handed to the SAM model constructor; default for `bbs_to_segments(checkpoint=...)`.
CHECKPOINT = "../../transient_data/sam_vit_b_01ec64.pth"
EMBEDDING_DIM = 3
REDUCTION_METHOD = "umap"
BATCH_SIZE = 4

In [None]:
# Pick the torch device once; `bbs_to_segments` reads this module-level `device`
# both when moving the SAM model and when creating the box prompt tensors.
device = infer_torch_device()
print(f"Using device: {device}")

In [None]:
def bbs_to_segments(
    input_table: tlc.Table,
    sam_model_type: str = MODEL_TYPE,
    checkpoint: str = CHECKPOINT,
):
    """Write a new table whose rows hold SAM masks prompted by the input table's bounding boxes.

    For every row of ``input_table`` the image is loaded, each bounding box in
    ``row["bbs"]["bb_list"]`` is used as a box prompt for the Segment Anything predictor,
    and one mask per box is written to a ``segments`` column together with the box's label
    and the score returned by the predictor.

    Uses the module-level ``device`` for the model and the prompt tensors.

    :param input_table: Table with an ``image`` column and a ``bbs`` column containing ``bb_list``.
    :param sam_model_type: Key into ``sam_model_registry`` selecting the SAM architecture.
    :param checkpoint: Path to the SAM checkpoint file passed to the model constructor.
    :returns: The table returned by ``TableWriter.finalize()``, named ``"<input name>-sam"``.
    :raises OSError: If an image referenced by a row cannot be read by OpenCV.
    """
    sam_model = sam_model_registry[sam_model_type](checkpoint=checkpoint)
    sam_model.to(device)
    sam_model.eval()

    sam_predictor = SamPredictor(sam_model)

    # Reuse the label classes of the input bounding boxes for the output segments.
    value_map = input_table.get_value_map("bbs.bb_list.label")

    table_writer = tlc.TableWriter(
        f"{input_table.name}-sam",
        input_table.dataset_name,
        input_table.project_name,
        description="Added SAM-segmentations from bounding box prompts",
        column_schemas={
            "image": tlc.ImageUrlStringValue(),
            "segments": tlc.InstanceSegmentationMasks(
                "segments",
                instance_properties_structure={
                    "label": tlc.CategoricalLabel("label", classes=value_map),
                    "score": tlc.Schema(
                        value=tlc.Float32Value(0, 1), description="Confidence score (predicted IoU) of the segmentation"
                    ),
                },
                is_prediction=True,
            ),
        },
        input_tables=[input_table.url],
    )

    for row in tqdm.tqdm(input_table, desc="Processing rows", total=len(input_table)):
        image = cv2.imread(row["image"])
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image (missing file or unsupported format): {row['image']}")

        # cv2.imread yields BGR pixels, while set_image assumes RGB unless told otherwise.
        sam_predictor.set_image(image, image_format="BGR")

        bb_list = row["bbs"]["bb_list"]
        labels = [bb["label"] for bb in bb_list]
        # x1/y1 are added to x0/y0 (i.e. treated as width/height) to get the corner format SAM expects.
        boxes = np.array([[bb["x0"], bb["y0"], bb["x0"] + bb["x1"], bb["y0"] + bb["y1"]] for bb in bb_list])

        # Call Predictor's predict_torch instead of predict, to allow multiple box prompts
        if boxes.size > 0:
            # Boxes must be mapped into the resized input frame used by the model.
            boxes = sam_predictor.transform.apply_boxes(boxes, sam_predictor.original_size)
            boxes_torch = torch.as_tensor(boxes, dtype=torch.float, device=device)

            masks, scores, _ = sam_predictor.predict_torch(
                None,
                None,
                boxes=boxes_torch,
                multimask_output=False,
            )

            # Convert masks from (num_bbs, 1, h, w) to (h, w, num_bbs)
            # NOTE(review): Fortran order is presumably required by the mask encoding downstream — confirm.
            segments = np.asfortranarray(masks.squeeze(1).permute(1, 2, 0).cpu().numpy().astype(np.uint8))
            scores = scores.squeeze(1).cpu().numpy().tolist()
        else:
            # No boxes in this row: write an empty mask stack and no scores.
            segments = np.zeros((0, 0, 0), dtype=np.uint8)
            scores = []

        output_row = {
            "image": row["image"],
            "segments": {
                "image_height": image.shape[0],
                "image_width": image.shape[1],
                "masks": segments,
                "instance_properties": {
                    "label": labels,
                    "score": scores,
                },
            },
        }

        table_writer.add_row(output_row)

    out_table = table_writer.finalize()
    return out_table

In [None]:
# Alternative input table, kept for reference:
# table = tlc.Table.from_names("train", "Aerial-aerial-sheep", "RF100VL")
# Use the PROJECT_NAME constant rather than repeating the project name literal.
table = tlc.Table.from_names("initial", "COCO128", PROJECT_NAME)
# Run SAM on every bounding box of the table and write the "<name>-sam" table.
out_table = bbs_to_segments(table)