Installer nødvendige avhengigheter.

In [None]:
!pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'
!pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision

Importer nødvendige pakker.

In [None]:
import os
import cv2
import torch
import numpy as np
import supervision as sv
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import matplotlib.pyplot as plt
from scipy.stats import mode
from supervision.draw.color import Color, ColorPalette


Sett opp SAM-modellen.

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Checkpoint file and the matching key in sam_model_registry.
# Alternative pairs that can be swapped in:
#   "models/sam_vit_b_01ec64.pth" with "vit_b"
#   "models/sam_vit_l_0b3195.pth" with "vit_l"
CHECKPOINT_PATH = "models/sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"

# Build the model from the checkpoint, then move it to the chosen device.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam = sam.to(device=DEVICE)

# Settings for automatic mask generation (see the SamAutomaticMaskGenerator docs).
mask_generator_settings = dict(
    points_per_side=32,            # points sampled along each side of the image
    pred_iou_thresh=0.9,           # drop masks whose predicted quality is below this
    stability_score_thresh=0.95,   # drop masks whose stability score is below this
    stability_score_offset=1.0,    # cutoff shift used when computing the stability score
    box_nms_thresh=0.1,            # box IoU cutoff for NMS de-duplication of masks
    crop_n_layers=1,               # additionally run prediction on one layer of image crops
    crop_nms_thresh=0.5,           # box IoU cutoff for NMS between different crops
    min_mask_region_area=5000,     # post-process regions/holes smaller than this (pixels)
    output_mode="binary_mask",
)
mask_generator = SamAutomaticMaskGenerator(model=sam, **mask_generator_settings)

Sett opp bildet du vil segmentere, og segmenter det ved hjelp av SAM.

In [None]:
IMAGE_PATH = "datasets/aalesund/1504200/200.jpg"
scale_percent = 30  # size of the working image as a percentage of the original

# cv2.imread returns None (it does not raise) when the file is missing or
# unreadable, so fail here with a clear message instead of on `.shape` below.
image_bgr = cv2.imread(IMAGE_PATH)
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")

# shape is (rows, cols, ...) while cv2.resize expects (width, height).
# max(1, ...) keeps a very small scale from rounding a side down to 0.
width = max(1, int(image_bgr.shape[1] * scale_percent / 100))
height = max(1, int(image_bgr.shape[0] * scale_percent / 100))
new_dim = (width, height)

# From here on image_bgr is the downscaled image; image_rgb is the same
# image in RGB channel order, which is what gets passed to the mask generator.
image_bgr = cv2.resize(image_bgr, new_dim, interpolation=cv2.INTER_AREA)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

sam_result = mask_generator.generate(image_rgb)

Definer hjelpefunksjoner for å sjekke om en maske ligger inne i en annen, og for å finne den vanligste fargen under en maske.

In [None]:
def is_mask_inside(outer_mask, inner_mask):
    """Tell whether ``inner_mask`` lies entirely within ``outer_mask``.

    Args:
        outer_mask: Array indexed at the positions selected by ``inner_mask``.
        inner_mask: Array of the same shape; entries ``> 0`` count as set.

    Returns:
        The result of ``np.all`` over the ``outer_mask`` values found at the
        set positions of ``inner_mask``. An ``inner_mask`` with no set
        positions therefore yields True.
    """
    inner_pixels = inner_mask > 0
    outer_values_under_inner = outer_mask[inner_pixels]
    return np.all(outer_values_under_inner)

def custom_mode(array):
    """Return the most frequent value in ``array``.

    ``np.unique`` returns its values sorted and ``np.argmax`` picks the first
    maximum, so a tie is resolved in favour of the smallest value.
    """
    values, counts = np.unique(array, return_counts=True)
    return values[np.argmax(counts)]

def get_most_common_color(image_bgr, mask, brightness_offset=50):
    """Return a brightened per-channel mode of the pixels selected by ``mask``.

    The mode is computed independently for each of the three channels, so the
    resulting triple is not guaranteed to occur as an actual pixel value.

    Args:
        image_bgr: Image array indexed by the positions where ``mask`` is set;
            three values per selected pixel are expected.
        mask: Array whose truthy entries select the pixels to inspect.
        brightness_offset: Amount added to every channel mode (default 50).

    Returns:
        A tuple of three ints, one per channel in the channel order of
        ``image_bgr``, each clamped to the range 0-255. ``(0, 0, 0)`` is
        returned when the mask selects nothing or the selection does not
        have three values per pixel.
    """
    pixels = image_bgr[np.where(mask)]
    if pixels.size == 0 or pixels.ndim != 2 or pixels.shape[1] != 3:
        return (0, 0, 0)

    def _brightened_mode(channel):
        # Clamp after adding the offset: an unclamped result can exceed 255
        # and is then no longer a valid 8-bit colour component.
        value = int(custom_mode(pixels[:, channel])) + brightness_offset
        return min(255, max(0, value))

    return (_brightened_mode(0), _brightened_mode(1), _brightened_mode(2))

