# Wetlands GeoAI End-to-End Workflow

This notebook walks through the full wetlands modelling pipeline:

1. Acquire NAIP and Sentinel-2 composites for an area of interest (AOI) and build stack manifests.
2. Prepare training data and fit the UNet semantic segmentation model.
3. Run streaming inference on a hold-out AOI using the trained model.
4. Compare predictions to National Wetlands Inventory (NWI) polygons on an interactive `leafmap` basemap.

The workflow leans on reusable modules in `wetlands_ml_geoai` and is designed to stream large rasters from stack manifests without exhausting system memory.


## Prerequisites

- Activate the project virtual environment and install dependencies (`pip install -r requirements.txt`).
- Ensure the `geoai` and `leafmap` packages are installed (both are already listed in `requirements.txt`).
- Provide AOI geometry files and wetlands labels locally (place them under the git-ignored `data/` directory).
- Configure STAC access (the public Element 84 Sentinel-2 catalogue works without authentication).


In [7]:
from __future__ import annotations

import json
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import geopandas as gpd
import leafmap
import numpy as np
import rasterio
from rasterio.features import shapes

# Locate the package sources: prefer ./src under the working directory and
# fall back to ../src when the notebook is started from a sub-folder.
NOTEBOOK_ROOT = Path.cwd().resolve()
SRC_PATH = (
    NOTEBOOK_ROOT / "src"
    if (NOTEBOOK_ROOT / "src").exists()
    else NOTEBOOK_ROOT.parent / "src"
)
# Put the directory at the front of the import path, but only when it exists
# and has not been added already (re-running the cell must not duplicate it).
if SRC_PATH.exists():
    if str(SRC_PATH) not in sys.path:
        sys.path.insert(0, str(SRC_PATH))

from wetlands_ml_geoai.sentinel2.compositing import run_pipeline
from wetlands_ml_geoai.training.unet import train_unet
from wetlands_ml_geoai.inference.unet_stream import infer_manifest

# Configure the root logger: INFO and above, rendered as
# "<timestamp> <level> <message>". basicConfig does nothing if the root
# logger already has handlers installed.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")


In [8]:
# Resolve project directories. The working directory is taken as the
# repository root when it contains src/wetlands_ml_geoai; otherwise its parent
# is used (the expected layout is <repo_root>/notebooks).
REPO_ROOT = Path.cwd().resolve()
if not REPO_ROOT.joinpath("src", "wetlands_ml_geoai").exists():
    REPO_ROOT = REPO_ROOT.parent

DATA_DIR = REPO_ROOT.joinpath("data")
ARTIFACTS_DIR = REPO_ROOT.joinpath("outputs")

# Create both folders up front so later cells can rely on them existing.
DATA_DIR.mkdir(parents=True, exist_ok=True)
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

print(
    f"Repository root: {REPO_ROOT}",
    f"Data directory: {DATA_DIR}",
    f"Artifacts directory: {ARTIFACTS_DIR}",
    sep="\n",
)


Repository root: C:\_code\python\wetlands_ml_codex
Data directory: C:\_code\python\wetlands_ml_codex\data
Artifacts directory: C:\_code\python\wetlands_ml_codex\outputs


In [None]:
@dataclass
class PipelineConfig:
    """Settings for one AOI's imagery / stack-manifest pipeline run.

    ``run_stack_pipeline`` forwards most fields to ``run_pipeline``;
    ``labels_path`` is read by the training and map-overlay cells instead.
    """

    # Short AOI name: appears in log messages and is used to name the
    # tile/model folders and the prediction raster.
    identifier: str
    # Area of interest, forwarded as run_pipeline(aoi=...). Note that
    # ensure_paths_exist checks this value as a filesystem path.
    aoi: str
    # Forwarded as run_pipeline(years=...); presumably the Sentinel-2
    # acquisition years -- confirm against run_pipeline.
    years: List[int]
    # Created before the run, passed as output_dir, and later searched
    # recursively for stack_manifest.json files.
    output_dir: Path
    # Label vectors used for training (train_unet labels_path) and for the
    # reference overlay on the map.
    labels_path: Path
    # Forwarded as wetlands_output_path; a missing file only produces a warning.
    wetlands_path: Optional[Path] = None
    # Forwarded to run_pipeline under the same names.
    auto_download_naip: bool = True
    auto_download_wetlands: bool = True
    # Forwarded as auto_download_naip_year.
    naip_year: Optional[int] = None
    # Forwarded unchanged; presumably a maximum cloud-cover percentage -- TODO confirm.
    cloud_cover: float = 60.0
    # Forwarded unchanged as min_clear_obs.
    min_clear_obs: int = 3
    # Season codes forwarded as run_pipeline(seasons=...).
    seasons: tuple[str, ...] = ("SPR", "SUM", "FAL")
    # Forwarded as auto_download_naip_max_items / auto_download_naip_preview.
    naip_max_items: Optional[int] = None
    naip_preview: bool = False
    # Forwarded unchanged as mask_dilation / naip_target_resolution.
    mask_dilation: int = 0
    naip_target_resolution: Optional[float] = None


