diff --git a/src/unified_graphics/diag.py b/src/unified_graphics/diag.py
index 4133e14b..a9370de1 100644
--- a/src/unified_graphics/diag.py
+++ b/src/unified_graphics/diag.py
@@ -2,17 +2,11 @@
 from collections import namedtuple
 from datetime import datetime, timedelta
 from enum import Enum
-from typing import Union
-from urllib.parse import urlparse
 
 import numpy as np
 import pandas as pd
 import sqlalchemy as sa
-import xarray as xr
-import zarr  # type: ignore
-from s3fs import S3FileSystem, S3Map  # type: ignore
 from werkzeug.datastructures import MultiDict
-from xarray.core.dataset import Dataset
 
 from .models import Analysis, WeatherModel
 
@@ -88,217 +82,6 @@ def get_model_metadata(session) -> ModelMetadata:
     )
 
 
-def get_store(url: str) -> Union[str, S3Map]:
-    result = urlparse(url)
-    if result.scheme in ["", "file"]:
-        return result.path
-
-    if result.scheme != "s3":
-        raise ValueError(f"Unsupported protocol '{result.scheme}' for URI: '{url}'")
-
-    region = os.environ.get("AWS_REGION", "us-east-1")
-    s3 = S3FileSystem(
-        key=os.environ.get("AWS_ACCESS_KEY_ID"),
-        secret=os.environ.get("AWS_SECRET_ACCESS_KEY"),
-        token=os.environ.get("AWS_SESSION_TOKEN"),
-        client_kwargs={"region_name": region},
-    )
-
-    return S3Map(root=f"{result.netloc}{result.path}", s3=s3, check=False)
-
-
-def open_diagnostic(
-    diag_zarr: str,
-    model: str,
-    system: str,
-    domain: str,
-    background: str,
-    frequency: str,
-    variable: Variable,
-    initialization_time: str,
-    loop: MinimLoop,
-) -> xr.Dataset:
-    store = get_store(diag_zarr)
-    group = (
-        f"/{model}/{system}/{domain}/{background}/{frequency}"
-        f"/{variable.value}/{initialization_time}/{loop.value}"
-    )
-    return xr.open_zarr(store, group=group, consolidated=False)
-
-
-def parse_filter_value(value):
-    if value == "true":
-        return 1
-
-    if value == "false":
-        return 0
-
-    try:
-        return float(value)
-    except ValueError:
-        return value
-
-
-# TODO: Refactor to a class
-# I think this might belong in a different module. It could be a class or set of classes
-# that represent different filters that can be added together into a filtering pipeline
-def get_bounds(filters: MultiDict):
-    for coord, value in filters.items():
-        extent = np.array(
-            [
-                [parse_filter_value(digit) for digit in pair.split(",")]
-                for pair in value.split("::")
-            ]
-        )
-        yield coord, extent.min(axis=0), extent.max(axis=0)
-
-
-def apply_filters(dataset: xr.Dataset, filters: MultiDict) -> Dataset:
-    for coord, lower, upper in get_bounds(filters):
-        data_array = dataset[coord]
-        dataset = dataset.where((data_array >= lower) & (data_array <= upper)).dropna(
-            dim="nobs"
-        )
-
-    # If the is_used filter is not passed, our default behavior is to include only used
-    # observations.
- if "is_used" not in filters: - dataset = dataset.where(dataset["is_used"]).dropna(dim="nobs") - - return dataset - - -def scalar( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - variable: Variable, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, -) -> pd.DataFrame: - data = open_diagnostic( - diag_zarr, - model, - system, - domain, - background, - frequency, - variable, - initialization_time, - loop, - ) - data = apply_filters(data, filters) - - return data.to_dataframe() - - -def temperature( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, -) -> pd.DataFrame: - return scalar( - diag_zarr, - model, - system, - domain, - background, - frequency, - Variable.TEMPERATURE, - initialization_time, - loop, - filters, - ) - - -def moisture( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, -) -> pd.DataFrame: - return scalar( - diag_zarr, - model, - system, - domain, - background, - frequency, - Variable.MOISTURE, - initialization_time, - loop, - filters, - ) - - -def pressure( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, -) -> pd.DataFrame: - return scalar( - diag_zarr, - model, - system, - domain, - background, - frequency, - Variable.PRESSURE, - initialization_time, - loop, - filters, - ) - - -def wind( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, -) -> pd.DataFrame | pd.Series: - data = open_diagnostic( - diag_zarr, - model, - system, - domain, - background, - frequency, - Variable.WIND, - initialization_time, - loop, - ) - - data = apply_filters(data, filters) - - return data.to_dataframe() - - def magnitude(dataset: pd.DataFrame) -> pd.DataFrame: return dataset.groupby(level=0).aggregate( { @@ -311,19 +94,71 @@ def magnitude(dataset: pd.DataFrame) -> pd.DataFrame: ) -def get_model_run_list( - diag_zarr: str, +def diag_observations( model: str, system: str, domain: str, background: str, frequency: str, - variable: Variable, -): - store = get_store(diag_zarr) - path = "/".join([model, system, domain, background, frequency, variable.value]) - with zarr.open_group(store, mode="r", path=path) as group: - return group.group_keys() + variable: str, + init_time: datetime, + loop: str, + uri: str, + filters: dict = {}, +) -> pd.DataFrame | pd.Series: + model_config = "_".join((model, background, system, domain, frequency)) + + is_used = filters.pop("is_used", True) + parquet_filters = [ + ("loop", "=", loop), + ("initialization_time", "=", init_time), + ] + + if isinstance(is_used, bool): + parquet_filters.append(("is_used", "=", is_used)) + + df = pd.read_parquet( + "/".join((uri, model_config, variable)), + columns=[ + "obs_minus_forecast_adjusted", + "obs_minus_forecast_unadjusted", + "observation", + "latitude", + "longitude", + "is_used", + ], + filters=parquet_filters, + ) + + # To apply the filters, we need the vector components in the columns, not + # the rows. + # FIXME: We should consider changing how we store the vector data so we + # don't have to unstack it every time. 
+ if "component" in df.index.names: + # FIXME: Specifically unstack the component level of the index because + # I'm seeing some data where the index is (component, nobs) instead of + # (nobs, component) + df = df.unstack("component") # type: ignore + + # Iterate over each filter and apply it + for col_name, filter_value in filters.items(): + arr = np.array(filter_value) + + # Boolean mask for the rows in the data that are within the range + # specified by the filter + mask = (df[col_name] >= arr.min(axis=0)) & (df[col_name] <= arr.max(axis=0)) + + # In the event of a vector variable, we will have a DataFrame mask + # instead of a Series, which we need to flatten to a series which + # evaluates to True only when every column in the frame is True. If any + # column is False, this row should be excluded from the data + if len(mask.shape) > 1: + mask = mask.all(axis="columns") + + # Apply the mask + df = df[mask] + + return df def history( diff --git a/src/unified_graphics/routes.py b/src/unified_graphics/routes.py index c4033820..b6b3d3b5 100644 --- a/src/unified_graphics/routes.py +++ b/src/unified_graphics/routes.py @@ -10,6 +10,7 @@ stream_template, url_for, ) +from werkzeug.datastructures import MultiDict from zarr.errors import FSPathExistNotDir, GroupNotFoundError # type: ignore from unified_graphics import diag @@ -18,6 +19,64 @@ bp = Blueprint("api", __name__) +def parse_filters(query: MultiDict) -> dict: + """Return a dictionary defining filters for diagnostic data + + You can pass a MultiDict (which is typically provided by Werkzeug/Flask to + represent the query string in a request) to this function to convert it + into a regular dictionary that defines a set of filters to limit + observations in diagnostic data. It will convert strings of "true" or + "false" to their boolean equivalents, and it treats two values separated by + a "::" as a range that is converted to a list. Order is preserved in the + ranges, so there is no guarantee that the first item in the list is the + lower bound. + + >>> query = MultiDict([ + ... ("obs_minus_forecast_adjusted", "0::0.3"), + ... ("obs_minus_forecast_adjusted", "-1::1"), + ... ("is_used", "true"), + ... ]) + >>> parse_filters(query) + ... {"obs_minus_forecast_adjusted": [[0.0, 0.3], [-1.0, 1.0]], "is_used": True} + """ + + def parse_value(value): + """Parse the string value from the query string into a usable filter value""" + + # If this value defines a range, split it into multiple values and + # parse each one individually. + if "::" in value: + return [parse_value(tok) for tok in value.split("::")] + + # If this is a true/false value, convert it to a boolean. We use "true" + # and "false" instead of the Pythonic "True" or "False" because these + # values are probably coming from a browser, the lower case forms are + # correct for JavaScript/JSON. + if value in ["true", "false"]: + return value == "true" + + # If the value is neither a range nor a boolean, we will assume it's a + # number. If the conversion to a float fails, this is an invalid value. + try: + return float(value) + except ValueError: + return value + + filters = {} + for col, value_list in query.lists(): + # Query strings can have the same key repeated multiple times - e.g. 
@@ -167,22 +226,19 @@ def history(model, system, domain, background, frequency, variable, loop):
 def diagnostics(
     model, system, domain, background, frequency, variable, initialization_time, loop
 ):
-    try:
-        v = diag.Variable(variable)
-    except ValueError:
-        return jsonify(msg=f"Variable not found: '{variable}'"), 404
-
-    variable_diagnostics = getattr(diag, v.name.lower())
-    data = variable_diagnostics(
-        current_app.config["DIAG_ZARR"],
+    filters = parse_filters(request.args)
+
+    data = diag.diag_observations(
         model,
         system,
         domain,
         background,
         frequency,
-        initialization_time,
-        diag.MinimLoop(loop),
-        request.args,
+        variable,
+        datetime.fromisoformat(initialization_time),
+        loop,
+        current_app.config["DIAG_PARQUET"],
+        filters,
     )[
         [
             "obs_minus_forecast_adjusted",
@@ -193,8 +249,7 @@ def diagnostics(
         ]
     ]
 
-    if "component" in data.index.names:
-        data = data.unstack()
+    if "component" in data.columns.names:
         data.columns = ["_".join(col) for col in data.columns]
 
     return data.to_json(orient="records"), {"Content-Type": "application/json"}
@@ -207,22 +262,19 @@ def diagnostics(
 def magnitude(
     model, system, domain, background, frequency, variable, initialization_time, loop
 ):
-    try:
-        v = diag.Variable(variable)
-    except ValueError:
-        return jsonify(msg=f"Variable not found: '{variable}'"), 404
-
-    variable_diagnostics = getattr(diag, v.name.lower())
-    data = variable_diagnostics(
-        current_app.config["DIAG_ZARR"],
+    filters = parse_filters(request.args)
+
+    data = diag.diag_observations(
         model,
         system,
         domain,
         background,
         frequency,
-        initialization_time,
-        diag.MinimLoop(loop),
-        request.args,
+        variable,
+        datetime.fromisoformat(initialization_time),
+        loop,
+        current_app.config["DIAG_PARQUET"],
+        filters,
     )[
         [
             "obs_minus_forecast_adjusted",
@@ -232,6 +284,6 @@ def magnitude(
             "latitude",
         ]
     ]
-    data = diag.magnitude(data)
+    data = diag.magnitude(data.stack())
 
     return data.to_json(orient="records"), {"Content-Type": "application/json"}
diff --git a/tests/test_diag.py b/tests/test_diag.py
index 86c041cd..e1acd84f 100644
--- a/tests/test_diag.py
+++ b/tests/test_diag.py
@@ -1,22 +1,16 @@
-import uuid
 from datetime import datetime
 from functools import partial
 
 import numpy as np
 import pandas as pd
 import pytest
-import xarray as xr
 from botocore.session import Session
 from moto.server import ThreadedMotoServer
-from s3fs import S3FileSystem, S3Map
 from werkzeug.datastructures import MultiDict
 
 from unified_graphics import diag
 from unified_graphics.models import Analysis, WeatherModel
 
-# Global resources for s3
-test_bucket_name = "osti-modeling-dev-rtma-vis"
-
 
 @pytest.fixture
 def aws_credentials(monkeypatch):
@@ -48,11 +42,6 @@ def s3_client(aws_credentials, moto_server):
     return session.create_client("s3", endpoint_url=moto_server)
 
 
-@pytest.fixture
-def test_key_prefix():
-    return f"/test/{uuid.uuid4()}/"
-
-
 def test_get_model_metadata(session):
     model_run_list = [
         ("RTMA", "WCOSS", "CONUS", "REALTIME", "HRRR", "2023-03-17T14:00"),
     ]
@@ -85,181 +74,6 @@ def test_get_model_metadata(session):
     )
 
 
-@pytest.mark.parametrize(
-    "uri,expected",
-    [
-        ("file:///tmp/diag.zarr", "/tmp/diag.zarr"),
-        ("/tmp/diag.zarr", "/tmp/diag.zarr"),
-    ],
-)
-def test_get_store_file(uri, expected):
-    result = diag.get_store(uri)
-
-    assert result == expected
-
-
-def test_get_store_s3(moto_server, s3_client, monkeypatch):
-    client = {"region_name": "us-east-1"}
-    uri = "s3://bucket/prefix/diag.zarr"
-    s3_client.create_bucket(Bucket="bucket")
-    s3_client.put_object(Bucket="bucket", Body=b"Test object", Key="prefix/diag.zarr")
-
-    monkeypatch.setattr(
-        diag,
-        "S3FileSystem",
-        partial(diag.S3FileSystem, endpoint_url=moto_server),
-    )
-
-    result = diag.get_store(uri)
-
-    assert result == S3Map(
-        root=uri,
-        s3=S3FileSystem(
-            client_kwargs=client,
-            endpoint_url=moto_server,
-        ),
-        check=False,
-    )
-
-
-def test_open_diagnostic(tmp_path, test_dataset):
-    diag_zarr_file = str(tmp_path / "test_diag.zarr")
-    expected = test_dataset()
-    group = "/".join(
-        (
-            expected.model,
-            expected.system,
-            expected.domain,
-            expected.background,
-            expected.frequency,
-            expected.name,
-            expected.initialization_time,
-            expected.loop,
-        )
-    )
-
-    expected.to_zarr(diag_zarr_file, group=group, consolidated=False)
-
-    result = diag.open_diagnostic(
-        diag_zarr_file,
-        expected.model,
-        expected.system,
-        expected.domain,
-        expected.background,
-        expected.frequency,
-        diag.Variable(expected.name),
-        expected.initialization_time,
-        diag.MinimLoop(expected.loop),
-    )
-
-    xr.testing.assert_equal(result, expected)
-
-
-@pytest.mark.parametrize(
-    "uri,expected",
-    [
-        (
-            "foo://an/unknown/uri.zarr",
-            "Unsupported protocol 'foo' for URI: 'foo://an/unknown/uri.zarr'",
-        ),
-        (
-            "ftp://an/unsupported/uri.zarr",
-            "Unsupported protocol 'ftp' for URI: 'ftp://an/unsupported/uri.zarr'",
-        ),
-    ],
-)
-def test_open_diagnostic_unknown_uri(uri, expected):
-    model = "RTMA"
-    system = "WCOSS"
-    domain = "CONUS"
-    background = "HRRR"
-    frequency = "REALTIME"
-    init_time = "2022-05-16T04:00"
-
-    with pytest.raises(ValueError, match=expected):
-        diag.open_diagnostic(
-            uri,
-            model,
-            system,
-            domain,
-            background,
-            frequency,
-            init_time,
-            diag.Variable.WIND,
-            diag.MinimLoop.GUESS,
-        )
-
-
-@pytest.mark.usefixtures("aws_credentials")
-def test_open_diagnostic_s3(moto_server, test_dataset, monkeypatch):
-    store = "s3://test_open_diagnostic_s3/test_diag.zarr"
-    expected = test_dataset()
-    group = "/".join(
-        (
-            expected.model,
-            expected.system,
-            expected.domain,
-            expected.background,
-            expected.frequency,
-            expected.name,
-            expected.initialization_time,
-            expected.loop,
-        )
-    )
-
-    monkeypatch.setattr(
-        diag,
-        "S3FileSystem",
-        partial(diag.S3FileSystem, endpoint_url=moto_server),
-    )
-
-    expected.to_zarr(
-        store,
-        group=group,
-        consolidated=False,
-        storage_options={"endpoint_url": moto_server},
-    )
-
-    result = diag.open_diagnostic(
-        store,
-        expected.model,
-        expected.system,
-        expected.domain,
-        expected.background,
-        expected.frequency,
-        diag.Variable(expected.name),
-        expected.initialization_time,
-        diag.MinimLoop(expected.loop),
-    )
-
-    xr.testing.assert_equal(result, expected)
-
-
-@pytest.mark.parametrize(
-    "mapping,expected",
-    [
-        ([("a", "1")], [("a", np.array([1.0]), np.array([1.0]))]),
-        ([("a", "1::2")], [("a", np.array([1.0]), np.array([2.0]))]),
-        ([("a", "2,4::3,1")], [("a", np.array([2.0, 1.0]), np.array([3.0, 4.0]))]),
-    ],
-    scope="class",
-)
-class TestGetBounds:
-    @pytest.fixture(scope="class")
-    def result(self, mapping):
-        filters = MultiDict(mapping)
-        return list(diag.get_bounds(filters))
-
-    def test_coord(self, result, expected):
-        assert result[0][0] == expected[0][0]
-
-    def test_lower_bounds(self, result, expected):
-        assert (result[0][1] == expected[0][1]).all()
-
-    def test_upper_bounds(self, result, expected):
-        assert (result[0][2] == expected[0][2]).all()
-
-
 def test_history(tmp_path, test_dataset, diag_parquet):
     run_list = [
         {
diff --git a/tests/test_unified_graphics.py b/tests/test_unified_graphics.py
index 5a551d88..ddb693d1 100644
--- a/tests/test_unified_graphics.py
+++ b/tests/test_unified_graphics.py
@@ -5,6 +5,7 @@
 import xarray as xr
 
 from unified_graphics import create_app
+from unified_graphics.etl.diag import prep_dataframe
 
 
 def get_group(ds: xr.Dataset) -> str:
@@ -23,7 +24,18 @@ def get_group(ds: xr.Dataset) -> str:
 
 
 def save(store: Path, data: xr.Dataset):
-    data.to_zarr(store, group=get_group(data), consolidated=False)
+    # FIXME: We should really use etl.diag.save here instead of trying to copy
+    # the saving logic.
+    parquet_file = (
+        store
+        / "_".join(
+            (data.model, data.background, data.system, data.domain, data.frequency)
+        )
+        / data.name
+    )
+    prep_dataframe(data).to_parquet(
+        parquet_file, engine="pyarrow", index=True, partition_cols=["loop"]
+    )
 
 
 @pytest.fixture(scope="module")
@@ -43,7 +55,7 @@ def diag_zarr_path(tmp_path):
 
 
 @pytest.fixture
-def t(model, diag_zarr_path, test_dataset):
+def t(model, tmp_path, test_dataset):
     ds = test_dataset(
         **model,
         initialization_time="2022-05-16T04:00",
@@ -53,16 +65,16 @@ def t(model, diag_zarr_path, test_dataset):
         forecast_unadjusted=[0, 1, -1],
         longitude=[90, 91, 89],
         latitude=[22, 23, 24],
-        is_used=[1, 1, 0],
+        is_used=[True, True, False],
     )
 
-    save(diag_zarr_path, ds)
+    save(tmp_path, ds)
 
     return ds
 
 
 @pytest.fixture
-def uv(model, diag_zarr_path, test_dataset):
+def uv(model, tmp_path, test_dataset):
     ds = test_dataset(
         **model,
         variable="uv",
@@ -72,11 +84,11 @@ def uv(model, diag_zarr_path, test_dataset):
         forecast_unadjusted=[[0, 0], [1, 1]],
         longitude=[90, 91],
         latitude=[22, 23],
-        is_used=[1, 1],
+        is_used=[True, True],
         component=["u", "v"],
     )
 
-    save(diag_zarr_path, ds)
+    save(tmp_path, ds)
 
     return ds
 
@@ -513,32 +525,4 @@ def test_diag_not_found(variable, client):
     )
 
     assert response.status_code == 404
-    assert response.json == {"msg": "Diagnostic file group not found"}
-
-
-@pytest.mark.parametrize(
-    "variable",
-    ["t", "q", "ps", "uv"],
-)
-def test_diag_read_error(variable, app, client):
-    Path(app.config["DIAG_ZARR"].replace("file://", "")).touch()
-
-    response = client.get(
-        f"/diag/RTMA/WCOSS/CONUS/HRRR/REALTIME/{variable}/2022-05-05T14:00/ges/"
-    )
-
-    assert response.status_code == 500
-    assert response.json == {"msg": "Unable to read diagnostic file group"}
-
-
-@pytest.mark.parametrize(
-    "url",
-    [
-        "not_a_variable/2022-05-05T14:00/ges/",
-    ],
-)
-def test_unknown_variable(url, client):
-    response = client.get(f"/diag/RTMA/WCOSS/CONUS/HRRR/REALTIME/{url}")
-
-    assert response.status_code == 404
-    assert response.json == {"msg": "Variable not found: 'not_a_variable'"}
+    assert response.json == {"msg": "Diagnostic file not found"}
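A usage sketch for the reworked endpoint (illustrative, not part of the change; the test name is hypothetical, and it assumes the model fixture matches the RTMA/WCOSS/CONUS/HRRR/REALTIME path used by the other tests here):

    def test_diag_with_filters(uv, client):
        # Ranges use the "::" separator and booleans are lower-case, as
        # expected by parse_filters in routes.py.
        response = client.get(
            "/diag/RTMA/WCOSS/CONUS/HRRR/REALTIME/uv/2022-05-16T04:00/ges/"
            "?longitude=89::91&is_used=true"
        )

        assert response.status_code == 200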