Skip to content

Commit

Permalink
Merge pull request #424 from DHI/fix-copy
Browse files Browse the repository at this point in the history
Use deepcopy
  • Loading branch information
daniel-caichac-DHI committed Mar 1, 2024
2 parents d030921 + 181f535 commit 4787be8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
16 changes: 4 additions & 12 deletions modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from copy import deepcopy
import os
from pathlib import Path
import tempfile
Expand Down Expand Up @@ -253,12 +254,10 @@ def rename(self, mapping: Dict[str, str]) -> "ComparerCollection":
return ComparerCollection(cmps)

@overload
def __getitem__(self, x: slice | Iterable[Hashable]) -> ComparerCollection:
...
def __getitem__(self, x: slice | Iterable[Hashable]) -> ComparerCollection: ...

@overload
def __getitem__(self, x: int | Hashable) -> Comparer:
...
def __getitem__(self, x: int | Hashable) -> Comparer: ...

def __getitem__(
self, x: int | Hashable | slice | Iterable[Hashable]
Expand Down Expand Up @@ -298,15 +297,8 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[Comparer]:
return iter(self._comparers.values())

def __copy__(self) -> "ComparerCollection":
cls = self.__class__
cp = cls.__new__(cls)
# TODO should this use deepcopy?
cp.__init__(list(self._comparers)) # type: ignore
return cp

def copy(self) -> "ComparerCollection":
return self.__copy__()
return deepcopy(self)

def __add__(
self, other: Union["Comparer", "ComparerCollection"]
Expand Down
10 changes: 10 additions & 0 deletions tests/test_comparercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,13 @@ def test_peak_ratio_2(cc_pr):
sk = cc_pr.skill(metrics=["peak_ratio"])
assert "peak_ratio" in sk.data.columns
assert sk.to_dataframe()["peak_ratio"].values == pytest.approx(1.0799999095653732)


def test_copy(cc):
cc2 = cc.copy()
assert cc2.n_models == 3
assert cc2.n_points == 10
assert cc2.start_time == pd.Timestamp("2019-01-01")
assert cc2.end_time == pd.Timestamp("2019-01-07")
assert cc2.obs_names == ["fake point obs", "fake track obs"]
assert cc2.mod_names == ["m1", "m2", "m3"]

0 comments on commit 4787be8

Please sign in to comment.