@dataclass
class TrainingConfig:
    """Hyper-parameters and output locations handed to ``train_unet``.

    Every field is passed to ``train_unet`` as the keyword argument of the
    same name by the training cell.
    """

    # Both directories are created by the training cell before train_unet runs;
    # models_dir is also where the inference cell looks for *.pth checkpoints.
    tiles_dir: Path
    models_dir: Path
    # Tiling parameters (presumably in pixels -- confirm against train_unet).
    tile_size: int = 512
    stride: int = 256
    # Optimisation settings, forwarded unchanged.
    batch_size: int = 4
    epochs: int = 10
    architecture: str = "unet"
    encoder_name: str = "resnet34"
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    # Forwarded as val_split; presumably the validation fraction -- TODO confirm.
    val_split: float = 0.2


@dataclass
class InferenceConfig:
    """Sliding-window settings handed to ``infer_manifest``.

    Every field is passed to ``infer_manifest`` as the keyword argument of the
    same name by the inference cell.
    """

    # Window geometry (presumably in pixels -- confirm against infer_manifest).
    window_size: int = 1024
    overlap: int = 256
    # Forwarded as probability_threshold.
    probability_threshold: float = 0.5
    num_classes: int = 2
    # The same defaults as TrainingConfig; keep them in sync with the model
    # that produced the checkpoint being loaded.
    architecture: str = "unet"
    encoder_name: str = "resnet34"


In [None]:
# --- User configuration ---
# Update these paths to point to your local data assets. Nothing is read or
# downloaded here; the objects are only consumed by the cells further down.
TRAIN_AOI_CONFIG = PipelineConfig(
    identifier="train_aoi",
    aoi=str(DATA_DIR / "aoi" / "train_extent.gpkg"),  # AOI file; run_stack_pipeline requires this path to exist
    years=[2023],
    output_dir=ARTIFACTS_DIR / "train_aoi",
    labels_path=DATA_DIR / "labels" / "train_wetlands.gpkg",  # Provide NWI (or custom) training labels
    wetlands_path=DATA_DIR / "labels" / "train_wetlands.gpkg",  # Same file as labels_path; passed on as wetlands_output_path
    auto_download_naip=True,
    auto_download_wetlands=True,
    naip_year=2021,
    naip_max_items=12,
)

# Hold-out AOI used for inference and for the comparison map.
TEST_AOI_CONFIG = PipelineConfig(
    identifier="test_aoi",
    aoi=str(DATA_DIR / "aoi" / "test_extent.gpkg"),  # AOI file; run_stack_pipeline requires this path to exist
    years=[2023],
    output_dir=ARTIFACTS_DIR / "test_aoi",
    labels_path=DATA_DIR / "labels" / "test_wetlands.gpkg",  # Optional: used for evaluation overlay
    wetlands_path=DATA_DIR / "labels" / "test_wetlands.gpkg",  # Same file as labels_path; passed on as wetlands_output_path
    auto_download_naip=True,
    auto_download_wetlands=True,
    naip_year=2021,
    naip_max_items=8,
)

# Tiles and checkpoints are grouped by the training AOI identifier.
# epochs and batch_size override the TrainingConfig defaults (10 and 4).
TRAINING_CONFIG = TrainingConfig(
    tiles_dir=ARTIFACTS_DIR / "tiles" / TRAIN_AOI_CONFIG.identifier,
    models_dir=ARTIFACTS_DIR / "models" / TRAIN_AOI_CONFIG.identifier,
    epochs=15,
    batch_size=6,
)

# All three values override the InferenceConfig defaults (1024, 256, 0.5).
INFERENCE_CONFIG = InferenceConfig(
    window_size=1536,
    overlap=384,
    probability_threshold=0.6,
)


In [None]:
# Execution switches. Each flag gates one of the cells below; with a flag left
# False the corresponding cell only logs that the step was skipped.
RUN_TRAIN_PIPELINE = False   # Set to True to generate Sentinel-2 composites and stack manifests for training AOI
RUN_TEST_PIPELINE = False    # Set to True to generate stack manifest for the inference AOI
RUN_TRAINING = False         # Set to True to launch UNet training (may take hours on CPU)
RUN_INFERENCE = False        # Set to True to run streaming inference on the test AOI



