# Flood Detection Inference

Run water/cloud classification on Sentinel-2 tiles using the WorldFloods UNet model.

> For implementation details, see `03_research_inference.ipynb`.

**Prerequisites**

1. Run `05_maji_download.ipynb` to download Sentinel-2 GeoTIFF tiles
2. Download model weights to `models/WF2_unet_rbgiswirs/model.pt`

**Output Classes**

| Value | Class | Description |
|-------|-------|-------------|
| 0 | Invalid | NoData / out of bounds |
| 1 | Land | Dry land |
| 2 | Water | Detected water |
| 3 | Cloud | Cloud or cloud shadow |

In [None]:
%matplotlib inline

from pathlib import Path

import numpy as np
import rasterio
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.patches import Rectangle, Patch
import torch
from tqdm.auto import tqdm

from maji import (
    load_model,
    run_inference,
    normalize_tile,
    classify_prediction,
    CLASS_NAMES,
    DEFAULT_MODEL_BANDS,
)
from maji.viz import (
    create_rgb_composite,
    get_classification_cmap,
    CLASSIFICATION_COLORS,
)

# Pick the compute backend in priority order: CUDA, then MPS, then CPU.
# The chained conditional short-circuits, so the MPS check only runs when
# CUDA is unavailable.
DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using device: {DEVICE}")
print(f"Model bands: {DEFAULT_MODEL_BANDS}")

## Configuration

In [None]:
# --- Mode ---
# True = demo mode (process only the single highest-variance patch for quick testing)
# False = full tile (process all patches in the grid)
SINGLE_PATCH = False

# --- Inference Parameters ---
# Adjacent patches start PATCH_SIZE - OVERLAP pixels apart (the stride).
PATCH_SIZE = 1024   # Patch dimensions (must be divisible by 16 for UNet)
OVERLAP = 64        # Overlap between adjacent patches for stitching

# --- Classification Thresholds ---
# Passed to classify_prediction() as th_water / th_cloud.
TH_WATER = 0.50      # Water probability threshold
TH_CLOUD = 0.98      # Cloud probability threshold

# --- Post-Processing ---
# The four thresholds below are forwarded to detect_water_spectral() when
# USE_SPECTRAL_POSTPROCESS is True; they are unused otherwise.
USE_SPECTRAL_POSTPROCESS = True   # Spectral water detection for ocean areas
NIR_WATER_THRESHOLD = 1500        # DN threshold (raised from 500, accounts for atm baseline)
MNDWI_THRESHOLD = -0.1            # MNDWI threshold (lowered from 0.0, ocean can be slightly negative)
SWIR_FLATNESS_THRESHOLD = 0.10    # Max RELATIVE difference between SWIR1/SWIR2 (10%)
SWIR_MAX_THRESHOLD = 1500         # Max SWIR value for water (excludes bright land)

# --- Paths ---
# Relative paths: resolved against the notebook's working directory.
DATA_DIR = Path("../DATA")
MODEL_DIR = Path("../models/WF2_unet_rbgiswirs")

