In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import os
import hydra
from hydra import initialize_config_module, initialize, compose
from hydra.core.global_hydra import GlobalHydra
import sys
from samgeo import SamGeo2, SamGeo
import cv2
import rasterio
!pip install pymupdf
import fitz 


# Run the rest of the notebook under CUDA autocast with float16.
# The autocast context is entered on purpose and never exited, so it stays
# active for every cell executed afterwards.
# All of this is skipped when no CUDA device is present; querying device 0
# unconditionally would raise on such a machine.
if torch.cuda.is_available():
    torch.autocast(device_type="cuda", dtype=torch.float16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

In [None]:
def extract_image_from_pdf(pdf_path, page_number=0, image_index=0):
    """
    Extract one embedded image from a specified page in a PDF.

    Args:
        pdf_path (str): Path to the PDF file.
        page_number (int): Page number (0-indexed) to extract the image from.
        image_index (int): Index of the image on the page (default is the first image).

    Returns:
        np.ndarray: Extracted image in RGB format, or None if the page has no
        image at ``image_index`` or the embedded bytes cannot be decoded by OpenCV.
    """
    pdf_document = fitz.open(pdf_path)
    try:
        page = pdf_document.load_page(page_number)  # Load the specified page
        images = page.get_images(full=True)  # Get all images on the page

        if not images or image_index >= len(images):
            print(f"No image found on page {page_number + 1} at index {image_index}.")
            return None

        # The first element of each image entry is the XREF used for extraction.
        xref = images[image_index][0]
        image_bytes = pdf_document.extract_image(xref)["image"]
    finally:
        # Close on every path (early return and exceptions included) so the
        # document handle is never leaked.
        pdf_document.close()

    # Decode the raw bytes; imdecode returns None when it cannot read the data.
    image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        print(f"Could not decode image {image_index} on page {page_number + 1}.")
        return None

    # OpenCV decodes to BGR; callers expect RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Load the image
# Alternative raster inputs kept for reference; when one of them is used, the
# commented-out RGB conversion at the bottom of this cell is needed as well.
#image_bgr = cv2.imread('dataset/aalesund/FOKUS/1504200/200.jpg')
#image_bgr = cv2.imread('dataset/aalesund/UTFORDRING/Monokrom/1504343A/343a.jpg')
#image_bgr = cv2.imread('aalesund_fokus/213.jpg')
#image_bgr = cv2.imread('dataset/molde/UTFORDRING/0577/0577_plankart.tif')

pdf_path = 'dataset/kristiansund/FOKUS/R-077/R-077 Plankart.pdf'
image_rgb = extract_image_from_pdf(pdf_path)
if image_rgb is None:
    # extract_image_from_pdf reports "no image" by returning None; stop here
    # with a readable message instead of an opaque OpenCV error below.
    raise ValueError(f"No image could be extracted from {pdf_path!r}.")

# Keep a BGR copy as well: the OpenCV-based steps later in the notebook use it.
image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

#image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

In [None]:
import os
import torch
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the SAM2 automatic mask generator from the pretrained checkpoint.
# `device` is passed explicitly; it was previously computed but never used, so
# the CPU fallback above had no effect.
# NOTE(review): relies on from_pretrained forwarding keyword arguments to the
# model builder - confirm against the installed sam2 version.
mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
    model_id="facebook/sam2-hiera-large",
    device=device,
    points_per_side=32,  # Define points per side
    points_per_batch=16,  # Number of points per batch
    pred_iou_thresh=0.75,  # Filter threshold for mask quality
    stability_score_thresh=0.75,  # Filter threshold for stability score
    stability_score_offset=1.0,
    mask_threshold=0.0,
    box_nms_thresh=0.5,
    crop_n_layers=1,
    crop_nms_thresh=1,
    crop_overlap_ratio=0.8,
    crop_n_points_downscale_factor=1,
    point_grids=None,
    min_mask_region_area=0,
    output_mode="binary_mask",
    use_m2m=True,
    multimask_output=False
)

print("Model and mask generator initialized successfully.")


