Skip to content

Commit

Permalink
Merge pull request #413 from DHI/obs
Browse files Browse the repository at this point in the history
Refactor `matching` to be Observation type agnostic
  • Loading branch information
ecomodeller committed Mar 20, 2024
2 parents 42823e3 + 7940ec6 commit eab3744
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 131 deletions.
3 changes: 3 additions & 0 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,9 @@ def mod_names(self) -> List[str]:
"""List of model result names"""
return list(self.raw_mod_data.keys())

def __contains__(self, key: str) -> bool:
    """Return True if `key` is one of the data variables of `self.data`.

    Enables the ``key in obj`` syntax on the enclosing class.
    """
    return key in self.data.data_vars

@property
def aux_names(self) -> List[str]:
"""List of auxiliary data names"""
Expand Down
117 changes: 34 additions & 83 deletions modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import warnings

from typing import (
Dict,
Iterable,
Collection,
List,
Literal,
Mapping,
Optional,
Union,
Sequence,
Expand All @@ -26,13 +26,13 @@

from . import model_result, Quantity
from .timeseries import TimeSeries
from .types import GeometryType, Period
from .types import Period
from .model._base import Alignable
from .model.grid import GridModelResult
from .model.dfsu import DfsuModelResult
from .model.track import TrackModelResult
from .model.point import PointModelResult
from .model.dummy import DummyModelResult
from .obs import Observation, PointObservation, TrackObservation
from .obs import Observation, observation
from .comparison import Comparer, ComparerCollection
from . import __version__

Expand Down Expand Up @@ -164,7 +164,7 @@ def from_matched(

@overload
def match(
obs: PointObservation | TrackObservation,
obs: Observation,
mod: Union[MRInputType, Sequence[MRInputType]],
*,
obs_item: Optional[IdxOrNameTypes] = None,
Expand All @@ -177,7 +177,7 @@ def match(

@overload
def match(
obs: Iterable[PointObservation | TrackObservation],
obs: Iterable[Observation],
mod: Union[MRInputType, Sequence[MRInputType]],
*,
obs_item: Optional[IdxOrNameTypes] = None,
Expand Down Expand Up @@ -335,6 +335,9 @@ def _single_obs_compare(
)
matched_data.attrs["weight"] = obs.weight

# TODO where does this line belong?
matched_data.attrs["modelskill_version"] = __version__

return Comparer(matched_data=matched_data, raw_mod_data=raw_mod_data)


Expand All @@ -348,10 +351,9 @@ def _get_global_start_end(idxs: Iterable[pd.DatetimeIndex]) -> Period:


def match_space_time(
observation: PointObservation | TrackObservation,
raw_mod_data: Dict[str, PointModelResult | TrackModelResult],
observation: Observation,
raw_mod_data: Mapping[str, Alignable],
max_model_gap: float | None = None,
spatial_tolerance: float = 1e-3,
) -> xr.Dataset:
"""Match observation with one or more model results in time domain
and return as xr.Dataset in the format used by modelskill.Comparer
Expand All @@ -365,114 +367,63 @@ def match_space_time(
----------
observation : Observation
Observation to be matched
raw_mod_data : Dict[str, PointModelResult | TrackModelResult]
Dictionary of model results ready for interpolation
raw_mod_data : Mapping[str, Alignable]
Mapping of model results ready for interpolation
max_model_gap : Optional[TimeDeltaTypes], optional
In case of non-equidistant model results (e.g. event data),
max_model_gap can be given e.g. as seconds, by default None
spatial_tolerance : float, optional
Tolerance for spatial matching, by default 1e-3
Returns
-------
xr.Dataset
Matched data in the format used by modelskill.Comparer
"""
obs_name = "Observation"
mod_names = list(raw_mod_data.keys())
idxs = [m.time for m in raw_mod_data.values()]
period = _get_global_start_end(idxs)

assert isinstance(observation, (PointObservation, TrackObservation))
gtype = "point" if isinstance(observation, PointObservation) else "track"
observation = observation.trim(period.start, period.end)

data = observation.data
data.attrs["name"] = observation.name
data = data.rename({observation.name: obs_name})
data = data.rename({observation.name: "Observation"})

for _, mr in raw_mod_data.items():
if isinstance(mr, PointModelResult):
assert len(observation.time) > 0
mri: TimeSeries = mr.interp_time(
new_time=observation.time, max_gap=max_model_gap
)
else:
mri = mr
for mr in raw_mod_data.values():
# TODO is `align` the correct name for this operation?
aligned = mr.align(observation, max_gap=max_model_gap)

if isinstance(observation, TrackObservation):
assert isinstance(mri, TrackModelResult)
mri.data = _select_overlapping_trackdata_with_tolerance(
observation=observation, mri=mri, spatial_tolerance=spatial_tolerance
if overlapping_names := set(aligned.data_vars) & set(data.data_vars):
warnings.warn(
"Model result has overlapping variable names with observation. Renamed with suffix `_model`."
)
aligned = aligned.rename({v: f"{v}_mod" for v in overlapping_names})

# check that model and observation have non-overlapping variables
if overlapping_names := set(mri.data.data_vars).intersection(
set(data.data_vars)
):
raise ValueError(
f"Model: '{mr.name}' and observation have overlapping variables: {overlapping_names}"
)

# TODO: is name needed?
for v in list(mri.data.data_vars):
data[v] = mri.data[v]
data.update(aligned)

# drop NaNs in model and observation columns (but allow NaNs in aux columns)
cols = list(
data.filter_by_attrs(kind=lambda k: k in ["model", "observation"]).data_vars
)
data = data.dropna(dim="time", subset=cols)

for n in mod_names:
data[n].attrs["kind"] = "model"
def mo_kind(k: str) -> bool:
return k in ["model", "observation"]

data.attrs["gtype"] = gtype
data.attrs["modelskill_version"] = __version__
# TODO mo_cols vs non_aux_cols?
mo_cols = data.filter_by_attrs(kind=mo_kind).data_vars
data = data.dropna(dim="time", subset=mo_cols)

return data


# TODO move to TrackModelResult
def _select_overlapping_trackdata_with_tolerance(
    observation: TrackObservation, mri: TrackModelResult, spatial_tolerance: float
) -> xr.Dataset:
    """Select the model track samples that coincide with the observation track.

    Model and observation data are converted to DataFrames and inner-joined
    on their index. Only rows whose x and y differences are both strictly
    below `spatial_tolerance` are kept.

    Parameters
    ----------
    observation : TrackObservation
        Observation whose `data` supplies the reference x/y positions
    mri : TrackModelResult
        Model result whose `data` is filtered
    spatial_tolerance : float
        Upper bound (exclusive) on the absolute x and y difference

    Returns
    -------
    xr.Dataset
        `mri.data` selected at the index values of the retained rows

    Warns
    -----
    UserWarning
        If one or more joined rows were discarded by the tolerance check
    """
    # overlapping column names get the `_mod` / `_obs` suffixes
    joined = mri.data.to_dataframe().join(
        observation.data.to_dataframe(),
        how="inner",
        lsuffix="_mod",
        rsuffix="_obs",
    )

    # a row survives only if it is close enough in both x and y
    near_x = np.abs((joined.x_mod - joined.x_obs)) < spatial_tolerance
    near_y = np.abs((joined.y_mod - joined.y_obs)) < spatial_tolerance
    kept = joined[near_x & near_y]

    n_points_removed = len(joined) - len(kept)
    if n_points_removed:
        warnings.warn(
            f"Removed {n_points_removed} model points outside observation track (spatial_tolerance={spatial_tolerance})"
        )

    return mri.data.sel(time=kept.index)


def _parse_single_obs(
obs: ObsInputType,
item: Optional[int | str] = None,
gtype: Optional[GeometryTypes] = None,
) -> PointObservation | TrackObservation:
if isinstance(obs, (PointObservation, TrackObservation)):
if item is not None:
obs_item: Optional[int | str],
gtype: Optional[GeometryTypes],
) -> Observation:
if isinstance(obs, Observation):
if obs_item is not None:
raise ValueError(
"obs_item argument not allowed if obs is an modelskill.Observation type"
)
return obs
else:
if (gtype is not None) and (
GeometryType.from_string(gtype) == GeometryType.TRACK
):
return TrackObservation(obs, item=item)
else:
return PointObservation(obs, item=item)
# observation factory can only handle track and point
return observation(obs, item=obs_item, gtype=gtype) # type: ignore


def _parse_models(
Expand Down
26 changes: 19 additions & 7 deletions modelskill/model/_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations
from collections import Counter
from typing import List, Optional, Protocol, Sequence, TYPE_CHECKING
from typing import Any, List, Optional, Protocol, Sequence, TYPE_CHECKING
from dataclasses import dataclass
import warnings

import pandas as pd

if TYPE_CHECKING:
import xarray as xr
from .point import PointModelResult
from .track import TrackModelResult

Expand Down Expand Up @@ -76,15 +77,26 @@ def extract(
self,
observation: PointObservation | TrackObservation,
spatial_method: Optional[str] = None,
) -> PointModelResult | TrackModelResult:
...
) -> PointModelResult | TrackModelResult: ...

def _extract_point(
self, observation: PointObservation, spatial_method: Optional[str] = None
) -> PointModelResult:
...
) -> PointModelResult: ...

def _extract_track(
self, observation: TrackObservation, spatial_method: Optional[str] = None
) -> TrackModelResult:
...
) -> TrackModelResult: ...


class Alignable(Protocol):
    """Structural type for objects that expose a time index and can be
    aligned to an observation.

    Any object providing a `time` property and an `align` method with the
    signatures below satisfies this protocol.
    """

    @property
    def time(self) -> pd.DatetimeIndex:
        """Time index of the object."""
        ...

    def align(
        self,
        observation: Observation,
        **kwargs: Any,
    ) -> xr.Dataset:
        """Align this object to `observation` and return the result as a dataset.

        Extra keyword arguments are implementation-specific.
        """
        ...

    # the attributes of the returned dataset have additional requirements, but we can't express that here
55 changes: 25 additions & 30 deletions modelskill/model/point.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import xarray as xr
import pandas as pd

from ..obs import PointObservation
from ..obs import Observation, PointObservation
from ..types import PointType
from ..quantity import Quantity
from ..timeseries import TimeSeries, _parse_point_input
from ._base import Alignable


class PointModelResult(TimeSeries):
class PointModelResult(TimeSeries, Alignable):
"""Construct a PointModelResult from a 0d data source:
dfs0 file, mikeio.Dataset/DataArray, pandas.DataFrame/Series
or xarray.Dataset/DataArray
Expand Down Expand Up @@ -77,53 +78,47 @@ def extract(
raise NotImplementedError(
"spatial interpolation not possible when matching point model results with point observations"
)
# TODO check x,y,z
return self

def interp_time(
self,
new_time: pd.DatetimeIndex,
dropna: bool = True,
max_gap: float | None = None,
**kwargs: Any,
) -> PointModelResult:
"""Interpolate time series to new time index
def interp_time(self, observation: Observation, **kwargs: Any) -> PointModelResult:
"""
Interpolate model result to the time of the observation
wrapper around xarray.Dataset.interp()
Parameters
----------
new_time : pd.DatetimeIndex
new time index
dropna : bool, optional
drop nan values, by default True
observation : Observation
The observation to interpolate to
**kwargs
keyword arguments passed to xarray.Dataset.interp()
Additional keyword arguments passed to xarray.interp
Returns
-------
TimeSeries
interpolated time series
PointModelResult
Interpolated model result
"""
if not isinstance(new_time, pd.DatetimeIndex):
try:
new_time = pd.DatetimeIndex(new_time)
except Exception:
raise ValueError(
"new_time must be a pandas DatetimeIndex (or convertible to one)"
)

# TODO: is it necessary to dropna before interpolation?
ds = self.align(observation, **kwargs)
return PointModelResult(ds)

def align(
self,
observation: Observation,
*,
max_gap: float | None = None,
**kwargs: Any,
) -> xr.Dataset:
new_time = observation.time

dati = self.data.dropna("time").interp(
time=new_time, assume_sorted=True, **kwargs
)
if dropna:
dati = dati.dropna(dim="time")

pmr = PointModelResult(dati)
if max_gap is not None:
pmr = pmr._remove_model_gaps(mod_index=self.time, max_gap=max_gap)
return pmr
return pmr.data

def _remove_model_gaps(
self,
Expand Down
Loading

0 comments on commit eab3744

Please sign in to comment.