# Echo the effective configuration so it is recorded with the cell output
print(f"Mode: {'SINGLE_PATCH (demo)' if SINGLE_PATCH else 'FULL_TILE'}")
print(f"Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"Overlap: {OVERLAP}")
print(f"Thresholds: water={TH_WATER}, cloud={TH_CLOUD}")
print(f"Spectral post-processing: {USE_SPECTRAL_POSTPROCESS}")

## Load GeoTIFF

Auto-discover tiles from `05_maji_download.ipynb` output.

In [None]:
# Find all GeoTIFF files in DATA directory
# The pattern "*/*.tif" only matches files exactly one sub-directory below
# DATA_DIR; sorting makes the selection below deterministic.
tif_files = sorted(DATA_DIR.glob("*/*.tif"))

# Fail early, pointing at the notebook that produces the tiles
if not tif_files:
    raise FileNotFoundError(
        f"No GeoTIFF files found in {DATA_DIR}. "
        "Run 05_maji_download.ipynb first."
    )

print(f"Found {len(tif_files)} GeoTIFF file(s):")
for f in tif_files:
    # Size on disk in decimal megabytes (bytes / 1e6)
    size_mb = f.stat().st_size / 1e6
    print(f"  {f.relative_to(DATA_DIR)} ({size_mb:.1f} MB)")

# Use the first file (or modify to select a specific one)
GEOTIFF_PATH = tif_files[0]
print(f"\nUsing: {GEOTIFF_PATH.name}")

In [None]:
# Open the selected GeoTIFF. The handle stays open at notebook scope because
# later cells read windows from it; it is closed in the final cell.
src = rasterio.open(GEOTIFF_PATH)

print(f"GeoTIFF: {GEOTIFF_PATH.name}")
print(f"Dimensions: {src.width} x {src.height} pixels")
print(f"Bands: {src.count}")
print(f"CRS: {src.crs}")
print(f"Dtype: {src.dtypes[0]}")
print()
print("Band descriptions:")
# Band numbers are 1-based; bands without a description get a generic label
for i, band_desc in enumerate(src.descriptions, start=1):
    desc = band_desc if band_desc else f"Band {i}"
    print(f"  {i}: {desc}")

## Tile Overview

Generate a downsampled thumbnail for visualization.

In [None]:
# Create 10x downsampled thumbnail
# Integer division: the thumbnail is at most 1/THUMBNAIL_SCALE of each dimension
THUMBNAIL_SCALE = 10
thumbnail_height = src.height // THUMBNAIL_SCALE
thumbnail_width = src.width // THUMBNAIL_SCALE

# Read bands for false color composite at reduced resolution
# Band indices: B02=1, B03=2, B04=3, B08=4, B11=5, B12=6 (1-indexed)
# False color: B11 (SWIR), B08 (NIR), B03 (Green) -> indices 5, 4, 2
# out_shape requests the reduced size directly from the read call, so the
# full-resolution bands are never held in memory here.
thumbnail_raw = src.read(
    [5, 4, 2],  # B11, B08, B03
    out_shape=(3, thumbnail_height, thumbnail_width)
).astype(np.float32)

# Percentile stretch for display
# Each channel is scaled independently so its 2nd..98th percentile range maps
# to [0, 1]; values outside are clipped. The 1e-8 term guards against a zero
# denominator when a channel is constant.
thumbnail_display = np.zeros((thumbnail_height, thumbnail_width, 3), dtype=np.float32)
for i in range(3):
    p2, p98 = np.percentile(thumbnail_raw[i], [2, 98])
    thumbnail_display[:, :, i] = np.clip(
        (thumbnail_raw[i] - p2) / (p98 - p2 + 1e-8), 0, 1
    )

print(f"Thumbnail: {thumbnail_width} x {thumbnail_height} pixels")

# Display
# thumbnail_display is reused by the patch-selection cell as the backdrop
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(thumbnail_display)
ax.set_title(f"False Color (SWIR-NIR-Green)\n{GEOTIFF_PATH.name}")
ax.axis("off")
plt.tight_layout()
plt.show()

## Patch Grid Calculation

Calculate the grid of patch positions covering the full tile.

In [None]:
def calculate_patch_grid(height, width, patch_size, overlap):
    """Calculate grid of (row, col) positions for patches.

    Patches are laid out every ``patch_size - overlap`` pixels. A patch that
    would run past the image edge is shifted back so it ends exactly on the
    edge (or starts at 0 when the image is smaller than one patch), which
    means the last row/column of patches may overlap its neighbour by more
    than ``overlap``.

    Parameters
    ----------
    height : int
        Image height in pixels.
    width : int
        Image width in pixels.
    patch_size : int
        Size of square patches.
    overlap : int
        Overlap between adjacent patches. Must be smaller than ``patch_size``.

    Returns
    -------
    list[tuple[int, int]]
        List of unique (row, col) starting positions in row-major order.

    Raises
    ------
    ValueError
        If ``overlap >= patch_size`` (the stride would be zero or negative).
    """
    stride = patch_size - overlap
    if stride <= 0:
        # A non-positive stride would either make range() fail with an
        # unrelated message or silently yield an empty grid.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than patch_size ({patch_size})"
        )

    def axis_starts(extent):
        """Return the ordered, de-duplicated patch start offsets along one axis."""
        # Furthest start that still keeps a full patch inside the image
        last_start = max(0, extent - patch_size)
        starts = []
        for start in range(0, extent, stride):
            start = min(start, last_start)
            # Clamped starts are non-decreasing, so duplicates are adjacent
            if not starts or starts[-1] != start:
                starts.append(start)
        return starts

    row_starts = axis_starts(height)
    col_starts = axis_starts(width)

    return [(row, col) for row in row_starts for col in col_starts]


# Build the list of patch origins covering the open tile
patches = calculate_patch_grid(src.height, src.width, PATCH_SIZE, OVERLAP)

# Grid shape is the number of distinct start rows and start columns
stride = PATCH_SIZE - OVERLAP
n_rows = len({row_start for row_start, _ in patches})
n_cols = len({col_start for _, col_start in patches})

print(f"Image: {src.width} x {src.height} pixels")
print(f"Patch size: {PATCH_SIZE} x {PATCH_SIZE}")
print(f"Stride: {stride}")
print(f"Grid: {n_cols} x {n_rows} = {len(patches)} patches")

## Patch Selection

Find an "interesting" patch with high variance (not empty or uniform).