In [None]:
def ensure_paths_exist(config: PipelineConfig) -> None:
    """Validate the local inputs a pipeline run depends on.

    The AOI is treated as a filesystem path and must exist. A missing wetlands
    GeoPackage is never fatal: it is only reported, and the message states
    whether ``config.auto_download_wetlands`` is actually switched on.

    Args:
        config: Pipeline settings; ``aoi``, ``wetlands_path`` and
            ``auto_download_wetlands`` are read.

    Raises:
        FileNotFoundError: If the AOI path does not exist.
    """
    # str() keeps the join below working even if a path-like AOI is supplied.
    missing = [str(path) for path in (config.aoi,) if not Path(path).exists()]

    wetlands_path = config.wetlands_path
    if wetlands_path and not wetlands_path.exists():
        # Lazy %-style arguments: the message is only rendered if it is emitted.
        if config.auto_download_wetlands:
            logging.warning(
                "Wetlands GPkg missing at %s; auto-download is enabled.", wetlands_path
            )
        else:
            logging.warning(
                "Wetlands GPkg missing at %s and auto-download is disabled.", wetlands_path
            )

    if missing:
        raise FileNotFoundError("Missing required inputs: " + ", ".join(missing))


def run_stack_pipeline(cfg: PipelineConfig) -> None:
    """Run the Sentinel-2 + NAIP stack pipeline for one AOI configuration.

    Inputs are checked with ``ensure_paths_exist``, ``cfg.output_dir`` is
    created, and the settings are handed to ``run_pipeline``. A log record is
    written before and after the run.

    Raises:
        FileNotFoundError: Propagated from ``ensure_paths_exist``.
    """
    ensure_paths_exist(cfg)
    cfg.output_dir.mkdir(parents=True, exist_ok=True)

    logging.info("Starting Sentinel-2 + NAIP pipeline for %s", cfg.identifier)

    # Map config fields onto the keyword names run_pipeline expects; several
    # NAIP/wetlands options are renamed on the way through.
    pipeline_kwargs = {
        "aoi": cfg.aoi,
        "years": cfg.years,
        "output_dir": cfg.output_dir,
        "seasons": cfg.seasons,
        "cloud_cover": cfg.cloud_cover,
        "min_clear_obs": cfg.min_clear_obs,
        "auto_download_naip": cfg.auto_download_naip,
        "auto_download_naip_year": cfg.naip_year,
        "auto_download_naip_max_items": cfg.naip_max_items,
        "auto_download_naip_preview": cfg.naip_preview,
        "auto_download_wetlands": cfg.auto_download_wetlands,
        "wetlands_output_path": cfg.wetlands_path,
        "mask_dilation": cfg.mask_dilation,
        "naip_target_resolution": cfg.naip_target_resolution,
    }
    run_pipeline(**pipeline_kwargs)

    logging.info("Finished pipeline for %s", cfg.identifier)


def discover_manifests(directory: Path) -> List[Path]:
    """Return every ``stack_manifest.json`` found below *directory*, sorted by path.

    The search is recursive; an empty list is returned when nothing matches.
    """
    manifests = list(directory.rglob("stack_manifest.json"))
    manifests.sort()
    return manifests


def choose_model_checkpoint(models_dir: Path, preferred_name: str = "best_model.pth") -> Path:
    """Pick the checkpoint file to load from *models_dir*.

    ``models_dir / preferred_name`` wins when it exists. Otherwise the ``*.pth``
    file with the latest modification time is returned and a warning is logged.

    Raises:
        FileNotFoundError: If the directory holds no ``*.pth`` file at all.
    """
    preferred = models_dir / preferred_name
    if preferred.exists():
        return preferred

    newest = max(
        models_dir.glob("*.pth"),
        key=lambda ckpt: ckpt.stat().st_mtime,
        default=None,
    )
    if newest is None:
        raise FileNotFoundError(f"No checkpoint files found under {models_dir}")

    logging.warning("Using most recent checkpoint: %s", newest)
    return newest


In [None]:
# Build the training AOI stack only when the switch is on; otherwise just say so.
if not RUN_TRAIN_PIPELINE:
    logging.info("Skipping training AOI pipeline; set RUN_TRAIN_PIPELINE = True to execute.")
else:
    run_stack_pipeline(TRAIN_AOI_CONFIG)


