In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import psycopg2
import os

In [2]:
def show_mask(mask, ax, random_color=False):
    """Overlay a segmentation mask on a matplotlib axis as a translucent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before being multiplied by the colour.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: When True, draw the RGB part of the colour from
            ``np.random.random``; otherwise use a fixed blue. Alpha is always 0.6.
    """
    alpha = np.array([0.6])
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcasting (h, w, 1) * (1, 1, 4) paints every masked pixel with `rgba`.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """Draw an unfilled green rectangle for ``box`` on a matplotlib axis.

    Args:
        box: Four values unpacked as (x0, y0, x1, y1) corner coordinates.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top, right, bottom = box
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill, only the outline is visible
        lw=2,
    )
    ax.add_patch(outline)

In [3]:
# Loading the SAM model and predictor
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# vit_b model checkpoint
sam_checkpoint_vitb = "../../experiments/checkpoints/sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Prefer the GPU but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing in sam.to().
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_vitb)
sam.to(device=device)

# Predictor for prompts (box prompts are issued through it in segment())
predictor = SamPredictor(sam)

  state_dict = torch.load(f)


In [4]:
# Image folders, one per camera direction.
# NOTE: segment() takes the last path component of these folders ("Back", "Fwd",
# "Left", "Right") as the direction passed to transform_bounding_boxes(), so the
# folder names must stay in sync with the directions handled there.
Back = "../../data/experimental/one_building_test/Back/"
Fwd = "../../data/experimental/one_building_test/Fwd/"
Left = "../../data/experimental/one_building_test/Left/"
Right = "../../data/experimental/one_building_test/Right/"

In [5]:
def wkt_to_bbox(wkt):
    """Return the axis-aligned bounding box of a 2D WKT polygon.

    The parser is tolerant of the common formatting variants of the same
    geometry: ``POLYGON((x y, x y))``, ``POLYGON((x y,x y))`` (no space after
    the comma) and ``POLYGON ((x y, x y))`` (space before the parentheses).
    Vertices of all rings are taken into account.

    Args:
        wkt: WKT string containing parenthesised ``x y`` vertex pairs.

    Returns:
        ``[min_x, min_y, max_x, max_y]`` as floats.

    Raises:
        ValueError: If the string contains no parenthesised coordinate list,
            no vertices, a vertex that is not exactly two numbers, or a value
            that cannot be parsed as a float.
    """
    start = wkt.find("(")
    if start == -1:
        raise ValueError(f"No coordinate list found in WKT: {wkt!r}")

    # Drop the geometry tag, then flatten ring parentheses so every vertex is
    # simply separated by a comma, whatever the surrounding whitespace is.
    flattened = wkt[start:].replace("(", " ").replace(")", " ")

    x_vals = []
    y_vals = []
    for vertex in flattened.split(","):
        parts = vertex.split()
        if not parts:
            continue
        if len(parts) != 2:
            raise ValueError(f"Expected 'x y' vertex, got {vertex.strip()!r} in WKT: {wkt!r}")
        x_vals.append(float(parts[0]))
        y_vals.append(float(parts[1]))

    if not x_vals:
        raise ValueError(f"No vertices found in WKT: {wkt!r}")

    return [min(x_vals), min(y_vals), max(x_vals), max(y_vals)]

def transform_bounding_boxes(boxes, img_width, img_height, direction):
    """Map bounding boxes into image coordinates and pad them per camera direction.

    For every ``(x_min, y_min, x_max, y_max)`` box the function:
      1. rotates it 90 degrees (x becomes the old y, y becomes ``img_width - x``),
      2. flips it vertically (``y -> img_height - y``),
      3. shifts it down by ``img_height / 3.1``,
      4. grows it by a direction-specific padding.

    Args:
        boxes: Iterable of 4-element boxes ``(x_min, y_min, x_max, y_max)``.
        img_width: Image width used by the rotation step.
        img_height: Image height used by the flip and the downward shift.
        direction: One of "Fwd", "Left", "Right", "Back".

    Returns:
        List of ``[x_min, y_min, x_max, y_max]`` lists. For any other
        ``direction`` no box is emitted, so the result is an empty list.
    """
    # Padding in pixels as (left, top, right, bottom); empirically tuned values.
    padding_by_direction = {
        "Fwd": (20, 300, 20, 20),
        "Left": (300, 100, 20, 200),
        "Right": (20, 20, 300, 200),
        "Back": (20, 20, 200, 600),
    }
    padding = padding_by_direction.get(direction)

    # Constant downward offset applied to every box (boxes otherwise sit too high).
    y_offset = img_height / 3.1

    result = []
    for src_x0, src_y0, src_x1, src_y1 in boxes:
        # Rotation: the old y range becomes the new x range, the old x range
        # is mirrored against the image width to give the new y range.
        rotated_y0 = img_width - src_x1
        rotated_y1 = img_width - src_x0
        left, right = src_y0, src_y1

        # Vertical flip followed by the downward shift.
        top = (img_height - rotated_y1) + y_offset
        bottom = (img_height - rotated_y0) + y_offset

        if padding is None:
            # Unknown direction: the box is dropped.
            continue

        pad_left, pad_top, pad_right, pad_bottom = padding
        result.append([left - pad_left, top - pad_top, right + pad_right, bottom + pad_bottom])

    return result

def get_unique_filename(output_dir, filename):
    """Return a filename that does not yet exist inside ``output_dir``.

    If ``filename`` is free it is returned unchanged. Otherwise ``_1``, ``_2``,
    ... is inserted before the extension until an unused name is found.

    Args:
        output_dir: Directory in which existence is checked.
        filename: Desired file name (not a full path).

    Returns:
        The first candidate name for which no file exists in ``output_dir``.
    """
    stem, suffix = os.path.splitext(filename)
    candidate = filename
    attempt = 0
    while os.path.exists(os.path.join(output_dir, candidate)):
        attempt += 1
        candidate = f"{stem}_{attempt}{suffix}"
    return candidate

def connect_to_database(imageid):
    """Fetch the bag ids and bounding boxes stored for one image.

    Opens a fresh PostgreSQL connection, reads the ``bag_ids`` and ``bboxes``
    columns of the first ``bag_in_image_utrecht`` row whose ``image_name``
    equals ``imageid``, and always closes the cursor and connection again,
    even if the query fails.

    Args:
        imageid: Value matched against the ``image_name`` column.

    Returns:
        Tuple ``(bag_ids, bboxes_wkt)`` taken from the row, or ``([], [])``
        when no row matches. The caller treats ``bboxes_wkt`` as an iterable
        of WKT polygon strings aligned with ``bag_ids``.

    Raises:
        psycopg2.Error: Propagated from connecting or querying.
    """
    # Database configuration; the password is read from the environment so it
    # never ends up in the notebook.
    DB_CONFIG = {
        "dbname" : "BagMapDB",
        "user" : "postgres",
        "password" : os.getenv("DB_PASSWORD"),
        "host" : "localhost",
        "port" : "5432"
    }

    conn = psycopg2.connect(**DB_CONFIG)
    try:
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT bag_ids, bboxes FROM bag_in_image_utrecht WHERE image_name = %s;", (imageid,))
            result = cursor.fetchone()
        finally:
            cursor.close()
    finally:
        # Release the connection even when the query raised.
        conn.close()

    if result:
        bag_ids = result[0]  # List of bag IDs
        bboxes_wkt = result[1]  # List of WKT polygons
        return bag_ids, bboxes_wkt

    return [], []

In [7]:
MASK_OUTPUT_ROOT = "../../data/masks/"

def segment(folder, imageid, target_bag_id="0344100000157740", buffer_size=500):
    """Segment one building in one image with SAM and save its binary mask.

    The image is read from ``folder/imageid``; the bag ids and WKT boxes for
    ``imageid`` are fetched from the database, transformed into image
    coordinates, and for every box whose bag id equals ``target_bag_id`` SAM is
    run on a crop around the box. The resulting mask is written (0/255, PNG) to
    ``MASK_OUTPUT_ROOT`` under the image path relative to the
    ``one_building_test`` data root, with ``.png`` appended.

    Args:
        folder: Image folder; its last path component is used as the camera
            direction for ``transform_bounding_boxes``.
        imageid: Image file name, also the key used for the database lookup.
        target_bag_id: Only boxes with this bag id are segmented.
        buffer_size: Margin in pixels added around the prompt box when cropping.

    Returns:
        None. Files that cannot be read as an image are skipped with a message.
    """
    # Read the image and direction from folder
    image_path = os.path.join(folder, imageid)
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread returns None instead of raising for unreadable/non-image files.
        print(f"Skipping {image_path}: could not be read as an image")
        return
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    img_height, img_width = image.shape[:2]
    direction = os.path.basename(os.path.normpath(folder))

    # Extract bag_ids and their bboxes
    bag_ids, bboxes_wkt = connect_to_database(imageid)

    # Transform bounding boxes for SAM. float32 matches the precision of the
    # torch tensor previously used here, without the host/GPU round trip.
    bboxes = np.asarray([wkt_to_bbox(wkt) for wkt in bboxes_wkt], dtype=np.float32)
    input_boxes = transform_bounding_boxes(bboxes, img_width, img_height, direction)

    for bag_id, prompt_box in zip(bag_ids, input_boxes):
        if bag_id != target_bag_id:
            continue

        xmin, ymin, xmax, ymax = prompt_box

        # Crop window: prompt box plus buffer, clipped to the image
        x_min_buff = max(0, int(xmin) - buffer_size)
        y_min_buff = max(0, int(ymin) - buffer_size)
        x_max_buff = min(img_width, int(xmax) + buffer_size)
        y_max_buff = min(img_height, int(ymax) + buffer_size)

        if x_min_buff >= x_max_buff or y_min_buff >= y_max_buff:
            # Box (plus buffer) lies entirely outside the image
            continue

        cropped_image = image[y_min_buff:y_max_buff, x_min_buff:x_max_buff]
        crop_height, crop_width = cropped_image.shape[:2]

        predictor.set_image(cropped_image)

        # Prompt box relative to the crop, clamped to valid pixel indices
        prompt_box_cropped_image = np.array([
            max(0, min(crop_width - 1, xmin - x_min_buff)),
            max(0, min(crop_height - 1, ymin - y_min_buff)),
            max(0, min(crop_width - 1, xmax - x_min_buff)),
            max(0, min(crop_height - 1, ymax - y_min_buff))
        ], dtype=np.float32)

        # Run segmentation
        masks, scores, _ = predictor.predict(box=prompt_box_cropped_image, multimask_output=False)

        if not (scores[0] > 0.0):
            continue

        # Convert mask to binary (0 or 1)
        binary_mask = (masks[0] > 0).astype(np.uint8)

        # Paste the crop-sized mask into a full-size mask, scaled to 0/255 for saving
        full_mask = np.zeros(image.shape[:2], dtype=np.uint8)
        full_mask[y_min_buff:y_max_buff, x_min_buff:x_max_buff] = binary_mask * 255

        # Construct the corresponding mask path
        relative_path = os.path.relpath(image_path, "../../data/experimental/one_building_test/")
        mask_output_path = os.path.join(MASK_OUTPUT_ROOT, relative_path + ".png")

        # Create directory if not exists
        os.makedirs(os.path.dirname(mask_output_path), exist_ok=True)

        # Save binary mask; cv2.imwrite signals failure through its return value
        if not cv2.imwrite(mask_output_path, full_mask):
            print(f"Failed to write mask to {mask_output_path}")




In [8]:
# Segment every image of every direction folder.
for folder in [Back, Fwd, Left, Right]:
    count = 0
    # os.listdir returns entries in arbitrary order; sort so the processing
    # order and the progress log are reproducible across runs and machines.
    for imageid in sorted(os.listdir(folder)):
        print(f"Segmenting {folder} {imageid} {count}")
        count += 1
        segment(folder, imageid)

Segmenting ../../data/experimental/one_building_test/Back/ 262005059_0055_01_0074_P00_01.jpg 0
Segmenting ../../data/experimental/one_building_test/Back/ 262005060_0055_01_0073_P00_01.jpg 1
Segmenting ../../data/experimental/one_building_test/Back/ 262005061_0055_01_0072_P00_01.jpg 2
Segmenting ../../data/experimental/one_building_test/Back/ 262005062_0055_01_0071_P00_01.jpg 3
Segmenting ../../data/experimental/one_building_test/Back/ 262005063_0055_01_0070_P00_01.jpg 4
Segmenting ../../data/experimental/one_building_test/Back/ 262005064_0055_01_0069_P00_01.jpg 5
Segmenting ../../data/experimental/one_building_test/Back/ 262005065_0055_01_0068_P00_01.jpg 6
Segmenting ../../data/experimental/one_building_test/Back/ 262005219_0056_01_0086_P00_01.jpg 7
Segmenting ../../data/experimental/one_building_test/Back/ 262005220_0056_01_0087_P00_01.jpg 8
Segmenting ../../data/experimental/one_building_test/Back/ 262005221_0056_01_0088_P00_01.jpg 9
Segmenting ../../data/experimental/one_building_te