Skip to content

Commit

Permalink
Merge pull request #416 from DHI/inline-cc
Browse files Browse the repository at this point in the history
Remove __setitem__ from ComparerCollection
  • Loading branch information
ecomodeller committed Mar 20, 2024
2 parents d55b392 + 4243f43 commit 42823e3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 28 deletions.
31 changes: 9 additions & 22 deletions modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,15 @@ class ComparerCollection(Mapping, Scoreable):

def __init__(self, comparers: Iterable[Comparer]) -> None:
self._comparers: Dict[str, Comparer] = {}
self._insert_comparers(comparers)

for cmp in comparers:
if cmp.name in self._comparers:
# comparer with this name already exists!
# maybe the user is trying to add a new model
# or a new time period
self._comparers[cmp.name] += cmp
else:
self._comparers[cmp.name] = cmp

self.plot = ComparerCollection.plotter(self)
"""Plot using the ComparerCollectionPlotter
Expand All @@ -106,15 +114,6 @@ def __init__(self, comparers: Iterable[Comparer]) -> None:
>>> cc.plot.hist()
"""

def _insert_comparers(self, comparer: Union[Comparer, Iterable[Comparer]]) -> None:
if isinstance(comparer, Iterable):
for c in comparer:
self[c.name] = c
elif isinstance(comparer, Comparer):
self[comparer.name] = comparer
else:
pass

@property
def _name(self) -> str:
return "Observations"
Expand Down Expand Up @@ -279,18 +278,6 @@ def __getitem__(

raise TypeError(f"Invalid type for __getitem__: {type(x)}")

def __setitem__(self, x: str, value: Comparer) -> None:
assert isinstance(
value, Comparer
), f"comparer must be a Comparer, not {type(value)}"
if x in self._comparers:
# comparer with this name already exists!
# maybe the user is trying to add a new model
# or a new time period
self._comparers[x] = self._comparers[x] + value # type: ignore
else:
self._comparers[x] = value

def __len__(self) -> int:
return len(self._comparers)

Expand Down
27 changes: 21 additions & 6 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,24 @@ def _to_observation(self) -> PointObservation | TrackObservation:
else:
raise NotImplementedError(f"Unknown gtype: {self.gtype}")

def __iadd__(self, other: Comparer): # type: ignore
from ..matching import match_space_time

missing_models = set(self.mod_names) - set(other.mod_names)
if len(missing_models) == 0:
# same obs name and same model names
self.data = xr.concat([self.data, other.data], dim="time").drop_duplicates(
"time"
)
else:
self.raw_mod_data.update(other.raw_mod_data)
matched = match_space_time(
observation=self._to_observation(), raw_mod_data=self.raw_mod_data # type: ignore
)
self.data = matched

return self

def __add__(
self, other: Union["Comparer", "ComparerCollection"]
) -> "ComparerCollection" | "Comparer":
Expand All @@ -792,12 +810,9 @@ def __add__(
if len(missing_models) == 0:
# same obs name and same model names
cmp = self.copy()
cmp.data = xr.concat([cmp.data, other.data], dim="time")
# cc.data = cc.data[
# ~cc.data.time.to_index().duplicated(keep="last")
# ] # 'first'
_, index = np.unique(cmp.data["time"], return_index=True)
cmp.data = cmp.data.isel(time=index)
cmp.data = xr.concat(
[cmp.data, other.data], dim="time"
).drop_duplicates("time")

else:
raw_mod_data = self.raw_mod_data.copy()
Expand Down

0 comments on commit 42823e3

Please sign in to comment.