# FloodMapper Inference Demo

This notebook demonstrates water/cloud classification using the WorldFloods UNet model
on a Sentinel-2 GeoTIFF tile.

**Key features:**
- Self-contained UNet implementation (no ml4floods dependency)
- Loads pre-trained WF2_unet_rbgiswirs weights
- **Interactive tile selection** - explore different regions using row/column sliders
- Visualizes RGB input and water/cloud prediction

## 1. Setup & Imports

In [None]:
%matplotlib inline

import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import rasterio
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 2. UNet Architecture

Standard 4-level encoder/decoder with skip connections.
This matches the ml4floods UNet architecture exactly.

In [None]:
def _double_conv(in_ch: int, out_ch: int) -> nn.Sequential:
    """Two consecutive Conv2d-ReLU blocks."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class UNet(nn.Module):
    """
    UNet for semantic segmentation with four encoder levels.

    Architecture:
        Encoder: four double-conv stages (64, 128, 256, 512 channels); a 2x2
            max pool follows each of the first three.
        Decoder: three steps of bilinear upsampling, concatenation with the
            matching encoder output (skip connection), and a double-conv stage.
        Output: 1x1 convolution producing ``num_classes`` channels of logits.

    Attribute names (``dconv_down1`` ... ``conv_last``) define the state-dict
    keys, so they must stay as they are for checkpoint loading to work.

    Parameters
    ----------
    in_channels : int
        Number of input channels (e.g., 6 for the six-band input used here).
    num_classes : int
        Number of output channels (2 for the water/cloud binary heads).
    """

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        # Contracting path
        self.dconv_down1 = _double_conv(in_channels, 64)
        self.dconv_down2 = _double_conv(64, 128)
        self.dconv_down3 = _double_conv(128, 256)
        self.dconv_down4 = _double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(2)

        # Expanding path: input width = skip channels + upsampled channels
        self.dconv_up3 = _double_conv(256 + 512, 256)
        self.dconv_up2 = _double_conv(128 + 256, 128)
        self.dconv_up1 = _double_conv(64 + 128, 64)

        # Per-pixel projection to class logits
        self.conv_last = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits with ``num_classes`` channels at the input's spatial size."""
        # Encoder: remember each stage's output for the skip connections
        skips = []
        for encoder in (self.dconv_down1, self.dconv_down2, self.dconv_down3):
            features = encoder(x)
            skips.append(features)
            x = self.maxpool(features)

        # Bottleneck stage (not followed by pooling)
        x = self.dconv_down4(x)

        # Decoder: consume the skips deepest-first, resizing to each skip's
        # exact height/width before concatenating along the channel axis
        for decoder in (self.dconv_up3, self.dconv_up2, self.dconv_up1):
            skip = skips.pop()
            x = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
            x = decoder(torch.cat([x, skip], dim=1))

        return self.conv_last(x)

## 3. Model Loading

Load the config and weights, stripping the `network.` prefix from state dict keys.

In [None]:
# Paths to the pre-trained model artifacts (relative to the working directory)
MODEL_DIR = Path("../scratch/FloodMapper/resources/models/WF2_unet_rbgiswirs")
CONFIG_PATH = MODEL_DIR / "config.json"
WEIGHTS_PATH = MODEL_DIR / "model.pt"

# Load config. Decode explicitly as UTF-8 (the JSON interchange encoding)
# instead of relying on the platform's default text encoding.
with open(CONFIG_PATH, encoding="utf-8") as f:
    config = json.load(f)

# Only the model hyperparameters are needed by the rest of the notebook
hyperparams = config["model_params"]["hyperparameters"]
print("Model config:")
print(f"  Channel configuration: {hyperparams['channel_configuration']}")
print(f"  Number of classes: {hyperparams['num_classes']}")
print(f"  Label names: {hyperparams['label_names']}")

