In [None]:
import os

# Must be set before the embedded Bokeh app is started further down.
# NOTE(review): "*" is a wildcard, i.e. websocket connections are accepted from
# any origin — convenient for remote/forwarded notebooks, but consider pinning
# it to the real host:port on shared machines.
os.environ["BOKEH_ALLOW_WS_ORIGIN"] = "*"

In [None]:
import json
import functools
import datetime as dt
from pathlib import Path

import torch
import numpy as np
import pandas as pd
import ipywidgets as ipw
from IPython.display import display
from PIL import Image
from tqdm.notebook import tqdm

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

import utils

In [None]:
from bokeh.plotting import figure
from bokeh import layouts as bklayouts
from bokeh.themes import Theme
from bokeh.document import Document
from bokeh.io import show, push_notebook, output_notebook, curdoc
from bokeh.events import Tap
from bokeh import models as bkmodels

# route Bokeh output to the notebook (instead of a standalone HTML file)
output_notebook()

In [None]:
def get_image(annotation_name, image_dir):
    """Locate the ``.jpg`` image that belongs to an annotation file stem.

    The image name is the annotation name with its last ``_``-separated
    segment removed (e.g. ``"cam_001_points"`` -> ``"cam_001.jpg"``).

    Args:
        annotation_name: stem of the annotation file (no extension).
        image_dir: ``pathlib.Path`` of the directory holding the images.

    Returns:
        The path of the image inside ``image_dir`` if that file exists,
        otherwise ``None``.
    """
    # everything before the last underscore; empty when there is no underscore
    stem, _, _ = annotation_name.rpartition("_")
    candidate = image_dir.joinpath(f"{stem}.jpg")
    return candidate if candidate.exists() else None

## Initialize SAM2 Model

The following cell downloads the weights of the large SAM2 model from the URL below, but only if the model folder does not already contain a downloaded model:

https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

For all available models see here: https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints

In [None]:
# fetch the SAM2 weights via the project helper
# NOTE(review): presumably a no-op when the weights are already on disk — confirm in utils.download_model
utils.download_model()

In [None]:
# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# large sam2: works on gpu > 8g
# NOTE(review): the text above links the sam2.1 large checkpoint while this path
# points at "sam2_hiera_large.pt" — confirm that this matches what
# utils.download_model() actually writes.
sam2_checkpoint = "../models/sam2_hiera_large.pt"
# the config path is turned into an absolute path string before being passed to build_sam2
model_cfg = str(Path("../models/sam2_hiera_l.yaml").absolute())

# base sam2: smaller version
# sam2_checkpoint = "./SAM2_models/sam2.1_hiera_base_plus.pt"
# model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# build the model on the selected device ...
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# ... and wrap it in the image predictor that the UI callbacks use
predictor = SAM2ImagePredictor(sam2_model)

In [None]:
def get_sam_masks(predictor, image, point_prompts):
    """Predict one mask per point prompt for a single image.

    The image is set on the predictor once; afterwards every prompt is
    sent on its own as a single-point request with ``multimask_output=False``.

    Args:
        predictor: object exposing ``set_image`` and ``predict``
            (here a ``SAM2ImagePredictor``).
        image: image passed unchanged to ``predictor.set_image``.
        point_prompts: iterable of 1-D coordinate arrays, one per prompt.

    Returns:
        ``np.ndarray`` built from the first mask of every prediction, in
        prompt order.
    """
    predictor.set_image(image)

    def _predict_single(point):
        # a lone point becomes a one-row batch with label 1
        # (presumably "foreground" in SAM2's convention — confirm in the SAM2 docs)
        predicted, _scores, _logits = predictor.predict(
            point_coords=point[np.newaxis, :],
            point_labels=np.ones(1),
            multimask_output=False,
        )
        return predicted[0]

    progress = tqdm(point_prompts, desc="getting masks for prompts", leave=False)
    return np.array([_predict_single(point) for point in progress])


## Create categories from a BIIGLE Label Tree export
Set the path to the BIIGLE exported file

In [None]:
# path of the label tree CSV exported from BIIGLE (relative to the notebook's working directory)
label_tree_file = "../data/labels.csv"

# raw label table; later cells read its "id", "parent_id", "name" and "color" columns
df_labels = pd.read_csv(label_tree_file)

