In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch

# from segment_anything import build_sam_vit_h
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image

from utils import (
    get_point_coord,
    get_polygon,
    show_points,
    show_res_multi,
    show_res
)

In [None]:
def get_image(image_name, image_dir):
    """Load an image from ``image_dir`` as an RGB numpy array.

    Some image files have a <space> in their names, but the annotation CSV
    lists them with <_> instead. When ``image_name`` does not exist verbatim,
    the directory is searched for a file whose name matches once "_" and " "
    are treated as equivalent. The comparison uses the full file name, so the
    extension must match as well.

    Args:
        image_name: file name as listed in the annotation CSV.
        image_dir: pathlib.Path of the directory holding the images.

    Returns:
        An (H, W, 3) uint8 RGB array, or None if no matching file was found
        (including when ``image_dir`` itself does not exist).
    """
    image_file = image_dir.joinpath(image_name)
    if not image_file.exists():
        wanted = image_name.replace("_", " ")
        # Search every regular file, not only lowercase "*.jpg": the match is
        # on the whole name, so e.g. ".JPG" files are handled too.
        image_file = next(
            (
                candidate
                for candidate in image_dir.glob("*")
                if candidate.is_file() and candidate.name.replace("_", " ") == wanted
            ),
            None,
        )

    if image_file is None:
        return None

    # Context manager releases the file handle immediately. The RGB conversion
    # is needed because SamPredictor.set_image expects an HxWx3 image;
    # grayscale / RGBA / palette files would otherwise break it.
    with Image.open(image_file) as img:
        return np.array(img.convert("RGB"))

def save_image_masks(masks, image_name, results_dir):
    """Write every mask as an 8-bit PNG (0 / 255) into results_dir/image_name/.

    Files are named by position in ``masks``: 000.png, 001.png, ...

    Args:
        masks: sequence of arrays where ``mask[0]`` is the 2-D mask plane
            (each mask is 3D: 1, y, x).
        image_name: name of the per-image sub-directory to create.
        results_dir: pathlib.Path of the results root; missing parents are created.
    """
    out_dir = results_dir / image_name
    out_dir.mkdir(parents=True, exist_ok=True)
    for index, mask in enumerate(masks):
        # keep the single 2-D plane and scale the 0/1 values to 0/255
        pixels = mask[0].astype(np.uint8) * 255
        Image.fromarray(pixels).save(out_dir / f"{index:03d}.png")

In [None]:
# Input data: the training images and the annotation CSV exported for them.
image_dir = Path("../../data/training_data/1")
annotation_file = image_dir.joinpath("Annotations", "14717-training-images-1.csv")

# Both lines should print True before running the rest of the notebook.
print(*(path.exists() for path in (image_dir, annotation_file)), sep="\n")

In [None]:
# Raw annotation export: one row per annotation; `shape_name` tells the shapes apart.
df_annotations = pd.read_csv(filepath_or_buffer=annotation_file)
df_annotations

In [None]:
# Keep only the point annotations; they become the SAM point prompts below.
df_points = df_annotations.loc[df_annotations["shape_name"].eq("Point")]
df_points

In [None]:
# Non-null entries per column for each image, i.e. how many points each image has.
df_points.groupby(by="filename").count()

Download the weights of the largest SAM model (ViT-H, `sam_vit_h_4b8939.pth`) from
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
and place the file in `../../SAM_models/`.

For the list of all available model checkpoints, see: https://github.com/facebookresearch/segment-anything?tab=readme-ov-file#model-checkpoints

In [None]:
# Segment Anything with the ViT-H backbone (see the download note in the
# preceding markdown cell for where the checkpoint comes from).
model_type = "vit_h"
sam_checkpoint = "../../SAM_models/sam_vit_h_4b8939.pth"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

In [None]:
# Identity recorded on every generated annotation row.
user_id = 2813
first_name = "Nils"
last_name = "Jacobsen"

# Generated rows continue numbering after the largest ids already in the table.
last_ann_label_id = int(df_annotations["annotation_label_id"].max())
last_label_id = int(df_annotations["label_id"].max())
last_annotation_id = int(df_annotations["annotation_id"].max())
# shape_id written for generated polygons.
# NOTE(review): assumed to be the export's id for "Polygon"; confirm against the
# existing shape_id / shape_name pairs in df_annotations.
polygon_id = 3

df_new_annotations = df_annotations.copy()
results_dir = Path("../../results/training_data/1")
# parents=True: "../../results/training_data" may not exist yet on a fresh
# checkout, and mkdir without it raises FileNotFoundError.
results_dir.mkdir(parents=True, exist_ok=True)

print(last_ann_label_id)
print(last_label_id)
print(last_annotation_id)

In [None]:
# Images with fewer point annotations than this are assumed to contain only
# the laser points and are not segmented (TODO confirm: two laser points per image).
MIN_POINTS_PER_IMAGE = 3


