# Automatic Mask Generator

The main goal of this notebook is to assess whether automatic mask generation is useful for finding objects or image descriptors, such as the complexity of the seabed.

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

from utils import download_model

In [None]:
def show_anns(anns):
    """Paint every annotation onto the current axes in a random translucent colour.

    Annotations are painted from the largest to the smallest area, so smaller
    masks end up on top of the larger ones they overlap. Any annotation whose
    area exceeds half of the image is skipped.

    Parameters
    ----------
    anns : list of dict
        Mask records as produced by the automatic mask generator; each needs an
        ``'area'`` entry (pixels) and a ``'segmentation'`` mask that is used as a
        boolean index into the image grid.
    """
    if not anns:
        return

    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax = plt.gca()
    # Keep the axis limits already set by the image shown underneath.
    ax.set_autoscale_on(False)

    height, width = by_area[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0  # alpha 0: transparent wherever no mask is painted
    half_area = height * width / 2

    for ann in by_area:
        if ann["area"] <= half_area:
            rgba = np.concatenate([np.random.random(3), [0.35]])
            overlay[ann['segmentation']] = rgba

    ax.imshow(overlay)

In [None]:
# Release cached, unused GPU memory (e.g. from a previous run) before loading the model.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

The following cell downloads the weights of the large SAM2 model from the URL below, but only if no model has been downloaded to the folder yet:

https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

For all available models see here: https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints

In [None]:
# Fetch the SAM2 checkpoint with the project helper from `utils`
# (per the note above, it only downloads when no model is present yet).
download_model()

In [None]:
from hydra import initialize, core

# Hydra keeps a process-wide singleton; clearing it lets this cell be re-run
# without Hydra complaining that it is already initialized.
core.global_hydra.GlobalHydra.instance().clear()

# Use the GPU when PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Large SAM2 model: needs a GPU with more than 8 GB of memory.
# NOTE(review): the download note above mentions `sam2.1_hiera_large.pt`, while
# this loads `sam2_hiera_large.pt` with the SAM2 (not 2.1) config - confirm
# which file `download_model()` actually writes.
config_dir = "../models/"
model_cfg = "sam2_hiera_l.yaml"
sam2_checkpoint = "../models/sam2_hiera_large.pt"

# Smaller alternative (base+), swap in if memory is tight:
# sam2_checkpoint = "../../SAM2_models/checkpoints/sam2_hiera_base_plus.pt"
# model_cfg = "sam2_hiera_b+.yaml"

with initialize(version_base=None, config_path=config_dir):
    sam2_model = build_sam2(
        model_cfg,
        sam2_checkpoint,
        device=device,
        apply_postprocessing=False,
    )

Choose the folder that contains the images to be analyzed.

In [None]:
# Folder with the frames to analyse; printing whether it exists is a quick
# sanity check on the relative path.
image_dir = Path("..") / "results" / "paparazzi_results"
print(image_dir.exists())

Choose the image file to be analyzed.

In [None]:
# Image to analyse. Kept as a plain string: the annotation file path is derived from it later.
image_file = "../results/paparazzi_results/VID_01_2023_GP__0.14.45.00.jpg"

# Load as an H x W x 3 RGB array. `convert("RGB")` normalises grayscale, CMYK,
# RGBA or palette files, since SAM2 expects a 3-channel RGB image. The `with`
# block closes the file handle that `Image.open` keeps lazily open.
with Image.open(image_file) as pil_image:
    image = np.array(pil_image.convert("RGB"))

In [None]:
# Display the raw frame before running the model.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.axis('off')
plt.show()

In [None]:
# Settings for SAM2AutomaticMaskGenerator. Constructor defaults as previously
# recorded in this notebook (verify against the installed sam2 version):
#   points_per_side=32, points_per_batch=64, pred_iou_thresh=0.88,
#   stability_score_thresh=0.95, stability_score_offset=1.0,
#   box_nms_thresh=0.7, crop_n_layers=0, crop_nms_thresh=0.7,
#   crop_overlap_ratio=512 / 1500, crop_n_points_downscale_factor=1,
#   point_grids=None, min_mask_region_area=0, output_mode="binary_mask"
mask_generator_kwargs = dict(
    points_per_side=64,            # denser prompt grid than the default 32
    pred_iou_thresh=0.87,
    stability_score_thresh=0.80,   # looser than the default 0.95
    crop_n_layers=1,               # also predict on one layer of image crops
    crop_n_points_downscale_factor=3,
    crop_nms_thresh=0.7,
    min_mask_region_area=500,      # small-region post-processing requires OpenCV
)
mask_generator = SAM2AutomaticMaskGenerator(model=sam2_model, **mask_generator_kwargs)

The following cell runs the model on the image and may take a while.

In [None]:
# Run automatic mask generation on the whole frame (the slow step).
# The result is a list with one dict per mask; its keys are listed below.
masks = mask_generator.generate(image)

Mask generation returns a list of masks, where each mask is a dictionary containing various data about the mask. Its keys are:
* `segmentation` : the mask
* `area` : the area of the mask in pixels
* `bbox` : the boundary box of the mask in XYWH format
* `predicted_iou` : the model's own prediction for the quality of the mask
* `point_coords` : the sampled input point that generated this mask
* `stability_score` : an additional measure of mask quality
* `crop_box` : the crop of the image used to generate this mask in XYWH format

In [None]:
# Number of masks found, and the keys available on each mask dict.
# Guarded so the cell does not raise IndexError when no mask was produced.
print(len(masks))
if masks:
    print(masks[0].keys())

In [None]:
# Overlay the generated masks (except those covering more than half the frame) on the image.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
show_anns(masks)
ax.axis('off')
plt.show()

We can load the annotations file to compare the masks with the manually tagged points.

In [None]:
import pandas as pd

# The annotation file sits next to the image, sharing its stem plus a
# "_2704x1520.txt" suffix (presumably the resolution the points were tagged
# at - TODO confirm it matches `image.shape`). Using `Path.stem` instead of
# slicing off the last four characters also handles extensions like ".jpeg".
image_path = Path(image_file)
file = image_path.with_name(image_path.stem + "_2704x1520.txt")

# Tab-separated, no header row: one annotated point per line as (x, y, label).
df_annotations = pd.read_csv(
    file, delimiter="\t", header=None, names=["x", "y", "label"]
)

In [None]:
def get_masks_from_point(x, y, mask_list=None):
    """Return the indices of all masks whose segmentation contains pixel (x, y).

    Parameters
    ----------
    x, y : int or float
        Column and row of the point; converted with ``int()`` before indexing.
    mask_list : list of dict, optional
        Mask records with a ``"segmentation"`` array. Defaults to the
        notebook-level ``masks`` list produced by the mask generator.

    Returns
    -------
    list of int
        Indices into the mask list. Empty when no mask covers the point or
        when the point lies outside the mask grid.
    """
    if mask_list is None:
        mask_list = masks  # notebook-level output of mask_generator.generate

    col, row = int(x), int(y)
    hits = []
    for idx, mask in enumerate(mask_list):
        seg = mask["segmentation"]
        # Without this check, negative coordinates would silently wrap to the
        # opposite edge and too-large ones would raise IndexError.
        if 0 <= row < seg.shape[0] and 0 <= col < seg.shape[1] and seg[row, col]:
            hits.append(idx)
    return hits

In [None]:
# For every annotated point, record the indices of the masks that contain it.
# Iterating the two columns directly avoids building a Series per row (as
# `apply(axis=1)` does) and yields an empty column for an empty annotation file.
df_annotations["masks"] = [
    get_masks_from_point(x, y)
    for x, y in zip(df_annotations["x"], df_annotations["y"])
]
df_annotations

Use the following interactive cells to step through the annotated points and see which masks contain each of them.

In [None]:
from ipywidgets import interact
from skimage.morphology import label, remove_small_objects

In [None]:
# Valid rows are 0 .. len(df_annotations) - 1; the previous `+1` offered one
# row past the end, which made `.loc` raise KeyError.
@interact(row=np.arange(0, len(df_annotations), 1))
def plot(row):
    """Show annotated point `row` on the frame, with the contour of every mask containing it.

    The indices stored in the "masks" column refer to `masks` as it was when
    that column was computed; re-run that cell after masks are filtered out.
    """
    plt.imshow(image)
    plt.scatter(df_annotations.loc[row, "x"], df_annotations.loc[row, "y"])
    # `mask_idx` rather than `label`, to avoid confusion with skimage's `label`.
    for mask_idx in df_annotations.loc[row, "masks"]:
        plt.contour(masks[mask_idx]["segmentation"])

## Process masks

Many masks have been segmented automatically.
Some of them correspond to the seabed, while others are very noisy because they cover out-of-focus regions of the image.
We need a way to tell these apart so that the noisy masks can be filtered out.

In [None]:
def get_area_coverage(mask):
    """Return the fraction of the image grid (first two axes) that `mask` covers."""
    n_pixels = mask.shape[0] * mask.shape[1]
    covered = np.sum(mask)
    return covered / n_pixels


def touches_border(mask):
    """Return, for each image edge, the fraction of that edge covered by `mask`.

    Despite the name, the values are fractions in [0, 1], not booleans. Keys
    are "left", "right", "top" and "bottom", in that order.
    """
    height, width = mask.shape[0], mask.shape[1]
    edges = {
        "left": (mask[:, 0], height),
        "right": (mask[:, -1], height),
        "top": (mask[0], width),
        "bottom": (mask[-1], width),
    }
    return {side: np.sum(pixels) / length for side, (pixels, length) in edges.items()}


def _count_labels(labeled):
    """Return the number of distinct non-zero labels in a label image."""
    # Counting only non-zero labels is correct in two cases where
    # `len(np.unique(labeled)) - 1` is not: when there is no background
    # pixel at all (an object filling the whole mask has no label 0), and
    # when removed objects leave gaps in the label numbering.
    return int(np.unique(labeled[labeled != 0]).size)


def get_number_of_objects(mask, size_limit=3000):
    """Count the connected components of `mask`, overall and above a size limit.

    Parameters
    ----------
    mask : array-like
        Binary mask, e.g. a mask record's ``"segmentation"`` array.
    size_limit : int, default 3000
        Size threshold in pixels, passed positionally to
        ``remove_small_objects`` (its ``min_size`` in scikit-image releases
        that still accept it - TODO check against the installed version).

    Returns
    -------
    dict
        ``{"all": <number of components>,
        "big": <number of components left after removing small ones>}``.
    """
    labeled = label(mask)
    number_of_objects = {"all": _count_labels(labeled)}
    labeled = remove_small_objects(labeled, size_limit)
    number_of_objects["big"] = _count_labels(labeled)
    return number_of_objects

In [None]:
# Attach per-mask descriptors that are used below to filter out bad masks.
for mask in masks:
    seg = mask["segmentation"]
    n_pixels = seg.shape[0] * seg.shape[1]
    mask["area_coverage"] = mask["area"] / n_pixels
    mask["touches_border"] = touches_border(seg)
    mask["number_of_objects"] = get_number_of_objects(seg)

In [None]:
# List the masks that contain more than one large connected component.
multi_object = (
    (idx, m["number_of_objects"])
    for idx, m in enumerate(masks)
    if m["number_of_objects"]["big"] > 1
)
for idx, counts in multi_object:
    print(idx, counts)

You can interactively go through the masks with the following cell.

In [None]:
@interact(row=np.arange(0, len(masks), 1))
def plot_masks(row):
    """Display mask number `row` as an image with a colour bar."""
    fig, ax = plt.subplots(figsize=(10, 6))
    shown = ax.imshow(masks[row]["segmentation"])
    fig.colorbar(shown, ax=ax)

The next two cells flag masks that are clearly wrong and then remove them from the `masks` list.

In [None]:
def get_masks_to_be_removed(masks, area_threshold=0.4, max_number_big_objects=1):
    """Return the indices of masks that look like errors or background.

    A mask is flagged when its ``"area_coverage"`` exceeds `area_threshold`,
    or when its ``"number_of_objects"["big"]`` exceeds `max_number_big_objects`.

    Parameters
    ----------
    masks : list of dict
        Mask records carrying the ``"area_coverage"`` and
        ``"number_of_objects"`` descriptors.
    area_threshold : float, default 0.4
        Maximum allowed fraction of the image covered by one mask.
    max_number_big_objects : int, default 1
        Maximum allowed number of large connected components in one mask.

    Returns
    -------
    list of int
        Ascending indices into `masks`, each listed once even when both
        criteria match. A duplicated index would make a pop-based removal
        loop delete an extra, unrelated mask.
    """
    return [
        n
        for n, mask in enumerate(masks)
        if mask["area_coverage"] > area_threshold
        or mask["number_of_objects"]["big"] > max_number_big_objects
    ]

# Indices of the masks to drop, shown for inspection before removal.
to_be_removed = get_masks_to_be_removed(
    masks, area_threshold=0.4, max_number_big_objects=1
)
to_be_removed

In [None]:
# Remove the flagged masks from `masks` in place.
# De-duplicate first, because popping the same index twice would drop an
# unrelated mask. Then pop from the highest index down, so that the indices
# still to be popped stay valid.
# NOTE: indices stored earlier in df_annotations["masks"] are stale after this.
for idx in sorted(set(to_be_removed), reverse=True):
    masks.pop(idx)
to_be_removed.clear()  # the flagged indices are consumed, as before

## Background

After removing the masks that correspond to mistakes or to the background, we can combine the remaining masks into a single map that gives an estimate of the seabed complexity.

In [None]:
# Union of all remaining masks: True wherever at least one mask covers the pixel.
# When every mask was filtered out, fall back to an all-False map of the image
# size instead of failing with IndexError on masks[0].
if masks:
    background = np.zeros_like(masks[0]["segmentation"])
else:
    background = np.zeros(image.shape[:2], dtype=bool)
for mask in masks:
    # Segmentations are boolean, so combining them is a logical OR.
    background |= mask["segmentation"]

By overlaying the contour of the combined mask on the annotated points, we can roughly see where the model had problems.
Some of the annotated points may lie slightly off the object they label, which is problematic for automatic evaluation.

In [None]:
# Draw the frame, the outline of the combined mask, and the annotated points (red) on one axes.
ax = plt.gca()
ax.imshow(image)
ax.contour(background)
ax.scatter(df_annotations["x"], df_annotations["y"], s=4, c="red")