Import dependencies

In [92]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from hydra import initialize
from hydra.core.global_hydra import GlobalHydra
from supervision import Detections
from supervision.draw.color import ColorPalette
import supervision as sv
import image_slicer
from shapely.geometry import Polygon


Configure PyTorch

In [93]:
# This notebook only runs inference, so switch autograd off for the whole
# session (a global alternative to wrapping every call in torch.no_grad()).
torch.autograd.set_grad_enabled(False)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Let cuDNN auto-tune its convolution algorithms for the input sizes it sees.
torch.backends.cudnn.benchmark = True


Function definitions

In [94]:
def adjust_mask_coordinates(mask, x_offset, y_offset, image_shape):
    """
    Paste a tile-local mask into an all-zero canvas the size of the original image.

    Args:
        mask (numpy.ndarray): 2-D mask in tile coordinates (any dtype is kept).
        x_offset (int): Column in the original image where the tile's left edge sits.
        y_offset (int): Row in the original image where the tile's top edge sits.
        image_shape (tuple): Shape of the original image. Only the first two
            entries (height, width) are used, so an (H, W, C) shape is fine.

    Returns:
        numpy.ndarray: An (H, W) array with ``mask``'s dtype. It is zero
        everywhere except the tile region, which holds ``mask``. Any part of
        ``mask`` that would fall outside the image is clipped away instead of
        raising a shape-mismatch error.

    Raises:
        ValueError: If ``mask`` is not 2-D.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2-D, got shape {mask.shape}")

    image_height, image_width = image_shape[:2]
    full_mask = np.zeros((image_height, image_width), dtype=mask.dtype)
    mask_height, mask_width = mask.shape

    # Destination window in image coordinates, clipped to the image bounds.
    # Clipping the start at 0 also stops negative offsets from wrapping around
    # via Python's negative-index slicing.
    y_start = max(y_offset, 0)
    x_start = max(x_offset, 0)
    y_stop = min(y_offset + mask_height, image_height)
    x_stop = min(x_offset + mask_width, image_width)
    if y_stop <= y_start or x_stop <= x_start:
        # The tile lies entirely outside the image: nothing to paste.
        return full_mask

    # Take the matching window of the mask, expressed in tile coordinates.
    full_mask[y_start:y_stop, x_start:x_stop] = mask[
        y_start - y_offset:y_stop - y_offset,
        x_start - x_offset:x_stop - x_offset,
    ]
    return full_mask


Load and preprocess image

In [95]:
# Load the image
image_path = 'datasets/aalesund/1504201/201.jpg'
image_bgr = cv2.imread(image_path)

# Convert to RGB for processing
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# Get image dimensions
image_height, image_width = image_rgb.shape[:2]

Initialize the SAM2 model and mask generator

In [96]:
# Paths to configuration and checkpoint files
config_file_path = "./sam2.1_hiera_l.yaml"
checkpoint_file_path = "./sam2.1_hiera_large.pt"

# Verify file existence
assert os.path.exists(config_file_path), f"Config file not found at {config_file_path}"
assert os.path.exists(checkpoint_file_path), f"Checkpoint file not found at {checkpoint_file_path}"

# Clear any existing Hydra instances
GlobalHydra.instance().clear()

# Initialize Hydra and build the model
with initialize(config_path="."):
    sam2_model = build_sam2(config_file=config_file_path, ckpt_path=checkpoint_file_path).to(device)

# Set model to evaluation mode
sam2_model.eval()

# Create the mask generator with optimized parameters
mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2_model,
    points_per_side=32,  # Adjusted for performance
    pred_iou_thresh=0.8,
    stability_score_thresh=0.9,
    stability_score_offset=1.0,
    mask_threshold=0.0,
    box_nms_thresh=1.0,
    crop_n_layers=0,
    min_mask_region_area=0,
    output_mode="binary_mask",
    use_m2m=True,
    multimask_output=False
)


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1


Slice the image into tiles and generate masks for each tile

In [97]:
import image_slicer
from PIL import Image

# Choose the number of tiles
num_tiles = 4  # Adjust to 2, 4, 8, etc.

# Slice the image (in memory only; nothing is written to disk)
tiles = image_slicer.slice(image_path, num_tiles, save=False)

# Tile size, taken from the first tile
tile_width = tiles[0].image.width
tile_height = tiles[0].image.height

# Map tile number -> (x_offset, y_offset) of the tile's top-left corner in the
# original image. `tile.coords` is the crop origin image_slicer used for the
# tile, so read it directly instead of rebuilding offsets from row/column
# indices times a single tile size.
# NOTE(review): image_slicer's Tile.row / Tile.column appear to be swapped
# relative to their names (both are read out of an x-first `position` tuple).
# Building offsets from them can transpose tile placement.
tile_positions = {tile.number: tuple(tile.coords) for tile in tiles}


In [98]:
all_masks = []
mask_id_counter = 0  # Unique identifier for each mask, across all tiles


def _shift_xywh(box, dx, dy):
    """Return an [x, y, w, h] box translated by (dx, dy)."""
    x, y, w, h = box
    return [x + dx, y + dy, w, h]


# Gradients are already off globally; no_grad keeps this cell safe on its own
with torch.no_grad():
    for tile in tiles:
        # convert('RGB') + np.array yields an RGB array, which is the format the
        # model expects. Do NOT convert to BGR here: that swaps the channels.
        tile_image_pil = tile.image.convert('RGB')
        tile_image_rgb = np.array(tile_image_pil)

        # Generate masks for the tile
        tile_masks = mask_generator.generate(tile_image_rgb)

        # Top-left corner of this tile in the original image
        x_offset, y_offset = tile_positions[tile.number]

        for mask in tile_masks:
            # Move the mask from tile coordinates into full-image coordinates.
            # NOTE: every mask becomes a full-size (H, W) array, so memory grows
            # with (number of masks) x H x W.
            mask['segmentation'] = adjust_mask_coordinates(
                mask['segmentation'], x_offset, y_offset, image_rgb.shape
            )

            # The geometric metadata is tile-relative too. Shift it so every
            # field refers to the original image. (SAM's automatic mask
            # generator documents bbox/crop_box as XYWH and point_coords as
            # [[x, y], ...].)
            if 'bbox' in mask:
                mask['bbox'] = _shift_xywh(mask['bbox'], x_offset, y_offset)
            if 'crop_box' in mask:
                mask['crop_box'] = _shift_xywh(mask['crop_box'], x_offset, y_offset)
            if 'point_coords' in mask:
                mask['point_coords'] = [
                    [px + x_offset, py + y_offset] for px, py in mask['point_coords']
                ]

            # Sequential IDs, so mask_id equals the mask's index in all_masks
            mask['mask_id'] = mask_id_counter
            mask_id_counter += 1

            all_masks.append(mask)


KeyboardInterrupt: 

Post-process masks and overlay them on the image

In [None]:
# ColorPalette() cannot be built without an explicit color list. Newer
# supervision releases expose a ready-made palette as ColorPalette.DEFAULT;
# older ones provide ColorPalette.default().
color_palette = (
    ColorPalette.DEFAULT if hasattr(ColorPalette, "DEFAULT") else ColorPalette.default()
)

# Stack the full-image masks into an (N, H, W) boolean array. Keep that shape
# even when no masks were generated, so Detections validation still passes.
if all_masks:
    masks = np.array([mask['segmentation'] for mask in all_masks], dtype=bool)
else:
    masks = np.zeros((0, *image_bgr.shape[:2]), dtype=bool)

# Detections needs an (N, 4) xyxy array aligned with the masks; an empty list is
# rejected. Derive the boxes from the masks themselves.
detections = Detections(xyxy=sv.mask_to_xyxy(masks), mask=masks)

# These detections carry no class_id, so the annotator's default per-class
# color lookup cannot resolve colors. Color by detection index instead; that
# matches mask_id when masks are numbered in append order.
mask_annotator = sv.MaskAnnotator(
    color=color_palette, opacity=0.6, color_lookup=sv.ColorLookup.INDEX
)
annotated_image = mask_annotator.annotate(
    scene=image_bgr.copy(),
    detections=detections
)

# Display the annotated image (convert BGR -> RGB for matplotlib)
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))
ax.set_title("Annotated Image with Masks")
ax.axis("off")
plt.show()


Vectorize masks into polygons and draw them

In [None]:
# Douglas-Peucker tolerance, as a fraction of each contour's perimeter
SIMPLIFY_EPSILON_RATIO = 0.01

polygons_list = []

for mask in all_masks:
    # RETR_EXTERNAL keeps only outer boundaries, so holes are not vectorized
    contours, _ = cv2.findContours(
        mask['segmentation'].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    for contour in contours:
        if contour.shape[0] < 3:
            continue

        # Simplify the contour
        epsilon = SIMPLIFY_EPSILON_RATIO * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)

        # Simplification can collapse a thin contour to 1-2 vertices. shapely's
        # Polygon raises ValueError on fewer than 3 points, so skip those.
        if approx.shape[0] < 3:
            continue

        poly = Polygon(approx.reshape(-1, 2))
        if poly.is_valid and poly.area > 0:
            polygons_list.append({
                'polygon': poly,
                'mask_id': mask['mask_id']
            })

# Draw each polygon outline on a copy of the original (BGR) image
image_with_polygons = image_bgr.copy()

for item in polygons_list:
    color = color_palette.by_idx(item['mask_id'])
    coords = np.array(item['polygon'].exterior.coords).astype(np.int32)
    cv2.polylines(image_with_polygons, [coords], isClosed=True, color=color.as_bgr(), thickness=2)

# Display the image with vectorized polygons (convert BGR -> RGB for matplotlib)
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(cv2.cvtColor(image_with_polygons, cv2.COLOR_BGR2RGB))
ax.set_title("Image with Vectorized Polygons")
ax.axis("off")
plt.show()


Convert masks to polygons and display