def _predict_point_masks(image, point_coords, point_labels):
    """Run SAM separately for every point prompt on ``image``.

    Args:
        image: array handed to ``predictor.set_image``.
        point_coords: (N, 2) array of prompt coordinates, as SAM expects.
        point_labels: length-N array of prompt labels (1 = positive).

    Returns:
        (masks, scores): two lists holding one ``predictor.predict`` result per
        prompt. Single-mask output is used, so each mask array is (1, y, x).
    """
    predictor.set_image(image)
    masks_per_point, scores_per_point = [], []
    for i in range(len(point_coords)):
        masks, scores, _logits = predictor.predict(
            point_coords=point_coords[i : i + 1],
            point_labels=point_labels[i : i + 1],
            multimask_output=False,
        )
        masks_per_point.append(masks)
        scores_per_point.append(scores)
    return masks_per_point, scores_per_point


# Re-derive the id counters from the original table and collect new rows in a
# fresh list. Re-running this cell then neither keeps counting ids upward nor
# appends the same polygons a second time.
last_ann_label_id = int(df_annotations["annotation_label_id"].max())
last_label_id = int(df_annotations["label_id"].max())
last_annotation_id = int(df_annotations["annotation_id"].max())
new_records = []

for image_name, row_locations in df_points.groupby("filename").groups.items():
    print(f"\nRow locations: {row_locations}")
    print(f"\nProcessing {image_name}")
    if len(row_locations) < MIN_POINTS_PER_IMAGE:
        # the image has only the Laser Points
        print("the image has only the laser points")
        continue

    # load the image; get_image returns None when no matching file exists
    test_image = get_image(image_name, image_dir)
    if test_image is None:
        print(f"image file not found in {image_dir}, skipping")
        continue

    # point prompts and their label names, in row order
    image_points = df_points.loc[row_locations]
    point_prompts = np.array(
        image_points["points"].apply(get_point_coord).to_list(), dtype=np.float32
    )
    species = image_points["label_name"].to_list()
    prompt_labels = np.ones(len(point_prompts))  # all prompts are positive
    # NOTE(review): laser points in an image are prompted too and get polygons;
    # filter them out by label_name here if that is not intended.

    # plot image + points
    fig, ax = plt.subplots(1, 1, figsize=(9, 8))
    ax.imshow(test_image)
    show_points(point_prompts, prompt_labels, ax, marker_size=15)
    plt.show()
    plt.close(fig)  # don't accumulate open figures across many images

    # get the SAM predictions for each point
    print(f"getting predictions for {len(point_prompts)} point prompts...")
    all_masks, all_scores = _predict_point_masks(test_image, point_prompts, prompt_labels)
    all_species = [[name] for name in species]

    # show the results
    show_res(
        all_masks,
        all_scores,
        all_species,
        input_points=point_prompts,
        input_labels=prompt_labels,
        input_box=None,
        image=test_image,
        saveit=True,
    )

    # save all masks
    image_id = int(df_points.loc[row_locations[0], "image_id"])
    print(image_id)
    filename = df_points.loc[row_locations[0], "filename"]
    print(filename)
    image_file = image_dir.joinpath(filename)
    print(image_file)
    save_image_masks(all_masks, image_file.stem, results_dir)

    # mask to polygons: one new "Polygon" row per point prompt
    print("getting polygons from masks...")
    for row_label, masks in zip(row_locations, all_masks):
        polygon = get_polygon(masks[0])
        last_ann_label_id += 1
        last_label_id += 1
        last_annotation_id += 1
        new_records.append(
            {
                "image_id": image_id,
                "filename": filename,
                "user_id": user_id,
                "firstname": first_name,
                "lastname": last_name,
                "annotation_label_id": last_ann_label_id,
                "label_id": last_label_id,
                "annotation_id": last_annotation_id,
                "label_hierarchy": df_points.loc[row_label, "label_hierarchy"],
                "label_name": df_points.loc[row_label, "label_name"],
                "shape_id": polygon_id,
                "shape_name": "Polygon",
                "points": str(polygon.ravel().tolist()),
            }
        )

# Build the output table in one step. Enlarging a DataFrame cell by cell with
# .loc first inserts an all-NaN row. That upcasts every integer id column to
# float (written as e.g. "123.0" in the CSV) and reallocates the frame per row.
if new_records:
    df_new_annotations = pd.concat(
        [df_annotations, pd.DataFrame(new_records)], ignore_index=True
    )
else:
    df_new_annotations = df_annotations.copy()
# Integer columns that the new rows leave empty would still become float
# because of the NaNs; store them as nullable integers instead.
integer_columns = df_annotations.select_dtypes(include=np.integer).columns
df_new_annotations = df_new_annotations.astype(
    {column: "Int64" for column in integer_columns}
)

print("\n\nDone!")

In [None]:
# Original annotations followed by the generated polygon rows.
df_new_annotations

In [None]:
# Persist the extended table (without the pandas index) next to the saved masks.
df_new_annotations.to_csv(results_dir / "new_annotation.csv", index=False)