# Split the Potsdam dataset tiles into chips to enable training with custom tools.

In [64]:
import glob
import itertools
import os
import warnings
from collections.abc import Generator

import numpy as np
import rasterio
import rasterio.windows
import torchgeo.datasets.utils
from rasterio.errors import NotGeoreferencedWarning
from torchgeo.datasets import Potsdam2D

In [22]:
def get_chips(
    src: rasterio.io.DatasetReader, width: int = 512, height: int = 512
) -> Generator[tuple[rasterio.windows.Window, "affine.Affine"], None, None]:
    """Yield windows and affine transforms that tile ``src`` on a regular grid.

    The grid starts at the top-left pixel and advances in steps of ``width``
    columns and ``height`` rows. Column offsets form the outer loop and row
    offsets the inner loop. Every window is clipped to the raster extent, so
    the windows along the right and bottom edges may be smaller than
    requested; callers decide whether to keep them.

    Args:
        src: Open raster whose size and transform define the grid.
        width: Requested chip width in pixels.
        height: Requested chip height in pixels.

    Yields:
        ``(window, transform)`` pairs, where ``transform`` georeferences the
        clipped ``window`` using ``src.transform``.
    """
    extent = rasterio.windows.Window(
        col_off=0, row_off=0, width=src.width, height=src.height
    )

    for x in range(0, src.width, width):
        for y in range(0, src.height, height):
            window = rasterio.windows.Window(
                col_off=x, row_off=y, width=width, height=height
            ).intersection(extent)
            yield window, rasterio.windows.transform(window, src.transform)

## NOTE: These steps should be performed only ONCE.

## Setup

In [48]:
# Root directory of the downloaded Potsdam dataset.
data_dirpath = r"C:\Users\Dimit\Downloads\Potsdam"

# Directory holding the image tiles, and the filename suffix that identifies
# them: the last "_"-separated token of the directory name plus ".tif"
# (e.g. "2_Ortho_RGB" -> "_RGB.tif").
image_dirname = "2_Ortho_RGB"
image_dirpath = os.path.join(data_dirpath, image_dirname)
image_name_glob = f"_{image_dirname.rpartition('_')[2]}.tif"

## Image Splitting

In [55]:
# Chip edge length, in pixels. Only chips of exactly this size are written,
# so every image chip has the same shape.
chip_size = 512

# Create a directory to store the image chips.
os.makedirs(os.path.join(image_dirpath, "images"), exist_ok=True)

# Gather the image paths (sorted for a deterministic processing order).
image_paths = sorted(
    glob.glob(os.path.join(image_dirpath, "*" + image_name_glob))
)

# Split the images.
for src_path in image_paths:
    src: rasterio.io.DatasetReader
    with rasterio.open(src_path) as src:
        meta: dict = src.meta.copy()
        chips = get_chips(src, width=chip_size, height=chip_size)
        for chip, transform in chips:
            # Keep only full-size chips. A squareness check is not enough:
            # the clipped bottom-right chip is square whenever both raster
            # dimensions leave the same remainder (e.g. a 6000x6000 raster
            # with 512 px chips leaves a 368x368 corner chip), and it would
            # be written with a misleading grid index.
            if chip.width != chip_size or chip.height != chip_size:
                continue

            # Update the chip metadata.
            meta.update(width=chip.width, height=chip.height, transform=transform)

            # Inject the chip's column and row index on the chip grid into
            # the filename, e.g. "<tile>_RGB.tif" -> "images/<tile>_3-7.tif".
            col_idx = int(chip.col_off) // chip_size
            row_idx = int(chip.row_off) // chip_size
            stem = os.path.basename(src_path).removesuffix(image_name_glob)
            dst_path = os.path.join(
                os.path.dirname(src_path), "images", f"{stem}_{col_idx}-{row_idx}.tif"
            )

            dst: rasterio.io.DatasetWriter
            with rasterio.open(dst_path, mode="w", **meta) as dst:
                dst.write(src.read(window=chip))

## Mask Splitting

In [75]:
# Chip edge length, in pixels. It must match the chip size used for the
# images, so that image and mask chips with the same filename cover the same
# pixels. Only chips of exactly this size are written.
chip_size = 512

# Create a directory to store the mask chips.
os.makedirs(os.path.join(data_dirpath, "masks"), exist_ok=True)

# Gather the mask paths (sorted for a deterministic processing order).
mask_paths = sorted(glob.glob(os.path.join(data_dirpath, "*_label.tif")))

# Split the masks.
for src_path in mask_paths:
    src: rasterio.io.DatasetReader
    # Silence NotGeoreferencedWarning while reading/writing the label rasters.
    # `catch_warnings(action=...)` only exists on Python 3.11+, so install
    # the filter explicitly inside the context instead.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=NotGeoreferencedWarning)
        with rasterio.open(src_path) as src:
            meta: dict = src.meta.copy()
            chips = get_chips(src, width=chip_size, height=chip_size)
            for chip, transform in chips:
                # Keep only full-size chips. A squareness check is not
                # enough: the clipped bottom-right chip is square whenever
                # both raster dimensions leave the same remainder (e.g. a
                # 6000x6000 raster with 512 px chips leaves a 368x368 corner
                # chip), and it would be written with a misleading grid index.
                if chip.width != chip_size or chip.height != chip_size:
                    continue

                # Convert the RGB-coded chip to a single-band raster.
                data = src.read(window=chip)  # (bands, rows, cols)
                # rgb_to_mask expects the colour channels on the last axis.
                data = np.moveaxis(data, source=0, destination=-1)
                data = torchgeo.datasets.utils.rgb_to_mask(
                    data, colors=Potsdam2D.colormap
                )

                # Update the chip metadata.
                meta.update(
                    width=chip.width, height=chip.height, count=1, transform=transform
                )

                # Inject the chip's column and row index on the chip grid into
                # the filename, e.g. "<tile>_label.tif" -> "masks/<tile>_3-7.tif".
                col_idx = int(chip.col_off) // chip_size
                row_idx = int(chip.row_off) // chip_size
                stem = os.path.basename(src_path).removesuffix("_label.tif")
                dst_path = os.path.join(
                    os.path.dirname(src_path), "masks", f"{stem}_{col_idx}-{row_idx}.tif"
                )

                dst: rasterio.io.DatasetWriter
                with rasterio.open(dst_path, mode="w", **meta) as dst:
                    dst.write(data, indexes=1)

## Dataset Filtering
Images that do not contain buildings are discarded, and every other class is mapped to background so that it is ignored during training.
This process transforms the dataset into one meant for building identification, such as the [Inria Aerial Image Labeling Benchmark](https://project.inria.fr/aerialimagelabeling/), but with a spatial resolution better suited for the downstream data.
The expectation is that building labeling is a contextually more appropriate pretext task than plain urban scene segmentation.

**TODO**: Consider discarding images based on a minimum building percentage in the corresponding masks.

In [None]:
# Class index of "building" in the single-band mask chips.
# NOTE(review): the mask values are produced by
# `rgb_to_mask(..., colors=Potsdam2D.colormap)`, so they presumably equal
# positions in `Potsdam2D.colormap`. Confirm that 3 is the building entry
# there (e.g. by checking `Potsdam2D.classes`), since an off-by-one here
# would silently select the wrong class.
building_idx = 3