In [23]:
import os
import random
import numpy as np
import torch
import rasterio
from PIL import Image
from torchvision import transforms

In [24]:
# Paths: where the Sentinel-2 mosaics are read from and where tiles are written
input_root = r"E:\Sentinel-2_mosaics"
output_dir = r"E:\processed_tiles"

# Create the images/masks folders for both dataset splits up front so that
# later save calls never hit a missing directory.
for split in ("train", "test"):
    for kind in ("images", "masks"):
        os.makedirs(f"{output_dir}/{split}/{kind}", exist_ok=True)

In [25]:
# NDWI Calculation
def compute_ndwi(green, nir):
    """Compute the Normalised Difference Water Index: (green - nir) / (green + nir).

    Parameters
    ----------
    green, nir : array_like
        Green and near-infrared band values with broadcast-compatible shapes.
        Neither input is modified.

    Returns
    -------
    numpy.ndarray
        Element-wise NDWI. Floating-point inputs keep their promoted dtype
        (float32 in -> float32 out); integer inputs are promoted to float64.
        Where ``green + nir`` is exactly zero the denominator is replaced by
        1e-6, so a pixel with both bands at zero yields 0 rather than NaN.
    """
    green = np.asarray(green)
    nir = np.asarray(nir)

    # Work in floating point from the start. With integer inputs the 1e-6
    # guard would otherwise be truncated to 0 (re-introducing 0/0), and with
    # unsigned inputs ``green - nir`` would wrap around instead of going negative.
    out_dtype = np.result_type(green.dtype, nir.dtype)
    if not np.issubdtype(out_dtype, np.inexact):
        out_dtype = np.dtype(np.float64)
    green = green.astype(out_dtype, copy=False)
    nir = nir.astype(out_dtype, copy=False)

    denom = green + nir
    # Typed scalar keeps the result in ``out_dtype`` (no silent float64 upcast).
    denom = np.where(denom == 0, out_dtype.type(1e-6), denom)
    return (green - nir) / denom

In [26]:
# Load Sentinel-2 bands
def load_bands(folder, scale=10000):
    """Load the B02, B03, B04 and B08 rasters found in ``folder``.

    A file is matched to a band when its name ends with ``"<band>.tif"``.
    If several files match the same band, the one that sorts last wins.

    Parameters
    ----------
    folder : str
        Directory containing one single-band GeoTIFF per band.
    scale : float, optional
        Divisor applied to the raw pixel values (default 10000).
        NOTE(review): presumably the Sentinel-2 reflectance quantification
        value - confirm against the product specification for these mosaics.

    Returns
    -------
    tuple of numpy.ndarray
        ``(B02, B03, B04, B08)``, each the first raster band as float32
        divided by ``scale``.

    Raises
    ------
    FileNotFoundError
        If any of the four bands has no matching ``.tif`` file in ``folder``.
    """
    band_names = ('B02', 'B03', 'B04', 'B08')
    band_files = dict.fromkeys(band_names)

    # Sorted so the choice between duplicate matches does not depend on the
    # filesystem's listing order.
    for file in sorted(os.listdir(folder)):
        for band in band_names:
            if file.endswith(f"{band}.tif"):
                band_files[band] = os.path.join(folder, file)

    # Fail early with a readable message instead of passing None to rasterio.
    missing = [band for band in band_names if band_files[band] is None]
    if missing:
        raise FileNotFoundError(
            f"Missing band file(s) {', '.join(missing)} in folder: {folder}"
        )

    bands = []
    for band in band_names:
        with rasterio.open(band_files[band]) as src:
            bands.append(src.read(1).astype(np.float32) / scale)

    return tuple(bands)

