# Dataset Tiling and Mask Preparation

In this notebook, the **S1S2 dataset** is split into **512 × 512 pixel tiles** taken every 412 pixels, so neighbouring tiles overlap by 100 pixels.

Some images contain **NoData regions**. These regions are indicated in the corresponding *validity masks*, where:

- `0` → Invalid (NoData)  
- `1` → Valid  

To ensure these areas are ignored during training:

- All invalid (`0`) pixels from the validity mask are **burned into the segmentation mask** using the value `99`.
- The value `99` is configured as an **ignore label** during model training.

This approach prevents NoData regions from influencing the loss or evaluation metrics.

In [None]:
from pathlib import Path
import rasterio as rio
from tqdm.auto import tqdm
import json
import numpy as np

from multiprocessing import Pool
from multiprocessing.pool import ThreadPool


In [None]:
# Tiling geometry: 512 px square tiles started every 412 px, so neighbouring
# tiles share a 100 px overlap.
patch_size = 512
stride = 412

# Values written into label patches for the two classes.
background_value = 0
water_value = 1

# When True, the *_L2A* scene variants are tiled too (they reuse the labels of
# the corresponding non-L2A scene).
do_L2A = True

# Raster band indices (rasterio counts bands from 1) read from every image.
image_bands = list(range(1, 5))
expected_image_channels = len(image_bands)

In [None]:
import os

# Output directory for the tiled training data; its parent is where the raw
# part* folders are searched for. The default is machine-specific, so allow an
# override through the S1S2_TRAIN_DIR environment variable.
data_dir = Path(
    os.environ.get(
        "S1S2_TRAIN_DIR", "/media/nick/4TB Working 6/Datasets/S1S2-Water/train"
    )
)

In [None]:
# Image and label patches go into sibling sub-folders of data_dir.
img_patches_dir, mask_patches_dir = data_dir / "images", data_dir / "labels"

# Creating the image folder with parents=True also creates data_dir, so the
# label folder only needs its own directory.
img_patches_dir.mkdir(parents=True, exist_ok=True)
mask_patches_dir.mkdir(exist_ok=True)

img_patches_dir, mask_patches_dir

In [None]:
# The raw S1S2-Water download (the part* folders) sits next to data_dir.
s1s2_folders = data_dir.parent
if not s1s2_folders.is_dir():
    # Fail loudly: otherwise every glob below silently finds nothing and the
    # later steps run on an empty dataset without complaint.
    raise FileNotFoundError(f"S1S2 root folder not found: {s1s2_folders}")
s1s2_folders.exists()

In [None]:
# Extracted dataset parts: every part* entry that is a directory. This skips
# the part*.zip archives (and any other non-directory part* file). Sorting
# makes the processing order reproducible across runs; set ordering of Paths
# depends on the per-process hash seed.
part_folders = sorted(p for p in s1s2_folders.glob("part*") if p.is_dir())
part_folders

In [None]:
# Collect every S2 image raster from all part folders. The L2A variants are
# gathered separately so they can be dropped when do_L2A is False.
s2_imgs = [
    path
    for folder in part_folders
    for path in folder.rglob("*sentinel12_s2_*_img.tif")
]
s2_img_L2A = [
    path
    for folder in part_folders
    for path in folder.rglob("*sentinel12_s2_*_L2A*.tif")
]

if not do_L2A:
    # Set membership instead of a list scan per image; same result.
    l2a_lookup = frozenset(s2_img_L2A)
    s2_imgs = [path for path in s2_imgs if path not in l2a_lookup]
len(s2_imgs)

In [None]:
# Display the list of scene rasters that will be tiled.
s2_imgs


In [None]:
def remap_label(
    patch: np.ndarray, background_value: int = 0, water_value: int = 1
) -> np.ndarray:
    """Return a copy of *patch* with 0 -> background_value and 1 -> water_value.

    Both masks are computed from the original values before anything is
    assigned, so the mapping stays correct when the target values overlap
    the source values (e.g. ``background_value=1, water_value=2``). Any other
    value, such as the 99 ignore label, is left unchanged. The input array is
    not modified.
    """
    remapped = patch.copy()
    is_background = patch == 0
    is_water = patch == 1
    remapped[is_background] = background_value
    remapped[is_water] = water_value
    return remapped


