# Paparazzi points to polygon with SAM

The goal of this notebook is to transform points saved from paparazzi into polygons by running SAM inference on those points.

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from utils import (
    download_model,
    get_polygon,
    show_points,
    show_res_multi,
    save_image_masks,
)

In [None]:
def get_image(annotation_name, image_dir):
    """Load the image that a point-annotation file belongs to.

    The image name is the annotation name with its last "_"-separated token
    removed, plus a ".jpg" extension
    (e.g. "frame_0001_points" -> "frame_0001.jpg").

    Args:
        annotation_name: Stem (file name without extension) of the
            annotation file.
        image_dir: Directory that contains the images (str or Path).

    Returns:
        A tuple (image, image_file): the image as an RGB (H, W, 3) uint8
        numpy array and the Path it was read from, or (None, None) when no
        matching image file exists.
    """
    # rpartition keeps everything before the last "_" and yields an empty
    # string when there is no "_" at all, same as joining split("_")[:-1].
    image_name = annotation_name.rpartition("_")[0]
    image_file = Path(image_dir).joinpath(image_name + ".jpg")
    if not image_file.is_file():
        return None, None

    # The context manager releases the file handle deterministically, and
    # convert("RGB") guarantees a 3-channel array even when the JPEG is stored
    # as greyscale or CMYK.
    with Image.open(image_file) as img:
        return np.array(img.convert("RGB")), image_file

The following cell downloads the large SAM2 model's weights from the URL below, but only if the model folder does not already contain a downloaded model:

https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

For all available models see here: https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints

In [None]:
# Helper imported from utils; according to the note above it fetches the SAM2
# checkpoint only when no model has been downloaded yet.
download_model()

In [None]:
from hydra import core, initialize

# --- model selection ---------------------------------------------------------
# Large SAM2 variant: works on GPUs with more than 8 GB of memory.
config_dir = "../models/"
model_cfg = "sam2_hiera_l.yaml"
sam2_checkpoint = "../models/sam2_hiera_large.pt"

# Smaller alternative (base-plus SAM2); uncomment both lines to use it instead:
# sam2_checkpoint = "../../SAM2_models/checkpoints/sam2_hiera_base_plus.pt"
# model_cfg = "sam2_hiera_b+.yaml"

# --- device ------------------------------------------------------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- build the predictor -----------------------------------------------------
# Clear the global Hydra state first so that this cell can be run again
# (initialize() is entered anew every time the cell executes).
core.global_hydra.GlobalHydra.instance().clear()

with initialize(version_base=None, config_path=config_dir):
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

predictor = SAM2ImagePredictor(sam2_model)

Before running inference, we should check that the data has been downloaded and that the image folder is where we expect it (the cell below should print `True`).

In [None]:
# Folder that holds the images; the point annotations are read later from its
# "Annotations" subfolder.
image_dir = Path("../data/20240813/VID_05_GP_50/")
# Should print True; False means the data is missing or the path is wrong.
print(image_dir.exists())


Finally, we can loop through every annotation file, load the matching image, and use the annotated points to prompt SAM2 and generate masks.

In [None]:
# Run SAM2 on every annotated point and export, per image:
#   - one binary mask PNG per point prompt (001.png, 002.png, ...)
#   - a polygon_masks.csv that links each prompt to its mask id and polygon
results_dir = Path("../results/paparazzi_results")
# parents=True: the parent "../results" folder may not exist yet.
results_dir.mkdir(parents=True, exist_ok=True)

# Materialize and sort the glob so that the number of files can be reported
# and the processing order is the same on every run.
annotation_files = sorted(image_dir.glob("Annotations/*.txt"))
print(f"Found {len(annotation_files)} annotation file(s) in {image_dir}")

for file in annotation_files:
    print(f"\nProcessing {file}...")
    image, image_file = get_image(file.stem, image_dir)
    if image is None:
        print("no matching image found, skipping.")
        continue

    df_annotations = pd.read_csv(
        file,
        delimiter="\t",
        header=None,
        names=["x", "y", "label"],
    )
    # NOTE(review): the "label" column is read but never used; every point is
    # passed to SAM2 as a positive prompt (point_labels is all ones).
    point_prompts = df_annotations[["x", "y"]].to_numpy()
    if len(point_prompts) == 0:
        # np.vstack below would raise on an empty list of prompts.
        print("no point prompts in this file, skipping.")
        continue

    # get SAM prediction for each row in point annotation file
    print(f"getting prediction for {len(point_prompts)} point prompts...")
    all_prompts = []
    all_masks = []
    all_scores = []
    predictor.set_image(image)
    for point in point_prompts:
        # predict() takes an array of points: shape (1, 2) for a single point
        prompt = point[np.newaxis, :]

        # predict masks using SAM2
        masks, scores, _ = predictor.predict(
            point_coords=prompt,
            point_labels=np.ones(1),
            multimask_output=False,
        )
        # order the returned masks from highest to lowest score
        sorted_ind = np.argsort(scores)[::-1]
        masks = masks[sorted_ind]
        scores = scores[sorted_ind]

        all_prompts.append(point)
        all_masks.append(masks)
        all_scores.append(scores)

    # show the results
    fig, ax = plt.subplots(1, 1, figsize=(9, 8))
    show_res_multi(all_masks, all_scores, image=image, input_box=None, ax=ax)
    show_points(np.vstack(all_prompts), np.ones(len(all_prompts)), ax=ax)
    plt.show()
    # release the figure so memory does not grow with the number of images
    plt.close(fig)

    # save the masks
    print("saving predicted masks...")
    save_image_masks(all_masks, image_file.stem, results_dir)

    save_dir = results_dir.joinpath(image_file.stem)
    save_dir.mkdir(parents=True, exist_ok=True)
    records = []
    for mask_id, (point, mask) in enumerate(zip(all_prompts, all_masks), start=1):
        # mask is 3D: 1, y, x
        mask_np = mask[0].astype(np.uint8) * 255
        Image.fromarray(mask_np).save(save_dir.joinpath(f"{mask_id:03d}.png"))

        polygon = get_polygon(mask_np)
        records.append({
            "image_file": image_file.name,
            # the prompt that produced THIS mask (not the last row of the file)
            "prompt": str(point.tolist()),
            "mask_id": f"{mask_id:03d}",
            "polygon": str(polygon.ravel().tolist()),
        })

    # build the table once from the collected records
    df_result = pd.DataFrame(
        records,
        columns=["image_file", "prompt", "mask_id", "polygon"],
    )
    df_result.to_csv(save_dir.joinpath("polygon_masks.csv"), index=False)

print("\n\nDone!")