In [27]:
# Slice into tiles
def slice_tiles(img, mask, size=128, allow_blank=False, min_mean=25):
    """Cut an image and its mask into aligned, non-overlapping square tiles.

    Tiles are taken on a regular grid starting at the top-left corner, row by
    row. Partial tiles at the right and bottom edges are discarded.

    Parameters
    ----------
    img, mask : PIL-style images
        Objects exposing ``.size`` as ``(width, height)`` and ``.crop(box)``.
        Both must have the same size.
    size : int, optional
        Tile edge length in pixels (default 128).
    allow_blank : bool, optional
        When True every full tile is kept. When False (default) a tile is
        kept only if the mean of its image pixel values exceeds ``min_mean``.
    min_mean : float, optional
        Brightness threshold used when ``allow_blank`` is False (default 25).

    Returns
    -------
    tuple of (list, list, list)
        ``(img_tiles, mask_tiles, positions)`` where ``positions[i]`` is the
        ``(x, y)`` top-left corner of tile ``i`` in the source image.

    Raises
    ------
    ValueError
        If ``img`` and ``mask`` differ in size; cropping them with the same
        boxes would otherwise silently produce misaligned tiles.
    """
    if img.size != mask.size:
        raise ValueError(
            f"Image size {img.size} does not match mask size {mask.size}"
        )

    img_tiles, mask_tiles, positions = [], [], []
    width, height = img.size

    # Stopping at ``dim - size + 1`` visits only origins whose tile fits
    # entirely inside the image, so no per-tile bounds check is needed.
    for y in range(0, height - size + 1, size):
        for x in range(0, width - size + 1, size):
            box = (x, y, x + size, y + size)
            tile = img.crop(box)
            if allow_blank or np.array(tile).mean() > min_mean:
                img_tiles.append(tile)
                mask_tiles.append(mask.crop(box))
                positions.append((x, y))
    return img_tiles, mask_tiles, positions

In [28]:
# Save PyTorch Tensors
def save_tensor_tiles(tiles, masks, positions, counter, folder_name, minmax_tracker, include_xy=False, dataset_type="train"):
    """Write each image/mask tile pair to disk as float32 ``.pt`` tensors.

    Files go to ``<output_dir>/<dataset_type>/images/img_<name>`` and
    ``<output_dir>/<dataset_type>/masks/msk_<name>``, where ``output_dir`` is
    the module-level output path and ``<name>`` encodes the running tile
    number, optionally the tile origin, and the percentage of mask pixels
    that are non-zero.

    Parameters
    ----------
    tiles, masks : sequences of images
        Tile images and their masks, paired by index.
    positions : sequence of (x, y)
        Top-left origin of each tile; used in the filename when
        ``include_xy`` is True.
    counter : int
        Number assigned to the first tile; incremented once per tile saved.
    folder_name : str
        Not used by this function.
    minmax_tracker : dict
        Dict with ``'min'`` and ``'max'`` keys, updated in place with the
        smallest and largest non-zero image pixel value seen so far.
    include_xy : bool, optional
        Add ``_x<x>_y<y>`` to the filename (default False).
    dataset_type : str, optional
        Sub-folder of ``output_dir`` to write into (default ``"train"``).

    Returns
    -------
    int
        The counter value after the last tile, i.e. the next free number.
    """
    next_index = counter
    for tile_img, tile_mask, origin in zip(tiles, masks, positions):
        x, y = origin
        pixels = np.array(tile_img)
        labels = np.array(tile_mask)

        # Share of mask pixels flagged as water, as a whole-number percentage
        water_pct = int(round(np.mean(labels > 0) * 100))

        # Filename: running number, optional tile origin, water percentage
        xy_part = f"_x{x}_y{y}" if include_xy else ""
        filename = f"num{next_index:05d}{xy_part}_water{water_pct:02d}.pt"

        # Scale 0-255 values to 0-1 and add a leading axis.
        # NOTE(review): for an (H, W, C) image array this yields a
        # (1, H, W, C) tensor, not channels-first - confirm that this is what
        # the downstream loader expects.
        img_tensor = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0).div(255.0)
        mask_tensor = torch.tensor(labels, dtype=torch.float32).unsqueeze(0).div(255.0)

        torch.save(img_tensor, os.path.join(output_dir, dataset_type, "images", f"img_{filename}"))
        torch.save(mask_tensor, os.path.join(output_dir, dataset_type, "masks", f"msk_{filename}"))

        # Running min/max over non-zero image pixels (diagnostic only)
        lit = pixels[pixels > 0]
        if lit.size:
            minmax_tracker['min'] = min(lit.min(), minmax_tracker['min'])
            minmax_tracker['max'] = max(lit.max(), minmax_tracker['max'])

        next_index += 1
    return next_index

In [29]:
# Main Processing Loop
# Run configuration
SEED = 42                    # fixes the random augmentations between runs
tile_size = 128
NDWI_WATER_THRESHOLD = 0.2   # pixels with NDWI above this are labelled water
N_AUGMENTATIONS = 3          # augmented copies generated per mosaic

