diff --git a/modelskill/comparison/_collection.py b/modelskill/comparison/_collection.py index f08e6592..a251fce6 100644 --- a/modelskill/comparison/_collection.py +++ b/modelskill/comparison/_collection.py @@ -1,4 +1,5 @@ from __future__ import annotations +from copy import deepcopy import os from pathlib import Path import tempfile @@ -253,12 +254,10 @@ def rename(self, mapping: Dict[str, str]) -> "ComparerCollection": return ComparerCollection(cmps) @overload - def __getitem__(self, x: slice | Iterable[Hashable]) -> ComparerCollection: - ... + def __getitem__(self, x: slice | Iterable[Hashable]) -> ComparerCollection: ... @overload - def __getitem__(self, x: int | Hashable) -> Comparer: - ... + def __getitem__(self, x: int | Hashable) -> Comparer: ... def __getitem__( self, x: int | Hashable | slice | Iterable[Hashable] @@ -298,15 +297,8 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[Comparer]: return iter(self._comparers.values()) - def __copy__(self) -> "ComparerCollection": - cls = self.__class__ - cp = cls.__new__(cls) - # TODO should this use deepcopy? - cp.__init__(list(self._comparers)) # type: ignore - return cp - def copy(self) -> "ComparerCollection": - return self.__copy__() + return deepcopy(self) def __add__( self, other: Union["Comparer", "ComparerCollection"] diff --git a/tests/test_comparercollection.py b/tests/test_comparercollection.py index 544d06a3..dec0b8b3 100644 --- a/tests/test_comparercollection.py +++ b/tests/test_comparercollection.py @@ -569,3 +569,13 @@ def test_peak_ratio_2(cc_pr): sk = cc_pr.skill(metrics=["peak_ratio"]) assert "peak_ratio" in sk.data.columns assert sk.to_dataframe()["peak_ratio"].values == pytest.approx(1.0799999095653732) + + +def test_copy(cc): + cc2 = cc.copy() + assert cc2.n_models == 3 + assert cc2.n_points == 10 + assert cc2.start_time == pd.Timestamp("2019-01-01") + assert cc2.end_time == pd.Timestamp("2019-01-07") + assert cc2.obs_names == ["fake point obs", "fake track obs"] + assert cc2.mod_names == ["m1", "m2", "m3"]