From 52581f6ab7fdf617f3b13bb71a64acb14f460137 Mon Sep 17 00:00:00 2001 From: Tadd Bindas Date: Tue, 7 Oct 2025 19:52:42 -0500 Subject: [PATCH] initial commit: added new geodatazoo class --- examples/eval/evaluate.ipynb | 4 +- .../parameter_maps/plot_parameter_map.ipynb | 4 +- pyproject.toml | 3 +- shoutout.md | 7 + src/ddr/geodatazoo/Dates.py | 128 ++++++ src/ddr/geodatazoo/Gauges.py | 54 +++ src/ddr/geodatazoo/__init__.py | 0 src/ddr/geodatazoo/protocols.py | 379 ++++++++++++++++++ .../providers/hydrofabric/README.md | 3 + .../providers/hydrofabric/__init__.py | 11 + .../providers/hydrofabric/attributes.py | 103 +++++ .../providers/hydrofabric/network.py | 148 +++++++ .../providers/hydrofabric/streamflow.py | 86 ++++ src/ddr/geodatazoo/types.py | 25 ++ 14 files changed, 950 insertions(+), 5 deletions(-) create mode 100644 shoutout.md create mode 100644 src/ddr/geodatazoo/Dates.py create mode 100644 src/ddr/geodatazoo/Gauges.py create mode 100644 src/ddr/geodatazoo/__init__.py create mode 100644 src/ddr/geodatazoo/protocols.py create mode 100644 src/ddr/geodatazoo/providers/hydrofabric/README.md create mode 100644 src/ddr/geodatazoo/providers/hydrofabric/__init__.py create mode 100644 src/ddr/geodatazoo/providers/hydrofabric/attributes.py create mode 100644 src/ddr/geodatazoo/providers/hydrofabric/network.py create mode 100644 src/ddr/geodatazoo/providers/hydrofabric/streamflow.py create mode 100644 src/ddr/geodatazoo/types.py diff --git a/examples/eval/evaluate.ipynb b/examples/eval/evaluate.ipynb index f804b7a..a1671b9 100644 --- a/examples/eval/evaluate.ipynb +++ b/examples/eval/evaluate.ipynb @@ -63,7 +63,7 @@ "# Loading paths to results. We're comparing to summed q_prime as it's a good indicator of if routing is working\n", "summed_q_prime_path = Path(\"./summed_q_prime.zarr\") # To obtain this, please run scripts/summed_q_prime.py\n", "predictions_path = Path(\n", - " \"./model_test.zarr\"\n", + " \"/projects/mhpi/tbindas/ddr/runs/ddr-v0.1.3-eval/2025-08-28_08-24-32/model_test.zarr\"\n", ") # To obtain this, please run scripts/test.py to evaluate a trained model\n", "\n", "ds_qp = xr.open_zarr(summed_q_prime_path)\n", @@ -221,7 +221,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.5" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/examples/parameter_maps/plot_parameter_map.ipynb b/examples/parameter_maps/plot_parameter_map.ipynb index 5ab778c..caa74c5 100644 --- a/examples/parameter_maps/plot_parameter_map.ipynb +++ b/examples/parameter_maps/plot_parameter_map.ipynb @@ -88,7 +88,9 @@ "outputs": [], "source": [ "# Load pretrained model states\n", - "model_states = Path(\"./ddr_v0.1.0a2_trained_model_weights.pt\")\n", + "model_states = Path(\n", + " \"/projects/mhpi/tbindas/ddr/runs/ddr-v0.1.3-train/2025-08-27_19-26-09/saved_models/_ddr-v0.1.3-train_epoch_5_mb_42.pt\"\n", + ")\n", "\n", "log.info(f\"Loading spatial_nn from checkpoint: {model_states.stem}\")\n", "state = torch.load(model_states, map_location=device)\n", diff --git a/pyproject.toml b/pyproject.toml index 167d3fa..7eca54a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "pykan==0.2.8", "scikit-learn==1.7.0", "scipy==1.16.0", + "torch==2.7.1", "tqdm==4.67.1", "xarray==2025.7.1", "zarr==3.0.9", @@ -52,7 +53,6 @@ docs = [ "sympy==1.14.0" ] cuda = [ - "torch==2.7.1", "cupy-cuda12x==13.4.1", ] @@ -63,7 +63,6 @@ tests = [ "ruff==0.12.2", "nbstripout==0.8.1", "boto3==1.39.14", - ] engine = [ "adbc-driver-sqlite==1.6.0", diff --git 
a/shoutout.md b/shoutout.md new file mode 100644 index 0000000..4c3f87a --- /dev/null +++ b/shoutout.md @@ -0,0 +1,7 @@ +### Shoutout File + +This is a file that shouts out work that was used as inspiration for parts of this code + +#### Neural Hydrology +https://github.com/neuralhydrology/neuralhydrology +The dataset abstractions used for hooking up many datasets in `datasetzoo` were used in abstracting the datasets in this code. diff --git a/src/ddr/geodatazoo/Dates.py b/src/ddr/geodatazoo/Dates.py new file mode 100644 index 0000000..03fc3d0 --- /dev/null +++ b/src/ddr/geodatazoo/Dates.py @@ -0,0 +1,128 @@ +import logging +from datetime import datetime +from typing import Any + +import numpy as np +import pandas as pd +import torch +from pydantic import BaseModel, ConfigDict, model_validator + +log = logging.getLogger(__name__) + + +class Dates(BaseModel): + """Dates class for handling time operations for training dMC models""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + daily_format: str = "%Y/%m/%d" + hourly_format: str = "%Y/%m/%d %H:%M:%S" + origin_start_date: str = "1980/01/01" + start_time: str + end_time: str + rho: int | None = None + batch_daily_time_range: pd.DatetimeIndex | None = pd.DatetimeIndex([], dtype="datetime64[ns]") + batch_hourly_time_range: pd.DatetimeIndex | None = pd.DatetimeIndex([], dtype="datetime64[ns]") + daily_time_range: pd.DatetimeIndex | None = pd.DatetimeIndex([], dtype="datetime64[ns]") + daily_indices: np.ndarray = np.empty(0) + hourly_time_range: pd.DatetimeIndex | None = pd.DatetimeIndex([], dtype="datetime64[ns]") + hourly_indices: torch.Tensor | None = torch.empty(0) + numerical_time_range: np.ndarray = np.empty(0) + + def __init__(self, **kwargs): + super().__init__( + start_time=kwargs["start_time"], + end_time=kwargs["end_time"], + rho=kwargs["rho"], + ) + + @model_validator(mode="after") + @classmethod + def validate_dates(cls, dates: Any) -> Any: + """Validates that the number of days you select is within the range of the start and end times""" + rho = dates.rho + if isinstance(rho, int): + if rho > len(dates.daily_time_range): + log.exception( + ValueError("Rho needs to be smaller than the routed period between start and end times") + ) + raise ValueError("Rho needs to be smaller than the routed period between start and end times") + return dates + + def model_post_init(self, __context: Any) -> None: + """Initializes the Dates object and time ranges""" + self.daily_time_range = pd.date_range( + datetime.strptime(self.start_time, self.daily_format), + datetime.strptime(self.end_time, self.daily_format), + freq="D", + inclusive="both", + ) + self.hourly_time_range = pd.date_range( + start=self.daily_time_range[0], + end=self.daily_time_range[-1], + freq="h", + inclusive="left", + ) + self.batch_daily_time_range = self.daily_time_range + self.set_batch_time(self.daily_time_range) + + def set_batch_time(self, daily_time_range: pd.DatetimeIndex) -> None: + """Sets the time range for the batch you're train/test/simulating + + Parameters + ---------- + daily_time_range : pd.DatetimeIndex + The daily time range you want to select + """ + self.batch_hourly_time_range = pd.date_range( + start=daily_time_range[0], + end=daily_time_range[-1], + freq="h", + inclusive="left", + ) + origin_start_date = datetime.strptime(self.origin_start_date, self.daily_format) + origin_base_start_time = int( + (daily_time_range[0].to_pydatetime() - origin_start_date).total_seconds() / 86400 + ) + origin_base_end_time = int( + 
(daily_time_range[-1].to_pydatetime() - origin_start_date).total_seconds() / 86400 + ) + + # The indices for the dates in your selected routing time range + self.numerical_time_range = np.arange(origin_base_start_time, origin_base_end_time + 1, 1) + + self._create_daily_indices() + self._create_hourly_indices() + + def _create_hourly_indices(self) -> None: + common_elements = self.hourly_time_range.intersection(self.batch_hourly_time_range) + self.hourly_indices = torch.tensor([self.hourly_time_range.get_loc(time) for time in common_elements]) + + def _create_daily_indices(self): + common_elements = self.daily_time_range.intersection(self.batch_daily_time_range) + self.daily_indices = np.array([self.daily_time_range.get_loc(time) for time in common_elements]) + + def calculate_time_period(self) -> None: + """Calculates the time period for the dataset using rho""" + if self.rho is not None: + sample_size = len(self.daily_time_range) + random_start = torch.randint(low=0, high=sample_size - self.rho, size=(1, 1))[0][0].item() + self.batch_daily_time_range = self.daily_time_range[random_start : (random_start + self.rho)] + self.set_batch_time(self.batch_daily_time_range) + + def set_date_range(self, chunk: np.ndarray) -> None: + """Sets the date range for the dataset + + Parameters + ---------- + chunk : np.ndarray + The chunk of the date range you want to select + """ + self.batch_daily_time_range = self.daily_time_range[chunk] + self.set_batch_time(self.batch_daily_time_range) + + def create_time_windows(self) -> np.ndarray: + """Creates the time slices, or windows, for testing the model""" + num_pieces = self.daily_time_range.shape[0] // self.rho + last_time = num_pieces * self.rho + reshaped_arr = np.reshape(self.daily_time_range[:last_time], (num_pieces, self.rho)) + return reshaped_arr diff --git a/src/ddr/geodatazoo/Gauges.py b/src/ddr/geodatazoo/Gauges.py new file mode 100644 index 0000000..9d62354 --- /dev/null +++ b/src/ddr/geodatazoo/Gauges.py @@ -0,0 +1,54 @@ +import csv +from pathlib import Path +from typing import Annotated + +from pydantic import AfterValidator, BaseModel, ConfigDict, PositiveFloat + + +def zfill_usgs_id(STAID: str) -> str: + """Ensures all USGS gauge ID strings are zero-filled to 8 digits + + Parameters + ---------- + STAID: str + The USGS Station ID + + Returns + ------- + str + The eight-digit USGS Gauge ID + """ + return STAID.zfill(8) + + +class Gauge(BaseModel): + """A pydantic object for managing properties for a Gauge and validating incoming CSV files""" + + model_config = ConfigDict(extra="ignore") + STAID: Annotated[str, AfterValidator(zfill_usgs_id)] + DRAIN_SQKM: PositiveFloat + + +class GaugeSet(BaseModel): + """A pydantic object for storing a list of Gauges""" + + gauges: list[Gauge] + + +def validate_gages(file_path: Path) -> GaugeSet: + """Reads the training gauges file and validates each row against a pydantic schema + + Parameters + ---------- + file_path: Path + The path to the gauges csv file + + Returns + ------- + GaugeSet + A set of pydantic-validated gauges + """ + with file_path.open() as f: + reader = csv.DictReader(f) + gauges = [Gauge.model_validate(row) for row in reader] + return GaugeSet(gauges=gauges) diff --git a/src/ddr/geodatazoo/__init__.py b/src/ddr/geodatazoo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ddr/geodatazoo/protocols.py b/src/ddr/geodatazoo/protocols.py new file mode 100644 index 0000000..51cdf68 --- /dev/null +++ b/src/ddr/geodatazoo/protocols.py @@ -0,0 +1,379 @@ +"""Protocol 
definitions for dataset providers. These define the interfaces that geospatial data objects must implement.""" + +import logging +from abc import ABC, abstractmethod +from typing import Protocol + +import numpy as np +import torch +from torch.utils.data import Dataset as TorchDataset + +from ddr.dataset.utils import Hydrofabric, fill_nans, naninfmean  # Hydrofabric path matches streamflow.py; fill_nans/naninfmean assumed to live in the same module +from ddr.geodatazoo.Dates import Dates +from ddr.validation.validate_configs import Config + +log = logging.getLogger(__name__) + + +class BaseGeoDataset(TorchDataset, ABC): + """Lays out the base implementation of a geospatial routing dataset""" + + def __init__(self, cfg: Config, is_train: bool): + self.cfg = cfg + self.is_train = is_train + self.dates = Dates(**self.cfg.experiment.model_dump()) + + self.observation_provider = None + + # Load data via abstract methods + self._load_network_data() + self._load_attribute_data() + self._load_observation_data() + + # Setup statistics for normalization + self._setup_statistics() + + # For test mode, build network once + if not is_train: + self._setup_test_network() + else: + self.hydrofabric = None + + # ======================================================================== + # Abstract methods - subclasses MUST implement these + # ======================================================================== + + @abstractmethod + def _load_network_data(self): + """Load network topology and channel parameters. + + Subclasses should set self.network_provider to an object implementing + the NetworkProvider protocol. + + Examples + -------- + >>> def _load_network_data(self): + ... from ddr.geodatazoo.providers.hydrofabric import NextGenNetworkProvider + ... + ... self.network_provider = NextGenNetworkProvider( + ... gpkg_path=self.cfg.data_sources.hydrofabric_gpkg, + ... adjacency_path=self.cfg.data_sources.conus_adjacency, + ... gages_adjacency_path=self.cfg.data_sources.gages_adjacency, + ... ) + """ + raise NotImplementedError + + @abstractmethod + def _load_attribute_data(self): + """Load catchment attributes. + + Subclasses should set self.attribute_provider to an object implementing + the AttributeProvider protocol. + + Examples + -------- + >>> def _load_attribute_data(self): + ... from ddr.dataset.attributes import AttributesReader + ... + ... self.attribute_provider = AttributesReader(cfg=self.cfg) + """ + raise NotImplementedError + + @abstractmethod + def _load_observation_data(self): + """Load observation data. + + Subclasses should set self.observation_provider to an object implementing + the ObservationProvider protocol, and populate self.gauge_ids. + + Examples + -------- + >>> def _load_observation_data(self): + ... from ddr.dataset.observations import IcechunkUSGSReader + ... + ... self.observation_provider = IcechunkUSGSReader(cfg=self.cfg) + ... 
self.gauge_ids = self.observation_provider.get_gauge_ids() + """ + raise NotImplementedError + + # ======================================================================== + # Common methods - implemented in base class + # ======================================================================== + + def _setup_statistics(self): + """Setup normalization statistics from attribute provider.""" + from ddr.dataset.statistics import set_statistics + + # Get the underlying dataset from attribute provider if available + attr_ds = getattr(self.attribute_provider, "ds", None) + if attr_ds is not None: + self.attr_stats = set_statistics(self.cfg, attr_ds) + + self.means = torch.tensor( + [self.attr_stats[attr].iloc[2] for attr in self.cfg.kan.input_var_names], + device=self.cfg.device, + dtype=torch.float32, + ).unsqueeze(1) + + self.stds = torch.tensor( + [self.attr_stats[attr].iloc[3] for attr in self.cfg.kan.input_var_names], + device=self.cfg.device, + dtype=torch.float32, + ).unsqueeze(1) + else: + self.means = None + self.stds = None + + def _setup_test_network(self): + """Build network once for all gauges (test mode only).""" + if not hasattr(self, "gauge_ids"): + raise ValueError("gauge_ids must be set before calling _setup_test_network") + + # Get all reach IDs for all gauges + all_reach_ids = [] + valid_gauge_ids = [] + + for gid in self.gauge_ids: + reach_id = f"wb-{gid}" + try: + upstream = self.network_provider.find_upstream_reaches(reach_id) + if len(upstream) > 0: + all_reach_ids.extend(upstream) + valid_gauge_ids.append(gid) + except Exception as e: + log.info(f"Cannot find upstream reaches for gauge {gid}: {e}") + continue + + # Remove duplicates + reach_ids = list(dict.fromkeys(all_reach_ids)) + + # Build hydrofabric + self.hydrofabric = self._build_hydrofabric(reach_ids, valid_gauge_ids) + + def _build_hydrofabric(self, reach_ids: list[str], gauge_ids: list[str] = None) -> Hydrofabric: + """Build hydrofabric object from reach IDs. 
+ + Parameters + ---------- + reach_ids : list[str] + List of reach IDs to include in network + gauge_ids : list[str], optional + List of gauge IDs for observations + + Returns + ------- + Hydrofabric + Complete hydrofabric object ready for model input + """ + # Get network topology + adjacency = self.network_provider.get_topology(reach_ids) + + # Get catchment IDs + catchment_ids = self.network_provider.get_catchment_ids(reach_ids) + + # Get channel parameters + channel_params = self.network_provider.get_channel_parameters(reach_ids) + + # Get spatial attributes + raw_attrs, normalized_attrs = self.attribute_provider.get_attributes( + catchment_ids, self.cfg.kan.input_var_names + ) + + # Fill NaNs in raw attributes + for r in range(raw_attrs.shape[0]): + row_means = torch.nanmean(raw_attrs[r]) + nan_mask = torch.isnan(raw_attrs[r]) + raw_attrs[r, nan_mask] = row_means + + # Create sparse adjacency tensor + adjacency_matrix = torch.sparse_csr_tensor( + crow_indices=adjacency.indptr, + col_indices=adjacency.indices, + values=adjacency.data, + size=adjacency.shape, + device=self.cfg.device, + dtype=torch.float32, + ) + + # Convert channel parameters to tensors with NaN filling + phys_means = torch.tensor( + [ + naninfmean(channel_params["length"]), + naninfmean(channel_params["slope"]), + naninfmean(channel_params["top_width"]), + naninfmean(channel_params["side_slope"]), + naninfmean(channel_params["x"]), + ], + device=self.cfg.device, + dtype=torch.float32, + ).unsqueeze(1) + + length = fill_nans( + torch.tensor(channel_params["length"], dtype=torch.float32), + row_means=phys_means[0], + ) + slope = fill_nans( + torch.tensor(channel_params["slope"], dtype=torch.float32), + row_means=phys_means[1], + ) + top_width = fill_nans( + torch.tensor(channel_params["top_width"], dtype=torch.float32), + row_means=phys_means[2], + ) + side_slope = fill_nans( + torch.tensor(channel_params["side_slope"], dtype=torch.float32), + row_means=phys_means[3], + ) + x = fill_nans( + torch.tensor(channel_params["x"], dtype=torch.float32), + row_means=phys_means[4], + ) + + # Handle observations + observations = None + gage_idx = None + gage_wb = None + + if self.observation_provider and gauge_ids: + observations = self.observation_provider.get_observations(self.dates) + + # Map gauge IDs to reach indices + gage_idx = [] + gage_wb = [] + for gid in gauge_ids: + wb_id = f"wb-{gid}" + if wb_id in reach_ids: + idx = reach_ids.index(wb_id) + gage_idx.append(np.array([idx])) + gage_wb.append(wb_id) + + return Hydrofabric( + spatial_attributes=raw_attrs, + length=length, + slope=slope, + side_slope=side_slope, + top_width=top_width, + x=x, + dates=self.dates, + adjacency_matrix=adjacency_matrix, + normalized_spatial_attributes=normalized_attrs, + observations=observations, + divide_ids=np.array(catchment_ids), + gage_idx=gage_idx, + gage_wb=gage_wb, + ) + + # ======================================================================== + # PyTorch Dataset interface + # ======================================================================== + + def __len__(self) -> int: + """Return number of samples based on training mode.""" + if self.is_train: + return len(self.gauge_ids) + else: + return len(self.dates.daily_time_range) + + def __getitem__(self, idx: int) -> str | int: + """Get item at index. + + Returns gauge ID if training, timestep index if testing. 
+ """ + if self.is_train: + return self.gauge_ids[idx] + else: + return idx + + def collate_fn(self, batch: list) -> Hydrofabric: + """Collate batch into Hydrofabric object. + + Parameters + ---------- + batch : list + List of gauge IDs (if is_train=True) or timestep indices (if is_train=False) + + Returns + ------- + Hydrofabric + Hydrofabric object ready for model input + """ + if self.is_train: + return self._collate_train(batch) + else: + return self._collate_test(batch) + + def _collate_train(self, gauge_ids: list[str]) -> Hydrofabric: + """Collate training batch (sample by gauge). + + Parameters + ---------- + gauge_ids : list[str] + Batch of gauge IDs + + Returns + ------- + Hydrofabric + Hydrofabric for these gauges with random time period + """ + # Calculate random time period if rho is set + self.dates.calculate_time_period() + + # Find all upstream reaches for these gauges + all_reach_ids = [] + valid_gauge_ids = [] + + for gid in gauge_ids: + reach_id = f"wb-{gid}" + try: + upstream = self.network_provider.find_upstream_reaches(reach_id) + if len(upstream) > 0: + all_reach_ids.extend(upstream) + valid_gauge_ids.append(gid) + except Exception as e: + log.info(f"Cannot find upstream reaches for gauge {gid}: {e}") + continue + + # Remove duplicates while preserving order + reach_ids = list(dict.fromkeys(all_reach_ids)) + + # Build and return hydrofabric + return self._build_hydrofabric(reach_ids, valid_gauge_ids) + + def _collate_test(self, time_indices: list[int]) -> Hydrofabric: + """Collate test batch (sample by time). + + Parameters + ---------- + time_indices : list[int] + Batch of timestep indices + + Returns + ------- + Hydrofabric + Pre-built hydrofabric with updated dates + """ + # Add previous day for interpolation if needed + indices = list(time_indices) + if 0 not in indices and len(indices) > 0: + prev_day = indices[0] - 1 + if prev_day >= 0: + indices.insert(0, prev_day) + + # Update dates + self.dates.set_date_range(np.array(indices)) + + # Return pre-built hydrofabric (network doesn't change) + return self.hydrofabric + + +class AttributeProvider(Protocol): + """Protocol for catchment attribute providers.""" + + def get_attributes( + self, catchment_ids: list[str], attribute_names: list[str] + ) -> tuple[torch.Tensor, torch.Tensor]: + """Return (raw_attributes, normalized_attributes).""" + ... + + +class StreamflowProvider(Protocol): + """Protocol for streamflow providers.""" + + def get_streamflow(self, catchment_ids: list[str], dates: Dates) -> torch.Tensor: + """Return streamflow (num_timesteps, num_catchments).""" + ... 
diff --git a/src/ddr/geodatazoo/providers/hydrofabric/README.md b/src/ddr/geodatazoo/providers/hydrofabric/README.md new file mode 100644 index 0000000..c683ea5 --- /dev/null +++ b/src/ddr/geodatazoo/providers/hydrofabric/README.md @@ -0,0 +1,3 @@ +### Hydrofabric + +This dataset is the Lynker v2.2 Hydrofabric diff --git a/src/ddr/geodatazoo/providers/hydrofabric/__init__.py b/src/ddr/geodatazoo/providers/hydrofabric/__init__.py new file mode 100644 index 0000000..efee0ea --- /dev/null +++ b/src/ddr/geodatazoo/providers/hydrofabric/__init__.py @@ -0,0 +1,11 @@ +"""This module contains providers for the NextGen Hydrofabric v2.2+""" + +from .attributes import NextGenAttributeProvider +from .network import NextGenNetworkProvider +from .streamflow import NextGenStreamflowProvider + +__all__ = [ + "NextGenNetworkProvider", + "NextGenAttributeProvider", + "NextGenStreamflowProvider", +] diff --git a/src/ddr/geodatazoo/providers/hydrofabric/attributes.py b/src/ddr/geodatazoo/providers/hydrofabric/attributes.py new file mode 100644 index 0000000..fe8ebb6 --- /dev/null +++ b/src/ddr/geodatazoo/providers/hydrofabric/attributes.py @@ -0,0 +1,103 @@ +"""NextGen attribute provider.""" + +import logging + +import torch + +from ddr.dataset.attributes import AttributesReader +from ddr.dataset.statistics import set_statistics +from ddr.validation.validate_configs import Config + +log = logging.getLogger(__name__) + + +class NextGenAttributeProvider: + """Provider for NextGen catchment attributes. + + This reads attributes from icechunk stores that use NextGen + catchment IDs (cat-XXX format). + + The icechunk store must have: + - Dimension: divide_id (with cat-XXX format IDs) + - Variables: Specified in cfg.kan.input_var_names + + Parameters + ---------- + cfg : Config + Configuration with: + - data_sources.attributes: Path to icechunk store + - kan.input_var_names: List of attribute names to load + - device: Device for tensors + + Examples + -------- + >>> provider = NextGenAttributeProvider(cfg) + >>> raw, normalized = provider.get_attributes( + ... catchment_ids=["cat-1", "cat-2"], attribute_names=["elevation", "slope"] + ... ) + >>> print(f"Raw shape: {raw.shape}") # (2, 2) - (n_attrs, n_catchments) + >>> print(f"Normalized shape: {normalized.shape}") # (2, 2) transposed + """ + + def __init__(self, cfg: Config): + self.cfg = cfg + + # Initialize icechunk reader + self.reader = AttributesReader(cfg) + + # Compute statistics for normalization + self.stats = set_statistics(cfg, self.reader.ds) + + # Pre-compute normalization parameters + self.means = torch.tensor( + [self.stats[attr].iloc[2] for attr in cfg.kan.input_var_names], + device=cfg.device, + dtype=torch.float32, + ).unsqueeze(1) + + self.stds = torch.tensor( + [self.stats[attr].iloc[3] for attr in cfg.kan.input_var_names], + device=cfg.device, + dtype=torch.float32, + ).unsqueeze(1) + + log.info(f"Initialized NextGenAttributeProvider with {len(cfg.kan.input_var_names)} attributes") + + def get_attributes( + self, catchment_ids: list[str], attribute_names: list[str] + ) -> tuple[torch.Tensor, torch.Tensor]: + """Get attributes for NextGen catchments. 
+ + Parameters + ---------- + catchment_ids : list[str] + NextGen catchment IDs (cat-XXX format) + attribute_names : list[str] + Names of attributes to retrieve (must match cfg.kan.input_var_names) + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + (raw_attributes, normalized_attributes) + - raw: shape (n_attributes, n_catchments) + - normalized: shape (n_catchments, n_attributes) - transposed for NN input + """ + # Read from icechunk + raw = self.reader( + divide_ids=catchment_ids, + attr_means=self.means, + device=self.cfg.device, + dtype=torch.float32, + ) + + # Fill NaNs with row means + for r in range(raw.shape[0]): + row_mean = torch.nanmean(raw[r]) + nan_mask = torch.isnan(raw[r]) + raw[r, nan_mask] = row_mean + + # Normalize: (x - mean) / std + normalized = (raw - self.means) / self.stds + + # Transpose for NN input: (n_catchments, n_attributes) + return raw, normalized.T diff --git a/src/ddr/geodatazoo/providers/hydrofabric/network.py b/src/ddr/geodatazoo/providers/hydrofabric/network.py new file mode 100644 index 0000000..cd648b3 --- /dev/null +++ b/src/ddr/geodatazoo/providers/hydrofabric/network.py @@ -0,0 +1,148 @@ +"""NextGen network topology and geometry provider.""" + +from pathlib import Path + +import geopandas as gpd +import numpy as np +from scipy import sparse + + +class NextGenNetworkProvider: + """Provider for NextGen Hydrofabric network. + + Handles: + - Loading geopackage flowpath attributes + - Building topology from zarr adjacency matrix + - Extracting channel geometry + - Mapping wb-XXX (waterbodies) to cat-XXX (catchments) + + Parameters + ---------- + gpkg_path : Path + Path to NextGen geopackage (e.g., conus_nextgen.gpkg) + adjacency_path : Path + Path to zarr store with pre-computed adjacency matrix + (created by engine/adjacency.py) + gages_adjacency_path : Path, optional + Path to gauge-specific adjacency matrices + (created by engine/gages_adjacency.py) + + Examples + -------- + >>> provider = NextGenNetworkProvider( + ... gpkg_path=Path("data/conus_nextgen.gpkg"), + ... adjacency_path=Path("data/conus_adjacency.zarr"), + ... 
) + >>> + >>> # Get full CONUS network + >>> topology = provider.get_topology() + >>> + >>> # Get subset for specific reaches + >>> subset_topology = provider.get_topology(["wb-1", "wb-2"]) + >>> + >>> # Get channel geometry + >>> geometry = provider.get_geometry(["wb-1", "wb-2"]) + >>> print(f"Length: {geometry.length}") + """ + + def __init__(self, gpkg_path: Path, adjacency_path: Path, gages_adjacency_path: Path | None = None): + self.gpkg_path = gpkg_path + self.adjacency_path = adjacency_path + self.gages_adjacency_path = gages_adjacency_path + + self._load_data() + + def _load_data(self): + """Load NextGen hydrofabric data.""" + # Load flowpath attributes from geopackage + self.flowpath_attr = gpd.read_file(self.gpkg_path, layer="flowpath-attributes-ml").set_index("id") + + # Remove duplicates (if any) + self.flowpath_attr = self.flowpath_attr[~self.flowpath_attr.index.duplicated(keep="first")] + + # Load CONUS adjacency matrix + self.conus_adjacency = read_zarr(self.adjacency_path) + self.hf_ids = self.conus_adjacency["order"][:] + + # Load gauge adjacency if provided + if self.gages_adjacency_path: + self.gages_adjacency = read_zarr(self.gages_adjacency_path) + else: + self.gages_adjacency = None + + self._build_topology() + + def _build_topology(self): + """Build sparse adjacency matrix.""" + rows = self.conus_adjacency["indices_0"][:].tolist() + cols = self.conus_adjacency["indices_1"][:].tolist() + shape = tuple(dict(self.conus_adjacency.attrs)["shape"]) + + self.full_topology = sparse.coo_matrix((np.ones(len(rows)), (rows, cols)), shape=shape).tocsr() + + self.id_to_idx = {f"wb-{_id}": i for i, _id in enumerate(self.hf_ids)} + + def get_topology(self, reach_ids: list[str] | None = None) -> sparse.csr_matrix: + """Get network topology as sparse adjacency matrix.""" + if reach_ids is None: + return self.full_topology + + indices = [self.id_to_idx[rid] for rid in reach_ids if rid in self.id_to_idx] + return self.full_topology[indices, :][:, indices] + + def get_reach_ids(self) -> list[str]: + """Get all reach IDs in topological order.""" + return [f"wb-{_id}" for _id in self.hf_ids] + + def get_geometry(self, reach_ids: list[str]) -> ChannelGeometry: + """Get channel geometry from geopackage.""" + subset = self.flowpath_attr.reindex(reach_ids) + + return ChannelGeometry( + length=subset["Length_m"].fillna(1000.0).values, + slope=subset["So"].fillna(0.001).values, + top_width=subset["TopWdth"].fillna(10.0).values, + side_slope=subset["ChSlp"].fillna(2.0).values, + x=subset["MusX"].fillna(0.2).values, + ) + + def get_catchment_ids(self, reach_ids: list[str]) -> list[str]: + """Map waterbody IDs to catchment IDs. + + NextGen convention: wb-XXX → cat-XXX + """ + return [rid.replace("wb-", "cat-") for rid in reach_ids] + + def find_upstream_reaches(self, outlet_reach_id: str) -> list[str]: + """Find all reaches upstream of outlet. + + Uses gauge adjacency if available, otherwise does BFS on full topology. 
+ """ + # Try gauge adjacency first (pre-computed, faster) + if self.gages_adjacency: + gauge_id = outlet_reach_id.replace("wb-", "") + if gauge_id in self.gages_adjacency: + order = self.gages_adjacency[gauge_id]["order"][:] + return [f"wb-{_id}" for _id in order] + + # Fallback to BFS on full topology + outlet_idx = self.id_to_idx.get(outlet_reach_id) + if outlet_idx is None: + return [] + + visited = set() + queue = [outlet_idx] + upstream_indices = [] + + while queue: + idx = queue.pop(0) + if idx in visited: + continue + visited.add(idx) + upstream_indices.append(idx) + + # Find upstream reaches + upstream = self.full_topology[idx, :].nonzero()[1] + queue.extend(upstream) + + return [f"wb-{self.hf_ids[i]}" for i in upstream_indices] diff --git a/src/ddr/geodatazoo/providers/hydrofabric/streamflow.py b/src/ddr/geodatazoo/providers/hydrofabric/streamflow.py new file mode 100644 index 0000000..48d3fd0 --- /dev/null +++ b/src/ddr/geodatazoo/providers/hydrofabric/streamflow.py @@ -0,0 +1,86 @@ +"""NextGen streamflow provider.""" + +import logging + +import torch + +from ddr.dataset.Dates import Dates +from ddr.dataset.streamflow import StreamflowReader +from ddr.validation.validate_configs import Config + +log = logging.getLogger(__name__) + + +class NextGenStreamflowProvider: + """Provider for NextGen streamflow/runoff data. + + This reads streamflow from icechunk stores that use NextGen + catchment IDs (cat-XXX format). + + The icechunk store must have: + - Dimensions: divide_id, time + - Variable: Qr (runoff in m³/s) + - divide_id values in cat-XXX format + + Parameters + ---------- + cfg : Config + Configuration with: + - data_sources.streamflow: Path to icechunk store + - device: Device for tensors + + Examples + -------- + >>> provider = NextGenStreamflowProvider(cfg) + >>> streamflow = provider.get_streamflow(catchment_ids=["cat-1", "cat-2"], dates=dates_object) + >>> print(f"Shape: {streamflow.shape}") # (n_timesteps, n_catchments) + """ + + def __init__(self, cfg: Config): + self.cfg = cfg + self.reader = StreamflowReader(cfg) + log.info("Initialized NextGenStreamflowProvider") + + def get_streamflow(self, catchment_ids: list[str], dates: Dates) -> torch.Tensor: + """Get streamflow for NextGen catchments. 
+ + Parameters + ---------- + catchment_ids : list[str] + NextGen catchment IDs (cat-XXX format) + dates : Dates + Time period for streamflow data + + Returns + ------- + torch.Tensor + Streamflow data with shape (n_timesteps, n_catchments) + """ + # Create temporary hydrofabric for reader interface + # (StreamflowReader expects Hydrofabric object) + import numpy as np + + from ddr.dataset.utils import Hydrofabric + + temp_hydrofabric = Hydrofabric( + divide_ids=np.array(catchment_ids), + dates=dates, + # Other fields not needed by StreamflowReader + spatial_attributes=None, + length=None, + slope=None, + side_slope=None, + top_width=None, + x=None, + adjacency_matrix=None, + normalized_spatial_attributes=None, + observations=None, + gage_idx=None, + gage_wb=None, + ) + + return self.reader( + hydrofabric=temp_hydrofabric, + device=self.cfg.device, + dtype=torch.float32, + ) diff --git a/src/ddr/geodatazoo/types.py b/src/ddr/geodatazoo/types.py new file mode 100644 index 0000000..6c96bce --- /dev/null +++ b/src/ddr/geodatazoo/types.py @@ -0,0 +1,25 @@ +"""Data types for the geodatazoo module.""" + +import numpy as np +from pydantic import BaseModel, ConfigDict + + +class ChannelParameters(BaseModel): + """Channel parameter data""" + + # numpy array fields need arbitrary_types_allowed under pydantic v2 (same pattern as Dates) + model_config = ConfigDict(arbitrary_types_allowed=True) + + length: np.ndarray # Channel length (meters) + slope: np.ndarray # Channel slope (dimensionless) + top_width: np.ndarray # Top width (meters) + side_slope: np.ndarray # Side slope ratio (horizontal:vertical) + x: np.ndarray # Muskingum X parameter (0-0.5) + + +class NetworkMetadata(BaseModel): + """Metadata about a river network""" + + name: str + version: str + num_reaches: int + coordinate_system: str + id_prefix: str + description: str
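
The patch ships the providers and the abstract BaseGeoDataset but no concrete dataset tying them together yet. The following is a rough usage sketch only: the class name NextGenDataset is hypothetical, the data_sources config fields are the ones named in the provider docstrings, and the observation reader follows the _load_observation_data example in protocols.py.

from ddr.geodatazoo.protocols import BaseGeoDataset
from ddr.geodatazoo.providers.hydrofabric import NextGenAttributeProvider, NextGenNetworkProvider


class NextGenDataset(BaseGeoDataset):
    """Hypothetical concrete dataset wiring the NextGen Hydrofabric providers together."""

    def _load_network_data(self):
        # Config fields below are the ones named in NextGenNetworkProvider's docstring.
        self.network_provider = NextGenNetworkProvider(
            gpkg_path=self.cfg.data_sources.hydrofabric_gpkg,
            adjacency_path=self.cfg.data_sources.conus_adjacency,
            gages_adjacency_path=self.cfg.data_sources.gages_adjacency,
        )

    def _load_attribute_data(self):
        self.attribute_provider = NextGenAttributeProvider(cfg=self.cfg)

    def _load_observation_data(self):
        # Mirrors the _load_observation_data example in the BaseGeoDataset docstring.
        from ddr.dataset.observations import IcechunkUSGSReader

        self.observation_provider = IcechunkUSGSReader(cfg=self.cfg)
        self.gauge_ids = self.observation_provider.get_gauge_ids()


# Lateral inflows for the routed reaches could then come from NextGenStreamflowProvider,
# e.g. NextGenStreamflowProvider(cfg).get_streamflow(catchment_ids, dataset.dates), and a
# torch DataLoader would use the dataset's own collate_fn:
# dataset = NextGenDataset(cfg, is_train=True)
# loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)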