From bff87472624070a68037bb229f8dcab66f216c14 Mon Sep 17 00:00:00 2001 From: Antithetical Date: Mon, 17 Nov 2025 11:44:35 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B(huggingface=5Fdataset=5Fto=5Fplaid?= =?UTF-8?q?):=20add=20option=20for=20silent=20replacement=20with=20warn=20?= =?UTF-8?q?flag?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 2 ++ src/plaid/bridges/huggingface_bridge.py | 2 +- src/plaid/containers/dataset.py | 7 +++++-- tests/bridges/test_huggingface_bridge.py | 14 ++++++++++++++ tests/containers/test_dataset.py | 21 +++++++++++++++++++++ 5 files changed, 43 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fd4fcd4..187df10a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixes +- (dataset/huggingface_bridge) add optional `warn` parameter to `Dataset.set_infos()` to allow silent replacement of infos; `huggingface_dataset_to_plaid` now uses `warn=False` to prevent unnecessary warnings + ### Removed ## [0.1.10] - 2025-10-29 diff --git a/src/plaid/bridges/huggingface_bridge.py b/src/plaid/bridges/huggingface_bridge.py index dd1e71e6..424d80d1 100644 --- a/src/plaid/bridges/huggingface_bridge.py +++ b/src/plaid/bridges/huggingface_bridge.py @@ -1535,7 +1535,7 @@ def parallel_convert(shard_path, n_workers): infos = huggingface_description_to_infos(ds.description) - dataset.set_infos(infos) + dataset.set_infos(infos, warn=False) problem_definition = huggingface_description_to_problem_definition(ds.description) diff --git a/src/plaid/containers/dataset.py b/src/plaid/containers/dataset.py index 45f80f2a..ebb306a5 100644 --- a/src/plaid/containers/dataset.py +++ b/src/plaid/containers/dataset.py @@ -931,11 +931,12 @@ def add_infos(self, cat_key: str, infos: dict[str, str]) -> None: for key, value in infos.items(): self._infos[cat_key][key] = value - def set_infos(self, infos: dict[str, dict[str, str]]) -> None: + def set_infos(self, infos: dict[str, dict[str, str]], warn: bool = True) -> None: """Set information to the :class:`Dataset `, overwriting the existing one. Args: infos (dict[str,dict[str,str]]): Information to associate with this data set (Dataset). + warn (bool, optional): If True, warns when replacing existing infos. Defaults to True. Raises: KeyError: Invalid category key format in provided infos. @@ -963,7 +964,9 @@ def set_infos(self, infos: dict[str, dict[str, str]]) -> None: f"{info_key=} not among authorized keys. Maybe you want to try among these keys {AUTHORIZED_INFO_KEYS[cat_key]}" ) - if len(self._infos) > 0: + # Check if there are any non-plaid infos being replaced + has_user_infos = any(key != "plaid" for key in self._infos.keys()) + if has_user_infos and warn: logger.warning("infos not empty, replacing it anyway") self._infos = copy.deepcopy(infos) diff --git a/tests/bridges/test_huggingface_bridge.py b/tests/bridges/test_huggingface_bridge.py index 452aca6d..82032758 100644 --- a/tests/bridges/test_huggingface_bridge.py +++ b/tests/bridges/test_huggingface_bridge.py @@ -272,6 +272,20 @@ def test_huggingface_dataset_to_plaid(self, hf_dataset): ds, _ = huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset) self.assert_plaid_dataset(ds) + def test_huggingface_dataset_to_plaid_no_warning(self, hf_dataset, caplog): + """Test that huggingface_dataset_to_plaid does not trigger infos replacement warning.""" + import logging + + with caplog.at_level(logging.WARNING): + ds, _ = huggingface_bridge.huggingface_dataset_to_plaid( + hf_dataset, verbose=False + ) + + # Should not warn about replacing infos + assert "infos not empty, replacing it anyway" not in caplog.text + # Dataset should still be valid + self.assert_plaid_dataset(ds) + def test_huggingface_dataset_to_plaid_with_ids_binary(self, hf_dataset): huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset, ids=[0, 1]) diff --git a/tests/containers/test_dataset.py b/tests/containers/test_dataset.py index ac7e2704..9ee05273 100644 --- a/tests/containers/test_dataset.py +++ b/tests/containers/test_dataset.py @@ -858,6 +858,27 @@ def test_set_infos(self, dataset, infos): {"legal": {"illegal_info_key": "PLAID2", "license": "BSD-3"}} ) + def test_set_infos_warn_parameter(self, dataset, infos, caplog): + """Test the warn parameter for silent replacement of infos.""" + import logging + + # First set should not warn (no user infos to replace) + with caplog.at_level(logging.WARNING): + dataset.set_infos(infos) + assert "infos not empty, replacing it anyway" not in caplog.text + + # Second set with warn=True (default) should warn + caplog.clear() + with caplog.at_level(logging.WARNING): + dataset.set_infos({"legal": {"owner": "Owner2"}}) + assert "infos not empty, replacing it anyway" in caplog.text + + # Third set with warn=False should not warn + caplog.clear() + with caplog.at_level(logging.WARNING): + dataset.set_infos({"legal": {"owner": "Owner3"}}, warn=False) + assert "infos not empty, replacing it anyway" not in caplog.text + def test_get_infos(self, dataset): assert dataset.get_infos()["plaid"]["version"] == str( Version(plaid.__version__)