diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 0ab3d142..d5b7a2d2 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -26,6 +26,22 @@ jobs: ruff check . --output-format=github ruff format --check . + type-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Set up Python 3.13 + uses: actions/setup-python@v6 + with: + python-version: "3.13" + cache: "pip" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[type-check] + - name: Type-check with mypy + run: mypy + test: needs: lint runs-on: ${{ matrix.os }} diff --git a/dataretrieval/nadp.py b/dataretrieval/nadp.py index 3d1ee442..d6b26381 100644 --- a/dataretrieval/nadp.py +++ b/dataretrieval/nadp.py @@ -29,6 +29,8 @@ """ +from __future__ import annotations + import io import re import warnings @@ -45,7 +47,7 @@ ) -def _warn_deprecated(): +def _warn_deprecated() -> None: warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=3) @@ -74,19 +76,19 @@ def _warn_deprecated(): class NADP_ZipFile(zipfile.ZipFile): """Extend zipfile.ZipFile for working on data from NADP""" - def tif_name(self): + def tif_name(self) -> str: """Get the name of the tif file in the zip file.""" filenames = self.namelist() r = re.compile(".*tif$") tif_list = list(filter(r.match, filenames)) return tif_list[0] - def tif(self): + def tif(self) -> bytes: """Read the tif file in the zip file.""" return self.read(self.tif_name()) -def get_annual_MDN_map(measurement_type, year, path): +def get_annual_MDN_map(measurement_type: str, year: str, path: str) -> str: """Download a MDN map from NDAP. This function looks for a zip file containing gridded information at: @@ -135,7 +137,12 @@ def get_annual_MDN_map(measurement_type, year, path): return str(path) -def get_annual_NTN_map(measurement_type, measurement=None, year=None, path="."): +def get_annual_NTN_map( + measurement_type: str, + measurement: str | None = None, + year: str | None = None, + path: str = ".", +) -> str: """Download a NTN map from NDAP. This function looks for a zip file containing gridded information at: @@ -193,7 +200,7 @@ def get_annual_NTN_map(measurement_type, measurement=None, year=None, path="."): return str(path) -def get_zip(url, filename): +def get_zip(url: str, filename: str) -> NADP_ZipFile: """Gets a ZipFile at url and returns it Parameters diff --git a/dataretrieval/nldi.py b/dataretrieval/nldi.py index 8d61fcc2..a03aa1e6 100644 --- a/dataretrieval/nldi.py +++ b/dataretrieval/nldi.py @@ -1,7 +1,7 @@ from __future__ import annotations from json import JSONDecodeError -from typing import Literal +from typing import Any, Literal, cast from dataretrieval.utils import query @@ -16,13 +16,17 @@ _VALID_NAVIGATION_MODES = ("UM", "DM", "UT", "DD") -def _query_nldi(url, query_params, error_message): +def _query_nldi( + url: str, + query_params: dict[str, str], + error_message: str, +) -> dict[str, Any] | list[Any]: # A helper function to query the NLDI API response = query(url, payload=query_params) if response.status_code != 200: raise ValueError(f"{error_message}. Error reason: {response.reason_phrase}") - response_data = {} + response_data: dict[str, Any] | list[Any] = {} try: response_data = response.json() except JSONDecodeError: @@ -32,7 +36,7 @@ def _query_nldi(url, query_params, error_message): return response_data -def _features_to_gdf(feature_collection: dict) -> gpd.GeoDataFrame: +def _features_to_gdf(feature_collection: dict[str, Any]) -> gpd.GeoDataFrame: """Build a GeoDataFrame from an NLDI FeatureCollection, tolerating empties. NLDI can legitimately return no features (e.g. a feature with nothing @@ -56,7 +60,7 @@ def get_flowlines( stop_comid: int | None = None, trim_start: bool = False, as_json: bool = False, -) -> gpd.GeoDataFrame | dict: +) -> gpd.GeoDataFrame | dict[str, Any]: """Gets the flowlines for the specified navigation either by comid or feature source in WGS84 lat/long coordinates as GeoDataFrame containing a polyline geometry. @@ -116,7 +120,7 @@ def get_flowlines( else: err_msg = f"Error getting flowlines for comid '{comid}'" - feature_collection = _query_nldi(url, query_params, err_msg) + feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg)) if as_json: return feature_collection gdf = _features_to_gdf(feature_collection) @@ -129,7 +133,7 @@ def get_basin( simplified: bool = True, split_catchment: bool = False, as_json: bool = False, -) -> gpd.GeoDataFrame | dict: +) -> gpd.GeoDataFrame | dict[str, Any]: """Gets the aggregated basin for the specified feature in WGS84 lat/lon as GeoDataFrame or as JSON conatining a polygon geometry. @@ -162,14 +166,17 @@ def get_basin( raise ValueError("feature_id is required") url = f"{NLDI_API_BASE_URL}/{feature_source}/{feature_id}/basin" - simplified = str(simplified).lower() - split_catchment = str(split_catchment).lower() - query_params = {"simplified": simplified, "splitCatchment": split_catchment} + simplified_str = str(simplified).lower() + split_catchment_str = str(split_catchment).lower() + query_params = { + "simplified": simplified_str, + "splitCatchment": split_catchment_str, + } err_msg = ( f"Error getting basin for feature source '{feature_source}' and " f"feature_id '{feature_id}'" ) - feature_collection = _query_nldi(url, query_params, err_msg) + feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg)) if as_json: return feature_collection gdf = _features_to_gdf(feature_collection) @@ -187,7 +194,7 @@ def get_features( long: float | None = None, stop_comid: int | None = None, as_json: bool = False, -) -> gpd.GeoDataFrame | dict: +) -> gpd.GeoDataFrame | dict[str, Any]: """Gets all features found along the specified navigation either by comid or feature source as points in WGS84 lat/long coordinates - a GeoDataFrame containing a point geometry. @@ -285,7 +292,7 @@ def get_features( query_params = {} err_msg = _features_err_msg(feature_source, feature_id, comid, data_source) - feature_collection = _query_nldi(url, query_params, err_msg) + feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg)) if as_json: return feature_collection gdf = _features_to_gdf(feature_collection) @@ -321,7 +328,7 @@ def get_features_by_data_source(data_source: str) -> gpd.GeoDataFrame: _validate_data_source(data_source) url = f"{NLDI_API_BASE_URL}/{data_source}" err_msg = f"Error getting features for data source '{data_source}'" - feature_collection = _query_nldi(url, {}, err_msg) + feature_collection = cast("dict[str, Any]", _query_nldi(url, {}, err_msg)) gdf = _features_to_gdf(feature_collection) return gdf @@ -336,7 +343,7 @@ def search( lat: float | None = None, long: float | None = None, distance: int = 50, -) -> dict: +) -> dict[str, Any]: """Searches for the specified feature in NLDI and returns the results as a dictionary. @@ -408,7 +415,7 @@ def search( if (lat is None) != (long is None): raise ValueError("Both lat and long are required") - find = find.lower() + find = cast(Literal["basin", "flowlines", "features"], find.lower()) if find not in ("basin", "flowlines", "features"): raise ValueError( f"Invalid value for find: {find} - allowed values are:" @@ -428,6 +435,10 @@ def search( return get_features(lat=lat, long=long, as_json=True) if find == "basin": + if feature_source is None or feature_id is None: + raise ValueError( + "feature_source and feature_id are required to find a basin" + ) return get_basin( feature_source=feature_source, feature_id=feature_id, as_json=True ) @@ -458,7 +469,7 @@ def search( ) -def _validate_data_source(data_source: str): +def _validate_data_source(data_source: str) -> None: # A helper function to validate user specified data source/feature source global _AVAILABLE_DATA_SOURCES @@ -487,7 +498,12 @@ def _validate_data_source(data_source: str): raise ValueError(err_msg) -def _features_err_msg(feature_source, feature_id, comid, data_source) -> str: +def _features_err_msg( + feature_source: str | None, + feature_id: str | None, + comid: int | None, + data_source: str | None, +) -> str: if feature_source is not None: return ( f"Error getting features for feature source '{feature_source}'" @@ -512,7 +528,7 @@ def _validate_navigation_mode(navigation_mode: str | None) -> str: def _validate_feature_source_comid( feature_source: str | None, feature_id: str | None, comid: int | None -): +) -> None: if feature_source is not None and feature_id is None: raise ValueError("feature_id is required if feature_source is provided") if feature_id is not None and feature_source is None: diff --git a/dataretrieval/nwis.py b/dataretrieval/nwis.py index 1372caa7..fafd0a08 100644 --- a/dataretrieval/nwis.py +++ b/dataretrieval/nwis.py @@ -9,7 +9,9 @@ import functools import threading import warnings +from collections.abc import Callable from json import JSONDecodeError +from typing import Any, NoReturn, TypeVar, cast import httpx import pandas as pd @@ -24,6 +26,8 @@ except ImportError: gpd = None +F = TypeVar("F", bound=Callable[..., Any]) + WATERDATA_BASE_URL = "https://nwis.waterdata.usgs.gov/" WATERDATA_URL = WATERDATA_BASE_URL + "nwis/" WATERSERVICE_URL = "https://waterservices.usgs.gov/nwis/" @@ -75,7 +79,7 @@ def _warn_deprecated(func_name: str) -> None: ) -def _deprecated(func): +def _deprecated(func: F) -> F: """Mark an nwis function as deprecated. Wrappers like ``get_record`` -> ``get_iv`` -> ``query_waterservices`` would @@ -89,7 +93,7 @@ def _deprecated(func): ) @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: if getattr(_deprecation_state, "active", False): return func(*args, **kwargs) _deprecation_state.active = True @@ -99,7 +103,7 @@ def wrapper(*args, **kwargs): finally: _deprecation_state.active = False - return wrapper + return cast(F, wrapper) def _parse_json_or_raise(response: httpx.Response) -> pd.DataFrame: @@ -123,7 +127,7 @@ def _parse_json_or_raise(response: httpx.Response) -> pd.DataFrame: def format_response( - df: pd.DataFrame, service: str | None = None, **kwargs + df: pd.DataFrame, service: str | None = None, **kwargs: Any ) -> pd.DataFrame: """Setup index for response from query. @@ -197,14 +201,14 @@ def preformat_peaks_response(df: pd.DataFrame) -> pd.DataFrame: return df -def get_qwdata(**kwargs): +def get_qwdata(**kwargs: Any) -> NoReturn: """Defunct: use ``waterdata.get_samples()``.""" raise NameError( "`nwis.get_qwdata` has been replaced with `waterdata.get_samples()`." ) -def get_discharge_measurements(**kwargs): +def get_discharge_measurements(**kwargs: Any) -> NoReturn: """Defunct: use ``waterdata.get_field_measurements()``.""" raise NameError( "`nwis.get_discharge_measurements` has been replaced " @@ -219,8 +223,8 @@ def get_discharge_peaks( end: str | None = None, multi_index: bool = True, ssl_check: bool = True, - **kwargs, -) -> tuple[pd.DataFrame, BaseMetadata]: + **kwargs: Any, +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Get discharge peaks from the waterdata service. @@ -285,7 +289,7 @@ def get_discharge_peaks( ) -def get_gwlevels(**kwargs): +def get_gwlevels(**kwargs: Any) -> NoReturn: """Defunct: use ``waterdata.get_continuous()``, ``waterdata.get_daily()``, or ``waterdata.get_field_measurements()``.""" raise NameError( @@ -298,8 +302,8 @@ def get_gwlevels(**kwargs): @_deprecated def get_stats( - sites: list[str] | str | None = None, ssl_check: bool = True, **kwargs -) -> tuple[pd.DataFrame, BaseMetadata]: + sites: list[str] | str | None = None, ssl_check: bool = True, **kwargs: Any +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Queries water services statistics information. @@ -359,7 +363,9 @@ def get_stats( @_deprecated -def query_waterdata(service: str, ssl_check: bool = True, **kwargs) -> httpx.Response: +def query_waterdata( + service: str, ssl_check: bool = True, **kwargs: Any +) -> httpx.Response: """ Queries waterdata. @@ -404,7 +410,7 @@ def query_waterdata(service: str, ssl_check: bool = True, **kwargs) -> httpx.Res @_deprecated def query_waterservices( - service: str, ssl_check: bool = True, **kwargs + service: str, ssl_check: bool = True, **kwargs: Any ) -> httpx.Response: """ Queries waterservices.usgs.gov @@ -473,8 +479,8 @@ def get_dv( end: str | None = None, multi_index: bool = True, ssl_check: bool = True, - **kwargs, -) -> tuple[pd.DataFrame, BaseMetadata]: + **kwargs: Any, +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Get daily values data from NWIS and return it as a ``pandas.DataFrame``. @@ -539,7 +545,9 @@ def get_dv( @_deprecated -def get_info(ssl_check: bool = True, **kwargs) -> tuple[pd.DataFrame, BaseMetadata]: +def get_info( + ssl_check: bool = True, **kwargs: Any +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Get site description information from NWIS. @@ -661,8 +669,8 @@ def get_iv( end: str | None = None, multi_index: bool = True, ssl_check: bool = True, - **kwargs, -) -> tuple[pd.DataFrame, BaseMetadata]: + **kwargs: Any, +) -> tuple[pd.DataFrame, NWIS_Metadata]: """Get instantaneous values data from NWIS and return it as a DataFrame. .. note:: @@ -725,7 +733,7 @@ def get_iv( return format_response(df, **kwargs), NWIS_Metadata(response, **kwargs) -def get_pmcodes(**kwargs): +def get_pmcodes(**kwargs: Any) -> NoReturn: """Defunct: use ``get_reference_table(collection='parameter-codes')``.""" raise NameError( "`nwis.get_pmcodes` has been replaced " @@ -733,7 +741,7 @@ def get_pmcodes(**kwargs): ) -def get_water_use(**kwargs): +def get_water_use(**kwargs: Any) -> NoReturn: """Defunct: no current replacement.""" raise NameError("`nwis.get_water_use` is defunct.") @@ -743,8 +751,8 @@ def get_ratings( site: str | None = None, file_type: str = "base", ssl_check: bool = True, - **kwargs, -) -> tuple[pd.DataFrame, BaseMetadata]: + **kwargs: Any, +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Rating table for an active USGS streamgage retrieval. @@ -797,7 +805,9 @@ def get_ratings( @_deprecated -def what_sites(ssl_check: bool = True, **kwargs) -> tuple[pd.DataFrame, BaseMetadata]: +def what_sites( + ssl_check: bool = True, **kwargs: Any +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Search NWIS for sites within a region with specific data. @@ -847,7 +857,7 @@ def get_record( state: str | None = None, service: str = "iv", ssl_check: bool = True, - **kwargs, + **kwargs: Any, ) -> pd.DataFrame: """ Get data from NWIS and return it as a ``pandas.DataFrame``. @@ -985,7 +995,10 @@ def get_record( return df elif service == "ratings": - df, _ = get_ratings(site=sites, ssl_check=ssl_check, **kwargs) + # the ratings service is single-site; get_ratings takes a scalar site + df, _ = get_ratings( + site=cast("str | None", sites), ssl_check=ssl_check, **kwargs + ) return df elif service == "stat": @@ -996,7 +1009,7 @@ def get_record( raise TypeError(f"{service} service not yet implemented") -def _read_json(json): +def _read_json(json: dict[str, Any]) -> pd.DataFrame: """ Reads a NWIS Water Services formatted JSON into a ``pandas.DataFrame``. @@ -1092,7 +1105,7 @@ def _read_json(json): return merged_df -def _read_rdb(rdb): +def _read_rdb(rdb: str) -> pd.DataFrame: """Parse an NWIS RDB response and apply NWIS-specific post-processing. Thin wrapper around :func:`dataretrieval.rdb.read_rdb` that adds the @@ -1102,7 +1115,7 @@ def _read_rdb(rdb): return format_response(read_rdb(rdb, dtypes=_NWIS_RDB_DTYPES)) -def _check_sites_value_types(sites): +def _check_sites_value_types(sites: list[str] | str | None) -> None: if sites and not isinstance(sites, list) and not isinstance(sites, str): raise TypeError("sites must be a string or a list of strings") @@ -1128,7 +1141,7 @@ class NWIS_Metadata(BaseMetadata): """ - def __init__(self, response, **parameters) -> None: + def __init__(self, response: httpx.Response, **parameters: Any) -> None: """Generates a standard set of metadata informed by the response with specific metadata for NWIS data. diff --git a/dataretrieval/samples.py b/dataretrieval/samples.py index 2259969c..025fa76e 100644 --- a/dataretrieval/samples.py +++ b/dataretrieval/samples.py @@ -7,9 +7,15 @@ from __future__ import annotations import warnings +from typing import TYPE_CHECKING, Any +if TYPE_CHECKING: + import pandas as pd -def get_usgs_samples(**kwargs): + from dataretrieval.utils import BaseMetadata + + +def get_usgs_samples(**kwargs: Any) -> tuple[pd.DataFrame, BaseMetadata]: """Deprecated: use ``waterdata.get_samples()`` instead. All keyword arguments are forwarded directly to diff --git a/dataretrieval/streamstats.py b/dataretrieval/streamstats.py index 6737d54c..039f292b 100644 --- a/dataretrieval/streamstats.py +++ b/dataretrieval/streamstats.py @@ -5,14 +5,17 @@ """ +from __future__ import annotations + import json +from typing import Any, cast import httpx from dataretrieval.utils import HTTPX_DEFAULTS -def download_workspace(workspaceID, format=""): +def download_workspace(workspaceID: str, format: str = "") -> httpx.Response: """Function to download streamstats workspace. Parameters @@ -46,7 +49,7 @@ def download_workspace(workspaceID, format=""): # return -def get_sample_watershed(): +def get_sample_watershed() -> Watershed: """Sample function to get a watershed object for a location in NY. Makes the function call :obj:`dataretrieval.streamstats.get_watershed` @@ -60,20 +63,23 @@ def get_sample_watershed(): from the streamstats JSON object. """ - return get_watershed("NY", -74.524, 43.939, format="object") + return cast( + "Watershed", + get_watershed("NY", -74.524, 43.939, format="object"), + ) def get_watershed( - rcode, - xlocation, - ylocation, - crs=4326, - includeparameters=True, - includeflowtypes=False, - includefeatures=True, - simplify=True, - format="geojson", -): + rcode: str, + xlocation: float, + ylocation: float, + crs: int | str = 4326, + includeparameters: bool = True, + includeflowtypes: bool = False, + includefeatures: bool = True, + simplify: bool = True, + format: str = "geojson", +) -> httpx.Response | Watershed: """Get watershed object based on location **Streamstats documentation:** @@ -115,7 +121,7 @@ def get_watershed( from the streamstats JSON object. """ - payload = { + payload: dict[str, str | int | float | bool] = { "rcode": rcode, "xlocation": xlocation, "ylocation": ylocation, @@ -170,14 +176,17 @@ class Watershed: :obj:`dataretrieval.streamstats.download_workspace`. """ - def __init__(self, rcode, xlocation, ylocation): + def __init__(self, rcode: str, xlocation: float, ylocation: float) -> None: """Delineate the watershed at ``(xlocation, ylocation)`` and parse the response onto this instance.""" - response = get_watershed(rcode, xlocation, ylocation, format="geojson") + response = cast( + httpx.Response, + get_watershed(rcode, xlocation, ylocation, format="geojson"), + ) self._populate(json.loads(response.text)) @classmethod - def from_streamstats_json(cls, streamstats_json) -> "Watershed": + def from_streamstats_json(cls, streamstats_json: dict[str, Any]) -> Watershed: """Create a :class:`Watershed` from an already-parsed StreamStats JSON payload, without issuing a new request. @@ -190,7 +199,7 @@ class state. self._populate(streamstats_json) return self - def _populate(self, streamstats_json) -> None: + def _populate(self, streamstats_json: dict[str, Any]) -> None: """Extract watershed fields from a StreamStats JSON payload onto this instance.""" self.watershed_point = streamstats_json["featurecollection"][0]["feature"] diff --git a/dataretrieval/utils.py b/dataretrieval/utils.py index 7bb03a69..f9766ee6 100644 --- a/dataretrieval/utils.py +++ b/dataretrieval/utils.py @@ -2,8 +2,11 @@ Useful utilities for data munging. """ +from __future__ import annotations + import warnings from collections.abc import Iterable +from typing import Any import httpx import pandas as pd @@ -11,13 +14,16 @@ import dataretrieval from dataretrieval.codes import tz -HTTPX_DEFAULTS = { +# Typed as ``dict[str, Any]`` (not the inferred ``dict[str, object]``) so that +# splatting it as ``**HTTPX_DEFAULTS`` into ``httpx.get`` / ``httpx.AsyncClient`` +# type-checks: the values are a heterogeneous bag of httpx keyword arguments. +HTTPX_DEFAULTS: dict[str, Any] = { "follow_redirects": True, "timeout": httpx.Timeout(60.0, connect=10.0), } -def to_str(listlike, delimiter=","): +def to_str(listlike: object, delimiter: str = ",") -> str | None: """Translates list-like objects into strings. Parameters @@ -54,7 +60,9 @@ def to_str(listlike, delimiter=","): return None -def format_datetime(df, date_field, time_field, tz_field): +def format_datetime( + df: pd.DataFrame, date_field: str, time_field: str, tz_field: str +) -> pd.DataFrame: """Creates a datetime field from separate date, time, and time zone fields. @@ -190,6 +198,7 @@ def _attach_datetime_columns(df: pd.DataFrame) -> pd.DataFrame: # Concat in one shot — per-column assignment on a wide CSV-derived # frame triggers pandas' fragmentation PerformanceWarning. df = pd.concat([df, pd.DataFrame(new_columns, index=df.index)], axis=1) + sort_key: str | None if "Activity_StartDateTime" in df.columns: sort_key = "Activity_StartDateTime" elif "ActivityStartDateTime" in df.columns: @@ -215,7 +224,7 @@ class BaseMetadata: """ - def __init__(self, response) -> None: + def __init__(self, response: httpx.Response) -> None: """Generates a standard set of metadata informed by the response. Parameters @@ -234,7 +243,7 @@ def __init__(self, response) -> None: self.url = str(response.url) self.query_time = response.elapsed self.header = response.headers - self.comment = None + self.comment: str | None = None # # not sure what statistic_info is # self.statistic_info = None @@ -244,13 +253,13 @@ def __init__(self, response) -> None: # These properties are to be set by `nwis` or `wqp`-specific metadata classes. @property - def site_info(self): + def site_info(self) -> Any: raise NotImplementedError( "site_info must be implemented by utils.BaseMetadata children" ) @property - def variable_info(self): + def variable_info(self) -> Any: raise NotImplementedError( "variable_info must be implemented by utils.BaseMetadata children" ) @@ -278,7 +287,12 @@ def _url_too_long_error(detail: str) -> ValueError: ) -def query(url, payload, delimiter=",", ssl_check=True): +def query( + url: str, + payload: dict[str, Any], + delimiter: str = ",", + ssl_check: bool = True, +) -> httpx.Response: """Send a query. Wrapper for httpx.get that handles errors, converts listed @@ -347,10 +361,10 @@ def query(url, payload, delimiter=",", ssl_check=True): class NoSitesError(Exception): """Custom error class used when selection criteria returns no sites/data.""" - def __init__(self, url): + def __init__(self, url: httpx.URL) -> None: self.url = url - def __str__(self): + def __str__(self) -> str: return ( "No sites/data found using the selection criteria specified in " f"url: {self.url}" diff --git a/dataretrieval/waterdata/api.py b/dataretrieval/waterdata/api.py index 1b609324..3144bf80 100644 --- a/dataretrieval/waterdata/api.py +++ b/dataretrieval/waterdata/api.py @@ -10,7 +10,7 @@ import logging from collections.abc import Iterable from io import StringIO -from typing import get_args +from typing import Any, get_args from urllib.parse import quote import httpx @@ -2018,7 +2018,7 @@ def get_peaks( def get_reference_table( collection: str, limit: int | None = None, - query: dict | None = None, + query: dict[str, Any] | None = None, max_rows: int | None = None, ) -> tuple[pd.DataFrame, BaseMetadata]: """Get metadata reference tables for the USGS Water Data API. @@ -2140,7 +2140,7 @@ def get_codes(code_service: CODE_SERVICES) -> tuple[pd.DataFrame, BaseMetadata]: def _get_samples_csv( - url: str, params: dict, ssl_check: bool + url: str, params: dict[str, Any], ssl_check: bool ) -> tuple[pd.DataFrame, httpx.Response]: """Issue a Samples CSV request and parse the body into a DataFrame. @@ -2852,7 +2852,7 @@ def get_channel( def get_cql( service: WATERDATA_SERVICES, - cql: str | dict, + cql: str | dict[str, Any], *, properties: str | Iterable[str] | None = None, bbox: list[float] | None = None, diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index ab079070..1f49ed97 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -59,7 +59,7 @@ from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta -from typing import Any, ClassVar +from typing import Any, ClassVar, cast from urllib.parse import quote_plus import httpx @@ -681,7 +681,7 @@ def _set_response_url(response: httpx.Response, url: str | httpx.URL) -> None: same ``.request``. """ try: - response.url = url # type: ignore[misc] + response.url = url # type: ignore[misc, assignment] except AttributeError: target = httpx.URL(str(url)) try: @@ -800,7 +800,7 @@ def _extract_axes(args: dict[str, Any]) -> list[_Axis]: axes.append(_Axis(arg_key=key, atoms=tuple(value), joiner=_LIST_SEP)) filter_expr = args.get("filter") - if _is_chunkable(filter_expr, args.get("filter_lang")): + if filter_expr is not None and _is_chunkable(filter_expr, args.get("filter_lang")): _check_numeric_filter_pitfall(filter_expr) clauses = _split_top_level_or(filter_expr) if len(clauses) >= 2: @@ -1560,7 +1560,12 @@ def resume(self) -> tuple[pd.DataFrame, Any]: """ concurrency = _read_concurrency_env() with start_blocking_portal() as portal: - return portal.call(functools.partial(self._run, concurrency)) + # ``portal.call`` returns ``Any`` because ``functools.partial`` + # erases ``_run``'s return type; restore the declared tuple. + return cast( + "tuple[pd.DataFrame, Any]", + portal.call(functools.partial(self._run, concurrency)), + ) async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: """ diff --git a/dataretrieval/waterdata/nearest.py b/dataretrieval/waterdata/nearest.py index 12aad61c..39a80332 100644 --- a/dataretrieval/waterdata/nearest.py +++ b/dataretrieval/waterdata/nearest.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Literal, get_args +from typing import Any, Literal, get_args import pandas as pd @@ -18,13 +18,13 @@ def get_nearest_continuous( - targets, + targets: Iterable[Any], monitoring_location_id: str | Iterable[str] | None = None, parameter_code: str | Iterable[str] | None = None, *, window: str | pd.Timedelta = "PT7M30S", on_tie: OnTie = "first", - **kwargs, + **kwargs: Any, ) -> tuple[pd.DataFrame, BaseMetadata]: """For each target timestamp, return the nearest continuous observation. @@ -138,13 +138,13 @@ def get_nearest_continuous( ... ) """ _check_nearest_kwargs(kwargs, on_tie) - targets = pd.DatetimeIndex(pd.to_datetime(targets, utc=True)) + target_index = pd.DatetimeIndex(pd.to_datetime(targets, utc=True)) window_td = pd.Timedelta(window) - if len(targets) == 0: + if len(target_index) == 0: raise ValueError("targets must contain at least one timestamp") - filter_expr = _build_window_or_filter(targets, window_td) + filter_expr = _build_window_or_filter(target_index, window_td) df, md = get_continuous( monitoring_location_id=monitoring_location_id, parameter_code=parameter_code, @@ -165,7 +165,7 @@ def get_nearest_continuous( selected = [ row for _, site_df in site_groups - for target in targets + for target in target_index if (row := _pick_nearest_row(site_df, target, window_td, on_tie)) is not None ] if not selected: @@ -173,7 +173,7 @@ def get_nearest_continuous( return pd.DataFrame(selected).reset_index(drop=True), md -def _check_nearest_kwargs(kwargs: dict, on_tie: OnTie) -> None: +def _check_nearest_kwargs(kwargs: dict[str, Any], on_tie: OnTie) -> None: """Reject kwargs the helper owns; validate ``on_tie``.""" for forbidden in ("time", "filter", "filter_lang"): if forbidden in kwargs: diff --git a/dataretrieval/waterdata/ratings.py b/dataretrieval/waterdata/ratings.py index ed242612..c0f870c1 100644 --- a/dataretrieval/waterdata/ratings.py +++ b/dataretrieval/waterdata/ratings.py @@ -246,15 +246,18 @@ def _search( STAC ``next`` link is followed until exhausted so a result set larger than one page isn't silently truncated. """ - params: dict[str, Any] | None = {"limit": min(limit, 10000)} + query_params: dict[str, Any] = {"limit": min(limit, 10000)} if filter_str is not None: - params["filter"] = filter_str + query_params["filter"] = filter_str if time_str is not None: - params["datetime"] = time_str + query_params["datetime"] = time_str if bbox is not None: - params["bbox"] = ",".join(map(str, bbox)) + query_params["bbox"] = ",".join(map(str, bbox)) url: str | None = f"{STAC_URL}/search" + # ``params`` is sent only on the first request; each STAC ``next`` link + # already carries the query, so it is reset to None inside the loop. + params: dict[str, Any] | None = query_params features: list[dict[str, Any]] = [] while url is not None: response = httpx.get( diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index ad1b3afd..5581d086 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -14,11 +14,12 @@ Iterable, Iterator, Mapping, + Sequence, ) from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from datetime import datetime, timedelta -from typing import Any, TypeVar, get_args +from typing import Any, TypeVar, cast, get_args from zoneinfo import ZoneInfo import httpx @@ -97,7 +98,7 @@ } -def _switch_arg_id(ls: dict[str, Any], id_name: str, service: str): +def _switch_arg_id(ls: dict[str, Any], id_name: str, service: str) -> dict[str, Any]: """ Switch argument id from its package-specific identifier to the standardized "id" key that the API recognizes. @@ -142,7 +143,9 @@ def _switch_arg_id(ls: dict[str, Any], id_name: str, service: str): return ls -def _switch_properties_id(properties: list[str] | None, id_name: str, service: str): +def _switch_properties_id( + properties: list[str] | None, id_name: str, service: str +) -> list[str]: """ Switch properties id from its package-specific identifier to the standardized "id" key that the API recognizes. @@ -233,7 +236,7 @@ def _parse_datetime(value: str) -> datetime | None: return None -def _format_one(dt, *, date: bool) -> str | None: +def _format_one(dt: str | None, *, date: bool) -> str | None: """Format a single datetime element for inclusion in the API time arg.""" if pd.isna(dt) or dt == "" or dt is None: return ".." @@ -251,7 +254,7 @@ def _format_one(dt, *, date: bool) -> str | None: def _format_api_dates( - datetime_input: str | list[str | None] | None, date: bool = False + datetime_input: str | Sequence[str | None] | None, date: bool = False ) -> str | None: """ Formats date or datetime input(s) for use with an API. @@ -330,11 +333,13 @@ def _format_api_dates( if _DURATION_RE.match(single) or "/" in single: return single - # Half-bounded ranges: NA endpoints render as ".."; any unparseable non-NA # element invalidates the range. - formatted = [_format_one(dt, date=date) for dt in datetime_input] - if any(f is None for f in formatted): - return None + formatted: list[str] = [] + for dt in datetime_input: + one = _format_one(dt, date=date) + if one is None: + return None + formatted.append(one) return "/".join(formatted) @@ -371,7 +376,7 @@ def _cql2_param(args: dict[str, Any]) -> str: return json.dumps(query, separators=(",", ":")) -def _default_headers(): +def _default_headers() -> dict[str, str]: """ Generate default HTTP headers for API requests. @@ -394,7 +399,9 @@ def _default_headers(): return headers -def _check_ogc_requests(endpoint: str = "daily", req_type: str = "queryables"): +def _check_ogc_requests( + endpoint: str = "daily", req_type: str = "queryables" +) -> dict[str, Any]: """ Sends an HTTP GET request to the specified OGC endpoint and request type, returning the JSON response. @@ -426,10 +433,12 @@ def _check_ogc_requests(endpoint: str = "daily", req_type: str = "queryables"): url = f"{OGC_API_URL}/collections/{endpoint}/{req_type}" resp = httpx.get(url, headers=_default_headers(), **HTTPX_DEFAULTS) _raise_for_non_200(resp) - return resp.json() + # ``Response.json`` is typed ``Any``; the OGC queryables/schema endpoints + # return a JSON object, and callers index it as a dict. + return cast("dict[str, Any]", resp.json()) -def _error_body(resp: httpx.Response): +def _error_body(resp: httpx.Response) -> str: """ Build an informative error message from an HTTP response. @@ -626,7 +635,7 @@ def _construct_api_requests( bbox: list[float] | None = None, limit: int | None = None, skip_geometry: bool = False, - **kwargs, + **kwargs: Any, ) -> httpx.Request: """ Constructs an HTTP request object for the specified water data API service. @@ -823,6 +832,8 @@ def _next_req_url( # body might supply. Guarded against mock-shaped ``resp.url`` # attributes (tests sometimes set strings or ``MagicMock``) # by falling open when host extraction isn't reliable. + next_host: str | None + cur_host: str | None try: next_host = httpx.URL(href).host resp_url = ( @@ -838,7 +849,10 @@ def _next_req_url( f"Refusing to follow cross-host next-page URL: " f"{next_host} != {cur_host}" ) - return href + # ``href`` comes from the JSON ``links`` array (typed ``Any``); the + # ``not href`` guard above already excluded empty/None, and it is a + # URL string (passed to ``httpx.URL`` above). + return cast("str", href) return None @@ -1915,11 +1929,10 @@ def _as_str_list( ``",".join(...)`` doesn't iterate it character-by-character — and materializes any other iterable via :func:`_normalize_str_iterable`. """ - return ( - [value] - if isinstance(value, str) - else _normalize_str_iterable(value, param_name) - ) + normalized = _normalize_str_iterable(value, param_name) + if isinstance(normalized, str): + return [normalized] + return normalized def _check_monitoring_location_id( diff --git a/dataretrieval/wqp.py b/dataretrieval/wqp.py index ff01c46a..1ca0098a 100644 --- a/dataretrieval/wqp.py +++ b/dataretrieval/wqp.py @@ -13,13 +13,14 @@ import warnings from io import StringIO -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pandas as pd from .utils import BaseMetadata, _attach_datetime_columns, query if TYPE_CHECKING: + import httpx from pandas import DataFrame @@ -67,9 +68,9 @@ def _read_wqp_csv(text: str) -> DataFrame: def get_results( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Query the WQP for results. @@ -185,9 +186,9 @@ def get_results( def what_sites( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for sites within a region with specific data. @@ -240,9 +241,9 @@ def what_sites( def what_organizations( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for organizations within a region with specific data. @@ -290,7 +291,11 @@ def what_organizations( return df, WQP_Metadata(response, **kwargs) -def what_projects(ssl_check=True, legacy=True, **kwargs): +def what_projects( + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, +) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for projects within a region with specific data. Any WQP API parameter can be passed as a keyword argument to this function. @@ -338,9 +343,9 @@ def what_projects(ssl_check=True, legacy=True, **kwargs): def what_activities( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for activities within a region with specific data. @@ -402,9 +407,9 @@ def what_activities( def what_detection_limits( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for result detection limits within a region with specific data. @@ -460,9 +465,9 @@ def what_detection_limits( def what_habitat_metrics( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for habitat metrics within a region with specific data. @@ -510,7 +515,11 @@ def what_habitat_metrics( return df, WQP_Metadata(response, **kwargs) -def what_project_weights(ssl_check=True, legacy=True, **kwargs): +def what_project_weights( + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, +) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for project weights within a region with specific data. Any WQP API parameter can be passed as a keyword argument to this function. @@ -562,7 +571,11 @@ def what_project_weights(ssl_check=True, legacy=True, **kwargs): return df, WQP_Metadata(response, **kwargs) -def what_activity_metrics(ssl_check=True, legacy=True, **kwargs): +def what_activity_metrics( + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, +) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for activity metrics within a region with specific data. Any WQP API parameter can be passed as a keyword argument to this function. @@ -614,7 +627,7 @@ def what_activity_metrics(ssl_check=True, legacy=True, **kwargs): return df, WQP_Metadata(response, **kwargs) -def wqp_url(service): +def wqp_url(service: str) -> str: """Construct the WQP URL for a given service.""" base_url = "https://www.waterqualitydata.us/data/" @@ -628,7 +641,7 @@ def wqp_url(service): return f"{base_url}{service}/Search?" -def wqx3_url(service): +def wqx3_url(service: str) -> str: """Construct the WQP URL for a given WQX 3.0 service.""" base_url = "https://www.waterqualitydata.us/wqx3/" @@ -659,7 +672,7 @@ class WQP_Metadata(BaseMetadata): Site information (via ``what_sites``) if the query included a ``siteid``. """ - def __init__(self, response, **parameters) -> None: + def __init__(self, response: httpx.Response, **parameters: Any) -> None: """Generates a standard set of metadata informed by the response with specific metadata for WQP data. @@ -703,7 +716,7 @@ def site_info(self) -> tuple[DataFrame, WQP_Metadata] | None: return what_sites(siteid=siteid) -def _check_kwargs(kwargs): +def _check_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: """Private function to check kwargs for unsupported parameters.""" mimetype = kwargs.get("mimeType") if mimetype == "geojson": @@ -716,7 +729,7 @@ def _check_kwargs(kwargs): return kwargs -def _warn_wqx3_use(): +def _warn_wqx3_use() -> None: message = ( "Support for the WQX3.0 profiles is experimental. " "Queries may be slow or fail intermittently." @@ -724,7 +737,7 @@ def _warn_wqx3_use(): warnings.warn(message, UserWarning, stacklevel=2) -def _warn_legacy_use(): +def _warn_legacy_use() -> None: message = ( "This function call will return the legacy WQX format, " "which means USGS data have not been updated since March 2024. " @@ -735,7 +748,7 @@ def _warn_legacy_use(): warnings.warn(message, DeprecationWarning, stacklevel=2) -def _warn_wqx3_unavailable(): +def _warn_wqx3_unavailable() -> None: # stacklevel=3: warn -> _warn_wqx3_unavailable -> _legacy_only_url -> what_* warnings.warn( "WQX3.0 profile not available, returning legacy profile.", diff --git a/pyproject.toml b/pyproject.toml index 62ac7478..57f60161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,11 @@ packages = ["dataretrieval", "dataretrieval.codes"] dataretrieval = ["py.typed"] [project.optional-dependencies] +# Minimal set the CI ``type-check`` job installs — just mypy + the package, +# not the whole test stack. +type-check = [ + "mypy<2", # <2 so it can still target Python 3.9 (the project's floor) +] test = [ "pytest > 5.0.0", "pytest-cov[all]", @@ -39,6 +44,7 @@ test = [ "coverage", "pytest-httpx", "ruff", + "dataretrieval[type-check]", # mypy, pinned once in the type-check extra ] doc = [ "docutils<0.22", @@ -102,3 +108,14 @@ skip-magic-trailing-comma = false line-ending = "auto" docstring-code-format = true docstring-code-line-length = 72 + +[tool.mypy] +# The package is fully annotated and passes ``mypy --strict``. The one +# remaining relaxation is ``ignore_missing_imports``: untyped third-party +# libraries (pandas, geopandas, anyio) are treated as ``Any`` instead of +# requiring stub packages. Dropping that — via pandas-stubs/types-requests and +# per-module overrides — can follow. +python_version = "3.9" # the project's minimum supported version +files = ["dataretrieval"] +strict = true +ignore_missing_imports = true