def extract_patch(
    input_array: np.ndarray,
    top: int,
    bottom: int,
    left: int,
    right: int,
    input_raster_path: Path,
    dataset: str,
    src: "rio.DatasetReader",  # quoted: not evaluated at definition time
    patch_dir: Path,
    label: bool = False,
    background_value: int = 0,
    water_value: int = 1,
) -> None:
    """Cut one window out of *input_array* and write it as a GeoTIFF.

    The output is written to
    ``patch_dir / "{stem}_{dataset}_{top}_{bottom}_{left}_{right}.tif"``, where
    ``stem`` is the source file stem with ``_msk`` / ``_img`` removed. The
    profile is copied from *src*, with the size, band count and a windowed
    geotransform updated.

    Args:
        input_array: (bands, rows, cols) array read from *src*.
        top, bottom, left, right: window bounds in pixels (end-exclusive).
        input_raster_path: source raster path; used to build the file name.
        dataset: split name embedded in the file name.
        src: open rasterio dataset that supplies profile and transform.
        patch_dir: output directory.
        label: if True, remap 0/1 to background_value/water_value.
        background_value, water_value: output class values for label patches.

    Raises:
        ValueError: if the window is not patch_size x patch_size.
    """
    patch = input_array[:, top:bottom, left:right]
    if patch.shape[-2:] != (patch_size, patch_size):
        raise ValueError(
            f"Patch [{top}:{bottom}, {left}:{right}] of {input_raster_path} has "
            f"shape {patch.shape}, expected {patch_size}x{patch_size}"
        )
    if label:
        # Remap on a copy: `patch` is a view into input_array, and in-place
        # edits would leak into overlapping neighbouring windows.
        patch = remap_label(patch, background_value, water_value)

    file_name = input_raster_path.stem.replace("_msk", "").replace("_img", "")
    patch_path = patch_dir / f"{file_name}_{dataset}_{top}_{bottom}_{left}_{right}.tif"

    local_profile = src.profile.copy()
    local_profile.update(
        {
            "height": patch_size,
            "width": patch_size,
            "count": patch.shape[0],
            "transform": src.window_transform(window=((top, bottom), (left, right))),
        }
    )

    with rio.open(patch_path, "w", **local_profile) as dst:
        dst.write(patch)

In [None]:
def tile_offsets(length: int, size: int, step: int) -> range:
    """Return the start offsets of every full ``size``-long window on an axis.

    Windows start at 0 and advance by ``step``. A window is kept only if it
    fits entirely inside ``length`` (``offset + size <= length``), including
    one that ends exactly on the edge. An axis shorter than ``size`` yields
    no offsets.

    >>> list(tile_offsets(1336, 512, 412))
    [0, 412, 824]
    """
    return range(0, length - size + 1, step)


def make_patches(img: Path) -> None:
    """Tile one S2 scene, and for non-L2A scenes its label, into patches.

    The split name is read from ``properties.split`` in the sibling
    ``*meta.json``. Label pixels whose validity-mask value is 0 are set to 99.
    Every full ``patch_size`` x ``patch_size`` window, stepped by ``stride``,
    is written via ``extract_patch`` to ``img_patches_dir``; for non-L2A scenes
    the matching label window goes to ``mask_patches_dir``. L2A scenes reuse
    those labels, so none are written for them. Pixels past the last full
    window on each axis are not covered. Scenes smaller than one patch are
    skipped with a message.

    Raises:
        FileNotFoundError: if no ``*meta.json`` sits next to *img*.
        ValueError: on a wrong channel count or an image/label size mismatch.
    """
    metadata = next(img.parent.glob("*meta.json"), None)
    if metadata is None:
        raise FileNotFoundError(f"No *meta.json found next to {img}")
    with metadata.open() as f:
        dataset = json.load(f)["properties"]["split"]

    is_l2a = "L2A" in img.name
    # L2A scenes share the label/validity rasters of the base scene.
    label_path = img.parent / img.name.replace("img", "msk").replace("_L2A", "")
    valid_path = img.parent / img.name.replace("img", "valid").replace("_L2A", "")

    with rio.open(img) as src, rio.open(label_path) as label_src, rio.open(
        valid_path
    ) as valid_src:
        img_array = src.read(image_bands)
        if img_array.shape[0] != expected_image_channels:
            raise ValueError(
                f"Expected {expected_image_channels} channels, got {img_array.shape[0]}"
            )

        label_array = label_src.read()
        valid_array = valid_src.read()
        # Burn NoData (validity 0) into the label as the 99 ignore value.
        label_array[valid_array == 0] = 99

        height, width = label_array.shape[-2:]
        if img_array.shape[-2:] != (height, width):
            raise ValueError(
                f"{img} is {img_array.shape[-2:]} but its label is {(height, width)}"
            )

        rows = tile_offsets(height, patch_size, stride)
        cols = tile_offsets(width, patch_size, stride)
        if not rows or not cols:
            print(f"{img} ({height}x{width}) is smaller than {patch_size}x{patch_size}; skipped")
            return

        for top in rows:
            bottom = top + patch_size
            for left in cols:
                right = left + patch_size
                extract_patch(
                    input_array=img_array,
                    top=top,
                    bottom=bottom,
                    left=left,
                    right=right,
                    input_raster_path=img,
                    src=src,
                    dataset=dataset,
                    patch_dir=img_patches_dir,
                    label=False,
                )
                if not is_l2a:  # avoid making duplicate labels for l2a images
                    extract_patch(
                        input_array=label_array,
                        top=top,
                        bottom=bottom,
                        left=left,
                        right=right,
                        input_raster_path=label_path,
                        src=label_src,
                        dataset=dataset,
                        patch_dir=mask_patches_dir,
                        label=True,
                        background_value=background_value,
                        water_value=water_value,
                    )