Gjør klar maskene og sett en terskel for hvor mange masker som kan ligge inne i en annen maske før den blir fjernet.

In [None]:
# Share of all masks that a single mask may contain before it is discarded.
CONTAINED_MASK_FRACTION = 0.5

# (index into sam_result, boolean segmentation, pixel count) for every
# non-empty mask.
masks_with_areas = [
    (i, mask['segmentation'], int(np.sum(mask['segmentation'])))
    for i, mask in enumerate(sam_result) if np.any(mask['segmentation'])
]

# Largest first, so each mask only has to be compared with the ones after it.
masks_with_areas.sort(key=lambda x: x[2], reverse=True)

# At least 1: with zero or one mask the plain product truncates to 0, and a
# threshold of 0 makes the `contained_count >= threshold` test true for every
# mask, so all of them would be removed.
contained_mask_threshold = max(1, int(CONTAINED_MASK_FRACTION * len(masks_with_areas)))
print(f"Contained Mask Threshold: {contained_mask_threshold}")

Filtrer maskene.

In [None]:
image_area = image_bgr.shape[0] * image_bgr.shape[1]

# Collect the sam_result indices of masks that contain too many other masks.
# masks_with_areas is sorted largest first, so only later entries are checked.
indices_to_remove = set()
for rank, (outer_idx, outer_mask, _) in enumerate(masks_with_areas):
    smaller_masks = masks_with_areas[rank + 1:]
    contained_count = sum(
        1 for _, candidate, _ in smaller_masks if is_mask_inside(outer_mask, candidate)
    )
    if contained_count >= contained_mask_threshold:
        indices_to_remove.add(outer_idx)

# Keep masks that were not flagged above and that do not cover the whole image.
filtered_masks_with_areas = [
    entry for entry in masks_with_areas
    if entry[0] not in indices_to_remove and entry[2] < image_area
]

# Parallel views of the surviving masks, both in the same (largest-first) order.
filtered_sam_result = [sam_result[entry[0]] for entry in filtered_masks_with_areas]
sorted_masks = [entry[1] for entry in filtered_masks_with_areas]

print(f"Number of masks after filtering: {len(filtered_masks_with_areas)}")


Generer en fargepalett basert på fargene under hver maske.

In [None]:
# One colour per retained mask, in the same order as sorted_masks.
sorted_mask_colors = []
for mask in sorted_masks:
    dominant_bgr = get_most_common_color(image_bgr, mask)
    sorted_mask_colors.append(Color.from_bgr_tuple(dominant_bgr))

custom_color_palette = ColorPalette(colors=sorted_mask_colors)

Annoter bildet med de forskjellige maskene.

In [None]:
detections = sv.Detections.from_sam(sam_result=filtered_sam_result)
mask_annotator = sv.MaskAnnotator(color=custom_color_palette, opacity=0.9)

# Detection k is painted with palette colour k.
# NOTE(review): this assumes from_sam keeps the largest-first order of
# filtered_sam_result - confirm for the installed supervision version.
custom_color_lookup = np.arange(len(sorted_mask_colors))

if len(detections) != len(custom_color_lookup):
    raise ValueError(
        f"Got {len(detections)} detections but {len(custom_color_lookup)} mask colors"
    )

# Errors are deliberately not caught here: printing and continuing would leave
# annotated_image_with_custom_colors undefined (or stale from an earlier run),
# and the plotting cell would then fail or show an outdated image.
annotated_image_with_custom_colors = mask_annotator.annotate(
    scene=image_bgr.copy(),
    detections=detections,
    custom_color_lookup=custom_color_lookup,
)


Vis det originale og det annoterte bildet.

In [None]:
# Side-by-side comparison; both images are converted from BGR to RGB for matplotlib.
fig, (ax_original, ax_annotated) = plt.subplots(1, 2, figsize=(12, 6))

ax_original.imshow(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
ax_original.set_title("Original Image")
ax_original.axis("off")

ax_annotated.imshow(cv2.cvtColor(annotated_image_with_custom_colors, cv2.COLOR_BGR2RGB))
ax_annotated.set_title("Annotated Image with Filtered Masks")
ax_annotated.axis("off")

plt.show()