In [None]:
def strip_prefix(state_dict: dict, prefix: str = "network.") -> dict:
    """
    Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    The ml4floods checkpoint saves weights as "network.dconv_down1.0.weight"
    but our UNet expects "dconv_down1.0.weight".

    Only a prefix at the very start of a key is removed. Keys that do not start
    with ``prefix`` are kept unchanged, and later occurrences of the same text
    inside a key are left alone.

    Parameters
    ----------
    state_dict : dict
        Mapping from parameter name to value.
    prefix : str
        Leading text to drop from each key.

    Returns
    -------
    dict
        New mapping with renamed keys; the values are the same objects.
    """
    # str.replace would also delete the text from the middle of a key,
    # so strip strictly from the front.
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}


# Build the network and restore the pre-trained weights.
# The input width is fixed by the six bands this notebook feeds to the model.
# NOTE(review): the checkpoint is named for a "bgriswirs" channel set - confirm
# that it corresponds to the six bands listed below.
IN_CHANNELS = 6  # B03, B04, B08, B8A, B11, B12
NUM_CLASSES = hyperparams["num_classes"]  # 2 (water head, cloud head)

model = UNet(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)

# Checkpoint keys carry a "network." prefix that the standalone UNet does not use.
# weights_only=True restricts unpickling to tensors and plain containers.
state_dict = strip_prefix(
    torch.load(WEIGHTS_PATH, map_location=device, weights_only=True),
    "network.",
)
model.load_state_dict(state_dict)

# Move to the selected device and switch to inference mode
model.to(device)
model.eval()

print(f"Model loaded successfully with {sum(p.numel() for p in model.parameters()):,} parameters")

## 4. Data Loading

Load a tile from the Sentinel-2 GeoTIFF and apply normalization.

**Band order in GeoTIFF:** B03 (Green), B04 (Red), B08 (NIR), B8A (NIR2), B11 (SWIR1), B12 (SWIR2), SCL

**Model expects:** First 6 bands (excluding SCL), normalized per-band.

In [None]:
# Per-band standardization statistics (from ml4floods SENTINEL2_NORMALIZATION).
# Each entry maps a Sentinel-2 band name to [mean, std].
# NOTE(review): the model config reports a "bgriswirs" channel configuration;
# verify that it really denotes the band set below (there is no blue band here).
NORMALIZATION = {
    "B03": [3238.08, 2549.49],  # Green
    "B04": [3418.90, 2811.78],  # Red
    "B08": [3981.96, 2500.48],  # NIR
    "B8A": [4226.75, 2589.29],  # NIR2
    "B11": [2391.66, 1500.03],  # SWIR1
    "B12": [1790.32, 1241.98],  # SWIR2
}

# Channel order used when building the model input
BAND_ORDER = ["B03", "B04", "B08", "B8A", "B11", "B12"]

# Split the [mean, std] pairs into two float32 vectors aligned with BAND_ORDER
means, stds = (
    np.array(column, dtype=np.float32)
    for column in zip(*(NORMALIZATION[band_name] for band_name in BAND_ORDER))
)

print("Normalization constants:")
for i, band in enumerate(BAND_ORDER):
    print(f"  {band}: mean={means[i]:.2f}, std={stds[i]:.2f}")

In [None]:
# Input scene for the demo
GEOTIFF_PATH = Path("../DATA/37MGT/2025-12-17_S2L2A.tif")

# The dataset handle is deliberately left open: later cells read the thumbnail
# and individual tiles from it on demand.
src = rasterio.open(GEOTIFF_PATH)

# Cache the raster size and fix the edge length of the square inference tile
img_height, img_width = src.height, src.width
tile_size = 1024

print(f"GeoTIFF info:")
print(f"  Shape: {src.count} bands x {src.height} x {src.width}")
print(f"  Band names: {src.descriptions}")
print(f"  CRS: {src.crs}")

print(f"\nImage dimensions: {img_height} x {img_width}")
print(f"Tile size: {tile_size} x {tile_size}")

## 5. Interactive Tile Selection

Use the sliders below to select a tile location. The red rectangle shows the selected 1024x1024 region.
When you've found an interesting area, run the next cell to execute inference.

In [None]:
# Create downsampled thumbnail for overview visualization
thumbnail_scale = 10
thumbnail_height = img_height // thumbnail_scale
thumbnail_width = img_width // thumbnail_scale