In [None]:
# Make the cell idempotent: a previous run has already prepended the "None" row
# (id == -1). Drop it first so that re-running the cell neither stacks a second
# copy nor counts "None" as a species.
df_labels = df_labels[df_labels["id"] != -1].copy()

# find the final species labels (not parents in the phylogeny tree)
parents = df_labels["id"].isin(df_labels["parent_id"])
df_labels["is_species"] = ~parents
print(f"number of species: {df_labels['is_species'].sum()}")

# add "None" category as the first row; the values are positional, so they must
# line up with the exported columns followed by the "is_species" flag
df_labels = pd.concat([
    pd.DataFrame([[-1, "None", np.nan, "000000", -1, -1, True]], columns=df_labels.columns),
    df_labels
], ignore_index=True)

df_labels

In [None]:
# Map every label id to its name once instead of filtering the whole table for
# each row. setdefault keeps the first occurrence should an id be duplicated,
# which is what a positional "first match" lookup would return as well.
label_names = {}
for label_id, label_name in zip(df_labels["id"], df_labels["name"]):
    label_names.setdefault(label_id, label_name)

categories = []

for _, df_row in df_labels.iterrows():
    parent_id = df_row["parent_id"]
    if pd.isna(parent_id):
        # no parent (root labels and the "None" row) -> empty supercategory;
        # pd.isna also accepts None/object values, where np.isnan would raise
        parent = ""
    else:
        # a parent id that is not part of the export also yields "" instead of
        # aborting the whole cell with an IndexError
        parent = label_names.get(int(parent_id), "")

    cat = utils.get_coco_category(
        supercategory=parent,
        cat_id=df_row["id"],
        name=df_row["name"]
    )
    cat["color"] = df_row["color"]  # add color
    cat["is_species"] = df_row["is_species"]
    categories.append(cat)

df_cats = pd.DataFrame(categories)
df_cats

## Input Images
Set the input image directory:

In [None]:
# directory holding the input images (relative to the notebook's working directory)
image_dir = Path("../data/paparazzi")
# show the resolved location so a wrong working directory is easy to spot
print(image_dir.absolute())

# stop right here rather than fail later inside the UI callbacks
assert image_dir.exists(), "Couldn't find the image directory"

## Input Point Prompts
Set the input point prompt directory (if not set, will use the image directory):

In [None]:
# Point prompt files are searched in `prompt_dir`; uncomment the next line to
# use a dedicated directory.
# prompt_dir = Path("./data/paparazzi")
prompt_dir = None
# leaving it as None means the prompts sit next to the images
if prompt_dir is None:
    prompt_dir = image_dir

# show the resolved location so a wrong working directory is easy to spot
print(prompt_dir.absolute())

assert prompt_dir.exists(), "Couldn't find the prompt directory"

In [None]:
# get images and prompts
# Both lists are indexed with the same position by the UI callbacks, so they
# have to stay aligned: a prompt file without a matching image is dropped as a
# pair. Filtering only the image list would shift every following image onto
# the wrong prompt file.
matched_pairs = [
    (file, get_image(file.stem, image_dir))
    for file in sorted(prompt_dir.glob("*.txt"))  # sorted -> same order on every run
]
matched_pairs = [
    (prompt_file, image_file)
    for prompt_file, image_file in matched_pairs
    if image_file is not None  # remove prompts whose image was not found
]

point_annotation_files = [prompt_file for prompt_file, _ in matched_pairs]
image_files = [image_file for _, image_file in matched_pairs]

print(f"number of images with prompt: {len(image_files)}")

## Interactive UI for Label Assignments

In [None]:
def get_bokeh_image(img_file):
    """Load an image file in the packed layout used for ``image_rgba``.

    The file is opened with PIL, converted to RGBA, and the four 8-bit
    channels of every pixel are reinterpreted as a single ``uint32``.

    Args:
        img_file: path of the image to load.

    Returns:
        Tuple ``(pixels, h, w)`` where ``pixels`` is an ``(h, w)`` ``uint32``
        array and ``h`` / ``w`` are the image height and width.
    """
    rgba = np.array(Image.open(img_file).convert("RGBA"))
    height, width = rgba.shape[:2]
    # view the channel bytes of each pixel as one 32-bit value, then drop the
    # now length-1 channel axis
    packed = rgba.view(dtype=np.uint32).reshape((height, width))
    return packed, height, width

