Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 31 additions & 45 deletions dataretrieval/waterdata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_check_profiles,
_default_headers,
_get_args,
_raise_for_non_200,
get_ogc_data,
get_stats_data,
)
Expand Down Expand Up @@ -2125,7 +2126,7 @@ def get_codes(code_service: CODE_SERVICES) -> pd.DataFrame:

response = httpx.get(url, headers=_default_headers(), **HTTPX_DEFAULTS)

response.raise_for_status()
_raise_for_non_200(response)

data_dict = json.loads(response.text)
data_list = data_dict["data"]
Expand All @@ -2135,6 +2136,30 @@ def get_codes(code_service: CODE_SERVICES) -> pd.DataFrame:
return df


def _get_samples_csv(
    url: str, params: dict, ssl_check: bool
) -> tuple[pd.DataFrame, httpx.Response]:
    """Fetch a Samples CSV endpoint and return the parsed table.

    Common request/parse step used by the Samples getters. The GET is
    sent with the standard headers (including ``X-Api-Key``); a non-200
    response is turned into the module's typed error (as on the
    OGC/stats path) rather than a bare ``HTTPStatusError``; the response
    text is then read as comma-separated values.

    Parameters
    ----------
    url : str
        Fully built endpoint URL.
    params : dict
        Query parameters merged into the request.
    ssl_check : bool
        Passed to httpx as ``verify``.

    Returns
    -------
    tuple[pd.DataFrame, httpx.Response]
        The parsed DataFrame and the raw response. Wrapping the response
        as metadata and any getter-specific post-processing are left to
        the caller.
    """
    # Log the fully merged URL so the exact request can be reproduced.
    full_url = httpx.URL(url).copy_merge_params(params)
    logger.debug("Request: %s", full_url)

    reply = httpx.get(
        url,
        params=params,
        verify=ssl_check,
        headers=_default_headers(),
        **HTTPX_DEFAULTS,
    )
    # Must run before parsing: an error body is not CSV.
    _raise_for_non_200(reply)

    csv_buffer = StringIO(reply.text)
    table = pd.read_csv(csv_buffer, sep=",")
    return table, reply


