# FLAIR Dataset Processing for Water Segmentation

This notebook processes the [FLAIR dataset](https://ignf.github.io/FLAIR/) (French Land cover from Aerospace ImageRy) for water segmentation training.

## Overview

The notebook performs the following steps:

1. **Mosaic Creation**: Combines individual scene tiles from the FLAIR dataset into larger mosaic images for both labels and aerial imagery
2. **Label Remapping**: Converts the multi-class FLAIR labels to binary water masks (water class=5 and pool class=13 → 1, everything else → 0)
3. **Patch Generation**: Splits the mosaics into 512×512 patches at multiple resolutions (0.2m and 1m)
4. **Water Filtering**: Identifies patches that contain water pixels, filtering out ~74% of patches that have no water
5. **CSV Export**: Creates a manifest file (`FLAIR_patches_with_water.csv`) linking each water-containing label patch to its corresponding image patch

## Configuration

- **Input**: FLAIR aerial imagery and labels (20cm resolution, RGF93/Lambert-93 projection - EPSG:2154)
- **Output**: 512×512 patches at 0.2m and 1m resolution
- **Water Classes**: Class 5 (water) and Class 13 (swimming pools) are merged into a single water class

In [1]:
import rasterio as rio
from pathlib import Path
import numpy as np
from rasterio.merge import merge
from affine import Affine
from tqdm.auto import tqdm
from multiprocessing import Pool
from functools import partial
import pandas as pd


In [2]:
# Root of the FLAIR training split.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
data_dir = Path("/media/nick/4TB Working 6/Datasets/FLAIR_retrain/train")
input_labels = data_dir / "flair_labels_train"
input_images = data_dir / "flair_aerial_train"
patch_dir = data_dir / "FLAIR patches"

# Fail early with a message that names the missing input directory.
assert input_labels.exists(), f"Label directory not found: {input_labels}"
assert input_images.exists(), f"Imagery directory not found: {input_images}"

# Output folder for the generated patches; safe to re-run.
patch_dir.mkdir(exist_ok=True, parents=True)

In [3]:
# Inspect the raster profile of the first label tile found.
# The context manager closes the dataset handle once the profile is copied,
# and next() avoids listing every tile just to take the first one.
with rio.open(next(input_labels.rglob("*.tif"))) as sample_src:
    sample_profile = sample_src.profile
sample_profile

{'driver': 'GTiff', 'dtype': 'uint8', 'nodata': None, 'width': 7680, 'height': 2560, 'count': 1, 'crs': CRS.from_wkt('PROJCS["RGF93 v1 / Lambert-93",GEOGCS["RGF93 v1",DATUM["Reseau_Geodesique_Francais_1993_v1",SPHEROID["GRS 1980",6378137,298.257222101,AUTHORITY["EPSG","7019"]],AUTHORITY["EPSG","6171"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4171"]],PROJECTION["Lambert_Conformal_Conic_2SP"],PARAMETER["latitude_of_origin",46.5],PARAMETER["central_meridian",3],PARAMETER["standard_parallel_1",49],PARAMETER["standard_parallel_2",44],PARAMETER["false_easting",700000],PARAMETER["false_northing",6600000],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","2154"]]'), 'transform': Affine(0.2, 0.0, 655202.0,
       0.0, -0.2, 6844909.00000001), 'blockxsize': 7680, 'blockysize': 16, 'tiled': False, 'compress': 'lzw', 'interleave': 'band'}

In [4]:
# FLAIR class IDs that are merged into the binary water mask.
water, pool = 5, 13

In [5]:
def process_subfolder(args):
    """Merge every GeoTIFF tile under one folder into a single mosaic file.

    Args:
        args: Tuple ``(subfolder, input_dir, dataset)``, packed into a single
            argument so the function can be used with ``Pool.imap``.
            ``subfolder`` is searched recursively for ``*.tif`` tiles,
            ``input_dir`` is where the mosaic is written, and ``dataset`` is
            the suffix used in the output file name. When ``dataset`` equals
            ``"labels"`` the mosaic is reduced to a binary mask in which the
            ``water`` and ``pool`` classes become 1 and everything else 0.

    Returns:
        Path of the written mosaic, or ``None`` when the folder has no tiles.
    """
    subfolder, input_dir, dataset = args
    # Sorted so the merge order, and the tile whose profile is reused, do not
    # depend on filesystem enumeration order.
    scene_tiles = sorted(subfolder.rglob("*.tif"))
    if not scene_tiles:
        return None

    src_files_to_mosaic = []
    try:
        for tile in scene_tiles:
            src_files_to_mosaic.append(rio.open(tile))

        mosaic, out_trans = merge(src_files_to_mosaic, res=0.2, dtype=np.uint8)
        # Copy the profile while the first dataset is still open.
        profile = src_files_to_mosaic[0].profile
    finally:
        # Release every handle, including when opening or merging fails midway.
        for src in src_files_to_mosaic:
            src.close()

    profile.update(
        {
            "height": mosaic.shape[1],
            "width": mosaic.shape[2],
            "transform": out_trans,
            # Force the CRS to EPSG:2154 on every mosaic.
            "crs": rio.crs.CRS.from_epsg(2154),
        }
    )
    raster_name = f"{subfolder.parent.name}_{subfolder.name}_{dataset}.tif"
    out_path = input_dir / raster_name

    if dataset == "labels":
        # Collapse the multi-class labels to a binary water mask.
        mosaic = np.isin(mosaic, (water, pool)).astype(np.uint8)

    with rio.open(out_path, "w", **profile) as dst:
        dst.write(mosaic)

    return out_path


def mosaic_from_dir(input_dir: Path, dataset: str, num_workers: int = 8) -> list[Path]:
    """Create one mosaic per second-level directory of ``input_dir``.

    Every directory found two levels below ``input_dir`` is handed to
    ``process_subfolder`` through a process pool.

    Args:
        input_dir: Directory whose ``*/*`` sub-directories are processed; it
            is also passed on as the destination for the mosaics.
        dataset: Suffix forwarded to ``process_subfolder``.
        num_workers: Number of worker processes in the pool.

    Returns:
        The non-``None`` results of ``process_subfolder``, in job order.
    """
    # One tuple per job because Pool.imap passes a single argument per task.
    jobs = [
        (scene_dir, input_dir, dataset)
        for group_dir in input_dir.glob("*")
        for scene_dir in group_dir.glob("*")
        if scene_dir.is_dir()
    ]

    with Pool(num_workers) as workers:
        outputs = list(tqdm(workers.imap(process_subfolder, jobs), total=len(jobs)))

    return [written for written in outputs if written is not None]

In [6]:
# Keep the two result lists under separate names so neither is overwritten.
label_mosaics = mosaic_from_dir(input_labels, "labels")
image_mosaics = mosaic_from_dir(input_images, "images")

# Label and image patches are paired by file name later on, so the two
# mosaic sets must line up.
assert len(label_mosaics) == len(image_mosaics), (
    f"Mosaic count mismatch: {len(label_mosaics)} label vs "
    f"{len(image_mosaics)} image mosaics"
)

# This name has always ended up holding the image mosaics; keep that value.
all_mosaics = image_mosaics

  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/757 [00:00<?, ?it/s]

In [7]:
def _patch_starts(length, patch_size, stride):
    """Return the start offsets of the patches covering ``length`` pixels.

    Offsets advance by ``stride``; the final patch is shifted back so that it
    ends exactly at ``length`` (it may overlap the previous one).
    """
    starts = [0]
    end = patch_size
    while end < length:
        end = min(end + stride, length)
        starts.append(end - patch_size)
    return starts


def make_patches(file, patch_dir, resolutions, patch_size=512, stride=512):
    """Cut one mosaic into square GeoTIFF patches at each requested resolution.

    The raster kind ("labels"/"images") and split ("train"/"test") are derived
    from the name of the folder containing ``file``. For every resolution the
    raster is read resampled via ``out_shape`` (rasterio's default resampling)
    and tiled; patches that already exist on disk are left untouched, so the
    function can be re-run to resume an interrupted job.

    Args:
        file: Path of the mosaic to split.
        patch_dir: Root output directory; patches are written to
            ``patch_dir/labels`` or ``patch_dir/images``.
        resolutions: Iterable of target pixel sizes, in the raster's CRS units.
        patch_size: Patch edge length in pixels.
        stride: Step in pixels between consecutive patches.

    Raises:
        ValueError: If ``stride`` is not positive, or the raster kind or split
            cannot be derived from the parent folder name.
    """
    if stride <= 0:
        raise ValueError("stride must be a positive number of pixels")

    # Kind and split depend only on the folder name, so resolve them once.
    folder_name = file.parent.name
    if "labels" in folder_name:
        kind = "labels"
    elif "aerial" in folder_name:
        kind = "images"
    else:
        raise ValueError("Unknown type")

    if "train" in folder_name:
        dataset = "train"
    elif "test" in folder_name:
        dataset = "test"
    else:
        raise ValueError("Unknown dataset")

    # Label and image patches share a base name so they can be paired later.
    base_name = file.stem.replace("_labels", "").replace("_images", "")
    out_dir = patch_dir / kind

    # Open the mosaic once for all resolutions and always close it afterwards.
    with rio.open(file) as src:
        native_res = src.res[0]
        origin_x, origin_y = src.transform[2], src.transform[5]

        for res in resolutions:
            scale = native_res / res
            array = src.read(
                out_shape=(
                    src.count,
                    int(src.height * scale),
                    int(src.width * scale),
                )
            )
            height, width = array.shape[1], array.shape[2]
            if height < patch_size or width < patch_size:
                # Too small to yield a single full patch at this resolution.
                continue

            # Geotransform of the resampled grid: same origin, new pixel size.
            grid_transform = Affine(res, 0, origin_x, 0, -res, origin_y)

            for top in _patch_starts(height, patch_size, stride):
                bottom = top + patch_size
                for left in _patch_starts(width, patch_size, stride):
                    right = left + patch_size
                    file_name = (
                        f"{base_name}_{top}_{bottom}_{left}_{right}_{res}m_{dataset}.tif"
                    )
                    # No stray space here, so file names do not start with " ".
                    out_path = out_dir / file_name
                    if out_path.exists():
                        continue

                    out_path.parent.mkdir(exist_ok=True, parents=True)
                    patch = array[:, top:bottom, left:right]
                    assert patch.shape[1] == patch_size
                    assert patch.shape[2] == patch_size

                    profile = src.profile
                    profile.update(
                        {
                            "height": patch_size,
                            "width": patch_size,
                            "transform": rio.windows.transform(
                                rio.windows.Window(left, top, patch_size, patch_size),
                                grid_transform,
                            ),
                        }
                    )

                    with rio.open(out_path, "w", **profile) as dst:
                        dst.write(patch)

In [8]:
# The worker callable is the same for both raster kinds, so build it once.
make_patches_partial = partial(
    make_patches, patch_dir=patch_dir, resolutions=[0.2, 1]
)

# Patch the label mosaics first, then the image mosaics, eight files at a time.
for input_dir in (input_labels, input_images):
    files = list(input_dir.glob("*.tif"))
    with Pool(8) as p:
        r = list(tqdm(p.imap(make_patches_partial, files), total=len(files)))

  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/757 [00:00<?, ?it/s]

In [9]:
# Patch sub-folders, one per raster kind, as written by make_patches.
labels_patch_dir = patch_dir.joinpath("labels")
image_patch_dir = patch_dir.joinpath("images")

In [10]:
# Sorted so the listing order is reproducible rather than whatever order the
# filesystem happens to return.
all_labels = sorted(labels_patch_dir.glob("*.tif"))
len(all_labels)

64240

In [11]:
# Sorted so the listing order is reproducible rather than whatever order the
# filesystem happens to return.
all_images = sorted(image_patch_dir.glob("*.tif"))
len(all_images)

64240

In [12]:
def check_for_water(label):
    """Return ``label`` if its raster has any non-zero pixel, otherwise ``None``.

    Args:
        label: Path of a label patch readable by rasterio.

    Returns:
        The same ``label`` value when at least one pixel is non-zero, else
        ``None``, so the results can be filtered after a pool map.
    """
    # Close the dataset promptly; this runs once per patch in each worker.
    with rio.open(label) as src:
        has_water = src.read().any()
    return label if has_water else None


# Scan every label patch in parallel and keep only the paths that were
# returned, i.e. the patches containing water.
with Pool(8) as p:
    r = [
        hit
        for hit in tqdm(p.imap(check_for_water, all_labels), total=len(all_labels))
        if hit is not None
    ]

  0%|          | 0/64240 [00:00<?, ?it/s]

In [13]:
def get_image_path(label_path: Path) -> Path:
    """Map a label patch path to the path of its matching image patch.

    The directory component named exactly ``labels`` that is closest to the
    file is swapped for ``images``. Other occurrences of the substring
    "labels" (for example in a parent folder such as ``labels_v2``) are left
    alone. If no directory is named exactly ``labels``, every occurrence of
    the substring is replaced instead.

    Args:
        label_path: Path of a label patch.

    Returns:
        Path of the corresponding image patch.
    """
    label_path = Path(label_path)
    parts = list(label_path.parts)
    # Walk the directory components from the file upwards; skip the file name.
    for idx in range(len(parts) - 2, -1, -1):
        if parts[idx] == "labels":
            parts[idx] = "images"
            return Path(*parts)
    # No exact "labels" directory: fall back to plain substring replacement.
    return Path(str(label_path).replace("labels", "images"))


# Manifest of water-containing patches: one row per label patch, paired with
# the image patch path derived from it.
images_with_water_df = pd.DataFrame(r, columns=["label_path"]).assign(
    image_path=lambda frame: frame["label_path"].map(get_image_path)
)

In [14]:
# Show the label/image manifest; the notebook renders a truncated preview.
images_with_water_df

Unnamed: 0,label_path,image_path
0,/media/nick/4TB Working 6/Datasets/FLAIR_retra...,/media/nick/4TB Working 6/Datasets/FLAIR_retra...
1,/media/nick/4TB Working 6/Datasets/FLAIR_retra...,/media/nick/4TB Working 6/Datasets/FLAIR_retra...
2,/media/nick/4TB Working 6/Datasets/FLAIR_retra...,/media/nick/4TB Working 6/Datasets/FLAIR_retra...
3,/media/nick/4TB Working 6/Datasets/FLAIR_retra...,/media/nick/4TB Working 6/Datasets/FLAIR_retra...
4,/media/nick/4TB Working 6/Datasets/FLAIR_retra...,/media/nick/4TB Working 6/Datasets/FLAIR_retra...
...,...,...
16879,/media/nick/4TB Working 6/Datasets/FLAIR_retra...,/media/nick/4TB Working 6/Datasets/FLAIR_retra...
16880,/media/nick/4TB Working 6/Datasets/FLAIR_retra...,/media/nick/4TB Working 6/Datasets/FLAIR_retra...
16881,/media/nick/4TB Working 6/Datasets/FLAIR_retra...,/media/nick/4TB Working 6/Datasets/FLAIR_retra...
16882,/media/nick/4TB Working 6/Datasets/FLAIR_retra...,/media/nick/4TB Working 6/Datasets/FLAIR_retra...


In [15]:
# Write the manifest next to the patch folder, without the index column.
manifest_csv = patch_dir.parent / "FLAIR_patches_with_water.csv"
images_with_water_df.to_csv(manifest_csv, index=False)