In [None]:
# Build the hold-out AOI stack only when the switch is on; otherwise just say so.
if not RUN_TEST_PIPELINE:
    logging.info("Skipping test AOI pipeline; set RUN_TEST_PIPELINE = True to execute.")
else:
    run_stack_pipeline(TEST_AOI_CONFIG)


In [None]:
# Collect whatever stack manifests already exist for the training AOI and
# report them; this runs regardless of the RUN_* switches.
train_manifests = discover_manifests(TRAIN_AOI_CONFIG.output_dir)
if train_manifests:
    for manifest in train_manifests:
        logging.info("Training manifest -> %s", manifest)
else:
    logging.warning("No training manifests detected under %s", TRAIN_AOI_CONFIG.output_dir)



In [None]:
if not RUN_TRAINING:
    logging.info("Skipping training; set RUN_TRAINING = True to launch model fitting.")
else:
    # Fail fast on missing prerequisites before any directory is created.
    if not train_manifests:
        raise RuntimeError("Training manifests missing. Generate them by enabling RUN_TRAIN_PIPELINE.")
    if not TRAIN_AOI_CONFIG.labels_path.exists():
        raise FileNotFoundError(f"Training labels not found at {TRAIN_AOI_CONFIG.labels_path}")

    TRAINING_CONFIG.tiles_dir.mkdir(parents=True, exist_ok=True)
    TRAINING_CONFIG.models_dir.mkdir(parents=True, exist_ok=True)

    logging.info("Launching UNet training with %s manifest(s).", len(train_manifests))
    # NOTE(review): the whole manifest list is passed to a parameter with a
    # singular name -- confirm train_unet accepts a sequence here.
    train_unet(
        labels_path=TRAIN_AOI_CONFIG.labels_path,
        tiles_dir=TRAINING_CONFIG.tiles_dir,
        models_dir=TRAINING_CONFIG.models_dir,
        stack_manifest_path=train_manifests,
        tile_size=TRAINING_CONFIG.tile_size,
        stride=TRAINING_CONFIG.stride,
        batch_size=TRAINING_CONFIG.batch_size,
        epochs=TRAINING_CONFIG.epochs,
        architecture=TRAINING_CONFIG.architecture,
        encoder_name=TRAINING_CONFIG.encoder_name,
        learning_rate=TRAINING_CONFIG.learning_rate,
        weight_decay=TRAINING_CONFIG.weight_decay,
        val_split=TRAINING_CONFIG.val_split,
    )


In [None]:
# Collect whatever stack manifests already exist for the hold-out AOI and
# report them; this runs regardless of the RUN_* switches.
test_manifests = discover_manifests(TEST_AOI_CONFIG.output_dir)
if test_manifests:
    for manifest in test_manifests:
        logging.info("Test manifest -> %s", manifest)
else:
    logging.warning("No test manifests detected under %s", TEST_AOI_CONFIG.output_dir)



In [None]:
# The output location is defined even when inference is skipped, because the
# map cell further down checks whether this raster exists.
prediction_dir = ARTIFACTS_DIR / "predictions"
prediction_dir.mkdir(parents=True, exist_ok=True)

prediction_raster = prediction_dir / f"{TEST_AOI_CONFIG.identifier}_unet_prediction.tif"

if RUN_INFERENCE:
    if not test_manifests:
        raise RuntimeError("No test manifest found. Enable RUN_TEST_PIPELINE to generate one.")
    if len(test_manifests) > 1:
        # Only one manifest is written to prediction_raster; make the choice
        # visible instead of silently dropping the others.
        logging.warning(
            "Found %d test manifests; running inference on the first only: %s",
            len(test_manifests),
            test_manifests[0],
        )
    model_checkpoint = choose_model_checkpoint(TRAINING_CONFIG.models_dir)
    logging.info("Using checkpoint: %s", model_checkpoint)

    infer_manifest(
        manifest=test_manifests[0],
        model_path=model_checkpoint,
        output_path=prediction_raster,
        window_size=INFERENCE_CONFIG.window_size,
        overlap=INFERENCE_CONFIG.overlap,
        num_channels=None,
        architecture=INFERENCE_CONFIG.architecture,
        encoder_name=INFERENCE_CONFIG.encoder_name,
        num_classes=INFERENCE_CONFIG.num_classes,
        probability_threshold=INFERENCE_CONFIG.probability_threshold,
    )
    logging.info("Inference complete -> %s", prediction_raster)
else:
    logging.info("Skipping inference; set RUN_INFERENCE = True after training completes.")