# Read a pseudo-RGB at reduced resolution: band 2 (B04, red) and band 1
# (B03, green); green is reused for the blue channel.
thumbnail_rgb = src.read(
    [2, 1, 1],  # B04, B03, B03 for pseudo-RGB
    out_shape=(3, thumbnail_height, thumbnail_width)
).astype(np.float32)

# Contrast-stretch between the 2nd and 98th percentiles for display.
# A flat image (p98 == p2) would divide by zero, so fall back to a unit range.
p2, p98 = np.percentile(thumbnail_rgb, [2, 98])
_stretch_range = p98 - p2 if p98 > p2 else 1.0
thumbnail_display = np.clip((thumbnail_rgb - p2) / _stretch_range, 0, 1)
thumbnail_display = np.transpose(thumbnail_display, (1, 2, 0))  # CHW -> HWC

# Largest valid top-left tile position per axis. Clamp at 0 so a raster smaller
# than one tile still yields a valid (min <= max) slider range.
max_row = max(img_height - tile_size, 0)
max_col = max(img_width - tile_size, 0)
center_row = max_row // 2
center_col = max_col // 2

print(f"Thumbnail shape: {thumbnail_display.shape}")
print(f"Tile selection range: row [0, {max_row}], col [0, {max_col}]")
print(f"Center position: row={center_row}, col={center_col}")


def extract_tile(row_start: int, col_start: int, tile_size: int = 1024):
    """
    Read one square window from the open GeoTIFF (module-level ``src``).

    Parameters
    ----------
    row_start : int
        Row (y) offset of the window's top edge.
    col_start : int
        Column (x) offset of the window's left edge.
    tile_size : int
        Width and height of the window in pixels.

    Returns
    -------
    tuple
        tile: np.ndarray holding bands 1-6 of the window as float32 raw values
        scl: np.ndarray holding band 7 (SCL) of the same window, dtype as stored
    """
    window = rasterio.windows.Window(
        col_off=col_start, row_off=row_start, width=tile_size, height=tile_size
    )
    reflectance_bands = list(range(1, 7))  # rasterio band indexes are 1-based
    tile = src.read(reflectance_bands, window=window).astype(np.float32)
    scl = src.read(7, window=window)
    return tile, scl

In [None]:
# === Helper Functions for Processing Pipeline ===

def normalize_tile(tile: np.ndarray, means: np.ndarray, stds: np.ndarray) -> np.ndarray:
    """Standardize each band of a channels-first tile: (value - mean) / std.

    ``means`` and ``stds`` hold one entry per band; two trailing axes are added
    so they broadcast across the tile's height and width.
    """
    spatial_axes = (1, 2)
    centered = tile - np.expand_dims(means, axis=spatial_axes)
    return centered / np.expand_dims(stds, axis=spatial_axes)


def pad_to_multiple(x: torch.Tensor, multiple: int = 16) -> tuple[torch.Tensor, tuple[int, int]]:
    """Reflect-pad the last two dims of ``x`` up to the next multiple of ``multiple``.

    Padding is added only at the bottom and right edges. Returns the (possibly
    padded) tensor together with the original ``(height, width)`` so the caller
    can crop the result back. When no padding is needed the input tensor itself
    is returned.
    """
    height, width = x.shape[-2:]
    extra_rows = -height % multiple
    extra_cols = -width % multiple
    if not (extra_rows or extra_cols):
        return x, (height, width)
    # F.pad takes the pads for the last dimension first: (left, right, top, bottom)
    padded = F.pad(x, (0, extra_cols, 0, extra_rows), mode="reflect")
    return padded, (height, width)


@torch.no_grad()
def run_inference(model: nn.Module, tile: np.ndarray, device: torch.device) -> tuple[np.ndarray, np.ndarray]:
    """Predict per-pixel probabilities for one normalized tile.

    The tile gets a leading batch axis, is padded so height and width are
    multiples of 16, and is passed through ``model``. A sigmoid is applied to
    each output channel independently and the padding is cropped away again.

    Returns ``(water_prob, cloud_prob)``: output channels 0 and 1 as NumPy arrays.
    """
    batch = torch.from_numpy(tile)[None].to(device)
    padded, (height, width) = pad_to_multiple(batch, multiple=16)
    probabilities = model(padded).sigmoid()
    # Drop the batch axis and crop back to the unpadded size
    cropped = probabilities[0, :, :height, :width].cpu().numpy()
    return cropped[0], cropped[1]  # water_prob, cloud_prob


