diff --git a/modelskill/comparison/_comparison.py b/modelskill/comparison/_comparison.py index 1e96a3cf..7510f2ac 100644 --- a/modelskill/comparison/_comparison.py +++ b/modelskill/comparison/_comparison.py @@ -640,6 +640,9 @@ def mod_names(self) -> List[str]: """List of model result names""" return list(self.raw_mod_data.keys()) + def __contains__(self, key: str) -> bool: + return key in self.data.data_vars + @property def aux_names(self) -> List[str]: """List of auxiliary data names""" diff --git a/modelskill/matching.py b/modelskill/matching.py index 463e8f5b..affb03fd 100644 --- a/modelskill/matching.py +++ b/modelskill/matching.py @@ -4,11 +4,11 @@ import warnings from typing import ( - Dict, Iterable, Collection, List, Literal, + Mapping, Optional, Union, Sequence, @@ -26,13 +26,13 @@ from . import model_result, Quantity from .timeseries import TimeSeries -from .types import GeometryType, Period +from .types import Period +from .model._base import Alignable from .model.grid import GridModelResult from .model.dfsu import DfsuModelResult from .model.track import TrackModelResult -from .model.point import PointModelResult from .model.dummy import DummyModelResult -from .obs import Observation, PointObservation, TrackObservation +from .obs import Observation, observation from .comparison import Comparer, ComparerCollection from . import __version__ @@ -164,7 +164,7 @@ def from_matched( @overload def match( - obs: PointObservation | TrackObservation, + obs: Observation, mod: Union[MRInputType, Sequence[MRInputType]], *, obs_item: Optional[IdxOrNameTypes] = None, @@ -177,7 +177,7 @@ def match( @overload def match( - obs: Iterable[PointObservation | TrackObservation], + obs: Iterable[Observation], mod: Union[MRInputType, Sequence[MRInputType]], *, obs_item: Optional[IdxOrNameTypes] = None, @@ -335,6 +335,9 @@ def _single_obs_compare( ) matched_data.attrs["weight"] = obs.weight + # TODO where does this line belong? + matched_data.attrs["modelskill_version"] = __version__ + return Comparer(matched_data=matched_data, raw_mod_data=raw_mod_data) @@ -348,10 +351,9 @@ def _get_global_start_end(idxs: Iterable[pd.DatetimeIndex]) -> Period: def match_space_time( - observation: PointObservation | TrackObservation, - raw_mod_data: Dict[str, PointModelResult | TrackModelResult], + observation: Observation, + raw_mod_data: Mapping[str, Alignable], max_model_gap: float | None = None, - spatial_tolerance: float = 1e-3, ) -> xr.Dataset: """Match observation with one or more model results in time domain and return as xr.Dataset in the format used by modelskill.Comparer @@ -365,114 +367,63 @@ def match_space_time( ---------- observation : Observation Observation to be matched - raw_mod_data : Dict[str, PointModelResult | TrackModelResult] - Dictionary of model results ready for interpolation + raw_mod_data : Mapping[str, Alignable] + Mapping of model results ready for interpolation max_model_gap : Optional[TimeDeltaTypes], optional In case of non-equidistant model results (e.g. event data), max_model_gap can be given e.g. as seconds, by default None - spatial_tolerance : float, optional - Tolerance for spatial matching, by default 1e-3 Returns ------- xr.Dataset Matched data in the format used by modelskill.Comparer """ - obs_name = "Observation" - mod_names = list(raw_mod_data.keys()) idxs = [m.time for m in raw_mod_data.values()] period = _get_global_start_end(idxs) - assert isinstance(observation, (PointObservation, TrackObservation)) - gtype = "point" if isinstance(observation, PointObservation) else "track" observation = observation.trim(period.start, period.end) data = observation.data data.attrs["name"] = observation.name - data = data.rename({observation.name: obs_name}) + data = data.rename({observation.name: "Observation"}) - for _, mr in raw_mod_data.items(): - if isinstance(mr, PointModelResult): - assert len(observation.time) > 0 - mri: TimeSeries = mr.interp_time( - new_time=observation.time, max_gap=max_model_gap - ) - else: - mri = mr + for mr in raw_mod_data.values(): + # TODO is `align` the correct name for this operation? + aligned = mr.align(observation, max_gap=max_model_gap) - if isinstance(observation, TrackObservation): - assert isinstance(mri, TrackModelResult) - mri.data = _select_overlapping_trackdata_with_tolerance( - observation=observation, mri=mri, spatial_tolerance=spatial_tolerance + if overlapping_names := set(aligned.data_vars) & set(data.data_vars): + warnings.warn( + "Model result has overlapping variable names with observation. Renamed with suffix `_model`." ) + aligned = aligned.rename({v: f"{v}_mod" for v in overlapping_names}) - # check that model and observation have non-overlapping variables - if overlapping_names := set(mri.data.data_vars).intersection( - set(data.data_vars) - ): - raise ValueError( - f"Model: '{mr.name}' and observation have overlapping variables: {overlapping_names}" - ) - - # TODO: is name needed? - for v in list(mri.data.data_vars): - data[v] = mri.data[v] + data.update(aligned) # drop NaNs in model and observation columns (but allow NaNs in aux columns) - cols = list( - data.filter_by_attrs(kind=lambda k: k in ["model", "observation"]).data_vars - ) - data = data.dropna(dim="time", subset=cols) - - for n in mod_names: - data[n].attrs["kind"] = "model" + def mo_kind(k: str) -> bool: + return k in ["model", "observation"] - data.attrs["gtype"] = gtype - data.attrs["modelskill_version"] = __version__ + # TODO mo_cols vs non_aux_cols? + mo_cols = data.filter_by_attrs(kind=mo_kind).data_vars + data = data.dropna(dim="time", subset=mo_cols) return data -# TODO move to TrackModelResult -def _select_overlapping_trackdata_with_tolerance( - observation: TrackObservation, mri: TrackModelResult, spatial_tolerance: float -) -> xr.Dataset: - mod_df = mri.data.to_dataframe() - obs_df = observation.data.to_dataframe() - - # 1. inner join on time - df = mod_df.join(obs_df, how="inner", lsuffix="_mod", rsuffix="_obs") - - # 2. remove model points outside observation track - n_points = len(df) - keep_x = np.abs((df.x_mod - df.x_obs)) < spatial_tolerance - keep_y = np.abs((df.y_mod - df.y_obs)) < spatial_tolerance - df = df[keep_x & keep_y] - if n_points_removed := n_points - len(df): - warnings.warn( - f"Removed {n_points_removed} model points outside observation track (spatial_tolerance={spatial_tolerance})" - ) - return mri.data.sel(time=df.index) - - def _parse_single_obs( obs: ObsInputType, - item: Optional[int | str] = None, - gtype: Optional[GeometryTypes] = None, -) -> PointObservation | TrackObservation: - if isinstance(obs, (PointObservation, TrackObservation)): - if item is not None: + obs_item: Optional[int | str], + gtype: Optional[GeometryTypes], +) -> Observation: + if isinstance(obs, Observation): + if obs_item is not None: raise ValueError( "obs_item argument not allowed if obs is an modelskill.Observation type" ) return obs else: - if (gtype is not None) and ( - GeometryType.from_string(gtype) == GeometryType.TRACK - ): - return TrackObservation(obs, item=item) - else: - return PointObservation(obs, item=item) + # observation factory can only handle track and point + return observation(obs, item=obs_item, gtype=gtype) # type: ignore def _parse_models( diff --git a/modelskill/model/_base.py b/modelskill/model/_base.py index b153d7d2..4f045c4d 100644 --- a/modelskill/model/_base.py +++ b/modelskill/model/_base.py @@ -1,12 +1,13 @@ from __future__ import annotations from collections import Counter -from typing import List, Optional, Protocol, Sequence, TYPE_CHECKING +from typing import Any, List, Optional, Protocol, Sequence, TYPE_CHECKING from dataclasses import dataclass import warnings import pandas as pd if TYPE_CHECKING: + import xarray as xr from .point import PointModelResult from .track import TrackModelResult @@ -76,15 +77,26 @@ def extract( self, observation: PointObservation | TrackObservation, spatial_method: Optional[str] = None, - ) -> PointModelResult | TrackModelResult: - ... + ) -> PointModelResult | TrackModelResult: ... def _extract_point( self, observation: PointObservation, spatial_method: Optional[str] = None - ) -> PointModelResult: - ... + ) -> PointModelResult: ... def _extract_track( self, observation: TrackObservation, spatial_method: Optional[str] = None - ) -> TrackModelResult: - ... + ) -> TrackModelResult: ... + + +class Alignable(Protocol): + + @property + def time(self) -> pd.DatetimeIndex: ... + + def align( + self, + observation: Observation, + **kwargs: Any, + ) -> xr.Dataset: ... + + # the attributues of the returned dataset have additional requirements, but we can't express that here diff --git a/modelskill/model/point.py b/modelskill/model/point.py index 3860d49c..0e14e5ee 100644 --- a/modelskill/model/point.py +++ b/modelskill/model/point.py @@ -5,13 +5,14 @@ import xarray as xr import pandas as pd -from ..obs import PointObservation +from ..obs import Observation, PointObservation from ..types import PointType from ..quantity import Quantity from ..timeseries import TimeSeries, _parse_point_input +from ._base import Alignable -class PointModelResult(TimeSeries): +class PointModelResult(TimeSeries, Alignable): """Construct a PointModelResult from a 0d data source: dfs0 file, mikeio.Dataset/DataArray, pandas.DataFrame/Series or xarray.Dataset/DataArray @@ -77,53 +78,47 @@ def extract( raise NotImplementedError( "spatial interpolation not possible when matching point model results with point observations" ) - # TODO check x,y,z return self - def interp_time( - self, - new_time: pd.DatetimeIndex, - dropna: bool = True, - max_gap: float | None = None, - **kwargs: Any, - ) -> PointModelResult: - """Interpolate time series to new time index + def interp_time(self, observation: Observation, **kwargs: Any) -> PointModelResult: + """ + Interpolate model result to the time of the observation wrapper around xarray.Dataset.interp() Parameters ---------- - new_time : pd.DatetimeIndex - new time index - dropna : bool, optional - drop nan values, by default True + observation : Observation + The observation to interpolate to **kwargs - keyword arguments passed to xarray.Dataset.interp() + Additional keyword arguments passed to xarray.interp + Returns ------- - TimeSeries - interpolated time series + PointModelResult + Interpolated model result """ - if not isinstance(new_time, pd.DatetimeIndex): - try: - new_time = pd.DatetimeIndex(new_time) - except Exception: - raise ValueError( - "new_time must be a pandas DatetimeIndex (or convertible to one)" - ) - - # TODO: is it necessary to dropna before interpolation? + ds = self.align(observation, **kwargs) + return PointModelResult(ds) + + def align( + self, + observation: Observation, + *, + max_gap: float | None = None, + **kwargs: Any, + ) -> xr.Dataset: + new_time = observation.time + dati = self.data.dropna("time").interp( time=new_time, assume_sorted=True, **kwargs ) - if dropna: - dati = dati.dropna(dim="time") pmr = PointModelResult(dati) if max_gap is not None: pmr = pmr._remove_model_gaps(mod_index=self.time, max_gap=max_gap) - return pmr + return pmr.data def _remove_model_gaps( self, diff --git a/modelskill/model/track.py b/modelskill/model/track.py index a500ae1d..021218e8 100644 --- a/modelskill/model/track.py +++ b/modelskill/model/track.py @@ -1,15 +1,18 @@ from __future__ import annotations -from typing import Optional, Sequence +from typing import Any, Optional, Sequence +import warnings +import numpy as np import xarray as xr -from ..obs import TrackObservation +from ..obs import Observation, TrackObservation from ..types import TrackType from ..quantity import Quantity from ..timeseries import TimeSeries, _parse_track_input +from ._base import Alignable -class TrackModelResult(TimeSeries): +class TrackModelResult(TimeSeries, Alignable): """Construct a TrackModelResult from a dfs0 file, mikeio.Dataset, pandas.DataFrame or a xarray.Datasets @@ -76,5 +79,25 @@ def extract( raise NotImplementedError( "spatial interpolation not possible when matching track model results with track observations" ) - # TODO check x,y,z return self + + def align(self, observation: Observation, **kwargs: Any) -> xr.Dataset: + spatial_tolerance = 1e-3 + + mri = self + mod_df = mri.data.to_dataframe() + obs_df = observation.data.to_dataframe() + + # 1. inner join on time + df = mod_df.join(obs_df, how="inner", lsuffix="_mod", rsuffix="_obs") + + # 2. remove model points outside observation track + n_points = len(df) + keep_x = np.abs((df.x_mod - df.x_obs)) < spatial_tolerance + keep_y = np.abs((df.y_mod - df.y_obs)) < spatial_tolerance + df = df[keep_x & keep_y] + if n_points_removed := n_points - len(df): + warnings.warn( + f"Removed {n_points_removed} model points outside observation track (spatial_tolerance={spatial_tolerance})" + ) + return mri.data.sel(time=df.index) diff --git a/tests/model/test_point.py b/tests/model/test_point.py index 3a00a23e..1dfbd73e 100644 --- a/tests/model/test_point.py +++ b/tests/model/test_point.py @@ -245,3 +245,35 @@ def test_point_model_result_from_nc_file(): assert mr.x == pytest.approx(366844) assert mr.y == pytest.approx(6154291) assert mr.name == "smhi_2095_klagshamn" + + +def test_interp_time(): + + df = pd.DataFrame( + { + "WL": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + "aux1": [1.1, 2.1, 3.1, 4.1, 5.1, 6.1], + "time": pd.date_range("2019-01-01", periods=6, freq="D"), + } + ).set_index("time") + + mr = ms.PointModelResult(df, item="WL", aux_items="aux1") + + obs_df = pd.DataFrame( + { + "WL": [1.5, 2.5], + "time": [ + pd.Timestamp("2019-01-01 12:15"), + pd.Timestamp("2019-01-03 18:30"), + ], + } + ).set_index("time") + + obs = ms.PointObservation(obs_df, item="WL") + + interp = mr.interp_time(obs) + assert interp.time[0] == pd.Timestamp("2019-01-01 12:15") + assert interp.time[-1] == pd.Timestamp("2019-01-03 18:30") + + assert interp.data["WL"].values[0] == pytest.approx(1.5104166666666665) + assert interp.data["aux1"].values[0] == pytest.approx(1.6104166666666666) diff --git a/tests/observation/test_point_obs.py b/tests/observation/test_point_obs.py index f6691d4e..30e19ab5 100644 --- a/tests/observation/test_point_obs.py +++ b/tests/observation/test_point_obs.py @@ -223,6 +223,16 @@ def test_mikeio_iteminfo_pretty_units(): assert obs.quantity.unit == "m^3/s" +def test_point_obs_repr(df_aux): + # Some basic test to see that repr does not fail + o = ms.PointObservation(df_aux, item="WL", aux_items=["aux1"]) + assert "aux1" in repr(o) + + # TODO ignore this for now + # o.z = -1 + # assert "-1" in repr(o) + + def test_point_observation_without_coords_are_nan(): # No coords in file, no coords supplied 😳 obs = ms.PointObservation( diff --git a/tests/observation/test_track_obs.py b/tests/observation/test_track_obs.py index 681d3a5e..191b5845 100644 --- a/tests/observation/test_track_obs.py +++ b/tests/observation/test_track_obs.py @@ -345,3 +345,13 @@ def test_track_aux_items_fail(df_aux): with pytest.raises(ValueError): ms.TrackObservation(df_aux, item="WL", x_item="x", y_item="y", aux_items=["x"]) + + +def test_track_basic_repr(df_aux): + # Some basic test to see that repr does not fail + o = ms.TrackObservation( + df_aux, item="WL", x_item="x", y_item="y", aux_items=["aux1"] + ) + assert "TrackObservation" in repr(o) + assert "WL" in repr(o) + assert "aux1" in repr(o) diff --git a/tests/test_match.py b/tests/test_match.py index 5d5c5522..0fb7a20d 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -426,11 +426,13 @@ def test_obs_and_mod_can_not_have_same_aux_item_names(): obs = ms.PointObservation(obs_df, item="wl", aux_items=["wind_speed"]) mod = ms.PointModelResult(mod_df, item="wl", aux_items=["wind_speed"]) - with pytest.raises(ValueError, match="wind_speed"): - ms.match(obs=obs, mod=mod) + with pytest.warns(match="_model"): + cmp = ms.match(obs=obs, mod=mod) + assert "wind_speed" in cmp + assert "wind_speed_mod" in cmp # renamed -def test_mod_aux_items_must_be_unique(): +def test_mod_aux_items_overlapping_names(): obs_df = pd.DataFrame( {"wl": [1.0, 2.0, 3.0], "wind_speed": [1.0, 2.0, 3.0]}, index=pd.date_range("2017-01-01", periods=3), @@ -454,11 +456,10 @@ def test_mod_aux_items_must_be_unique(): mod2_df, item="wl", aux_items=["wind_speed"], name="remote" ) - with pytest.raises(ValueError) as e: - ms.match(obs=obs, mod=[mod, mod2]) + # we don't care which model the aux data comes from + cmp = ms.match(obs=obs, mod=[mod, mod2]) - assert "wind_speed" in str(e.value) - assert "remote" in str(e.value) + assert "wind_speed" in cmp def test_multiple_obs_not_allowed_with_non_spatial_modelresults(): diff --git a/tests/test_trackcompare.py b/tests/test_trackcompare.py index 910888cb..c170ab17 100644 --- a/tests/test_trackcompare.py +++ b/tests/test_trackcompare.py @@ -386,6 +386,7 @@ def test_df_input(obs_tiny_df, mod_tiny3): assert isinstance(obs_tiny_df, pd.DataFrame) assert len(obs_tiny_df["2017-10-27 13:00:02":"2017-10-27 13:00:02"]) == 2 + with pytest.warns(UserWarning, match="Removed 2 duplicate timestamps"): cmp = ms.match(obs_tiny_df, mod_tiny3, gtype="track")