In [None]:
# Plot Datasources
def get_datasources(image_files):
    """Create the two Bokeh data sources that back the plot.

    Args:
        image_files: sequence of image paths; the first one is loaded as the
            initially displayed image.

    Returns:
        Tuple ``(image_ds, mask_ds)``: ``image_ds`` holds one row with the
        image path, packed pixels, height and width; ``mask_ds`` starts with
        empty ``label`` / ``cat_id`` / ``xs`` / ``ys`` / ``color`` columns.
    """
    pixels, height, width = get_bokeh_image(image_files[0].absolute())
    image_columns = {
        "path": [str(image_files[0].absolute())],
        "image": [pixels],
        "h": [height],
        "w": [width],
    }
    # every mask column starts empty and is filled once SAM has been run
    mask_columns = {column: [] for column in ("label", "cat_id", "xs", "ys", "color")}

    image_ds = bkmodels.ColumnDataSource(data=image_columns)
    mask_ds = bkmodels.ColumnDataSource(data=mask_columns)
    return image_ds, mask_ds

In [None]:
# Status output widget: the callbacks decorated with `output.capture` print
# their messages into this scrollable, fixed-height box.
output = ipw.Output(layout={
    "border": "1px solid dodgerblue",
    "height": "160px",
    "overflow": "scroll",
    "font-size": "12px"
})

In [None]:
def match_labels(df_prompts, df_cats):
    """Resolve the free-text prompt labels against the known categories.

    For every label the candidates are tried in priority order: the full
    label first, then each of its space-separated tokens. The first
    candidate that equals a category name wins, so an exact match is never
    overridden by a partial one that merely appears earlier in ``df_cats``.

    Args:
        df_prompts: prompt table with a string ``label`` column. It is
            modified in place (``label`` is overwritten, ``cat_id`` and
            ``color`` are added).
        df_cats: category table with ``name``, ``id`` and ``color`` columns;
            ``color`` is handed to ``utils.hex2rgb``.

    Returns:
        The same ``df_prompts`` object. Unmatched labels keep their text and
        get ``cat_id=None`` and the colour ``(0, 0, 0, 0.5)``.
    """
    # name -> (id, colour); first occurrence wins if a name is duplicated
    lookup = {}
    for name, cat_id, color_hex in zip(df_cats["name"], df_cats["id"], df_cats["color"]):
        lookup.setdefault(name, (cat_id, color_hex))

    cat_names = []
    cat_ids = []
    colors = []

    for label in df_prompts["label"]:
        # candidates in priority order: whole label, then its single words
        candidates = [label] + label.split(" ")
        matched = next((cand for cand in candidates if cand in lookup), None)
        if matched is not None:
            cat_id, color_hex = lookup[matched]
            cat_names.append(matched)
            cat_ids.append(int(cat_id))
            colors.append(utils.hex2rgb(color_hex))
        else:
            cat_names.append(label)
            cat_ids.append(None)
            colors.append((0, 0, 0, 0.5))

    df_prompts["label"] = cat_names
    # object dtype keeps ints as ints and None as None; a plain list mixing
    # both would be upcast by pandas to float64 with NaN
    df_prompts["cat_id"] = pd.Series(cat_ids, index=df_prompts.index, dtype=object)
    df_prompts["color"] = colors

    return df_prompts

In [None]:
def create_plot():
    """Build the empty, pre-styled Bokeh figure for image and masks.

    The y axis is flipped and both ranges are bounded below by 0. Besides the
    toolbar tools, a tap tool and a hover tool are attached with their
    toolbar buttons hidden.

    Returns:
        The configured figure, named ``"main_plot"``.
    """
    axis_gray = (127, 127, 127)
    grid_gray = (50, 50, 50, 0.5)

    fig = figure(
        name="main_plot",
        # title="SAM2 Masks Label Assignment",
        width=1000,
        height=570,
        output_backend="webgl",
        sizing_mode="stretch_width",
        match_aspect=True,
        aspect_ratio="auto",
        toolbar_location="above",
        tools="pan,wheel_zoom,box_zoom,reset,box_select",
        active_drag="pan",
        active_scroll=None,
    )

    # colours
    fig.background_fill_color = (80, 80, 80)
    fig.border_fill_color = (230, 230, 230)
    for axis in (fig.xaxis, fig.yaxis):
        axis.axis_line_color = axis_gray
    for grid in (fig.xgrid, fig.ygrid):
        grid.grid_line_color = grid_gray

    # flipped y range with non-negative bounds on both axes
    fig.y_range.flipped = True
    fig.y_range.bounds = (0, None)
    fig.x_range.bounds = (0, None)

    # title styling
    fig.title.text_font_size = "18px"
    fig.title.align = "center"
    fig.title.padding = 5

    # toolbar: no logo; tap and hover stay active but get no toolbar button
    fig.toolbar.logo = None
    tap_tool = bkmodels.TapTool(visible=False)
    hover_tool = bkmodels.HoverTool(tooltips=[("Name", "@name"),], visible=False)
    fig.add_tools(tap_tool, hover_tool)

    return fig