In [None]:
# Tile all scenes concurrently. make_patches returns None; draining the
# iterator drives the tqdm bar and re-raises any worker exception here.
with ThreadPool(processes=16) as pool:
    tiled = pool.imap(make_patches, s2_imgs)
    results = list(tqdm(tiled, total=len(s2_imgs)))

In [None]:
def check_imgs(patch_path):
    """Print a diagnostic if an image patch looks wrong; return None.

    Reports (on stdout) patches that cannot be read, that are not
    ``(expected_image_channels, patch_size, patch_size)``, or that have no
    mask patch of the same name (with ``_L2A`` removed) in mask_patches_dir.
    """
    try:
        with rio.open(patch_path) as src:
            array = src.read()
    except Exception as e:  # report any unreadable file and keep checking
        print(f"{patch_path} failed check {e}")
        return

    expected_shape = (expected_image_channels, patch_size, patch_size)
    if array.shape != expected_shape:
        print(f"{patch_path} failed check, {array.shape}")

    # L2A image patches share the mask of the corresponding base scene.
    label_path = mask_patches_dir / patch_path.name.replace("_L2A", "")
    if not label_path.exists():
        print(f"{label_path} does not exist")


def check_masks(patch_path):
    """Print a diagnostic if a mask patch looks wrong; return None.

    Reports (on stdout) patches that cannot be read, that are not
    ``(1, patch_size, patch_size)``, or that contain values other than
    background_value, water_value and the 99 ignore label.
    """
    try:
        with rio.open(patch_path) as src:
            array = src.read()
    except Exception as e:  # report any unreadable file and keep checking
        print(f"{patch_path} failed check {e}")
        return

    if array.shape != (1, patch_size, patch_size):
        print(f"{patch_path} failed check, {array.shape}")

    allowed = [background_value, water_value, 99]
    if not np.isin(array, allowed).all():
        print(f"{patch_path} failed check, {np.unique(array)}")

In [None]:
# Gather every written patch and report how many image / mask tiles exist.
img_patches, mask_patches = (
    list(directory.rglob("*.tif"))
    for directory in (img_patches_dir, mask_patches_dir)
)
print(len(img_patches), len(mask_patches))

In [None]:
# Masks are written only for non-L2A scenes (L2A patches reuse them), so
# compare the non-L2A image patches with the masks directly. Halving the
# total with floor division would let an odd number of image patches pass.
n_l2a = sum("_L2A" in p.name for p in img_patches)
n_base = len(img_patches) - n_l2a
assert n_base == len(mask_patches), (
    f"{n_base} non-L2A image patches but {len(mask_patches)} mask patches"
)
if do_L2A:
    assert n_l2a == n_base, f"{n_l2a} L2A image patches vs {n_base} non-L2A ones"
else:
    assert n_l2a == 0, f"do_L2A is False but found {n_l2a} L2A image patches"

In [None]:
# Validate every image patch in 4 worker processes; check_imgs prints problems.
with Pool(processes=4) as pool:
    img_checks = pool.imap(check_imgs, img_patches)
    results = list(tqdm(img_checks, total=len(img_patches)))

  0%|          | 0/87880 [00:00<?, ?it/s]

In [32]:
# Validate every mask patch in 4 worker processes; check_masks prints problems.
with Pool(processes=4) as pool:
    mask_checks = pool.imap(check_masks, mask_patches)
    results = list(tqdm(mask_checks, total=len(mask_patches)))

  0%|          | 0/43940 [00:00<?, ?it/s]