In [22]:
from types import SimpleNamespace
from pathlib import Path
from owslib.wms import WebMapService
from tqdm import tqdm
import torch

from ultralytics import YOLO

from PIL import Image

# Shared, mutable settings container; its attributes are attached in the configuration cell
CFG = SimpleNamespace()

# 00 - Define functions

In [23]:
def load_and_save_images_from_wms(
        url: str,
        layer_name: str,
        coords_center: tuple,
        img_size: tuple,
        resolution: float,
        path: Path,
        filename: str,
        return_image_object: bool = False,
        ):
    """
    Request one tile from a WMS server and write a world file that georeferences it.

    The tile is requested in EPSG:3857 (Web Mercator) as a JPEG. The bounding box is
    derived from the tile centre, the pixel size and the ground resolution.

    Inputs:
        url (str): The URL of the WMS server.
        layer_name (str): The name of the layer to retrieve.
        coords_center (tuple): The tile centre as (x, y) in EPSG:3857 coordinates.
        img_size (tuple): The size of the image in pixels (width, height).
        resolution (float): The resolution of the image in meters per pixel.
        path (Path): The directory where the output files are written (created if missing).
        filename (str): The base name (without extension) for the output files.
        return_image_object (bool): If True, return the tile as a PIL Image instead of
            writing it to disk.
    Outputs:
        A world file `<filename>.pgw` is always written to `path`.
        If `return_image_object` is False, `<filename>.jpg` is written as well and
        None is returned; otherwise the PIL Image is returned and no JPEG is written.
    """
    # Ensure the output directory exists
    path.mkdir(parents=True, exist_ok=True)

    # Bounding box from the centre coordinates, image size and resolution
    half_width = (img_size[0] * resolution) / 2
    half_height = (img_size[1] * resolution) / 2
    x_min = coords_center[0] - half_width
    x_max = coords_center[0] + half_width
    y_min = coords_center[1] - half_height
    y_max = coords_center[1] + half_height

    # For debugging purposes, print the bounding box coordinates
    print(f"Bounding box coordinates: ({x_min}, {y_min}, {x_max}, {y_max})")

    # Get the image from the WMS server
    wms = WebMapService(url)
    img = wms.getmap(
        layers=[layer_name],
        srs='EPSG:3857',  # Web Mercator
        bbox=(x_min, y_min, x_max, y_max),
        size=img_size,
        format='image/jpeg',
        transparent=True,  # JPEG has no alpha channel; flag kept so the request is unchanged
        )

    # World file layout: x pixel size, two rotation terms, negative y pixel size, then
    # the x/y coordinates of the CENTRE of the upper-left pixel. The bbox corner is the
    # outer edge of that pixel, so the origin is shifted inwards by half a pixel.
    origin_x = x_min + resolution / 2
    origin_y = y_max - resolution / 2
    world_file_content = f"{resolution}\n0.0\n0.0\n{-resolution}\n{origin_x}\n{origin_y}\n"
    # NOTE(review): '.pgw' is the world-file extension paired with PNG rasters; GIS tools
    # usually look for '.jgw' next to a '.jpg'. Kept as-is so existing outputs do not change.
    world_file_path = path / f'{filename}.pgw'
    with open(world_file_path, 'w') as world_file:
        world_file.write(world_file_content)

    if return_image_object:
        # Hand the response to PIL and return the decoded image object
        return Image.open(img)

    # Otherwise persist the raw JPEG bytes; the context manager closes the file even
    # if reading the response or writing fails.
    with open(path / f'{filename}.jpg', 'wb') as out:
        out.write(img.read())
    return None

In [24]:
def make_prediction_and_save_mask_from_pil(
        image: Image.Image,
        model,
        output_dir: Path,
        image_name: str,
        ):
    """
    Predicts and saves the combined class mask for a given PIL image using a YOLO model.

    All instance masks of a prediction are merged into one single-channel 8-bit image:
    background pixels are 0, class 0 pixels are 100 and class 1 pixels are 255.
    Instances of any other class are ignored. Where instances overlap, the one that
    appears later in the prediction wins.

    Args:
        image (PIL.Image.Image): Input image.
        model: YOLO model object; must provide `predict(source=..., verbose=...)`.
        output_dir (Path): Directory to save the mask (created if missing).
        image_name (str): Name for the output mask file (without extension).

    Returns:
        None. Writes `<image_name>.png` to `output_dir` for each prediction that
        contains masks; nothing is written when the model finds no instances.
    """
    # Ensure the output directory exists (including missing parents)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Make predictions using the YOLO model
    results = model.predict(
        source=image,
        verbose=False,
        )

    # Pixel value written into the combined mask for each class ID
    class_values = {0: 100, 1: 255}

    for pred in results:
        if pred.masks is None:
            # No instances detected for this prediction
            continue

        # Bring masks to the CPU: the combined mask below lives on the CPU, and indexing
        # a CPU tensor with a boolean mask on another device raises an error.
        masks = pred.masks.data.cpu()  # shape: (num_instances, H, W)
        class_ids = pred.boxes.cls.int().cpu().tolist()  # one class ID per mask

        combined_mask = torch.zeros(masks.shape[1], masks.shape[2], dtype=torch.uint8)

        for mask, class_id in zip(masks, class_ids):
            if class_id in class_values:
                # Boolean view of the mask selects the pixels to paint
                combined_mask[mask.bool()] = class_values[class_id]

        mask_img = Image.fromarray(combined_mask.numpy())
        mask_img.save(output_dir / f"{image_name}.png")

