diff --git a/src/osekit/core_api/base_dataset.py b/src/osekit/core_api/base_dataset.py index 9ad8612e..71497418 100644 --- a/src/osekit/core_api/base_dataset.py +++ b/src/osekit/core_api/base_dataset.py @@ -222,6 +222,7 @@ def write( """ last = len(self.data) if last is None else last + self._check_duplicate_data_names(first_idx=first, last_idx=last) for data in tqdm( self.data[first:last], disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"), @@ -237,6 +238,7 @@ def to_dict(self) -> dict: The serialized dictionary representing the ``BaseDataset``. """ + self._check_duplicate_data_names() return { "data": {str(d): d.to_dict() for d in self.data}, "name": self._name, @@ -631,3 +633,38 @@ def _parse_file( return False else: return True + + def _check_duplicate_data_names( + self, + first_idx: int = 0, + last_idx: int | None = None, + ) -> None: + """Raise an error if the dataset contains duplicate data names. + + This method is called when the dataset is converted to a dictionary + (since the data names are dictionary keys) or when the data of the + dataset is exported. + + Parameters + ---------- + first_idx: int + Index of the first data to consider (in case of a partial export) + last_idx: int|None + Index of the last data to consider (in case of a partial export) + If None, data will be considered up to the last one. 
+ + """ + last_idx = last_idx if last_idx is not None else len(self.data) + data_names = [data.name for data in self.data[first_idx:last_idx]] + unique_names = set(data_names) + duplicated_data_names = { + str(name) for name in unique_names if data_names.count(name) > 1 + } + if duplicated_data_names: + msg = ( + f"Duplicate data names found in the {self} {self.__class__.__name__}.\n" + f"Consider renaming the following data which names appear " + f"more than once to avoid errors or missing exports:\n" + f"{'\n'.join(duplicated_data_names)}" + ) + raise ValueError(msg) diff --git a/src/osekit/core_api/spectro_dataset.py b/src/osekit/core_api/spectro_dataset.py index 796da576..8595874b 100644 --- a/src/osekit/core_api/spectro_dataset.py +++ b/src/osekit/core_api/spectro_dataset.py @@ -148,6 +148,7 @@ def save_spectrogram( """ last = len(self.data) if last is None else last + self._check_duplicate_data_names(first_idx=first, last_idx=last) multiprocess( self._save_spectrogram, self.data[first:last], @@ -361,6 +362,7 @@ def save_all( """ last = len(self.data) if last is None else last + self._check_duplicate_data_names(first_idx=first, last_idx=last) self.data[first:last] = multiprocess( func=self._save_all_, enumerable=self.data[first:last], diff --git a/tests/test_core_api_base.py b/tests/test_core_api_base.py index 139e9cc9..50e8b8d2 100644 --- a/tests/test_core_api_base.py +++ b/tests/test_core_api_base.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import AbstractContextManager, nullcontext from pathlib import Path from typing import Literal @@ -2567,3 +2568,70 @@ def test_dataset_remove_empty_data_threshold_errors() -> None: with pytest.raises(ValueError, match=r"Threshold should be between 0 and 1."): ds.remove_empty_data(threshold=1.5) + + +@pytest.mark.parametrize( + ( + "data_names", + "expected_error_names_dict", + "expected_error_names_export", + "first", + "last", + ), + [ + pytest.param( + ["ken", "kesey", "sometimes"], + 
nullcontext(), + nullcontext(), + None, + None, + id="no_duplicates_shouldnt_raise", + ), + pytest.param( + ["ken", "ken", "kesey"], + pytest.raises(ValueError, match="ken"), + pytest.raises(ValueError, match="ken"), + None, + None, + id="duplicates_should_raise", + ), + pytest.param( + ["ken", "ken", "kesey", "sometimes"], + pytest.raises(ValueError, match="ken"), + nullcontext(), + 1, + None, + id="duplicates_not_included_in_write_shouldnt_raise_begin", + ), + pytest.param( + ["ken", "kesey", "sometimes", "ken"], + pytest.raises(ValueError, match="ken"), + nullcontext(), + None, + 3, + id="duplicates_not_included_in_write_shouldnt_raise_end", + ), + ], +) +def test_duplicate_data_error( + data_names: list[str], + expected_error_names_dict: AbstractContextManager, + expected_error_names_export: AbstractContextManager, + first: int | None, + last: int | None, +) -> None: + ds = DummyDataset( + [ + DummyData.from_files( + [DummyFile("foo", begin=Timestamp("1997-01-28 00:00:00"))], + name=name, + ) + for name in data_names + ], + ) + + with expected_error_names_dict: + ds.to_dict() + + with expected_error_names_export: + ds.write(Path(r"bar"), first=first, last=last) diff --git a/tests/test_spectro.py b/tests/test_spectro.py index 7e575fc7..51066aa8 100644 --- a/tests/test_spectro.py +++ b/tests/test_spectro.py @@ -1486,12 +1486,17 @@ def patch_collect() -> None: assert collect_calls[0] == 2 # noqa: PLR2004 - sds = SpectroDataset([sd] * 5) + sds = SpectroDataset(sd.split(5)) sds.save_spectrogram(tmp_path / "output") assert collect_calls[0] == 7 # noqa: PLR2004 - ltass = LTASDataset([ltas] * 5) + ltass = LTASDataset( + [ + LTASData.from_spectro_data(sd, nb_time_bins=ltas.nb_time_bins) + for sd in ltas.split(5) + ], + ) ltass.save_spectrogram(tmp_path / "output") assert collect_calls[0] == 12 # noqa: PLR2004 @@ -1872,3 +1877,29 @@ def mock_to_db(self: SpectroData | None, sx: np.ndarray) -> np.ndarray: assert len(get_value_calls) == 1 assert get_value_calls[0] == sd 
assert np.array_equal(sx_db, mock_to_db(None, sx=sd.get_value())) + + +def test_duplicate_data_check(monkeypatch: pytest.MonkeyPatch) -> None: + check_calls = [0] + + def mock_check_duplicate_data_names( + *args, # noqa: ANN002 + **kwargs, # noqa: ANN003 + ) -> None: + check_calls[0] += 1 + + monkeypatch.setattr( + SpectroDataset, + "_check_duplicate_data_names", + mock_check_duplicate_data_names, + ) + + sds = SpectroDataset([]) + + sds.save_spectrogram(folder=Path("bantam")) + + assert check_calls[0] == 1 + + sds.save_all(matrix_folder=Path("bantam"), spectrogram_folder=Path("lyons")) + + assert check_calls[0] == 2 # noqa: PLR2004