From 2ceb9bfdf273526717893afeb3c13a7eac0cb250 Mon Sep 17 00:00:00 2001 From: Goeff Thomas Date: Tue, 15 Apr 2025 23:55:04 +0000 Subject: [PATCH 1/2] Add validation of loader kwargs to `dataset_load` As discussed on #238: https://github.com/Kaggle/kagglehub/pull/238#discussion_r2027399457 http://b/388077145 --- CHANGELOG.md | 1 + src/kagglehub/datasets.py | 44 ++++++++++++++++++++++++++++++--- tests/test_dataset_load.py | 50 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f6bd332..fc00c6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ * Fix `model_signing` breaking changes from `1.0.0` release * Add `KaggleDatasetAdapter.POLARS` support to `dataset_load` +* Add validation of kwargs to `dataset_load` ## v0.3.11 (April 1, 2025) diff --git a/src/kagglehub/datasets.py b/src/kagglehub/datasets.py index cddbd39..ae8ae34 100755 --- a/src/kagglehub/datasets.py +++ b/src/kagglehub/datasets.py @@ -15,12 +15,19 @@ # Patterns that are always ignored for dataset uploading. DEFAULT_IGNORE_PATTERNS = [".git/", "*/.git/", ".cache/", ".huggingface/"] # Mapping of adapters to the optional dependencies required to run them -LOAD_DATASET_ADAPTER_OPTIONAL_DEPENDENCIES_MAP = { +DATASET_LOAD_ADAPTER_OPTIONAL_DEPENDENCIES_MAP = { KaggleDatasetAdapter.HUGGING_FACE: "hf-datasets", KaggleDatasetAdapter.PANDAS: "pandas-datasets", KaggleDatasetAdapter.POLARS: "polars-datasets", } +# Mapping of adapters to the valid kwargs to use for that adapter +DATASET_LOAD_VALID_KWARGS_MAP = { + KaggleDatasetAdapter.HUGGING_FACE: {"hf_kwargs", "pandas_kwargs", "sql_query"}, + KaggleDatasetAdapter.PANDAS: {"pandas_kwargs", "sql_query"}, + KaggleDatasetAdapter.POLARS: {"sql_query", "polars_frame_type", "polars_kwargs"}, +} + def dataset_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: """Download dataset files @@ -80,7 +87,7 @@ def dataset_load( pandas_kwargs: Any = None, # noqa: ANN401 sql_query: Optional[str] = None, hf_kwargs: Any = None, # noqa: ANN401 - polars_frame_type: PolarsFrameType = PolarsFrameType.LAZY_FRAME, + polars_frame_type: Optional[PolarsFrameType] = None, polars_kwargs: Any = None, # noqa: ANN401 ) -> Any: # noqa: ANN401 """Load a Kaggle Dataset into a python object based on the selected adapter @@ -96,7 +103,8 @@ def dataset_load( for details: https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html hf_kwargs: (dict) Optional set of kwargs to pass to Dataset.from_pandas() while constructing the Dataset - polars_frame_type: (PolarsFrameType) Optional value to control what type of frame to return from polars + polars_frame_type: (PolarsFrameType) Optional control for which Frame to return: LazyFrame or DataFrame. The + default is PolarsFrameType.LAZY_FRAME. polars_kwargs: (dict) Optional set of kwargs to pass to the polars `read_*` method while constructing the DataFrame(s) Returns: @@ -107,6 +115,16 @@ def dataset_load( A LazyFrame or DataFrame (or dict[int | str, LazyFrame] / dict[int | str, DataFrame] for Excel-like files with multiple sheets) """ + validate_dataset_load_args( + adapter, + pandas_kwargs=pandas_kwargs, + sql_query=sql_query, + hf_kwargs=hf_kwargs, + polars_frame_type=polars_frame_type, + polars_kwargs=polars_kwargs, + ) + # Define the default behavior internally so we can assess kwarg validity in validate_dataset_load_args above + polars_frame_type = polars_frame_type if polars_frame_type is not None else PolarsFrameType.LAZY_FRAME try: if adapter is KaggleDatasetAdapter.HUGGING_FACE: import kagglehub.hf_datasets @@ -134,7 +152,7 @@ def dataset_load( not_implemented_error_message = f"{adapter} is not yet implemented" raise NotImplementedError(not_implemented_error_message) except ImportError: - adapter_optional_dependency = LOAD_DATASET_ADAPTER_OPTIONAL_DEPENDENCIES_MAP[adapter] + adapter_optional_dependency = DATASET_LOAD_ADAPTER_OPTIONAL_DEPENDENCIES_MAP[adapter] import_warning_message = ( f"The 'dataset_load' function requires the '{adapter_optional_dependency}' extras. " f"Install them with 'pip install kagglehub[{adapter_optional_dependency}]'" @@ -157,3 +175,21 @@ def load_dataset( "load_dataset is deprecated and will be removed in a future version.", DeprecationWarning, stacklevel=2 ) return dataset_load(adapter, handle, path, pandas_kwargs=pandas_kwargs, sql_query=sql_query, hf_kwargs=hf_kwargs) + + +def validate_dataset_load_args( + adapter: KaggleDatasetAdapter, + **kwargs: Any, # noqa: ANN401 +) -> None: + valid_kwargs = DATASET_LOAD_VALID_KWARGS_MAP[adapter] + invalid_kwargs_list: list[str] = [] + for key, value in kwargs.items(): + if key not in valid_kwargs and value is not None: + invalid_kwargs_list.append(key) + + if len(invalid_kwargs_list) == 0: + return + + invalid_kwargs = ", ".join(invalid_kwargs_list) + invalid_kwargs_msg = f"{invalid_kwargs} are invalid for {adapter}" + raise ValueError(invalid_kwargs_msg) from None diff --git a/tests/test_dataset_load.py b/tests/test_dataset_load.py index 9f2e496..9a35e66 100644 --- a/tests/test_dataset_load.py +++ b/tests/test_dataset_load.py @@ -42,6 +42,19 @@ def _load_hf_dataset_with_invalid_file_type_and_assert_raises(self) -> None: ) self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) + def _load_hf_dataset_with_other_loader_kwargs_and_assert_raises(self) -> None: + with self.assertRaises(ValueError) as cm: + dataset_load( + KaggleDatasetAdapter.HUGGING_FACE, + DATASET_HANDLE, + TEXT_FILE, + polars_frame_type=PolarsFrameType.LAZY_FRAME, + polars_kwargs={}, + ) + self.assertIn( + "polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.HUGGING_FACE", str(cm.exception) + ) + def _load_hf_dataset_with_multiple_tables_and_assert_raises(self) -> None: with self.assertRaises(ValueError) as cm: dataset_load( @@ -85,6 +98,10 @@ def _load_hf_dataset_with_splits_and_assert_loaded(self) -> None: self.assertEqual(TEST_SPLIT_SIZE if split_name == "test" else TRAIN_SPLIT_SIZE, dataset.num_rows) self.assertEqual(SHAPES_COLUMNS, dataset.column_names) + def test_hf_dataset_with_other_loader_kwargs_raises(self) -> None: + with create_test_cache(): + self._load_hf_dataset_with_other_loader_kwargs_and_assert_raises() + def test_hf_dataset_with_invalid_file_type_raises(self) -> None: with create_test_cache(): self._load_hf_dataset_with_invalid_file_type_and_assert_raises() @@ -139,6 +156,20 @@ def _load_pandas_dataset_with_invalid_file_type_and_assert_raises(self) -> None: ) self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) + def _load_pandas_dataset_with_other_loader_kwargs_and_assert_raises(self) -> None: + with self.assertRaises(ValueError) as cm: + dataset_load( + KaggleDatasetAdapter.PANDAS, + DATASET_HANDLE, + TEXT_FILE, + hf_kwargs={}, + polars_frame_type=PolarsFrameType.LAZY_FRAME, + polars_kwargs={}, + ) + self.assertIn( + "hf_kwargs, polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.PANDAS", str(cm.exception) + ) + def _load_pandas_simple_dataset_and_assert_loaded( self, file_extension: str, @@ -187,6 +218,10 @@ def test_pandas_dataset_with_invalid_file_type_raises(self) -> None: with create_test_cache(): self._load_pandas_dataset_with_invalid_file_type_and_assert_raises() + def test_pandas_dataset_with_other_loader_kwargs_raises(self) -> None: + with create_test_cache(): + self._load_pandas_dataset_with_other_loader_kwargs_and_assert_raises() + def test_pandas_dataset_with_multiple_tables_succeeds(self) -> None: with create_test_cache(): self._load_pandas_dataset_with_multiple_tables_and_assert_loaded() @@ -249,6 +284,17 @@ def _load_polars_dataset_with_invalid_file_type_and_assert_raises(self) -> None: ) self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) + def _load_polars_dataset_with_other_loader_kwargs_and_assert_raises(self) -> None: + with self.assertRaises(ValueError) as cm: + dataset_load( + KaggleDatasetAdapter.POLARS, + DATASET_HANDLE, + TEXT_FILE, + pandas_kwargs={}, + hf_kwargs={}, + ) + self.assertIn("pandas_kwargs, hf_kwargs are invalid for KaggleDatasetAdapter.POLARS", str(cm.exception)) + def _load_polars_simple_dataset_and_assert_loaded( self, file_extension: str, @@ -332,6 +378,10 @@ def test_polars_dataset_with_invalid_file_type_raises(self) -> None: with create_test_cache(): self._load_polars_dataset_with_invalid_file_type_and_assert_raises() + def test_polars_dataset_with_other_loader_kwargs_raises(self) -> None: + with create_test_cache(): + self._load_polars_dataset_with_other_loader_kwargs_and_assert_raises() + def test_polars_dataset_with_multiple_tables_succeeds(self) -> None: with create_test_cache(): self._load_polars_dataset_with_multiple_tables_and_assert_loaded(PolarsFrameType.LAZY_FRAME) From 78c5d313e02b4ecdf6d27818540d5821f5bb4db7 Mon Sep 17 00:00:00 2001 From: Goeff Thomas Date: Mon, 21 Apr 2025 17:14:28 +0000 Subject: [PATCH 2/2] PR feedback --- src/kagglehub/datasets.py | 8 ++-- tests/test_dataset_load.py | 87 +++++++++++++++++++++----------------- 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/src/kagglehub/datasets.py b/src/kagglehub/datasets.py index ae8ae34..51787f1 100755 --- a/src/kagglehub/datasets.py +++ b/src/kagglehub/datasets.py @@ -22,7 +22,7 @@ } # Mapping of adapters to the valid kwargs to use for that adapter -DATASET_LOAD_VALID_KWARGS_MAP = { +_DATASET_LOAD_VALID_KWARGS_MAP = { KaggleDatasetAdapter.HUGGING_FACE: {"hf_kwargs", "pandas_kwargs", "sql_query"}, KaggleDatasetAdapter.PANDAS: {"pandas_kwargs", "sql_query"}, KaggleDatasetAdapter.POLARS: {"sql_query", "polars_frame_type", "polars_kwargs"}, @@ -123,7 +123,6 @@ def dataset_load( polars_frame_type=polars_frame_type, polars_kwargs=polars_kwargs, ) - # Define the default behavior internally so we can assess kwarg validity in validate_dataset_load_args above polars_frame_type = polars_frame_type if polars_frame_type is not None else PolarsFrameType.LAZY_FRAME try: if adapter is KaggleDatasetAdapter.HUGGING_FACE: @@ -181,7 +180,7 @@ def validate_dataset_load_args( adapter: KaggleDatasetAdapter, **kwargs: Any, # noqa: ANN401 ) -> None: - valid_kwargs = DATASET_LOAD_VALID_KWARGS_MAP[adapter] + valid_kwargs = _DATASET_LOAD_VALID_KWARGS_MAP[adapter] invalid_kwargs_list: list[str] = [] for key, value in kwargs.items(): if key not in valid_kwargs and value is not None: @@ -191,5 +190,4 @@ def validate_dataset_load_args( return invalid_kwargs = ", ".join(invalid_kwargs_list) - invalid_kwargs_msg = f"{invalid_kwargs} are invalid for {adapter}" - raise ValueError(invalid_kwargs_msg) from None + logger.warning(f"{invalid_kwargs} are invalid for {adapter} and will be ignored") diff --git a/tests/test_dataset_load.py b/tests/test_dataset_load.py index 9a35e66..bf4c57d 100644 --- a/tests/test_dataset_load.py +++ b/tests/test_dataset_load.py @@ -1,3 +1,5 @@ +import io +import logging import os from typing import Any from unittest.mock import MagicMock, patch @@ -5,7 +7,7 @@ import polars as pl from requests import Response -from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load +from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load, logger from kagglehub.exceptions import KaggleApiHTTPError from tests.fixtures import BaseTestCase @@ -42,17 +44,20 @@ def _load_hf_dataset_with_invalid_file_type_and_assert_raises(self) -> None: ) self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) - def _load_hf_dataset_with_other_loader_kwargs_and_assert_raises(self) -> None: - with self.assertRaises(ValueError) as cm: - dataset_load( - KaggleDatasetAdapter.HUGGING_FACE, - DATASET_HANDLE, - TEXT_FILE, - polars_frame_type=PolarsFrameType.LAZY_FRAME, - polars_kwargs={}, - ) + def _load_hf_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None: + output_stream = io.StringIO() + handler = logging.StreamHandler(output_stream) + logger.addHandler(handler) + dataset_load( + KaggleDatasetAdapter.HUGGING_FACE, + DATASET_HANDLE, + AUTO_COMPRESSED_FILE_NAME, + polars_frame_type=PolarsFrameType.LAZY_FRAME, + polars_kwargs={}, + ) + captured_output = output_stream.getvalue() self.assertIn( - "polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.HUGGING_FACE", str(cm.exception) + "polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.HUGGING_FACE", captured_output ) def _load_hf_dataset_with_multiple_tables_and_assert_raises(self) -> None: @@ -98,9 +103,9 @@ def _load_hf_dataset_with_splits_and_assert_loaded(self) -> None: self.assertEqual(TEST_SPLIT_SIZE if split_name == "test" else TRAIN_SPLIT_SIZE, dataset.num_rows) self.assertEqual(SHAPES_COLUMNS, dataset.column_names) - def test_hf_dataset_with_other_loader_kwargs_raises(self) -> None: + def test_hf_dataset_with_other_loader_kwargs_prints_warning(self) -> None: with create_test_cache(): - self._load_hf_dataset_with_other_loader_kwargs_and_assert_raises() + self._load_hf_dataset_with_other_loader_kwargs_and_assert_warning() def test_hf_dataset_with_invalid_file_type_raises(self) -> None: with create_test_cache(): @@ -156,18 +161,21 @@ def _load_pandas_dataset_with_invalid_file_type_and_assert_raises(self) -> None: ) self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) - def _load_pandas_dataset_with_other_loader_kwargs_and_assert_raises(self) -> None: - with self.assertRaises(ValueError) as cm: - dataset_load( - KaggleDatasetAdapter.PANDAS, - DATASET_HANDLE, - TEXT_FILE, - hf_kwargs={}, - polars_frame_type=PolarsFrameType.LAZY_FRAME, - polars_kwargs={}, - ) + def _load_pandas_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None: + output_stream = io.StringIO() + handler = logging.StreamHandler(output_stream) + logger.addHandler(handler) + dataset_load( + KaggleDatasetAdapter.PANDAS, + DATASET_HANDLE, + AUTO_COMPRESSED_FILE_NAME, + hf_kwargs={}, + polars_frame_type=PolarsFrameType.LAZY_FRAME, + polars_kwargs={}, + ) + captured_output = output_stream.getvalue() self.assertIn( - "hf_kwargs, polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.PANDAS", str(cm.exception) + "hf_kwargs, polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.PANDAS", captured_output ) def _load_pandas_simple_dataset_and_assert_loaded( @@ -218,9 +226,9 @@ def test_pandas_dataset_with_invalid_file_type_raises(self) -> None: with create_test_cache(): self._load_pandas_dataset_with_invalid_file_type_and_assert_raises() - def test_pandas_dataset_with_other_loader_kwargs_raises(self) -> None: + def test_pandas_dataset_with_other_loader_kwargs_prints_warning(self) -> None: with create_test_cache(): - self._load_pandas_dataset_with_other_loader_kwargs_and_assert_raises() + self._load_pandas_dataset_with_other_loader_kwargs_and_assert_warning() def test_pandas_dataset_with_multiple_tables_succeeds(self) -> None: with create_test_cache(): @@ -284,16 +292,19 @@ def _load_polars_dataset_with_invalid_file_type_and_assert_raises(self) -> None: ) self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) - def _load_polars_dataset_with_other_loader_kwargs_and_assert_raises(self) -> None: - with self.assertRaises(ValueError) as cm: - dataset_load( - KaggleDatasetAdapter.POLARS, - DATASET_HANDLE, - TEXT_FILE, - pandas_kwargs={}, - hf_kwargs={}, - ) - self.assertIn("pandas_kwargs, hf_kwargs are invalid for KaggleDatasetAdapter.POLARS", str(cm.exception)) + def _load_polars_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None: + output_stream = io.StringIO() + handler = logging.StreamHandler(output_stream) + logger.addHandler(handler) + dataset_load( + KaggleDatasetAdapter.POLARS, + DATASET_HANDLE, + AUTO_COMPRESSED_FILE_NAME, + pandas_kwargs={}, + hf_kwargs={}, + ) + captured_output = output_stream.getvalue() + self.assertIn("pandas_kwargs, hf_kwargs are invalid for KaggleDatasetAdapter.POLARS", captured_output) def _load_polars_simple_dataset_and_assert_loaded( self, @@ -378,9 +389,9 @@ def test_polars_dataset_with_invalid_file_type_raises(self) -> None: with create_test_cache(): self._load_polars_dataset_with_invalid_file_type_and_assert_raises() - def test_polars_dataset_with_other_loader_kwargs_raises(self) -> None: + def test_polars_dataset_with_other_loader_kwargs_prints_warning(self) -> None: with create_test_cache(): - self._load_polars_dataset_with_other_loader_kwargs_and_assert_raises() + self._load_polars_dataset_with_other_loader_kwargs_and_assert_warning() def test_polars_dataset_with_multiple_tables_succeeds(self) -> None: with create_test_cache():