def plot_masks(plot, mask_ds: bkmodels.ColumnDataSource):
    """Add the mask polygons to ``plot`` and bind the hover tool to them.

    Args:
        plot: figure that already carries a hover tool (``plot.hover``).
        mask_ds: data source providing the ``xs``, ``ys``, ``color``,
            ``label`` and ``cat_id`` columns.
    """
    renderer = plot.patches(
        xs="xs", ys="ys",
        source=mask_ds,
        name="masks",
        fill_color={"field": "color"},
        fill_alpha=0.55,
        line_color="white",
        line_width=2,
    )
    base_glyph = renderer.glyph

    # unselected masks fade out while a selection is active
    renderer.nonselection_glyph = base_glyph.clone(fill_alpha=0.25)

    # hovered and selected masks share the magenta/yellow highlight and differ
    # only in opacity and outline width
    highlight_overrides = {
        "hover_glyph": dict(fill_color=(255, 45, 255, 0.75), line_color="yellow", line_width=4),
        "selection_glyph": dict(fill_color=(255, 45, 255, 0.85), line_color="yellow", line_width=3),
    }
    for glyph_attr, overrides in highlight_overrides.items():
        setattr(renderer, glyph_attr, base_glyph.clone(**overrides))

    # restrict the tooltip to the masks and show their label and category id
    plot.hover.renderers = [renderer]
    plot.hover.tooltips = [("Label", "@label"), ("Category ID", "@cat_id")]

In [None]:
# widget events

def selected_image_on_change(attr, old, new, image_ds):
    """Bokeh ``on_change`` handler: load the newly selected image.

    Args:
        attr: name of the changed property (unused).
        old: previous dropdown value (unused).
        new: new dropdown value, the position of the image in the
            module-level ``image_files`` list; ``None`` is ignored.
        image_ds: data source of the image glyph, updated in place.
    """
    if new is None:
        return

    image_path = image_files[int(new)].absolute()
    pixels, height, width = get_bokeh_image(image_path)
    image_ds.data.update({
        "path": [str(image_path)],
        "image": [pixels],
        "h": [height],
        "w": [width],
    })


@output.capture(clear_output=True, wait=True)
def get_image_masks(button, image_dd, mask_ds, df_categories):
    """Button callback: run SAM2 on the selected image and show its masks.

    Loads the image and its point prompts, matches the prompt labels against
    the categories, predicts one mask per prompt, and replaces the content of
    ``mask_ds`` with the resulting polygons. Progress messages are captured
    by the ``output`` widget.

    Args:
        button: click event supplied by Bokeh (unused).
        image_dd: image dropdown; its value is the position of the image in
            the module-level ``image_files`` / ``point_annotation_files`` lists.
        mask_ds: data source of the mask patches, updated in place.
        df_categories: category table handed to ``match_labels``.
    """
    # Cast like selected_image_on_change does: the widget value may not come
    # back as a plain int, and it is used as a list index below.
    image_index = int(image_dd.value)
    img = Image.open(image_files[image_index]).convert("RGB")
    img = np.array(img)

    df_prompts = utils.load_prompt_data(point_annotation_files[image_index], is_paparazzi=True)
    # match labels with categories
    df_prompts = match_labels(df_prompts, df_categories)
    prompts = df_prompts[["x", "y"]].to_numpy()
    print(f"number of extracted prompts: {len(prompts)}")

    # run sam predictor
    masks = get_sam_masks(
        predictor, img,
        prompts
    )
    # update masks datasource: one polygon per mask, split into x and y columns
    polygons = [
        utils.get_polygon(mask)
        for mask in masks
    ]
    mask_data = dict(
        label=df_prompts["label"].to_list(),
        cat_id=df_prompts["cat_id"].to_list(),
        xs=[poly[:, 0] for poly in polygons],
        ys=[poly[:, 1] for poly in polygons],
        color=df_prompts["color"].to_list(),
    )
    mask_ds.data.update(mask_data)

    print("\rSAM predictor is done!")