def get_samples(
ssl_check: bool = True,
service: SERVICES = "results",
Expand Down Expand Up @@ -2349,19 +2374,7 @@ def get_samples(

url = f"{SAMPLES_URL}/{service}/{profile}"

logger.debug("Request: %s", httpx.URL(url).copy_merge_params(params))

response = httpx.get(
url,
params=params,
verify=ssl_check,
headers=_default_headers(),
**HTTPX_DEFAULTS,
)

response.raise_for_status()

df = pd.read_csv(StringIO(response.text), delimiter=",")
df, response = _get_samples_csv(url, params, ssl_check)
df = _attach_datetime_columns(df)

return df, BaseMetadata(response)
Expand Down Expand Up @@ -2423,19 +2436,7 @@ def get_samples_summary(
url = f"{SAMPLES_URL}/summary/{quote(monitoringLocationIdentifier, safe='')}"
params = {"mimeType": "text/csv"}

logger.debug("Request: %s", httpx.URL(url).copy_merge_params(params))

response = httpx.get(
url,
params=params,
verify=ssl_check,
headers=_default_headers(),
**HTTPX_DEFAULTS,
)

response.raise_for_status()

df = pd.read_csv(StringIO(response.text), delimiter=",")
df, response = _get_samples_csv(url, params, ssl_check)

return df, BaseMetadata(response)

Expand Down Expand Up @@ -2767,6 +2768,8 @@ def get_channel(
channel_name : string or iterable of strings, optional
The channel name.
channel_flow : string or iterable of strings, optional
The channel discharge (flow).
channel_flow_unit : string or iterable of strings, optional
The units for channel discharge.
channel_width : string or iterable of strings, optional
The channel width.
Expand Down Expand Up @@ -2797,24 +2800,7 @@ def get_channel(
longitudinal_velocity_description : string or iterable of strings, optional
The longitudinal velocity description.
measurement_type : string or iterable of strings, optional
The measurement type.
The last time a record was refreshed in our database. This may happen
due to regular operational processes and does not necessarily indicate
that anything about the measurement has changed. You can query this field
using date-times or intervals, adhering to RFC 3339, or using ISO 8601
duration objects. Intervals may be bounded or half-bounded (double-dots
at start or end).
Examples:

* A date-time: "2018-02-12T23:20:50Z"
* A bounded interval: "2018-02-12T00:00:00Z/2018-03-18T12:31:12Z"
* Half-bounded intervals: "2018-02-12T00:00:00Z/.." or
"../2018-03-18T12:31:12Z"
* Duration objects: "P1M" for data from the past month or "PT36H" for the
last 36 hours

Only features that have a last_modified that intersects the value of
datetime are selected.
The type of channel measurement.
skip_geometry : boolean, optional
This option can be used to skip response geometries for each feature.
The returning object will be a data frame with no spatial information.
Expand Down
5 changes: 3 additions & 2 deletions dataretrieval/waterdata/ratings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_check_monitoring_location_id,
_default_headers,
_format_api_dates,
_raise_for_non_200,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -248,7 +249,7 @@ def _search(
verify=ssl_check,
**HTTPX_DEFAULTS,
)
response.raise_for_status()
_raise_for_non_200(response)
return response.json().get("features", [])


Expand All @@ -262,7 +263,7 @@ def _download_and_parse(
response = httpx.get(
url, headers=_default_headers(), verify=ssl_check, **HTTPX_DEFAULTS
)
response.raise_for_status()
_raise_for_non_200(response)

if file_path is not None:
with open(os.path.join(file_path, feature["id"]), "w") as f:
Expand Down
6 changes: 3 additions & 3 deletions dataretrieval/waterdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,6 @@ def _format_api_dates(
"""
if datetime_input is None:
return None
# Get timezone
local_timezone = datetime.now().astimezone().tzinfo

# Convert single string to list for uniform processing
if isinstance(datetime_input, str):
Expand Down Expand Up @@ -300,7 +298,9 @@ def _format_api_dates(
return single

# Half-bounded ranges: NA endpoints render as ".."; any unparseable non-NA
# element invalidates the range.
# element invalidates the range. Resolve the local tz only now — after the
# all-NA / duration / interval guards above have had their chance to return.
local_timezone = datetime.now().astimezone().tzinfo
formatted = [
_format_one(dt, date=date, local_tz=local_timezone) for dt in datetime_input
]
Expand Down
24 changes: 24 additions & 0 deletions tests/waterdata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,30 @@ def test_get_samples_summary_rejects_list():
get_samples_summary(monitoringLocationIdentifier=["USGS-04183500"])


def test_get_samples_raises_typed_error_on_429(httpx_mock):
    """A 429 from the Samples endpoint surfaces as ``RateLimited``.

    Non-200 responses now raise the module's typed error, consistent with
    the OGC/stats path, rather than a bare ``httpx.HTTPStatusError``.
    """
    from dataretrieval.waterdata.chunking import RateLimited

    # Mocked rate-limit reply carrying a Retry-After hint.
    rate_limit_headers = {"Retry-After": "30"}
    httpx_mock.add_response(status_code=429, headers=rate_limit_headers)

    query = {
        "service": "results",
        "profile": "fullphyschem",
        "monitoringLocationIdentifier": "USGS-05406500",
    }
    with pytest.raises(RateLimited):
        get_samples(**query)


def test_get_samples_summary_raises_typed_error_on_5xx(httpx_mock):
    """The Samples summary getter raises ``ServiceUnavailable`` on a 5xx."""
    from dataretrieval.waterdata.chunking import ServiceUnavailable

    # Mocked server-side failure for the summary request.
    httpx_mock.add_response(status_code=503)

    site_id = "USGS-04183500"
    with pytest.raises(ServiceUnavailable):
        get_samples_summary(monitoringLocationIdentifier=site_id)


def test_check_profiles():
"""Tests that correct errors are raised for invalid profiles."""
with pytest.raises(ValueError):
Expand Down
Loading