# Seed every RNG the augmentations may draw from so a re-run reproduces the
# same training tiles.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

counter = 0
minmax_tracker = {'min': float('inf'), 'max': float('-inf')}

# Loop-invariant: build the jitter transform once; each call still samples
# fresh brightness/contrast/saturation factors.
color_jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1)

# NOTE(review): transform_pair is not defined in the cells shown here; it must
# be defined by an earlier cell before this one is run.
# NOTE(review): test and train tiles are cut from the same mosaics, so the test
# set is not independent of the training set - confirm this is intended.

# Sorted so tile numbering does not depend on the filesystem's listing order.
for folder_name in sorted(os.listdir(input_root)):
    folder_path = os.path.join(input_root, folder_name)
    if not os.path.isdir(folder_path):
        continue

    print(f"Processing: {folder_name}")
    B02, B03, B04, B08 = load_bands(folder_path)

    # Binary water mask (0 / 255) from the green and NIR bands
    ndwi = compute_ndwi(B03, B08)
    mask = (ndwi > NDWI_WATER_THRESHOLD).astype(np.uint8) * 255

    # 8-bit RGB composite; scaled values above 255 are clipped
    rgb = np.stack([B04, B03, B02], axis=-1)
    rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8)

    pil_img = Image.fromarray(rgb)
    pil_mask = Image.fromarray(mask)

    # Save full test tiles (no filtering or augmentation)
    test_tiles, test_masks, test_positions = slice_tiles(pil_img, pil_mask, tile_size, allow_blank=True)
    counter = save_tensor_tiles(test_tiles, test_masks, test_positions, counter, folder_name, minmax_tracker, include_xy=True, dataset_type="test")

    # Save augmented training copies (random colour jitter + paired transform)
    for _ in range(N_AUGMENTATIONS):
        jittered = color_jitter(pil_img)
        aug_img, aug_mask = transform_pair(jittered, pil_mask)
        train_tiles, train_masks, train_positions = slice_tiles(aug_img, aug_mask, tile_size)
        counter = save_tensor_tiles(train_tiles, train_masks, train_positions, counter, folder_name, minmax_tracker, include_xy=False, dataset_type="train")

# Additional information
print(f"\nDone. Total image/mask tiles saved: {counter}")
print(f"Non-zero pixel value range across all image tiles: min = {minmax_tracker['min']:.3f}, max = {minmax_tracker['max']:.3f}")

Processing: Sentinel-2_mosaic_2020_Q1_47QPF_0_0
Processing: Sentinel-2_mosaic_2020_Q1_47QPG_0_0
Processing: Sentinel-2_mosaic_2020_Q1_47RPH_0_0
Processing: Sentinel-2_mosaic_2020_Q1_48RTQ_0_0
Processing: Sentinel-2_mosaic_2020_Q1_48RTR_0_0
Processing: Sentinel-2_mosaic_2020_Q1_49SER_0_0
Processing: Sentinel-2_mosaic_2020_Q1_49SES_0_0
Processing: Sentinel-2_mosaic_2020_Q1_50QKM_0_0
Processing: Sentinel-2_mosaic_2020_Q1_50QLM_0_0
Processing: Sentinel-2_mosaic_2020_Q1_50RQS_0_0
Processing: Sentinel-2_mosaic_2025_Q1_47QPF_0_0
Processing: Sentinel-2_mosaic_2025_Q1_47QPG_0_0
Processing: Sentinel-2_mosaic_2025_Q1_47RNH_0_0
Processing: Sentinel-2_mosaic_2025_Q1_47RPH_0_0
Processing: Sentinel-2_mosaic_2025_Q1_48RTQ_0_0
Processing: Sentinel-2_mosaic_2025_Q1_48RTR_0_0
Processing: Sentinel-2_mosaic_2025_Q1_49SER_0_0
Processing: Sentinel-2_mosaic_2025_Q1_49SES_0_0
Processing: Sentinel-2_mosaic_2025_Q1_49SFU_0_0
Processing: Sentinel-2_mosaic_2025_Q1_50QKM_0_0
Processing: Sentinel-2_mosaic_2025_Q1_50

I also tried other file formats such as .png. They took up less than half the storage; however, processing times were considerably longer.

The minmax tracker added considerable processing time; I would remove it if creating the slices again.