def toggle_masks(active):
    """Toggle callback: hide the mask renderer while the toggle is active.

    Args:
        active: current state of the toggle button.
    """
    found = curdoc().select({"name": "masks"})
    if not found:
        # nothing named "masks" in the current document -> nothing to toggle
        return
    found[0].visible = not active


@output.capture(clear_output=False, wait=True)
def assign_label(
    button, mask_ds: bkmodels.ColumnDataSource,
    cat_dropdown: bkmodels.Select, df_categories: pd.DataFrame
):
    """Button callback: give the selected masks the category of the dropdown.

    Patches the ``label``, ``color`` and ``cat_id`` columns of ``mask_ds`` for
    every selected mask and reports the change in the ``output`` widget.

    Args:
        button: click event supplied by Bokeh (unused).
        mask_ds: data source of the mask patches; its selection is read and
            its columns are patched in place.
        cat_dropdown: category dropdown; its value is the row position of the
            category in ``df_categories``.
        df_categories: category table with ``name``, ``id`` and ``color``.
    """
    selected_indices = mask_ds.selected.indices
    if len(selected_indices) == 0:
        print("No mask was selected!")
        return

    # Cast like selected_image_on_change does: the widget value may not come
    # back as a plain int, and .iloc needs an integer position.
    category = df_categories.iloc[int(cat_dropdown.value)]
    new_label = category["name"]
    # plain int, consistent with the ids produced by match_labels
    new_cat_id = int(category["id"])
    # the colour is the same for all selected masks, so convert it once
    new_color = utils.hex2rgb(category["color"])

    # assign the selected category to the selected masks
    mask_ds.patch({
        "label": [(index, new_label) for index in selected_indices],
        "color": [(index, new_color) for index in selected_indices],
        "cat_id": [(index, new_cat_id) for index in selected_indices],
    })
    print(f"{len(selected_indices)} mask labels was changed to {new_label} with category id {new_cat_id}.")


@output.capture(clear_output=False, wait=True)
def export_as_coco(
    button, image_dd,
    df_categories: pd.DataFrame,
    image_ds: bkmodels.ColumnDataSource,
    mask_ds: bkmodels.ColumnDataSource
):
    """Button callback: write the current masks to a COCO JSON file.

    The file is named after the ``file_name`` of the first image entry and is
    placed in ``./coco_results/paparazzi`` (created when missing). Messages
    are captured by the ``output`` widget.

    Args:
        button: click event supplied by Bokeh (unused).
        image_dd: image dropdown (currently unused).
        df_categories: category table handed to ``utils.convert_to_coco``.
        image_ds: data source of the displayed image.
        mask_ds: data source of the mask patches.
    """
    # nothing to export while the mask source is still empty
    if not len(mask_ds.data["label"]):
        print("There no masks with assigned labels!")
        return

    coco_annotations = utils.convert_to_coco(df_categories, 0, image_ds, mask_ds)

    # save the coco annotations to a json file
    out_dir = Path("./coco_results/paparazzi")
    out_dir.mkdir(exist_ok=True, parents=True)
    img_name = coco_annotations["images"][0]["file_name"]
    out_file = out_dir / f"{img_name}_coco_annotations.json"
    with open(out_file, mode="w") as f:
        json.dump(coco_annotations, f, indent=4)

    print(f"Annotations were saved to {out_file.absolute()}")