In [None]:
# Run automatic mask generation on the RGB image. Later cells index each entry
# of `sam_result` with the 'segmentation' key.
sam_result = mask_generator.generate(image_rgb)
# Leftover from an alternative API where the masks are read from an attribute
# after calling generate() - presumably the samgeo wrapper imported above; verify.
#mask_generator.generate(image_rgb)
#sam_result = mask_generator.masks
# Number of masks produced.
print(len(sam_result))

In [None]:
# Import necessary modules
import numpy as np
import cv2

# Define function to check if one mask is completely inside another
def is_mask_inside(outer_mask, inner_mask):
    """Return whether every set pixel of ``inner_mask`` is also set in ``outer_mask``.

    An ``inner_mask`` without any set pixel counts as contained.
    """
    inner_pixels = inner_mask > 0
    # Look at the outer mask only where the inner mask is set.
    return outer_mask[inner_pixels].all()

# Define function to check if a mask is on a predominantly gray or black background
def is_colorful_region(image, mask, saturation_threshold=20, brightness_threshold=50):
    """
    Decide whether the masked part of an image is both saturated and bright enough.

    Args:
        image: Original image (in BGR format).
        mask: 2D binary mask for the region of interest; it is resized with
            nearest-neighbour interpolation when its shape differs from the image.
        saturation_threshold: Mean HSV saturation inside the mask must exceed this.
        brightness_threshold: Mean HSV value inside the mask must exceed this.

    Returns:
        bool: True if both means are strictly above their thresholds, False otherwise.

    Raises:
        ValueError: If ``mask`` is not two-dimensional.
    """
    if mask.ndim != 2:
        raise ValueError("The mask should be a 2D binary array.")

    # OpenCV expects an 8-bit mask.
    region = mask.astype(np.uint8)

    # Bring the mask to the image resolution when the two disagree.
    img_shape = image.shape[:2]
    if region.shape != img_shape:
        region = cv2.resize(region, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

    # Scale binary mask (0 or 1) to 0 or 255
    region = region * 255

    # Work in HSV so saturation and value can be read directly.
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    hsv_in_region = cv2.bitwise_and(hsv, hsv, mask=region)

    # Channel 1 is saturation, channel 2 is value; average over masked pixels only.
    mean_saturation = cv2.mean(hsv_in_region[..., 1], mask=region)[0]
    mean_brightness = cv2.mean(hsv_in_region[..., 2], mask=region)[0]

    return mean_saturation > saturation_threshold and mean_brightness > brightness_threshold



# Step 1: Collect every non-empty mask together with its pixel area and
# bounding box. The position in `sam_result` is stored so the surviving
# entries can be looked up again at the end.
masks_with_areas_and_bboxes = []
for i, mask in enumerate(sam_result):
    segmentation = mask['segmentation']
    if not np.any(segmentation):
        continue
    area = np.sum(segmentation)
    # argwhere yields (row, col) pairs, i.e. (y, x).
    coords = np.argwhere(segmentation)
    y_coords, x_coords = coords[:, 0], coords[:, 1]
    bbox = (x_coords.min(), y_coords.min(), x_coords.max(), y_coords.max())
    masks_with_areas_and_bboxes.append((i, segmentation, area, bbox))

# Sort masks by area (from largest to smallest)
masks_with_areas_and_bboxes.sort(key=lambda x: x[2], reverse=True)  # (index, mask, area, bbox)

# A mask is dropped when it fully contains at least this many smaller masks
# (10% of the mask count). The floor of 1 matters: with fewer than 10 masks the
# plain integer truncation gives 0, and "contains >= 0 masks" is true for every
# mask, which would remove all of them.
contained_mask_threshold = max(1, int(0.1 * len(masks_with_areas_and_bboxes)))

# Identify masks to remove
indices_to_remove = set()

# Loop through masks and remove larger masks that contain multiple smaller masks
for i, (outer_idx, outer_mask, outer_area, outer_bbox) in enumerate(masks_with_areas_and_bboxes):
    outer_min_x, outer_min_y, outer_max_x, outer_max_y = outer_bbox
    contained_count = 0  # Counter for masks contained within the current outer mask

    # The list is sorted by area, so everything after position i is not larger.
    for inner_idx, inner_mask, inner_area, inner_bbox in masks_with_areas_and_bboxes[i+1:]:
        inner_min_x, inner_min_y, inner_max_x, inner_max_y = inner_bbox

        # Cheap bounding-box test first; the pixel-wise test only runs when it passes.
        bbox_inside = (inner_min_x >= outer_min_x and inner_max_x <= outer_max_x and
                       inner_min_y >= outer_min_y and inner_max_y <= outer_max_y)
        if bbox_inside and is_mask_inside(outer_mask, inner_mask):
            contained_count += 1

    # Only mark the larger mask for removal if it contains at least `contained_mask_threshold` smaller masks
    if contained_count >= contained_mask_threshold:
        indices_to_remove.add(outer_idx)

# Filter out the unwanted masks
filtered_masks_with_areas_and_bboxes = [
    (idx, mask, area, bbox)
    for idx, mask, area, bbox in masks_with_areas_and_bboxes
    if idx not in indices_to_remove
]

# Also remove any masks that cover the entire image (if any)
image_area = image_bgr.shape[0] * image_bgr.shape[1]
filtered_masks_with_areas_and_bboxes = [
    (idx, mask, area, bbox)
    for idx, mask, area, bbox in filtered_masks_with_areas_and_bboxes
    if area < image_area
]

# Colour filter. With both thresholds at 0 this only drops regions whose mean
# saturation or mean brightness is not strictly greater than 0.
filtered_masks_with_areas_and_bboxes = [
    (idx, mask, area, bbox)
    for idx, mask, area, bbox in filtered_masks_with_areas_and_bboxes
    if is_colorful_region(image_bgr, mask, saturation_threshold=0, brightness_threshold=0)
]

# Create a filtered sam_result
filtered_sam_result = [sam_result[idx] for idx, _, _, _ in filtered_masks_with_areas_and_bboxes]

# Debug: Print the number of masks after filtering
print(f"Total masks after filtering: {len(filtered_sam_result)}")


In [None]:
from shapely.geometry import MultiPolygon, Polygon

# Pixel count of the source image, used as the reference for area thresholds.
image_area = image_bgr.shape[0] * image_bgr.shape[1]

# Outlines are drawn on a copy so the original image stays untouched.
image_with_polygons = image_bgr.copy()

# Candidate polygons with their area and the index of the mask they came from.
mask_polygons = []

# Final geometries that end up drawn on the image.
polygons_list = []

# Function to smooth contour using moving average
def smooth_contour(contour, window_size=5):
    """
    Smooth a closed contour with a circular moving average.

    Args:
        contour: Array of contour points, one point per row (e.g. shape (N, 2)).
        window_size: Number of neighbouring points averaged per output point.
            Even values are bumped to the next odd value so the window is
            centred; values below 1 are treated as 1.

    Returns:
        np.ndarray: int32 array with one smoothed point per input point
        (fractional coordinates are truncated by the integer cast).
    """
    contour = np.asarray(contour)
    n_points = len(contour)
    if n_points == 0:
        return np.empty((0,) + contour.shape[1:], dtype=np.int32)

    window_size = max(1, window_size)
    # Ensure window_size is odd
    if window_size % 2 == 0:
        window_size += 1
    half_window = window_size // 2

    # Index the neighbours modulo the contour length. Unlike slice-based
    # padding, this wraps correctly even when the contour has fewer points
    # than half the window, so the output keeps the input point count.
    offsets = np.arange(-half_window, half_window + 1)
    window_idx = (np.arange(n_points)[:, None] + offsets[None, :]) % n_points

    # contour[window_idx] has shape (N, window, ...); average over the window axis.
    smoothed_contour = contour[window_idx].mean(axis=1)
    return smoothed_contour.astype(np.int32)

# Turn every surviving mask into smoothed polygons.
for idx, mask_dict in enumerate(filtered_sam_result):
    # findContours needs an 8-bit image.
    mask = mask_dict['segmentation'].astype(np.uint8)

    # Outer contours only, keeping every boundary pixel.
    contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

    # An empty result simply means the loop body never runs.
    for contour in contours:
        # Need at least 5 points to smooth
        if contour.shape[0] < 5:
            continue

        # Flatten to one (x, y) row per point before smoothing.
        smoothed_contour = smooth_contour(contour.reshape(-1, 2), window_size=15)  # Adjust window_size as needed

        # A polygon needs at least three vertices.
        if smoothed_contour.shape[0] < 3:
            continue

        polygon = Polygon(smoothed_contour)
        if not (polygon.is_valid and polygon.area != 0):
            # buffer(0) is tried as a repair; give up if the result is still unusable.
            polygon = polygon.buffer(0)
            if not (polygon.is_valid and polygon.area != 0):
                continue

        # Keep the geometry together with its area and the originating mask index.
        mask_polygons.append({'area': polygon.area, 'polygon': polygon, 'index': idx})

# Upper bound on polygon size: anything covering 10% of the image or more is dropped.
max_area_threshold = 0.1 * image_area
mask_polygons = [entry for entry in mask_polygons if entry['area'] < max_area_threshold]

# Debug: Print the number of polygons after excluding large masks
print(f"Total polygons after excluding large masks: {len(mask_polygons)}")

# Largest first, so the overlap filter that follows sees big polygons before small ones.
mask_polygons = sorted(mask_polygons, key=lambda entry: entry['area'], reverse=True)

# Polygons accepted by the overlap filter are collected here.
filtered_polygons = []

# Function to check if a polygon is mostly within existing polygons
def is_polygon_mostly_within(poly, existing_polys, area_overlap_threshold=0.95):
    """
    Check whether ``poly`` is covered by any single polygon in ``existing_polys``.

    Args:
        poly: Geometry exposing ``area`` and ``intersection(other)``.
        existing_polys: Iterable of geometries to compare against.
        area_overlap_threshold: Minimum fraction of ``poly``'s area that must
            lie inside one existing polygon.

    Returns:
        bool: True if, for at least one existing polygon, the intersection area
        divided by ``poly``'s area reaches the threshold. A zero-area ``poly``
        is never reported as contained.
    """
    poly_area = poly.area
    # The zero-area guard does not depend on the other polygon, so it is
    # checked once up front instead of after every intersection.
    if poly_area == 0:
        return False

    return any(
        poly.intersection(existing_poly).area / poly_area >= area_overlap_threshold
        for existing_poly in existing_polys
    )

# Greedy overlap filter: walk the candidates in their current order and keep a
# polygon only if it does not overlap an already accepted one by the given
# fraction (0.05 here) of its own area.
for idx, poly_dict in enumerate(mask_polygons):
    accepted_shapes = [kept['polygon'] for kept in filtered_polygons]
    if is_polygon_mostly_within(poly_dict['polygon'], accepted_shapes, area_overlap_threshold=0.05):
        print(f"Polygon {idx} is mostly within another polygon and will be removed.")
        continue
    filtered_polygons.append(poly_dict)

# Debug: Print the number of polygons after filtering
print(f"Total polygons after overlap filtering: {len(filtered_polygons)}")

# Draw the accepted geometries and record them in `polygons_list`.
for poly_dict in filtered_polygons:
    poly = poly_dict['polygon']

    # Normalise to a list of simple polygons so there is one drawing path.
    if isinstance(poly, Polygon):
        parts = [poly]
    elif isinstance(poly, MultiPolygon):
        # Members of a MultiPolygon are drawn only when valid and non-empty.
        parts = [part for part in poly.geoms if part.is_valid and not part.is_empty]
    else:
        # Any other geometry type is ignored.
        parts = []

    for part in parts:
        outline = np.array(list(part.exterior.coords)).astype(np.int32)
        cv2.polylines(image_with_polygons, [outline], isClosed=True, color=(0, 255, 0), thickness=5)
        polygons_list.append(part)

# Show the result; the drawing image is BGR, matplotlib expects RGB.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(cv2.cvtColor(image_with_polygons, cv2.COLOR_BGR2RGB))
ax.set_title("Image with Filtered Polygons (Smoothed Contours)")
ax.axis("off")
plt.show()