In [None]:
def find_interesting_patch(src, patches, patch_size, sample_band=4):
    """Return the patch whose valid pixels vary the most in one band.

    Every candidate patch is read from ``sample_band``; pixels equal to or
    below zero are treated as nodata. Patches with fewer than half of their
    pixels valid are skipped, and the remaining ones are ranked by the
    standard deviation of their valid pixels.

    Parameters
    ----------
    src : rasterio.DatasetReader
        Open rasterio file handle.
    patches : list[tuple[int, int]]
        Candidate (row, col) patch origins. Must not be empty.
    patch_size : int
        Side length of the square read window.
    sample_band : int
        1-indexed band to read. Default 4 = B08 (NIR) which has good contrast.

    Returns
    -------
    tuple[int, int]
        (row, col) of the highest-variance patch. Falls back to the first
        entry of ``patches`` when no patch beats a standard deviation of 0.
    """
    winner = patches[0]
    winner_spread = 0

    for top, left in patches:
        # Window takes (col_off, row_off, width, height)
        band = src.read(
            sample_band,
            window=rasterio.windows.Window(left, top, patch_size, patch_size),
        )
        nonzero = band > 0

        # Ignore patches that are mostly nodata
        if np.mean(nonzero) < 0.5:
            continue

        spread = np.std(band[nonzero])
        if spread > winner_spread:
            winner_spread = spread
            winner = (top, left)

    return winner


# Find the most interesting patch for demo mode
# The result is only used as the selected patch when SINGLE_PATCH is True,
# but it is computed either way (one band read per patch in the grid).
interesting_patch = find_interesting_patch(src, patches, PATCH_SIZE)
print(f"Most interesting patch: row={interesting_patch[0]}, col={interesting_patch[1]}")

In [None]:
# Select patches to process
# selected_patches is what the inference cell iterates over
if SINGLE_PATCH:
    selected_patches = [interesting_patch]
    print(f"SINGLE_PATCH mode: processing 1 patch")
else:
    selected_patches = patches
    print(f"FULL_TILE mode: processing {len(patches)} patches")

# Visualize selected patch(es) on thumbnail
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(thumbnail_display)

# Scale factor for drawing rectangles
# Converts full-resolution pixel coordinates into thumbnail pixel coordinates
scale_y = thumbnail_display.shape[0] / src.height
scale_x = thumbnail_display.shape[1] / src.width

# Draw all patches in light color
# Rectangle takes (x, y) = (col, row) as its anchor, then width and height
for row, col in patches:
    rect = Rectangle(
        (col * scale_x, row * scale_y),
        PATCH_SIZE * scale_x, PATCH_SIZE * scale_y,
        fill=False, edgecolor="white", linewidth=0.3, alpha=0.3
    )
    ax.add_patch(rect)

# Highlight selected patches in red
# Added after the white outlines so the red outlines are drawn on top
for row, col in selected_patches:
    rect = Rectangle(
        (col * scale_x, row * scale_y),
        PATCH_SIZE * scale_x, PATCH_SIZE * scale_y,
        fill=False, edgecolor="red", linewidth=2
    )
    ax.add_patch(rect)

# In demo mode the title also carries the (row, col) origin of the chosen patch
title = f"Selected: {len(selected_patches)} patch(es)"
if SINGLE_PATCH:
    title += f" at ({interesting_patch[0]}, {interesting_patch[1]})"
ax.set_title(title)
ax.axis("off")
plt.tight_layout()
plt.show()

## Load Model

In [None]:
# Load pre-trained model
# Both files are expected inside MODEL_DIR; load_model comes from the maji package
model = load_model(
    weights_path=MODEL_DIR / "model.pt",
    config_path=MODEL_DIR / "config.json",
    device=DEVICE,
)

# Total number of scalar parameters across all parameter tensors
param_count = sum(p.numel() for p in model.parameters())
print(f"Model loaded: {param_count:,} parameters on {DEVICE}")

## Run Inference

