In [1]:
import os
import random

import cv2
import numpy as np
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

In [2]:
# Where the source images are read from, and where the colored segmentation
# overlays are written to (both relative to the notebook's working directory).
input_folder = "../../data/experimental/bouwpub/building_masks/"
output_folder = "../../data/experimental/bouwpub/building_level_segmentation_test/"

# Create the output directory (and any missing parents) up front;
# exist_ok=True makes re-running this cell a no-op when it already exists.
os.makedirs(name=output_folder, exist_ok=True)

In [58]:
# Build the SAM model from a local checkpoint and wrap it in an automatic
# mask generator that is used by the segmentation loop.
sam_checkpoint = "../../experiments/checkpoints/sam_vit_b_01ec64.pth"
model_type = "vit_b"  # registry key used to pick the model builder for this checkpoint

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing in sam.to().
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# crop_n_layers > 0 makes the generator re-run prediction on crops of the
# image (per the segment-anything docs), which costs considerably more time
# per image than the default of 0 — especially on CPU.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    crop_n_layers=4,
)


In [None]:
# Segment every selected image in input_folder with SAM and save an RGBA
# overlay in output_folder where each mask is filled with its own color and
# everything not covered by a mask stays fully transparent.

# Dedicated, seeded generator so the mask colors are identical on every run.
color_rng = random.Random(42)

for imageid in os.listdir(input_folder):
    # NOTE: only this single image is processed; remove the filter to run
    # the segmentation on every file in input_folder.
    if imageid != "IMG-20240605-WA0034.png":
        continue

    image_path = os.path.join(input_folder, imageid)

    # Load image with transparency support
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread reports unreadable / non-image files by returning None
        # instead of raising, so skip them explicitly.
        print(f"Skipped: {imageid} (could not be read as an image)")
        continue

    # Convert whatever layout was loaded into the 3-channel RGB array that is
    # handed to the mask generator.
    if image.ndim == 2:  # Grayscale: IMREAD_UNCHANGED yields a 2-D array
        alpha_channel = np.full(image.shape, 255, dtype=np.uint8)
        image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[-1] == 4:  # BGRA
        alpha_channel = image[:, :, 3]
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
    else:  # No transparency
        alpha_channel = np.full(image.shape[:2], 255, dtype=np.uint8)
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # NOTE(review): alpha_channel is computed but not applied to the output
    # below — confirm whether transparent input pixels should be masked out.

    # Generate masks
    masks = mask_generator.generate(image_rgb)

    # Create an RGBA output image (all zeros = fully transparent)
    mask_image_rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)

    # Paint the largest masks first so smaller masks that overlap them end up
    # on top instead of being hidden.
    for ann in sorted(masks, key=lambda ann: ann['area'], reverse=True):
        mask = ann['segmentation']

        color = [color_rng.randint(0, 255) for _ in range(3)]
        mask_image_rgba[mask, :3] = color  # Assign color
        mask_image_rgba[mask, 3] = 255  # Make mask visible

    # Save with a unique filename if file already exists
    base_name, ext = os.path.splitext(imageid)
    output_path = os.path.join(output_folder, imageid)
    count = 1
    while os.path.exists(output_path):
        new_name = f"{base_name}_{count}{ext}"
        output_path = os.path.join(output_folder, new_name)
        count += 1

    # cv2.imwrite returns False when the file could not be written; do not
    # report success in that case.
    if not cv2.imwrite(output_path, mask_image_rgba):
        print(f"Failed: {imageid} -> could not write {output_path}")
        continue

    print(f"Processed: {imageid} -> Saved at {output_path}")

print("Segmentation complete!")