diff --git a/dataretrieval/waterdata/api.py b/dataretrieval/waterdata/api.py index 6ffddc26..ddbc31b1 100644 --- a/dataretrieval/waterdata/api.py +++ b/dataretrieval/waterdata/api.py @@ -34,6 +34,7 @@ _check_profiles, _default_headers, _get_args, + _raise_for_non_200, get_ogc_data, get_stats_data, ) @@ -2125,7 +2126,7 @@ def get_codes(code_service: CODE_SERVICES) -> pd.DataFrame: response = httpx.get(url, headers=_default_headers(), **HTTPX_DEFAULTS) - response.raise_for_status() + _raise_for_non_200(response) data_dict = json.loads(response.text) data_list = data_dict["data"] @@ -2135,6 +2136,30 @@ def get_codes(code_service: CODE_SERVICES) -> pd.DataFrame: return df +def _get_samples_csv( + url: str, params: dict, ssl_check: bool +) -> tuple[pd.DataFrame, httpx.Response]: + """Issue a Samples CSV request and parse the body into a DataFrame. + + Shared tail for the Samples getters: sends the GET with the standard + headers (including ``X-Api-Key``), raises a typed error on a non-200 + (consistent with the OGC/stats path) instead of a bare + ``HTTPStatusError``, and reads the CSV. The caller wraps the response + as metadata and applies any per-getter post-step. + """ + logger.debug("Request: %s", httpx.URL(url).copy_merge_params(params)) + response = httpx.get( + url, + params=params, + verify=ssl_check, + headers=_default_headers(), + **HTTPX_DEFAULTS, + ) + _raise_for_non_200(response) + df = pd.read_csv(StringIO(response.text), delimiter=",") + return df, response + + def get_samples( ssl_check: bool = True, service: SERVICES = "results", @@ -2349,19 +2374,7 @@ def get_samples( url = f"{SAMPLES_URL}/{service}/{profile}" - logger.debug("Request: %s", httpx.URL(url).copy_merge_params(params)) - - response = httpx.get( - url, - params=params, - verify=ssl_check, - headers=_default_headers(), - **HTTPX_DEFAULTS, - ) - - response.raise_for_status() - - df = pd.read_csv(StringIO(response.text), delimiter=",") + df, response = _get_samples_csv(url, params, ssl_check) df = _attach_datetime_columns(df) return df, BaseMetadata(response) @@ -2423,19 +2436,7 @@ def get_samples_summary( url = f"{SAMPLES_URL}/summary/{quote(monitoringLocationIdentifier, safe='')}" params = {"mimeType": "text/csv"} - logger.debug("Request: %s", httpx.URL(url).copy_merge_params(params)) - - response = httpx.get( - url, - params=params, - verify=ssl_check, - headers=_default_headers(), - **HTTPX_DEFAULTS, - ) - - response.raise_for_status() - - df = pd.read_csv(StringIO(response.text), delimiter=",") + df, response = _get_samples_csv(url, params, ssl_check) return df, BaseMetadata(response) @@ -2767,6 +2768,8 @@ def get_channel( channel_name : string or iterable of strings, optional The channel name. channel_flow : string or iterable of strings, optional + The channel discharge (flow). + channel_flow_unit : string or iterable of strings, optional The units for channel discharge. channel_width : string or iterable of strings, optional The channel width. @@ -2797,24 +2800,7 @@ def get_channel( longitudinal_velocity_description : string or iterable of strings, optional The longitudinal velocity description. measurement_type : string or iterable of strings, optional - The measurement type. - The last time a record was refreshed in our database. This may happen - due to regular operational processes and does not necessarily indicate - anything about the measurement has changed. You can query this field - using date-times or intervals, adhering to RFC 3339, or using ISO 8601 - duration objects. Intervals may be bounded or half-bounded (double-dots - at start or end). - Examples: - - * A date-time: "2018-02-12T23:20:50Z" - * A bounded interval: "2018-02-12T00:00:00Z/2018-03-18T12:31:12Z" - * Half-bounded intervals: "2018-02-12T00:00:00Z/.." or - "../2018-03-18T12:31:12Z" - * Duration objects: "P1M" for data from the past month or "PT36H" for the - last 36 hours - - Only features that have a last_modified that intersects the value of - datetime are selected. + The type of channel measurement. skip_geometry : boolean, optional This option can be used to skip response geometries for each feature. The returning object will be a data frame with no spatial information. diff --git a/dataretrieval/waterdata/ratings.py b/dataretrieval/waterdata/ratings.py index 0e1b503d..8bdd99b8 100644 --- a/dataretrieval/waterdata/ratings.py +++ b/dataretrieval/waterdata/ratings.py @@ -29,6 +29,7 @@ _check_monitoring_location_id, _default_headers, _format_api_dates, + _raise_for_non_200, ) logger = logging.getLogger(__name__) @@ -248,7 +249,7 @@ def _search( verify=ssl_check, **HTTPX_DEFAULTS, ) - response.raise_for_status() + _raise_for_non_200(response) return response.json().get("features", []) @@ -262,7 +263,7 @@ def _download_and_parse( response = httpx.get( url, headers=_default_headers(), verify=ssl_check, **HTTPX_DEFAULTS ) - response.raise_for_status() + _raise_for_non_200(response) if file_path is not None: with open(os.path.join(file_path, feature["id"]), "w") as f: diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 475998ca..0f86730f 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -268,8 +268,6 @@ def _format_api_dates( """ if datetime_input is None: return None - # Get timezone - local_timezone = datetime.now().astimezone().tzinfo # Convert single string to list for uniform processing if isinstance(datetime_input, str): @@ -300,7 +298,9 @@ def _format_api_dates( return single # Half-bounded ranges: NA endpoints render as ".."; any unparseable non-NA - # element invalidates the range. + # element invalidates the range. Resolve the local tz only now — after the + # all-NA / duration / interval guards above have had their chance to return. + local_timezone = datetime.now().astimezone().tzinfo formatted = [ _format_one(dt, date=date, local_tz=local_timezone) for dt in datetime_input ] diff --git a/tests/waterdata_test.py b/tests/waterdata_test.py index c3836534..92c978de 100644 --- a/tests/waterdata_test.py +++ b/tests/waterdata_test.py @@ -129,6 +129,30 @@ def test_get_samples_summary_rejects_list(): get_samples_summary(monitoringLocationIdentifier=["USGS-04183500"]) +def test_get_samples_raises_typed_error_on_429(httpx_mock): + """Non-200 from the Samples endpoint now raises the module's typed error + (RateLimited on 429) — consistent with the OGC/stats path — instead of a + bare httpx.HTTPStatusError.""" + from dataretrieval.waterdata.chunking import RateLimited + + httpx_mock.add_response(status_code=429, headers={"Retry-After": "30"}) + with pytest.raises(RateLimited): + get_samples( + service="results", + profile="fullphyschem", + monitoringLocationIdentifier="USGS-05406500", + ) + + +def test_get_samples_summary_raises_typed_error_on_5xx(httpx_mock): + """A 5xx from the Samples summary endpoint raises ServiceUnavailable.""" + from dataretrieval.waterdata.chunking import ServiceUnavailable + + httpx_mock.add_response(status_code=503) + with pytest.raises(ServiceUnavailable): + get_samples_summary(monitoringLocationIdentifier="USGS-04183500") + + def test_check_profiles(): """Tests that correct errors are raised for invalid profiles.""" with pytest.raises(ValueError):