In [None]:
def create_blend_weights(height, width, margin, min_weight=1e-3):
    """Create smooth 2D blend weights using raised cosine (Hann) tapers.

    The weights are the outer product of two 1D profiles that are 1 in the
    interior and ramp down along a raised cosine towards each edge, so that
    overlapping patches cross-fade instead of producing visible seams.

    The ramp is floored at ``min_weight`` rather than reaching exactly 0.
    With a true zero on the outermost row/column, a pixel on the border of
    the stitched region (covered by no other patch) would accumulate zero
    total weight and its weighted average would be lost.

    Parameters
    ----------
    height : int
        Patch height in pixels.
    width : int
        Patch width in pixels.
    margin : int
        Size of the taper region at each edge. It is limited to half the
        patch size along each axis; 0 disables the taper.
    min_weight : float, optional
        Smallest value a 1D taper may take (default 1e-3). The smallest 2D
        weight, at the corners, is therefore ``min_weight ** 2``.

    Returns
    -------
    np.ndarray
        2D float32 array of shape (height, width) with strictly positive
        weights no larger than 1.
    """
    def hann_taper(size, margin):
        """Create 1D weights with a floored Hann taper at both ends."""
        weights = np.ones(size, dtype=np.float32)
        # Limit margin to half the size so the two ramps never overlap
        m = min(margin, size // 2)
        if m > 0:
            # Raised cosine ramp: 0.5 * (1 - cos(pi * x / m)), x = 0..m-1
            taper = 0.5 * (1 - np.cos(np.pi * np.arange(m) / m))
            # Keep every weight strictly positive (the ramp starts at 0)
            taper = np.maximum(taper, min_weight)
            weights[:m] = taper
            weights[-m:] = taper[::-1]
        return weights

    # Separable construction: one profile per axis
    weights_y = hann_taper(height, margin)
    weights_x = hann_taper(width, margin)

    # Outer product gives the 2D weights
    return np.outer(weights_y, weights_x)


def run_inference_on_patches(src, model, patches, patch_size, overlap, device):
    """Run the model patch by patch and blend the results into full maps.

    Each patch is read from bands 1-6, normalized, passed through the model,
    and accumulated into the output with per-pixel blend weights. The final
    probabilities are the weighted averages over all patches covering a pixel.

    Parameters
    ----------
    src : rasterio.DatasetReader
        Open rasterio file handle.
    model : torch.nn.Module
        Loaded UNet model.
    patches : list[tuple[int, int]]
        (row, col) origins of the patches to process. Must not be empty.
    patch_size : int
        Nominal patch side length; reads are truncated at the image edge.
    overlap : int
        Overlap between adjacent patches. The blend taper spans twice this.
    device : torch.device
        Device for inference.

    Returns
    -------
    water_prob : np.ndarray
        Blended water probability map for the processed region.
    cloud_prob : np.ndarray
        Blended cloud probability map for the processed region.
    invalid_mask : np.ndarray
        Boolean mask, True where all six bands are zero in any patch
        covering the pixel.
    bounds : tuple
        (min_row, min_col, max_row, max_col) of the processed region in
        full-image pixel coordinates.
    """
    # Bands 1-6 are the model inputs; any further band (e.g. SCL) is not read
    model_band_indexes = [1, 2, 3, 4, 5, 6]

    # Bounding box of the requested patches, clipped to the raster extent
    row_starts = [origin[0] for origin in patches]
    col_starts = [origin[1] for origin in patches]
    min_row = min(row_starts)
    min_col = min(col_starts)
    max_row = min(max(row_starts) + patch_size, src.height)
    max_col = min(max(col_starts) + patch_size, src.width)
    out_shape = (max_row - min_row, max_col - min_col)

    # Accumulators for the weighted sums and the total weight per pixel
    water_sum = np.zeros(out_shape, dtype=np.float32)
    cloud_sum = np.zeros(out_shape, dtype=np.float32)
    weight_sum = np.zeros(out_shape, dtype=np.float32)
    invalid_mask = np.zeros(out_shape, dtype=bool)

    # Taper over twice the overlap for smoother transitions
    blend_margin = overlap * 2

    for row, col in tqdm(patches, desc="Processing patches"):
        # Truncate the window where the patch would leave the image
        window = rasterio.windows.Window(
            col,
            row,
            min(patch_size, src.width - col),
            min(patch_size, src.height - row),
        )
        tile = src.read(model_band_indexes, window=window).astype(np.float32)
        tile_h, tile_w = tile.shape[1:]

        # Nodata is detected on the raw values, before normalization
        patch_invalid = np.all(tile == 0, axis=0)

        patch_water, patch_cloud = run_inference(model, normalize_tile(tile), device)
        blend_weights = create_blend_weights(tile_h, tile_w, margin=blend_margin)

        # Location of this patch inside the output arrays
        top = row - min_row
        left = col - min_col
        region = (slice(top, top + tile_h), slice(left, left + tile_w))

        water_sum[region] += patch_water * blend_weights
        cloud_sum[region] += patch_cloud * blend_weights
        weight_sum[region] += blend_weights
        invalid_mask[region] |= patch_invalid

    # Floor the total weight so pixels with no accumulated weight yield 0, not NaN
    weight_sum = np.maximum(weight_sum, 1e-8)

    return (
        water_sum / weight_sum,
        cloud_sum / weight_sum,
        invalid_mask,
        (min_row, min_col, max_row, max_col),
    )


# Run inference
# Outputs cover only the bounding box of selected_patches; `bounds` records
# that box as (min_row, min_col, max_row, max_col) in full-image pixels and
# is reused by the post-processing, export and visualization cells.
print(f"Running inference on {len(selected_patches)} patch(es)...")
water_prob, cloud_prob, invalid_mask, bounds = run_inference_on_patches(
    src, model, selected_patches, PATCH_SIZE, OVERLAP, DEVICE
)

# Quick sanity summary of the stitched maps
print(f"\nOutput shape: {water_prob.shape}")
print(f"Water prob range: [{water_prob.min():.3f}, {water_prob.max():.3f}]")
print(f"Cloud prob range: [{cloud_prob.min():.3f}, {cloud_prob.max():.3f}]")
print(f"Invalid pixels: {invalid_mask.sum():,} ({100*invalid_mask.mean():.1f}%)")

### Known Limitations

**Open Ocean Detection**: The WorldFloods model was trained on flood events
(temporary water on land), so pure ocean patches may produce low water probabilities.
When `USE_SPECTRAL_POSTPROCESS` is enabled, pixels that the model scores below 0.5
are boosted if they pass an MNDWI test or a combined low-NIR / flat, low-SWIR test,
as a fallback for obvious water areas.

In [None]:
def detect_water_spectral(tile, water_prob, nir_threshold=1500, mndwi_threshold=-0.1,
                          swir_flatness_threshold=0.10, swir_max_threshold=1500,
                          model_water_threshold=0.5):
    """Detect water using MNDWI + low-NIR + spectral flatness.

    A pixel is selected for boosting when the model is not already confident
    (``water_prob < model_water_threshold``), the pixel contains data, and
    either of the following holds:

    * MNDWI is above ``mndwi_threshold``; or
    * SWIR1 and SWIR2 are nearly equal (relative difference below
      ``swir_flatness_threshold``), both are below ``swir_max_threshold``,
      and NIR is below ``nir_threshold``.

    Pixels where every band is zero (nodata) are never selected. Without this
    exclusion they would always qualify, because zeros give MNDWI = 0 and a
    SWIR relative difference of 0.

    Parameters
    ----------
    tile : np.ndarray
        Tile data (6, H, W) in model band order [B02, B03, B04, B08, B11, B12].
    water_prob : np.ndarray
        Model water probability (H, W).
    nir_threshold : float
        NIR DN value below which pixel may be water.
    mndwi_threshold : float
        MNDWI threshold for water detection (can be negative).
    swir_flatness_threshold : float
        Max RELATIVE difference |SWIR1-SWIR2|/mean for spectral flatness (e.g., 0.10 = 10%).
    swir_max_threshold : float
        Max SWIR value to be considered water.
    model_water_threshold : float, optional
        Only pixels with ``water_prob`` strictly below this value are
        eligible for boosting (default 0.5).

    Returns
    -------
    boost_mask : np.ndarray
        Boolean mask (H, W) of pixels to boost.
    mndwi : np.ndarray
        MNDWI values (H, W) for diagnostics; 0 where green and SWIR1 are both 0.
    """
    # Work in float32 so differences of integer DN values cannot wrap around
    green = tile[1].astype(np.float32)
    nir = tile[3].astype(np.float32)
    swir1 = tile[4].astype(np.float32)
    swir2 = tile[5].astype(np.float32)

    # MNDWI: (Green - SWIR1) / (Green + SWIR1); epsilon avoids division by zero
    mndwi = (green - swir1) / (green + swir1 + 1e-8)

    # Criterion 1: Low NIR (raised threshold to account for atm baseline)
    low_nir = nir < nir_threshold

    # Criterion 2: MNDWI above threshold (can be slightly negative for ocean)
    mndwi_water = mndwi > mndwi_threshold

    # Criterion 3: Spectral flatness in SWIR (ocean signature)
    # Use RELATIVE difference: |SWIR1-SWIR2| / mean(SWIR1,SWIR2)
    swir_mean = (swir1 + swir2) / 2 + 1e-8
    swir_relative_diff = np.abs(swir1 - swir2) / swir_mean
    swir_flat = swir_relative_diff < swir_flatness_threshold

    # Criterion 4: SWIR values below max (excludes bright land/clouds)
    swir_low = (swir1 < swir_max_threshold) & (swir2 < swir_max_threshold)

    # Flatness only counts when SWIR and NIR are both low
    spectral_flat_water = swir_flat & swir_low & low_nir

    # Nodata pixels (all bands zero) must not be reported as water
    has_data = np.any(tile != 0, axis=0)

    # (MNDWI OR spectral flatness), only where the model is uncertain and data exists
    boost_mask = (
        (mndwi_water | spectral_flat_water)
        & (water_prob < model_water_threshold)
        & has_data
    )

    return boost_mask, mndwi


# Optional spectral fallback: raise the model's water probability where the
# band values indicate water. Rebinds the notebook-level `water_prob`.
if USE_SPECTRAL_POSTPROCESS:
    # Read tile data for spectral analysis
    # Same region the inference covered, so arrays align pixel-for-pixel with water_prob
    min_row, min_col, max_row, max_col = bounds
    window = rasterio.windows.Window(
        min_col, min_row,
        max_col - min_col, max_row - min_row
    )
    tile_data_spectral = src.read([1, 2, 3, 4, 5, 6], window=window).astype(np.float32)

    print("Applying spectral water detection...")
    boost_mask, mndwi = detect_water_spectral(
        tile_data_spectral, water_prob,
        nir_threshold=NIR_WATER_THRESHOLD,
        mndwi_threshold=MNDWI_THRESHOLD,
        swir_flatness_threshold=SWIR_FLATNESS_THRESHOLD,
        swir_max_threshold=SWIR_MAX_THRESHOLD,
    )

    # Boost water probability for detected water pixels
    # NOTE(review): the boost value 0.7 is fixed; if TH_WATER is set above 0.7
    # the boosted pixels would presumably no longer classify as water —
    # verify against classify_prediction.
    water_prob_boosted = water_prob.copy()
    water_prob_boosted[boost_mask] = np.maximum(water_prob[boost_mask], 0.7)

    boosted_pixels = boost_mask.sum()

    # Diagnostic statistics
    # These recompute the SWIR criteria with the same formulas used inside
    # detect_water_spectral, purely for the printed counts below.
    swir1 = tile_data_spectral[4]
    swir2 = tile_data_spectral[5]
    swir_mean = (swir1 + swir2) / 2 + 1e-8
    swir_relative_diff = np.abs(swir1 - swir2) / swir_mean
    swir_flat = swir_relative_diff < SWIR_FLATNESS_THRESHOLD
    swir_low = (swir1 < SWIR_MAX_THRESHOLD) & (swir2 < SWIR_MAX_THRESHOLD)

    # Counts are over the whole region and are not restricted to pixels the
    # model scored below 0.5; only the "Boosted" line reflects the final mask.
    print(f"  Low-NIR (<{NIR_WATER_THRESHOLD}) pixels: {(tile_data_spectral[3] < NIR_WATER_THRESHOLD).sum():,}")
    print(f"  MNDWI > {MNDWI_THRESHOLD} pixels: {(mndwi > MNDWI_THRESHOLD).sum():,}")
    print(f"  SWIR flat (<{100*SWIR_FLATNESS_THRESHOLD:.0f}% rel diff) pixels: {swir_flat.sum():,}")
    print(f"  Spectral flat water: {(swir_flat & swir_low & (tile_data_spectral[3] < NIR_WATER_THRESHOLD)).sum():,}")
    print(f"  Boosted: {boosted_pixels:,} pixels ({100*boosted_pixels/water_prob.size:.1f}%)")

    # From here on, downstream cells see the boosted probabilities.
    # Re-running this cell operates on the already-boosted map.
    water_prob = water_prob_boosted
else:
    print("Spectral post-processing: disabled")

In [None]:
# Visualize spectral water detection results
# Uses mndwi, boost_mask and tile_data_spectral from the post-processing cell,
# so that cell must have run with USE_SPECTRAL_POSTPROCESS enabled.
if USE_SPECTRAL_POSTPROCESS:
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # MNDWI values
    # Fixed color limits of -1..1 keep the diverging colormap centered on 0
    im_mndwi = axes[0].imshow(mndwi, cmap="RdBu", vmin=-1, vmax=1)
    axes[0].set_title(f"MNDWI\n[{mndwi.min():.2f}, {mndwi.max():.2f}]")
    axes[0].axis("off")
    plt.colorbar(im_mndwi, ax=axes[0], shrink=0.8)

    # Low-NIR mask
    # Index 3 of the six-band stack is B08 (NIR)
    low_nir_mask = tile_data_spectral[3] < NIR_WATER_THRESHOLD
    axes[1].imshow(low_nir_mask, cmap="Blues")
    axes[1].set_title(f"Low-NIR Mask (NIR < {NIR_WATER_THRESHOLD})\n{low_nir_mask.sum():,} pixels")
    axes[1].axis("off")

    # Combined boost mask
    axes[2].imshow(boost_mask, cmap="Greens")
    axes[2].set_title(f"Spectral Water Boost Mask\n{boost_mask.sum():,} pixels boosted")
    axes[2].axis("off")

    plt.suptitle("Spectral Water Detection Diagnostics", fontsize=12)
    plt.tight_layout()
    plt.show()
else:
    print("Spectral post-processing disabled - no MNDWI visualization")

In [None]:
# Compare band statistics between ocean (center) and land (top-left) patches
# The "land" and "ocean" labels describe where the samples are taken from in
# the processed region; they are positional, not derived from the data.
if USE_SPECTRAL_POSTPROCESS:
    # Patch size for sampling
    DIAG_PATCH_SIZE = 512

    # Define patch locations
    # Top-left corner (land)
    land_row, land_col = 0, 0
    # Center of the region (ocean). Clamped at 0 so that a region smaller
    # than DIAG_PATCH_SIZE does not produce a negative slice start.
    center_row = max(0, (tile_data_spectral.shape[1] - DIAG_PATCH_SIZE) // 2)
    center_col = max(0, (tile_data_spectral.shape[2] - DIAG_PATCH_SIZE) // 2)

    # Extract patches (slices are truncated automatically at the region edge)
    land_patch = tile_data_spectral[:, land_row:land_row+DIAG_PATCH_SIZE,
                                       land_col:land_col+DIAG_PATCH_SIZE]
    ocean_patch = tile_data_spectral[:, center_row:center_row+DIAG_PATCH_SIZE,
                                        center_col:center_col+DIAG_PATCH_SIZE]

    # Band names
    band_names = ["B02\n(Blue)", "B03\n(Green)", "B04\n(Red)",
                  "B08\n(NIR)", "B11\n(SWIR1)", "B12\n(SWIR2)"]

    # Calculate statistics for each band
    land_median = [np.median(land_patch[i]) for i in range(6)]
    land_std = [np.std(land_patch[i]) for i in range(6)]
    ocean_median = [np.median(ocean_patch[i]) for i in range(6)]
    ocean_std = [np.std(ocean_patch[i]) for i in range(6)]

    # Plot: grouped bars, one pair per band, error bars show one stdev
    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(6)
    width = 0.35

    # Land bars
    ax.bar(x - width/2, land_median, width, yerr=land_std,
           label="Land (top-left)", color="saddlebrown", alpha=0.8, capsize=4)
    # Ocean bars
    ax.bar(x + width/2, ocean_median, width, yerr=ocean_std,
           label="Ocean (center)", color="steelblue", alpha=0.8, capsize=4)

    ax.set_xlabel("Band")
    ax.set_ylabel("DN Value")
    # \u00b1 is the plus-minus sign, written as an escape so the title
    # survives re-encoding of the notebook file
    ax.set_title("Band Statistics: Land vs Ocean Patches\n(median \u00b1 stdev)")
    ax.set_xticks(x)
    ax.set_xticklabels(band_names)
    ax.grid(axis="y", alpha=0.3)

    # Add horizontal line at NIR threshold
    ax.axhline(y=NIR_WATER_THRESHOLD, color="red", linestyle="--",
               label=f"NIR threshold ({NIR_WATER_THRESHOLD})", alpha=0.7)
    # Build the legend once, after every labelled artist has been added
    ax.legend()

    plt.tight_layout()
    plt.show()

    # Print numeric values
    print("\nBand statistics:")
    print(f"{'Band':<8} {'Land median':>12} {'Land std':>10} {'Ocean median':>13} {'Ocean std':>11}")
    print("-" * 58)
    for i, name in enumerate(["B02", "B03", "B04", "B08", "B11", "B12"]):
        print(f"{name:<8} {land_median[i]:>12.0f} {land_std[i]:>10.0f} {ocean_median[i]:>13.0f} {ocean_std[i]:>11.0f}")

## Classification

In [None]:
# Apply thresholds to get discrete classes
# water_prob here is the boosted map when spectral post-processing ran
prediction = classify_prediction(
    water_prob, cloud_prob, invalid_mask,
    th_water=TH_WATER, th_cloud=TH_CLOUD
)

# Print class distribution
print("Classification results:")
print(f"  Thresholds: water={TH_WATER}, cloud={TH_CLOUD}")
print()

# np.unique only returns class values that actually occur, so classes with
# zero pixels are not listed. CLASS_NAMES is looked up by class value.
unique, counts = np.unique(prediction, return_counts=True)
for cls, count in zip(unique, counts):
    pct = 100 * count / prediction.size
    print(f"  {CLASS_NAMES[cls]:8s}: {count:>10,} pixels ({pct:5.1f}%)")

## Export Classification GeoTIFF

Save the classification result as a georeferenced GeoTIFF for use in QGIS or other GIS software.

In [None]:
# --- Output Path ---
# Save alongside input file with "_classification" suffix
# An existing file at this path is opened in "w" mode below, i.e. replaced.
output_path = GEOTIFF_PATH.parent / f"{GEOTIFF_PATH.stem}_classification.tif"

# Reopen source to get geospatial metadata
with rasterio.open(GEOTIFF_PATH) as src_meta:
    # Adjust transform if we processed a subset (bounds != full image)
    # Composing with a pixel-space translation moves the raster origin to the
    # top-left corner of the processed region; for the full image the offset
    # is (0, 0) and the transform is unchanged.
    min_row, min_col, max_row, max_col = bounds
    transform = src_meta.transform * rasterio.Affine.translation(min_col, min_row)

    # Single-band uint8 raster with the source CRS; class 0 doubles as nodata
    profile = {
        "driver": "GTiff",
        "dtype": "uint8",
        "width": prediction.shape[1],
        "height": prediction.shape[0],
        "count": 1,
        "crs": src_meta.crs,
        "transform": transform,
        "compress": "lzw",
        "nodata": 0,  # Invalid class = 0
    }

# Write classification raster
with rasterio.open(output_path, "w", **profile) as dst:
    dst.write(prediction.astype(np.uint8), 1)
    dst.set_band_description(1, "Classification")

print(f"Saved: {output_path}")
print(f"  Size: {output_path.stat().st_size / 1e6:.1f} MB")
print(f"  Classes: 0=Invalid, 1=Land, 2=Water, 3=Cloud")

## Visualization

In [None]:
# Read RGB data for the processed region
# All six model bands are read (not just RGB) because the false color
# composite below also needs the NIR and SWIR bands.
min_row, min_col, max_row, max_col = bounds
window = rasterio.windows.Window(
    min_col, min_row,
    max_col - min_col, max_row - min_row
)
tile_data = src.read([1, 2, 3, 4, 5, 6], window=window).astype(np.float32)

# Create composites
# NOTE(review): `bands` are presumably 0-based indices into tile_data
# (4=B11, 3=B08, 1=B03; 2=B04, 1=B03, 0=B02) — confirm against maji.viz.
false_color = create_rgb_composite(tile_data, bands=(4, 3, 1))  # SWIR-NIR-Green
true_color = create_rgb_composite(tile_data, bands=(2, 1, 0))   # Red-Green-Blue

# Get colormap
# Reused by both visualization cells that follow
cmap = get_classification_cmap()

print(f"Processed region: {tile_data.shape[2]} x {tile_data.shape[1]} pixels")

In [None]:
# Main visualization: 2x2 grid
# Top row: model probabilities; bottom row: input composite and final classes
fig, axes = plt.subplots(2, 2, figsize=(14, 14))

# Cloud probability
im_cloud = axes[0, 0].imshow(cloud_prob, cmap="Greys", vmin=0, vmax=1)
axes[0, 0].set_title(f"Cloud Probability\n[{cloud_prob.min():.2f}, {cloud_prob.max():.2f}]")
axes[0, 0].axis("off")
plt.colorbar(im_cloud, ax=axes[0, 0], shrink=0.8)

# Water probability
# Shows the boosted map when spectral post-processing ran
im_water = axes[0, 1].imshow(water_prob, cmap="Blues", vmin=0, vmax=1)
axes[0, 1].set_title(f"Water Probability\n[{water_prob.min():.2f}, {water_prob.max():.2f}]")
axes[0, 1].axis("off")
plt.colorbar(im_water, ax=axes[0, 1], shrink=0.8)

# False color composite
axes[1, 0].imshow(false_color)
axes[1, 0].set_title("False Color (SWIR-NIR-Green)")
axes[1, 0].axis("off")

# Classification
# vmin/vmax pin class values 0..3 to the colormap even if a class is absent
im_class = axes[1, 1].imshow(prediction, cmap=cmap, vmin=0, vmax=3)
axes[1, 1].set_title("Classification")
axes[1, 1].axis("off")

# Add legend for classification
# legend_patches is also used by the comparison cell further down.
# Only the first three components of each color are passed as the face color.
legend_patches = [
    Patch(facecolor=CLASSIFICATION_COLORS[i][:3], label=CLASS_NAMES[i])
    for i in range(4)
]
axes[1, 1].legend(
    handles=legend_patches, loc="lower right",
    framealpha=0.9, fontsize=10
)

plt.suptitle(f"{GEOTIFF_PATH.name}\n{prediction.shape[1]} x {prediction.shape[0]} pixels", fontsize=12)
plt.tight_layout()
plt.show()

## True Color Comparison

In [None]:
# Side-by-side: True color, False color, Classification
# Depends on true_color, false_color, cmap and legend_patches created by the
# two preceding visualization cells.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

axes[0].imshow(true_color)
axes[0].set_title("True Color (RGB)")
axes[0].axis("off")

axes[1].imshow(false_color)
axes[1].set_title("False Color (SWIR-NIR-Green)")
axes[1].axis("off")

# Same class-to-color mapping as the 2x2 figure (class values 0..3)
axes[2].imshow(prediction, cmap=cmap, vmin=0, vmax=3)
axes[2].set_title("Classification")
axes[2].axis("off")
axes[2].legend(
    handles=legend_patches, loc="lower right",
    framealpha=0.9, fontsize=10
)

plt.suptitle(f"Comparison: {GEOTIFF_PATH.name}", fontsize=12)
plt.tight_layout()
plt.show()

In [None]:
# Close file handle
# Releases the dataset opened in the metadata cell. Any cell that reads from
# `src` needs that cell to be re-run first after this point.
src.close()
print("File handle closed.")