# 01 - Download image and make prediction

### Configuration for downloading data

In [25]:
# --- WMS source -----------------------------------------------------------
CFG.WMS_URL = 'http://geoservices.buergernetz.bz.it/mapproxy/root/ows?'
# Other layers used so far: 'p_bz-Orthoimagery:Aerial-2014-RGB', 'p_bz-Orthoimagery:Aerial-2020-RGB'
CFG.LAYER_NAME = 'p_bz-Orthoimagery:Aerial-2023-RGB'

# --- Tile geometry --------------------------------------------------------
CFG.IMG_SIZE = (640, 640)  # pixels per tile (width, height)
CFG.RESOLUTION = 0.2       # meters per pixel

# --- Raster layout, in Pseudo-Mercator coordinates (EPSG:3857) -------------
# Lower-left corner of the area to cover and the number of tiles in x and y direction
CFG.AEREA_LEFT_BOTTOM_COORDS = (1233023, 5870474)
CFG.RASTER_SIZE = (1, 1)

# --- Output ---------------------------------------------------------------
CFG.PATH_TO_OUTPUT_DIR = Path.cwd() / 'predicted_masks'
# When True, existing files in the output directory are deleted before the run
CFG.DELETE_EXISTING = True

# --- Model ----------------------------------------------------------------
CFG.PATH_TO_MODEL = 'models/yolo11m_250imgs_100epochs_best.pt'
model = YOLO(CFG.PATH_TO_MODEL)

In [26]:
# Delete existing files in the output directory if specified
if CFG.DELETE_EXISTING and CFG.PATH_TO_OUTPUT_DIR.exists():
    for file in CFG.PATH_TO_OUTPUT_DIR.glob('*'):
        # Only regular files are removed; unlink() would raise on a sub-directory
        if file.is_file():
            file.unlink()

# Ground distance covered by one tile in x and y (pixels * meters per pixel)
x_step = CFG.IMG_SIZE[0] * CFG.RESOLUTION
y_step = CFG.IMG_SIZE[1] * CFG.RESOLUTION

with tqdm(total=CFG.RASTER_SIZE[0] * CFG.RASTER_SIZE[1], desc='Total progress') as pbar_total:

    for col_x in range(CFG.RASTER_SIZE[0]):
        for row_y in range(CFG.RASTER_SIZE[1]):

            # Centre of the current tile: lower-left corner of the area, plus half a
            # tile, plus the column/row offset. The y offset must use y_step so that
            # non-square tiles are placed correctly.
            coords_center = (
                CFG.AEREA_LEFT_BOTTOM_COORDS[0] + x_step/2 + (col_x*x_step),
                CFG.AEREA_LEFT_BOTTOM_COORDS[1] + y_step/2 + (row_y*y_step),
            )

            # Download the tile (world file is written, image is returned in memory)
            filename = f'{int(coords_center[0])}_{int(coords_center[1])}'
            image = load_and_save_images_from_wms(
                url=CFG.WMS_URL,
                layer_name=CFG.LAYER_NAME,
                coords_center=coords_center,
                img_size=CFG.IMG_SIZE,
                resolution=CFG.RESOLUTION,
                path=CFG.PATH_TO_OUTPUT_DIR,
                filename=filename,
                return_image_object=True,
                )

            # Run the model on the tile and save the class mask under the same base name
            make_prediction_and_save_mask_from_pil(
                image=image,
                model=model,
                output_dir=CFG.PATH_TO_OUTPUT_DIR,
                image_name=filename,
                )

            # Update the progress bar
            pbar_total.update(1)

Total progress:   0%|          | 0/1 [00:00<?, ?it/s]

Bounding box coordinates: (1233023.0, 5870474.0, 1233151.0, 5870602.0)


Total progress: 100%|██████████| 1/1 [00:01<00:00,  1.87s/it]