def classify_prediction(water_prob: np.ndarray, cloud_prob: np.ndarray,
                        invalid_mask: np.ndarray | None = None,
                        th_water: float = 0.5, th_cloud: float = 0.5) -> np.ndarray:
    """Classify pixels: 0=invalid, 1=land, 2=water, 3=cloud."""
    pred = np.ones_like(water_prob, dtype=np.uint8)
    pred[water_prob > th_water] = 2
    pred[cloud_prob > th_cloud] = 3
    if invalid_mask is not None:
        pred[invalid_mask] = 0
    return pred


def create_rgb_composite(tile: np.ndarray, bands: tuple[int, int, int] = (1, 0, 0)) -> np.ndarray:
    """
    Create a contrast-stretched three-channel composite for visualization.

    Parameters
    ----------
    tile : np.ndarray
        Channels-first array; ``tile[b]`` must be a 2-D band image.
    bands : tuple[int, int, int]
        Indices into ``tile`` used for the red, green and blue display channels.
        The default reuses band 0 for both green and blue (pseudo-RGB).

    Returns
    -------
    np.ndarray
        Array of shape (height, width, 3) with values clipped to [0, 1]. The
        stretch maps the 2nd percentile to 0 and the 98th percentile to 1,
        computed jointly over the three selected bands.
    """
    rgb = np.stack([tile[b] for b in bands], axis=-1)
    p2, p98 = np.percentile(rgb, [2, 98])
    # A flat tile (e.g. all nodata) has p98 == p2; dividing by that zero range
    # would produce NaNs, so fall back to a unit range instead.
    stretch_range = p98 - p2 if p98 > p2 else 1.0
    return np.clip((rgb - p2) / stretch_range, 0, 1)


# Class codes produced by classify_prediction and their display names
class_names = {0: "Invalid", 1: "Land", 2: "Water", 3: "Cloud"}

# One RGBA color per class code, in code order
colors = [
    [0, 0, 0, 1],           # 0: invalid -> black
    [0.76, 0.70, 0.50, 1],  # 1: land    -> tan
    [0, 0.3, 0.8, 1],       # 2: water   -> blue
    [0.9, 0.9, 0.9, 1],     # 3: cloud   -> light gray
]
cmap = ListedColormap(colors)

print("Helper functions defined.")

In [None]:
from ipywidgets import interact, IntSlider

# Tile-position sliders. They live at module level so a later cell can read
# their current values. continuous_update=False defers updates until the
# handle is released.
_slider_common = dict(min=0, step=256, continuous_update=False)
row_slider = IntSlider(max=max_row, value=center_row, description='Row:', **_slider_common)
col_slider = IntSlider(max=max_col, value=center_col, description='Col:', **_slider_common)

@interact(row_start=row_slider, col_start=col_slider)
def show_tile_selection(row_start, col_start):
    """Draw the overview thumbnail with the selected tile outlined in red (no inference)."""
    # Conversion factors from full-resolution pixels to thumbnail pixels, per axis
    thumb_rows, thumb_cols = thumbnail_display.shape[:2]
    y_scale = thumb_rows / img_height
    x_scale = thumb_cols / img_width

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(thumbnail_display)

    # Outline of the tile footprint; matplotlib's xy is (x, y) = (column, row)
    ax.add_patch(
        plt.Rectangle(
            xy=(col_start * x_scale, row_start * y_scale),
            width=tile_size * x_scale,
            height=tile_size * y_scale,
            fill=False, edgecolor='red', linewidth=2,
        )
    )
    ax.set_title(f"Selected tile: row={row_start}, col={col_start}")
    ax.axis('off')
    plt.tight_layout()
    plt.show()

## 6. Run Inference

Execute this cell to run inference on the currently selected tile position.

In [None]:
# Read the tile position currently selected with the sliders
row_start, col_start = row_slider.value, col_slider.value

print(f"Running inference on tile at row={row_start}, col={col_start}...")