In [None]:
def create_tools_layout(image_ds, mask_ds, df_categories):
    """Build the control panel shown above the plot and wire its callbacks.

    Left group: image dropdown, "Get Masks" button, "Hide Masks" toggle.
    Right group: category dropdown, "Assign the Label" button with its help
    button, "Export to COCO" button.

    Args:
        image_ds: data source of the displayed image.
        mask_ds: data source of the mask patches.
        df_categories: category table; its ``name`` column fills the category
            dropdown and the table is passed on to the callbacks.

    Returns:
        The Bokeh grid layout containing all controls.
    """
    # geometry shared by all regular buttons
    button_box = dict(align="end", width=120, height=30)

    # --- image selection -------------------------------------------------
    img_dropdown = bkmodels.Select(
        name="img_dropdown",
        title="Input Image:",
        options=[(position, path.name) for position, path in enumerate(image_files)],
        value=0,
    )
    img_dropdown.on_change(
        "value",
        functools.partial(selected_image_on_change, image_ds=image_ds),
    )

    # --- run SAM ---------------------------------------------------------
    run_sam_button = bkmodels.Button(
        name="sam_button",
        label="Get Masks",
        button_type="primary",
        **button_box,
    )
    run_sam_button.on_click(
        functools.partial(
            get_image_masks,
            image_dd=img_dropdown,
            mask_ds=mask_ds,
            df_categories=df_categories,
        )
    )

    # --- show / hide masks -----------------------------------------------
    toggle_masks_button = bkmodels.Toggle(
        name="toggle_masks_button",
        button_type="default",
        label="Hide Masks",
        active=False,
        **button_box,
    )
    toggle_masks_button.on_click(toggle_masks)

    # --- category selection (all categories, in table order) --------------
    cat_dropdown = bkmodels.Select(
        name="category_dropdown",
        title="Selected Mask Label:",
        options=list(enumerate(df_categories["name"].to_list())),
        value=0,
    )

    # --- assign label (+ help) -------------------------------------------
    assign_button = bkmodels.Button(
        name="assign_button",
        label="Assign the Label",
        button_type="primary",
        **button_box,
    )
    assign_help_button = bkmodels.HelpButton(
        tooltip=bkmodels.Tooltip(
            content="Assign the selected label to the selected masks",
            position="right",
            visible=True,
        ),
        button_type="light",
        align="end",
        height=30,
        margin=(5, 5, 5, -5),
    )
    assign_button.on_click(
        functools.partial(
            assign_label,
            mask_ds=mask_ds,
            cat_dropdown=cat_dropdown,
            df_categories=df_categories,
        )
    )

    # --- export ----------------------------------------------------------
    export_button = bkmodels.Button(
        name="export_button",
        label="Export to COCO",
        button_type="success",
        **button_box,
    )
    export_button.on_click(
        functools.partial(
            export_as_coco,
            image_dd=img_dropdown,
            df_categories=df_categories,
            image_ds=image_ds,
            mask_ds=mask_ds,
        )
    )

    # --- layout for the top panel: two groups side by side -----------------
    image_group = bklayouts.column(
        bklayouts.row(img_dropdown, run_sam_button, toggle_masks_button, spacing=5)
    )
    label_group = bklayouts.column(
        bklayouts.row(cat_dropdown, assign_button, assign_help_button, export_button, spacing=5)
    )
    return bklayouts.grid(
        children=[
            bklayouts.row(image_group, label_group, spacing=70),
        ]
    )


In [None]:
# bokeh app
def bokeh_app(doc: Document, image_files, df_categories):
    """Assemble the complete UI and attach it to a Bokeh document.

    Args:
        doc: document to which the UI is added as a root.
        image_files: sequence of image paths; the first one is shown initially.
        df_categories: category table used by the label controls.
    """
    image_ds, mask_ds = get_datasources(image_files)

    # image first, masks on top of it
    plot = create_plot()
    plot.image_rgba(source=image_ds, image="image", x=0, y=0, dw="w", dh="h")
    plot_masks(plot, mask_ds)

    # controls above, plot below
    controls = create_tools_layout(image_ds, mask_ds, df_categories)
    doc.add_root(
        bklayouts.grid(
            children=[controls, bklayouts.row(plot)],
            sizing_mode="stretch_width",
        )
    )


## Run the App

In [None]:
# Start the app inside the notebook; `bokeh_app` receives the document from
# Bokeh, the image list and the category table are bound beforehand via partial.
show(
    functools.partial(bokeh_app, image_files=image_files, df_categories=df_cats),
    notebook_handle=True
)

# empty the status widget and render it below the app
output.clear_output()
display(output)