# Leakance Impact Figures

This notebook generates figures demonstrating the impact of leakance (groundwater–surface water exchange) across the contiguous United States (CONUS).

**Prerequisites:**
1. A trained model checkpoint with leakance enabled
2. `scripts/router.py` output (`chrout.zarr`) — contains `zeta_sum`, `q_prime_sum`, and `catchment_ids`
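3. Two `scripts/test.py` outputs (`model_test.zarr`) — one with leakance ON, one with leakance OFF (same checkpoint)
4. The MERIT catchment shapefile (must have a `COMID` column), used to map per-reach values
5. A gage reference CSV (must have a `STAID` column), used to locate gages on the Delta-NSE map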

**Figures:**
1. Learned parameter maps (K_D, d_gw, leakance_factor)
2. Normalized cumulative zeta map (leakance impact per reach)
3. Delta-NSE map (metric improvement from leakance)
4. Representative hydrographs (with/without leakance)

In [None]:
# Cell 1: Imports + Config
import logging
from pathlib import Path

import contextily as cx
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import xarray as xr
import yaml
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable

from ddr._version import __version__
from ddr.io.readers import ForcingsReader
from ddr.nn import kan, leakance_lstm
from ddr.routing.torch_mc import dmc
from ddr.routing.utils import denormalize
from ddr.scripts_utils import load_checkpoint
from ddr.validation import Config, Metrics, plot_cdf, plot_gauge_map, plot_time_series

log = logging.getLogger(__name__)

# ── User configuration ──────────────────────────────────────────────
# Edit these paths so they point at your own artifacts before running the
# remaining cells; every later cell reads from these module-level names.
CONFIG_PATH = Path("./example_config.yaml")  # Your MERIT leakance config
CHROUT_PATH = Path("./chrout.zarr")  # router.py output with zeta_sum
TEST_LEAKANCE_ON = Path("./model_test_leakance_on.zarr")  # test.py with leakance
TEST_LEAKANCE_OFF = Path("./model_test_leakance_off.zarr")  # test.py without leakance
MERIT_SHP = Path("./cat_pfaf_7_MERIT_Hydro_v07_Basins_v01_bugfix1.shp")  # MERIT catchments
GAGES_CSV = Path("./training_gauges.csv")  # Gage reference file
SAVE_DIR = Path("./figures")
# parents=True lets SAVE_DIR be changed to a nested path without a manual mkdir.
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Decode the YAML explicitly as UTF-8 instead of relying on the platform's
# locale-dependent default encoding.
with open(CONFIG_PATH, encoding="utf-8") as f:
    config = Config(**yaml.safe_load(f))

# All model inference in this notebook runs on the device named in the config.
device = torch.device(config.device)

In [None]:
# Cell 2: Figure 1 — Learned Leakance Parameter Maps (3-panel)
#
# Run LSTM on forcings + attributes to extract K_D, d_gw, leakance_factor.
# Temporal mean gives spatial maps. Join to MERIT shapefile via COMID.

# Instantiate KAN (needed for dataset setup) and LSTM
nn = kan(
    input_var_names=config.kan.input_var_names,
    learnable_parameters=config.kan.learnable_parameters,
    hidden_size=config.kan.hidden_size,
    num_hidden_layers=config.kan.num_hidden_layers,
    grid=config.kan.grid,
    k=config.kan.k,
    seed=config.seed,
    device=config.device,
)
leakance_nn = leakance_lstm(
    input_var_names=config.leakance_lstm.input_var_names,
    forcing_var_names=config.leakance_lstm.forcing_var_names,
    hidden_size=config.leakance_lstm.hidden_size,
    num_layers=config.leakance_lstm.num_layers,
    dropout=config.leakance_lstm.dropout,
    seed=config.seed,
    device=config.device,
)
forcings_reader_nn = ForcingsReader(config)

# Load both KAN + LSTM from checkpoint
load_checkpoint(nn, config.experiment.checkpoint, device, leakance_nn=leakance_nn)
nn = nn.eval()
leakance_nn = leakance_nn.eval()

# Get dataset
dataset = config.geodataset.get_dataset_class(cfg=config)

# Batched LSTM leakance inference (avoids GPU OOM on full CONUS ~180k reaches)
BATCH_SIZE = 10_000

# Forcings are materialized on the CPU; only one batch at a time is moved to
# `device` inside the loop below.
with torch.no_grad():
    all_forcings = forcings_reader_nn(
        routing_dataclass=dataset.routing_dataclass, device="cpu", dtype=torch.float32
    )
all_attributes = dataset.routing_dataclass.normalized_spatial_attributes

N = all_attributes.shape[0]

# Per-reach temporal means are accumulated in float32. IEEE half precision has a
# smallest positive value of about 6e-8 and roughly three significant decimal
# digits, so small rate coefficients such as K_D can be flushed to zero or
# coarsely quantized. These buffers hold only N scalars each, so full precision
# costs next to nothing.
LEAKANCE_PARAMS = ("K_D", "d_gw", "leakance_factor")
param_means = {name: torch.zeros(N, dtype=torch.float32) for name in LEAKANCE_PARAMS}

with torch.no_grad():
    for start in range(0, N, BATCH_SIZE):
        end = min(start + BATCH_SIZE, N)
        # Forcings are sliced along their second axis and attributes along their
        # first, so both slices cover the same block of reaches.
        batch_forcings = all_forcings[:, start:end, :].to(device)
        batch_attrs = all_attributes[start:end, :].to(device)

        batch_out = leakance_nn(forcings=batch_forcings, attributes=batch_attrs)

        for name in LEAKANCE_PARAMS:
            # Average the normalized output over dim 0 (the temporal mean), then
            # map it into the configured parameter range.
            param_means[name][start:end] = (
                denormalize(batch_out[name].mean(dim=0), config.params.parameter_ranges[name])
                .cpu()
                .float()
            )

        # Drop references to the per-batch device tensors before the next batch
        # is moved over, then release cached GPU blocks.
        del batch_forcings, batch_attrs, batch_out
        if device.type == "cuda":
            torch.cuda.empty_cache()

del all_forcings

K_D = param_means["K_D"].numpy()
d_gw = param_means["d_gw"].numpy()
leakance_factor = param_means["leakance_factor"].numpy()

# Join to MERIT shapefile
gdf = gpd.read_file(MERIT_SHP).set_index("COMID")
divide_ids = np.array(dataset.routing_dataclass.divide_ids)
# Selecting with .loc puts the rows in divide_ids order (and raises KeyError on a
# COMID missing from the shapefile), so the positional column assignments below
# line up with the model output. The copy makes the subset an independent frame
# before new columns are added.
gdf = gdf.loc[divide_ids].copy()
gdf["K_D"] = K_D
gdf["d_gw"] = d_gw
gdf["leakance_factor"] = leakance_factor
gdf = gdf.to_crs(epsg=4326)

In [None]:
# Plot 3-panel parameter maps
def param_plot(
    gdf,
    var,
    save_name,
    cmap="plasma",
    unit_label=None,
    title=None,
    vmin=None,
    vmax=None,
    ascending=False,
    dpi=100,
    save_dpi=600,
):
    """Create a CONUS parameter map with basemap and colorbar, and save it.

    Args:
        gdf: GeoDataFrame holding the geometries and the column to plot. The
            axis limits below are hard-coded in degrees, so the frame is
            expected to be in a geographic CRS (callers in this notebook
            reproject to EPSG:4326 first).
        var: Name of the column in ``gdf`` to colour by. Rows where it is null
            are dropped.
        save_name: Output path passed to ``Figure.savefig``.
        cmap: Matplotlib colormap name used for both the map and the colorbar.
        unit_label: Optional unit string; when given the colorbar is labelled
            ``"<var> (<unit_label>)"``, otherwise just ``var``.
        title: Optional axes title.
        vmin: Lower colour limit. Defaults to the minimum of the plotted data.
        vmax: Upper colour limit. Defaults to the maximum of the plotted data.
        ascending: Sort direction applied to ``var`` before plotting.
            NOTE(review): presumably this controls which geometries are drawn
            on top of others — confirm against geopandas' draw order.
        dpi: Resolution of the on-screen figure.
        save_dpi: Resolution of the saved image (defaults to 600).

    Returns:
        The ``(fig, ax)`` pair of the created map.

    Raises:
        ValueError: If ``var`` has no non-null values to plot.
    """
    gdf_clean = gdf.dropna(subset=[var]).sort_values(by=var, ascending=ascending)
    # Fail before creating a figure so an all-null column does not leave an
    # orphan empty figure behind or surface as an opaque numpy reduction error.
    if gdf_clean.empty:
        raise ValueError(f"No non-null values in column {var!r}; nothing to plot.")

    data = gdf_clean[var].values
    if vmin is None:
        vmin = np.nanmin(data)
    if vmax is None:
        vmax = np.nanmax(data)

    fig, ax = plt.subplots(figsize=(7, 4), dpi=dpi)
    gdf_clean.plot(ax=ax, column=var, cmap=cmap, linewidth=0.3, vmin=vmin, vmax=vmax, zorder=1)
    # Basemap tiles sit underneath the data layer (zorder 0 vs. 1).
    cx.add_basemap(
        ax, crs=gdf_clean.crs, source=cx.providers.CartoDB.Positron, alpha=0.6, zorder=0, attribution=False
    )
    ax.set_xlim(-125, -66)
    ax.set_ylim(24, 53)
    ax.set_xticks([])
    ax.set_yticks([])
    if title:
        ax.set_title(title, fontsize=14)

    # Build the colorbar from a stand-alone mappable so its range follows the
    # explicit vmin/vmax rather than whatever the plotted collection holds.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="3%", pad=0.1)
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_array([])
    sm.set_clim(vmin, vmax)
    cbar = fig.colorbar(sm, cax=cax)
    cbar.set_label(f"{var} ({unit_label})" if unit_label else var)

    # Operate on this figure explicitly instead of pyplot's implicit
    # "current figure", which can be a different figure in a notebook session.
    fig.tight_layout()
    fig.savefig(save_name, dpi=save_dpi, bbox_inches="tight")
    return fig, ax


# One map per learned leakance parameter. Each entry is
# (column in gdf, output file name, per-map styling); the resolution is shared.
_PARAM_MAP_SPECS = [
    (
        "K_D",
        "K_D_map.png",
        {"cmap": "viridis", "title": "Leakance Coefficient $K_D$ (1/s)", "unit_label": "1/s"},
    ),
    (
        "d_gw",
        "d_gw_map.png",
        {"cmap": "RdBu_r", "title": "Groundwater Depth Threshold $d_{gw}$ (m)", "unit_label": "m"},
    ),
    (
        "leakance_factor",
        "leakance_factor_map.png",
        {"cmap": "viridis", "title": "Leakance Factor", "vmin": 0, "vmax": 1},
    ),
]

for _column, _file_name, _style in _PARAM_MAP_SPECS:
    param_plot(gdf, _column, SAVE_DIR / _file_name, dpi=200, **_style)

plt.show()

In [None]:
# Cell 3: Figure 2 — Normalized Cumulative Zeta Map
#
# impact = zeta_sum / (q_prime_sum + epsilon)
# Positive = losing stream (water lost to groundwater)
# Negative = gaining stream (groundwater feeds stream)

ds_chrout = xr.open_zarr(CHROUT_PATH)

# Pull the three arrays this figure needs out of the router output.
catchment_ids = ds_chrout["catchment_ids"].values
q_prime_sum = ds_chrout["q_prime_sum"].values
zeta_sum = ds_chrout["zeta_sum"].values

# epsilon keeps the ratio finite where q_prime_sum is zero; the ratio is then
# clamped to [-1, 1] so extreme values do not dominate.
epsilon = 1e-6
impact = np.clip(zeta_sum / (q_prime_sum + epsilon), -1.0, 1.0)

# Join to MERIT shapefile: rows are selected in catchment_ids order so the
# impact values attach positionally, then reprojected to EPSG:4326 for plotting.
gdf_impact = (
    gpd.read_file(MERIT_SHP)
    .set_index("COMID")
    .loc[catchment_ids]
    .assign(leakance_impact=impact)
    .to_crs(epsg=4326)
)

param_plot(
    gdf=gdf_impact,
    var="leakance_impact",
    save_name=SAVE_DIR / "leakance_impact_map.png",
    cmap="RdBu_r",
    title="Normalized Cumulative Leakance ($\\sum \\zeta / \\sum q'$)",
    vmin=-0.5,
    vmax=0.5,
    dpi=200,
)
plt.show()

In [None]:
# Cell 4: Figure 3 — Delta-NSE Map (Leakance ON minus OFF)
#
# Compare two test.py runs (same checkpoint, leakance toggled at inference)

ds_on, ds_off = (xr.open_zarr(store) for store in (TEST_LEAKANCE_ON, TEST_LEAKANCE_OFF))

# Restrict both runs to the gages they share; selecting with the same id array
# also puts the two datasets in the same gage order.
common_gages = np.intersect1d(ds_on.gage_ids.values, ds_off.gage_ids.values)
ds_on, ds_off = (run.sel(gage_ids=common_gages) for run in (ds_on, ds_off))

metrics_on = Metrics(pred=ds_on.predictions.values, target=ds_on.observations.values)
metrics_off = Metrics(pred=ds_off.predictions.values, target=ds_off.observations.values)

# NSE is clipped to [-1, 1] before differencing.
nse_on = np.clip(metrics_on.nse, -1, 1)
nse_off = np.clip(metrics_off.nse, -1, 1)
delta_nse = nse_on - nse_off

# Gage reference table keyed by STAID, zero-padded to 8 characters.
gages_df = (
    pd.read_csv(GAGES_CSV)
    .assign(STAID=lambda frame: frame["STAID"].astype(str).str.zfill(8))
    .set_index("STAID")
)

# One row per common gage, carrying the scores used by the map and by the
# hydrograph cell.
selected_gages = (
    gages_df.loc[common_gages]
    .reset_index()
    .assign(delta_NSE=delta_nse, NSE_leakance_on=nse_on, NSE_leakance_off=nse_off)
)

# Map of the per-gage NSE difference.
fig = plot_gauge_map(
    gages=selected_gages,
    metric_column="delta_NSE",
    title=r"$\Delta$NSE (Leakance ON $-$ OFF)",
    colormap="RdBu",
    colorbar_label=r"$\Delta$NSE",
    vmin=-0.1,
    vmax=0.1,
    figsize=(16, 8),
    point_size=30,
    path=SAVE_DIR / "delta_nse_map.png",
    show_plot=True,
)

# CDF overlay of the two NSE distributions.
fig_cdf, ax_cdf = plot_cdf(
    data_list=[nse_off, nse_on],
    title="NSE CDF: Leakance ON vs OFF",
    legend_labels=["Leakance OFF", "Leakance ON"],
    figsize=(10, 6),
    xlabel="NSE",
    ylabel="Cumulative Frequency",
    reference_line=None,
    xlim=(0, 1),
)
plt.savefig(SAVE_DIR / "nse_cdf_comparison.png", dpi=300, bbox_inches="tight")
plt.show()

# Text summary. NaN deltas count toward the total but toward neither
# "improved" nor "degraded", because comparisons with NaN are False.
n_gages = len(delta_nse)
n_improved = np.sum(delta_nse > 0)
n_degraded = np.sum(delta_nse < 0)

print(f"Median NSE (leakance ON):  {np.nanmedian(nse_on):.4f}")
print(f"Median NSE (leakance OFF): {np.nanmedian(nse_off):.4f}")
print(f"Median delta-NSE:          {np.nanmedian(delta_nse):.4f}")
print(f"Gages improved:            {n_improved} / {n_gages}")
print(f"Gages degraded:            {n_degraded} / {n_gages}")

In [None]:
# Cell 5: Figure 4 — Representative Hydrographs
#
# Select gages by delta-NSE ranking: best improved, worst degraded, neutral.

# Rank gages by delta-NSE. Rows with an undefined delta-NSE are dropped first:
# sort_values places NaN last by default, so without this the "Most Degraded"
# pick (the last row) would be a gage whose NSE could not be computed.
ranked = selected_gages.dropna(subset=["delta_NSE"]).sort_values("delta_NSE", ascending=False)
if ranked.empty:
    raise ValueError("No gages with a defined delta-NSE; cannot select representative hydrographs.")

best_improved = ranked.iloc[0]
worst_degraded = ranked.iloc[-1]
# "Neutral" is the gage whose delta-NSE is closest to zero.
neutral_idx = ranked["delta_NSE"].abs().idxmin()
neutral = ranked.loc[neutral_idx]

representative_gages = [best_improved, neutral, worst_degraded]
panel_labels = ["Best Improved", "Neutral", "Most Degraded"]

time_range = ds_on.time.values

# strict=True: the two lists are built side by side above, so a length
# mismatch would be a programming error worth surfacing.
for gage_info, label in zip(representative_gages, panel_labels, strict=True):
    gage_id = gage_info["STAID"]
    # Fall back to the station id when the reference table has no STANAME column.
    gage_name = gage_info.get("STANAME", gage_id)

    # Both runs are compared against the same observations, taken from the
    # leakance-ON dataset.
    pred_on = ds_on.sel(gage_ids=gage_id).predictions.values
    pred_off = ds_off.sel(gage_ids=gage_id).predictions.values
    obs = ds_on.sel(gage_ids=gage_id).observations.values

    gage_metrics_on = {"nse": float(gage_info["NSE_leakance_on"])}
    gage_metrics_off = {"nse": float(gage_info["NSE_leakance_off"])}

    title = f"{label}: {gage_name} ({gage_id}) | $\\Delta$NSE={gage_info['delta_NSE']:.4f}"

    # The leakance-ON run is the primary series; the leakance-OFF run is
    # overlaid as an additional prediction.
    plot_time_series(
        prediction=pred_on,
        observation=obs,
        time_range=time_range,
        gage_id=gage_id,
        name=gage_name,
        metrics=gage_metrics_on,
        path=SAVE_DIR / f"hydrograph_{gage_id}_{label.lower().replace(' ', '_')}.png",
        title=title,
        additional_predictions=[
            (pred_off, "DDR (no leakance)", gage_metrics_off),
        ],
    )

print("Hydrograph figures saved.")