# Pull the raw six-band tile (plus the SCL band) from the open GeoTIFF
tile, scl = extract_tile(row_start, col_start, tile_size)
print(f"  Tile shape: {tile.shape}, value range: [{tile.min():.0f}, {tile.max():.0f}]")

# Standardize per band, then predict per-pixel water and cloud probabilities
tile_normalized = normalize_tile(tile, means, stds)
water_prob, cloud_prob = run_inference(model, tile_normalized, device)
print(f"  Water prob: [{water_prob.min():.3f}, {water_prob.max():.3f}]")
print(f"  Cloud prob: [{cloud_prob.min():.3f}, {cloud_prob.max():.3f}]")

# A pixel that is exactly zero in every band is treated as invalid
invalid_mask = (tile == 0).all(axis=0)
prediction = classify_prediction(water_prob, cloud_prob, invalid_mask)

# Report how many pixels ended up in each class present in the prediction
unique, counts = np.unique(prediction, return_counts=True)
print("\nClassification results:")
for cls, count in zip(unique, counts):
    pct = 100 * count / prediction.size
    print(f"  {class_names[cls]}: {count:,} pixels ({pct:.1f}%)")

## 7. Visualization

Display the pseudo-RGB composite, classification map, water probability map, and false color composites.

In [None]:
# Main visualization: RGB, Classification, Water Probability
# The "RGB" here is a pseudo-RGB: tile band 1 (B04) for red and band 0 (B03)
# for both green and blue.
rgb = create_rgb_composite(tile, bands=(1, 0, 0))

fig1, axes = plt.subplots(1, 3, figsize=(15, 5))

# Panel 1: contrast-stretched pseudo-RGB of the raw tile
axes[0].imshow(rgb)
axes[0].set_title("RGB Composite (B04, B03, B03)")
axes[0].axis("off")

# Panel 2: discrete class map; vmin/vmax pin class codes 0-3 to the four
# colormap entries even when some classes are absent from this tile
im = axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=3)
axes[1].set_title("Classification")
axes[1].axis("off")
cbar = plt.colorbar(im, ax=axes[1], ticks=[0, 1, 2, 3], shrink=0.8)
cbar.ax.set_yticklabels(["Invalid", "Land", "Water", "Cloud"])

# Panel 3: water probability before thresholding, on a fixed 0-1 scale
im2 = axes[2].imshow(water_prob, cmap="Blues", vmin=0, vmax=1)
axes[2].set_title("Water Probability")
axes[2].axis("off")
plt.colorbar(im2, ax=axes[2], shrink=0.8)

plt.suptitle(f"Tile at row={row_start}, col={col_start}", fontsize=12)
plt.tight_layout()
plt.show()

# False color composites (second figure; `axes` is rebound to the new axes)
fig2, axes = plt.subplots(1, 2, figsize=(12, 5))

# Tile bands 2, 1, 0 = B08 (NIR), B04 (red), B03 (green)
false_color = create_rgb_composite(tile, bands=(2, 1, 0))
axes[0].imshow(false_color)
axes[0].set_title("False Color (NIR-Red-Green)\nWater appears dark")
axes[0].axis("off")

# Tile bands 4, 2, 1 = B11 (SWIR1), B08 (NIR), B04 (red)
swir_color = create_rgb_composite(tile, bands=(4, 2, 1))
axes[1].imshow(swir_color)
axes[1].set_title("SWIR Composite (SWIR1-NIR-Red)\nWater appears very dark")
axes[1].axis("off")

plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated:

1. **Self-contained UNet implementation** - No ml4floods dependency required
2. **Model loading** - Strip `network.` prefix from state dict keys  
3. **Data loading** - Open GeoTIFF with rasterio, create thumbnail overview
4. **Interactive tile selection** - Use sliders to explore the image (lightweight)
5. **Inference pipeline** - Run on-demand when you've selected a tile
6. **Visualization** - RGB composite, classification map, and false color views

**Workflow:**
1. Use the sliders in Section 5 to explore the image and select a tile
2. When ready, execute Section 6 to run inference on the selected tile
3. Execute Section 7 to visualize the results
4. Return to Section 5 to select a new tile and repeat

The model produces two probability outputs (water, cloud) which are thresholded at 0.5 to produce discrete classes.