In [1]:
import glob
import itertools
import os
from collections.abc import Generator
from functools import partial

import numpy as np
import rasterio
import rasterio.merge
import rasterio.windows
from torchgeo.datasets import Potsdam2D
from torchgeo.datasets.utils import rgb_to_mask

In [2]:
root_dirpath = r"C:\Users\Dimit\Downloads\Potsdam"
join_to_root_dir = partial(os.path.join, root_dirpath)

In [3]:
image_dirname = "2_Ortho_RGB"
image_dirpath = join_to_root_dir(image_dirname)

label_dirname = "5_Labels_all"
label_dirpath = join_to_root_dir(label_dirname)

In [4]:
image_paths = glob.glob(os.path.join(image_dirpath, "*.tif"))
image_paths.sort()

label_paths = glob.glob(os.path.join(label_dirpath, "*.tif"))
label_paths.sort()

In [5]:
def get_chips(
    src: rasterio.io.DatasetReader, width: int = 512, height: int = 512
) -> Generator[tuple[rasterio.windows.Window, "affine.Affine"], None, None]:
    num_cols, num_rows = src.width, src.height

    tile = rasterio.windows.Window(
        col_off=0, row_off=0, width=num_cols, height=num_rows
    )

    offsets = itertools.product(range(0, num_cols, width), range(0, num_rows, height))
    for col_off, row_off in offsets:
        chip = rasterio.windows.Window(
            col_off=col_off, row_off=row_off, width=width, height=height
        ).intersection(tile)
        transform = rasterio.windows.transform(chip, src.transform)
        yield chip, transform

In [12]:
image_chip_dirpath = os.path.join(image_dirpath, "chips")
os.makedirs(image_chip_dirpath, exist_ok=True)

label_chip_dirpath = os.path.join(label_dirpath, "chips")
os.makedirs(label_chip_dirpath, exist_ok=True)

In [7]:
for i, src_path in enumerate(image_paths):
    print(f"{i}/{len(image_paths)}")
    src: rasterio.io.DatasetReader
    with rasterio.open(src_path) as src:
        meta: rasterio.profiles.Profile = src.meta.copy()
        for chip, transform in get_chips(src):
            if not (chip.width == chip.height == 512):
                continue
            src_data = src.read(window=chip)
            dst_name = (
                os.path.basename(src_path).removesuffix("_RGB.tif")
                + f"_{chip.col_off // chip.width}-{chip.row_off // chip.height}"
                + ".tif"
            )
            meta.update(width=chip.width, height=chip.height, transform=transform)
            dst_path = os.path.join(image_chip_dirpath, dst_name)
            dst: rasterio.io.DatasetWriter
            with rasterio.open(dst_path, mode="w", **meta) as dst:
                dst.write(src_data)

0/38
1/38
2/38
3/38
4/38
5/38
6/38
7/38
8/38
9/38
10/38
11/38
12/38
13/38
14/38
15/38
16/38
17/38
18/38
19/38
20/38
21/38
22/38
23/38
24/38
25/38
26/38
27/38
28/38
29/38
30/38
31/38
32/38
33/38
34/38
35/38
36/38
37/38


In [13]:
for i, src_path in enumerate(label_paths):
    print(f"{i}/{len(label_paths)}")
    src: rasterio.io.DatasetReader
    with rasterio.open(src_path) as src:
        meta: rasterio.profiles.Profile = src.meta.copy()
        for chip, _ in get_chips(src):
            if not (chip.width == chip.height == 512):
                continue
            src_data = src.read(window=chip)
            if not np.allclose(src_data, src_data.astype(np.uint8)):
                raise RuntimeError(f"bad mask: {src_path}")
            src_data = rgb_to_mask(np.moveaxis(src_data, 0, -1), Potsdam2D.colormap)
            # if not np.any(src_data):
            #     raise RuntimeError(f"bad mask: {src_path}")
            dst_name = (
                os.path.basename(src_path).removesuffix("_label.tif")
                + f"_{chip.col_off // chip.width}-{chip.row_off // chip.height}"
                + ".tif"
            )
            with rasterio.open(os.path.join(image_chip_dirpath, dst_name)) as img_src:
                crs = img_src.crs
                transform = img_src.transform
            meta.update(
                count=1,
                width=chip.width,
                height=chip.height,
                crs=crs,
                transform=transform,
            )
            dst_path = os.path.join(label_chip_dirpath, dst_name)
            dst: rasterio.io.DatasetWriter
            with rasterio.open(dst_path, mode="w", **meta) as dst:
                dst.write(src_data, indexes=1)

0/38
1/38
2/38
3/38
4/38
5/38
6/38
7/38
8/38
9/38
10/38
11/38
12/38
13/38
14/38
15/38
16/38
17/38
18/38
19/38
20/38
21/38
22/38
23/38
24/38
25/38
26/38
27/38
28/38
29/38
30/38
31/38
32/38
33/38
34/38
35/38
36/38
37/38
