In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from PIL import Image

In [None]:
def get_image(annotation_name, image_dir):
    """Load the JPEG that an annotation name refers to.

    The image stem is the annotation name with its last ``_``-separated
    token removed; the file looked up is ``image_dir/<stem>.jpg``.

    Args:
        annotation_name: Annotation identifier, split on ``_``.
        image_dir: Directory object exposing ``joinpath`` (e.g. ``pathlib.Path``).

    Returns:
        ``(pixels, path)`` -- the image as a NumPy array together with the
        path it was read from -- or ``(None, None)`` if that file does not exist.
    """
    stem_tokens = annotation_name.split("_")[:-1]
    candidate = image_dir.joinpath("_".join(stem_tokens) + ".jpg")
    if not candidate.exists():
        return None, None
    return np.array(Image.open(candidate)), candidate


def save_image_masks(masks, image_name, results_dir):
    """Write every mask of one image to ``results_dir/image_name/NNN.png``.

    The output folder is created if needed (including parents). Files are
    named by the mask's position in ``masks``, zero-padded to three digits.

    Args:
        masks: Iterable of arrays; only the first plane ``mask[0]`` of each
            entry is written (per the original note, masks are ``1, y, x``).
        image_name: Name of the sub-folder that receives the PNG files.
        results_dir: Parent directory object exposing ``joinpath``.
    """
    target_dir = results_dir.joinpath(image_name)
    target_dir.mkdir(parents=True, exist_ok=True)
    for index, mask in enumerate(masks):
        # Cast the first plane to uint8 and scale by 255 for an 8-bit image.
        plane = mask[0].astype(np.uint8) * 255
        out_path = target_dir.joinpath(f"{index:03d}.png")
        Image.fromarray(plane).save(out_path)

In [None]:
def show_anns(anns):
    """Overlay annotation masks on the current matplotlib axes.

    Annotations are painted in order of decreasing ``"area"``, so smaller
    masks end up on top of larger ones. Each mask gets a random RGB colour
    with a fixed alpha of 0.35; pixels not covered by any mask stay fully
    transparent.

    Args:
        anns: Sequence of dicts with an ``"area"`` entry and a
            ``"segmentation"`` array used to index the overlay.

    Returns:
        None. Does nothing when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    by_area_desc = sorted(anns, key=lambda ann: ann["area"], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    # RGBA canvas sized like the first (largest) mask: white, alpha zero.
    mask_shape = by_area_desc[0]["segmentation"].shape
    overlay = np.ones((mask_shape[0], mask_shape[1], 4))
    overlay[:, :, 3] = 0
    for ann in by_area_desc:
        rgba = np.append(np.random.random(3), 0.35)
        overlay[ann["segmentation"]] = rgba
    axes.imshow(overlay)

In [None]:
# Release any cached GPU memory held by PyTorch (e.g. from an earlier run in
# this kernel) before the SAM model is loaded in the next cell.
torch.cuda.empty_cache()

In [None]:
# Pretrained SAM weights and the matching architecture key (ViT-H).
sam_checkpoint = "../results/SAM_models/sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Look up the builder registered for this architecture, then load the checkpoint.
build_sam = sam_model_registry[model_type]
sam = build_sam(checkpoint=sam_checkpoint)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
sam.to(device=device)

# Prompt-driven alternative, not used in this notebook (SamPredictor is not imported):
# predictor = SamPredictor(sam)

In [None]:
# Folder with the extracted frames, relative to the notebook's working
# directory; print whether it exists so a wrong location is noticed early.
image_dir = Path("..", "results", "paparazzi_results")
print(image_dir.exists())

In [None]:
# Frame used for the mask-generation run below.
image_file = "../results/paparazzi_results/VID_01_2023_GP__0.14.45.00.jpg"

# Force 3-channel RGB so grayscale, palette or RGBA files still produce the
# H x W x 3 uint8 array that SAM expects. The context manager closes the
# underlying file as soon as the pixels have been decoded.
with Image.open(image_file) as frame:
    image = np.array(frame.convert("RGB"))

In [None]:
# Preview the input frame on a 10x10 inch figure with the axes hidden.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.axis("off")
plt.show()

In [None]:
# Reference -- SamAutomaticMaskGenerator keyword arguments and their defaults:
#   points_per_side: Optional[int] = 32
#   points_per_batch: int = 64
#   pred_iou_thresh: float = 0.88
#   stability_score_thresh: float = 0.95
#   stability_score_offset: float = 1.0
#   box_nms_thresh: float = 0.7
#   crop_n_layers: int = 0
#   crop_nms_thresh: float = 0.7
#   crop_overlap_ratio: float = 512 / 1500
#   crop_n_points_downscale_factor: int = 1
#   point_grids: Optional[List[np.ndarray]] = None
#   min_mask_region_area: int = 0
#   output_mode: str = "binary_mask"

# Overrides used for this run; everything not listed keeps its default.
generator_options = dict(
    points_per_side=40,  # default 32
    pred_iou_thresh=0.87,  # default 0.88
    stability_score_thresh=0.80,  # default 0.95
    crop_n_layers=1,  # default 0
    crop_n_points_downscale_factor=2,  # default 1
    crop_nms_thresh=0.7,  # same as the default, stated explicitly
    min_mask_region_area=500,  # Requires open-cv to run post-processing
)

mask_generator = SamAutomaticMaskGenerator(model=sam, **generator_options)

In [None]:
# Run SAM's automatic mask generation on the loaded frame; the result is the
# list of per-mask records inspected and plotted in the cells below.
masks = mask_generator.generate(image)

Mask generation returns a list of masks, where each mask is a dictionary containing various data about the mask. The keys are:
* `segmentation` : the mask
* `area` : the area of the mask in pixels
* `bbox` : the bounding box of the mask in XYWH format
* `predicted_iou` : the model's own prediction for the quality of the mask
* `point_coords` : the sampled input point that generated this mask
* `stability_score` : an additional measure of mask quality
* `crop_box` : the crop of the image used to generate this mask in XYWH format

In [None]:
# Sanity check on the generator output: how many masks were produced and
# which fields each mask record carries.
print(len(masks))
if masks:
    print(masks[0].keys())
else:
    # The generator can return an empty list when every candidate is filtered
    # out; report that instead of failing with an IndexError on masks[0].
    print("No masks were generated; consider relaxing the generator thresholds.")

In [None]:
# Draw the frame, then let show_anns paint the masks onto the same (current) axes.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
show_anns(masks)
ax.axis("off")
plt.show()