In [None]:
def raster_prediction_to_polygons(raster_path: Path, class_value: int = 1) -> gpd.GeoDataFrame:
    """Vectorise one class of a single-band prediction raster.

    Args:
        raster_path: Prediction raster; band 1 is read and cast to ``uint8``.
        class_value: Pixel value whose connected regions become polygons.

    Returns:
        A GeoDataFrame in the raster's CRS with a ``class`` column and one row
        per region. When the class does not occur, an empty frame with the
        same columns is returned and a warning is logged.

    Raises:
        FileNotFoundError: If *raster_path* does not exist.
    """
    if not raster_path.exists():
        raise FileNotFoundError(f"Prediction raster not found: {raster_path}")

    # Everything needed is copied out so the dataset can be closed immediately.
    with rasterio.open(raster_path) as src:
        band = src.read(1).astype(np.uint8)
        transform = src.transform
        crs = src.crs

    # Masking restricts polygonisation to the requested class, so background
    # regions are never traced only to be thrown away afterwards.
    class_mask = band == class_value
    features = [
        {"geometry": geom, "properties": {"class": int(value)}}
        for geom, value in shapes(band, mask=class_mask, transform=transform)
    ]

    if not features:
        logging.warning("No polygons extracted for class %s", class_value)
        # Same columns as the non-empty result so callers see one schema.
        return gpd.GeoDataFrame(columns=["geometry", "class"], geometry="geometry", crs=crs)
    return gpd.GeoDataFrame.from_features(features, crs=crs)


def load_wetlands_labels(path: Path) -> gpd.GeoDataFrame:
    """Read the wetlands label layer at *path* into a GeoDataFrame.

    Layers that carry no CRS are tagged as EPSG:4326 (tagged, not reprojected)
    and a warning is logged.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"Wetlands labels not found: {path}")

    labels = gpd.read_file(path)
    if labels.crs is not None:
        return labels

    logging.warning("Labels at %s lack CRS metadata; assuming EPSG:4326", path)
    labels.set_crs("EPSG:4326", inplace=True)
    return labels



In [None]:
if not prediction_raster.exists():
    logging.info("Prediction raster not yet generated. Run inference to create %s", prediction_raster)
else:
    # Vectorise the prediction and bring it into EPSG:4326 for the web map.
    predicted_polys = raster_prediction_to_polygons(prediction_raster)
    if predicted_polys.crs is not None and predicted_polys.crs.to_string() != "EPSG:4326":
        predicted_polys = predicted_polys.to_crs(4326)

    # Reference labels are optional: without them the map shows predictions only.
    try:
        reference_labels = load_wetlands_labels(TEST_AOI_CONFIG.labels_path)
        if reference_labels.crs.to_string() != "EPSG:4326":
            reference_labels = reference_labels.to_crs(4326)
    except FileNotFoundError:
        reference_labels = None
        logging.warning("Reference wetlands not available; map will show predictions only.")

    has_reference = reference_labels is not None and not reference_labels.empty
    has_predictions = not predicted_polys.empty

    # Centre the map on the labels when available, else on the predictions.
    if has_reference:
        centroid = reference_labels.unary_union.centroid
    elif has_predictions:
        centroid = predicted_polys.unary_union.centroid
    else:
        centroid = None

    if centroid is None:
        raise RuntimeError("Map extent unavailable; ensure predictions or labels are present.")

    map_center = (centroid.y, centroid.x)  # (lat, lon) order
    m = leafmap.Map(center=map_center, zoom=12, measure_control=False)
    m.add_basemap("Esri.WorldImagery")

    if has_reference:
        m.add_gdf(
            reference_labels,
            layer_name="NWI Wetlands",
            style={
                "color": "#3182bd",
                "fillColor": "#3182bd",
                "fillOpacity": 0.5,
                "weight": 1,
            },
        )

    if has_predictions:
        # fillOpacity 0.0 draws predictions as outlines on top of the filled labels.
        m.add_gdf(
            predicted_polys,
            layer_name="Predicted Wetlands",
            style={
                "color": "#e377c2",
                "fillColor": "#e377c2",
                "fillOpacity": 0.0,
                "weight": 2,
            },
        )

    display(m)


## Next Steps

- Validate that stack manifests include the expected 25-band (NAIP + 3-season Sentinel-2) configuration before training.
- Inspect model checkpoints under `outputs/models/<identifier>` and track metrics logged by `geoai`.
- After generating predictions, compute quantitative scores (IoU, precision/recall) by rasterizing ground-truth polygons onto the manifest grid.
- Package results or derived layers in the `outputs/` directory; avoid committing rasters or GeoPackages to git.
