In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image

from utils import (
    get_point_coord,
    get_polygon,
    show_points,
    show_mask,
    show_res_multi,
)

In [None]:
def get_image(annotation_name, image_dir):
    """Load the JPEG image that belongs to a point-annotation file.

    The image name is the annotation name with its last "_"-separated token
    removed, e.g. "scene_01_points" -> "scene_01.jpg".

    Args:
        annotation_name: Stem of the annotation file (no extension).
        image_dir: ``pathlib.Path`` of the directory holding the images.

    Returns:
        A tuple ``(image, image_file)``: the image as an RGB ``numpy`` array
        and the path it was read from, or ``(None, None)`` when no matching
        ".jpg" file exists in ``image_dir``.
    """
    image_name = "_".join(annotation_name.split("_")[:-1])
    image_file = image_dir.joinpath(image_name + ".jpg")
    if not image_file.exists():
        return None, None

    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it as soon as the pixel data has been copied into the array.
    with Image.open(image_file) as pil_image:
        # Force 3-channel RGB so grayscale/CMYK JPEGs also yield an
        # (H, W, 3) array, which is what the SAM predictor is fed with.
        return np.array(pil_image.convert("RGB")), image_file


def save_image_masks(masks, image_name, results_dir):
    """Write every mask of one image to ``results_dir/image_name`` as a PNG.

    Args:
        masks: Iterable of mask arrays; the first element along axis 0 of
            each one is the 2D mask that gets saved.
        image_name: Name of the sub-folder to create for this image.
        results_dir: ``pathlib.Path`` of the root output directory.

    Files are named by their zero-based position: "000.png", "001.png", ...
    """
    out_dir = results_dir / image_name
    out_dir.mkdir(parents=True, exist_ok=True)
    for index, mask in enumerate(masks):
        # drop the leading axis and scale to 0/255 so the PNG is viewable
        pixels = mask[0].astype(np.uint8) * 255
        target = out_dir / f"{index:03d}.png"
        Image.fromarray(pixels).save(target)

Download the large SAM model's weights (ViT-H) from here:
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
and save the file in `../results/SAM_models/`, which is where the next cell loads it from.

For all available models see here: https://github.com/facebookresearch/segment-anything?tab=readme-ov-file#model-checkpoints

In [None]:
# Checkpoint file with the model weights and the registry key that selects
# the model builder it is loaded into.
sam_checkpoint = "../results/SAM_models/sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Build the model of the selected type from the checkpoint.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam.to(device=device)

# Predictor wrapper around the model; used below for the point prompts.
predictor = SamPredictor(sam)

In [None]:
# Directory that is searched for the "*.txt" point-annotation files and the
# ".jpg" images they belong to.
image_dir = Path("../results/paparazzi_results")
# Sanity check: False means the relative path does not resolve from the
# current working directory.
print(image_dir.exists())

In [None]:
# Root output directory; one sub-folder per image is created below it.
# NOTE(review): this is the same path as `image_dir` — confirm that writing
# the results next to the inputs is intended.
results_dir = Path("../results/paparazzi_results")
# parents=True (as for the per-image folders) so a missing "../results"
# does not raise FileNotFoundError.
results_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Every "*.txt" file in image_dir is a tab-separated list of point prompts
# (columns: x, y, label) for one image. Sorted so the processing order is
# reproducible between runs.
annotation_files = sorted(image_dir.glob("*.txt"))

for file in annotation_files:
    print(f"\nProcessing {file}...")
    image, image_file = get_image(file.stem, image_dir)
    if image is None:
        # no matching ".jpg" for this annotation file
        continue

    df_annotations = pd.read_csv(
        file,
        delimiter="\t",
        header=None,
        names=["x", "y", "label"]
    )
    if df_annotations.empty:
        # without prompts there is nothing to predict, and np.vstack below
        # would raise on an empty list
        print("no point prompts found, skipping...")
        continue

    # get SAM prediction for each row in point annotation file
    print(f"getting prediction for {len(df_annotations)} point prompts...")
    all_prompts = []
    all_masks = []
    all_scores = []
    # set the image once; every point prompt below is predicted against it
    predictor.set_image(image)
    for _row_index, row in df_annotations.iterrows():
        prompt = np.array([[row["x"], row["y"]]])
        # a single point with label 1, one mask per prompt
        masks, scores, _logits = predictor.predict(
            point_coords=prompt,
            point_labels=np.ones(1),
            multimask_output=False,
        )
        all_prompts.append(prompt[0])
        all_masks.append(masks)
        all_scores.append(scores)

    # show the results
    fig, ax = plt.subplots(1, 1, figsize=(9, 8))
    show_res_multi(all_masks, all_scores, image=image, input_box=None, ax=ax)
    show_points(np.vstack(all_prompts), np.ones(len(all_prompts)), ax=ax)
    plt.show()
    # release the figure so memory does not grow with the number of images
    plt.close(fig)

    # save the masks (PNG per mask) and their polygons (one CSV per image)
    print("saving predicted masks...")
    save_dir = results_dir.joinpath(image_file.stem)
    save_dir.mkdir(parents=True, exist_ok=True)
    records = []
    for mask_id, (prompt_xy, mask) in enumerate(
        zip(all_prompts, all_masks), start=1
    ):
        # mask is 3D: 1, y, x
        mask_np = mask[0].astype(np.uint8) * 255
        Image.fromarray(mask_np).save(save_dir.joinpath(f"{mask_id:03d}.png"))
        polygon = get_polygon(mask_np)
        records.append({
            "image_file": image_file.name,
            # the prompt that produced THIS mask; reading the leftover `row`
            # of the prediction loop stored the file's last prompt in every
            # CSV row
            "prompt": str(prompt_xy.tolist()),
            "mask_id": f"{mask_id:03d}",
            "polygon": str(polygon.ravel().tolist()),
        })
    # build the frame in one go instead of growing it cell by cell
    df_result = pd.DataFrame(
        records, columns=["image_file", "prompt", "mask_id", "polygon"]
    )
    df_result.to_csv(save_dir.joinpath("polygon_masks.csv"), index=False)

print("Done!")