Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

* Fix `model_signing` breaking changes from `1.0.0` release
* Add `KaggleDatasetAdapter.POLARS` support to `dataset_load`
* Add validation of kwargs to `dataset_load`

## v0.3.11 (April 1, 2025)

Expand Down
42 changes: 38 additions & 4 deletions src/kagglehub/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@
# Patterns that are always ignored for dataset uploading.
DEFAULT_IGNORE_PATTERNS = [".git/", "*/.git/", ".cache/", ".huggingface/"]
# Mapping of adapters to the optional dependencies required to run them
LOAD_DATASET_ADAPTER_OPTIONAL_DEPENDENCIES_MAP = {
DATASET_LOAD_ADAPTER_OPTIONAL_DEPENDENCIES_MAP = {
KaggleDatasetAdapter.HUGGING_FACE: "hf-datasets",
KaggleDatasetAdapter.PANDAS: "pandas-datasets",
KaggleDatasetAdapter.POLARS: "polars-datasets",
}

# For each adapter, the set of dataset_load kwarg names that it actually consumes.
# Anything outside this set (when non-None) triggers a warning in validate_dataset_load_args.
_DATASET_LOAD_VALID_KWARGS_MAP = {
    KaggleDatasetAdapter.HUGGING_FACE: {"hf_kwargs", "pandas_kwargs", "sql_query"},
    KaggleDatasetAdapter.PANDAS: {"pandas_kwargs", "sql_query"},
    KaggleDatasetAdapter.POLARS: {"sql_query", "polars_frame_type", "polars_kwargs"},
}


def dataset_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
"""Download dataset files
Expand Down Expand Up @@ -80,7 +87,7 @@ def dataset_load(
pandas_kwargs: Any = None, # noqa: ANN401
sql_query: Optional[str] = None,
hf_kwargs: Any = None, # noqa: ANN401
polars_frame_type: PolarsFrameType = PolarsFrameType.LAZY_FRAME,
polars_frame_type: Optional[PolarsFrameType] = None,
polars_kwargs: Any = None, # noqa: ANN401
) -> Any: # noqa: ANN401
"""Load a Kaggle Dataset into a python object based on the selected adapter
Expand All @@ -96,7 +103,8 @@ def dataset_load(
for details: https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html
hf_kwargs:
(dict) Optional set of kwargs to pass to Dataset.from_pandas() while constructing the Dataset
polars_frame_type: (PolarsFrameType) Optional value to control what type of frame to return from polars
polars_frame_type: (PolarsFrameType) Optional control for which Frame to return: LazyFrame or DataFrame. The
default is PolarsFrameType.LAZY_FRAME.
polars_kwargs:
(dict) Optional set of kwargs to pass to the polars `read_*` method while constructing the DataFrame(s)
Returns:
Expand All @@ -107,6 +115,15 @@ def dataset_load(
A LazyFrame or DataFrame (or dict[int | str, LazyFrame] / dict[int | str, DataFrame] for Excel-like
files with multiple sheets)
"""
validate_dataset_load_args(
adapter,
pandas_kwargs=pandas_kwargs,
sql_query=sql_query,
hf_kwargs=hf_kwargs,
polars_frame_type=polars_frame_type,
polars_kwargs=polars_kwargs,
)
polars_frame_type = polars_frame_type if polars_frame_type is not None else PolarsFrameType.LAZY_FRAME
try:
if adapter is KaggleDatasetAdapter.HUGGING_FACE:
import kagglehub.hf_datasets
Expand Down Expand Up @@ -134,7 +151,7 @@ def dataset_load(
not_implemented_error_message = f"{adapter} is not yet implemented"
raise NotImplementedError(not_implemented_error_message)
except ImportError:
adapter_optional_dependency = LOAD_DATASET_ADAPTER_OPTIONAL_DEPENDENCIES_MAP[adapter]
adapter_optional_dependency = DATASET_LOAD_ADAPTER_OPTIONAL_DEPENDENCIES_MAP[adapter]
import_warning_message = (
f"The 'dataset_load' function requires the '{adapter_optional_dependency}' extras. "
f"Install them with 'pip install kagglehub[{adapter_optional_dependency}]'"
Expand All @@ -157,3 +174,20 @@ def load_dataset(
"load_dataset is deprecated and will be removed in a future version.", DeprecationWarning, stacklevel=2
)
return dataset_load(adapter, handle, path, pandas_kwargs=pandas_kwargs, sql_query=sql_query, hf_kwargs=hf_kwargs)


def validate_dataset_load_args(
    adapter: KaggleDatasetAdapter,
    **kwargs: Any,  # noqa: ANN401
) -> None:
    """Warn about kwargs that do not apply to the chosen adapter.

    Any kwarg whose value is not None but whose name is absent from the
    adapter's entry in _DATASET_LOAD_VALID_KWARGS_MAP is reported in a single
    warning; such kwargs are ignored by the loader rather than raising.
    """
    allowed = _DATASET_LOAD_VALID_KWARGS_MAP[adapter]
    # None means "not supplied" (all these kwargs default to None), so only
    # flag names the caller explicitly set for the wrong adapter.
    ignored = [name for name, value in kwargs.items() if value is not None and name not in allowed]
    if ignored:
        logger.warning(f"{', '.join(ignored)} are invalid for {adapter} and will be ignored")
63 changes: 62 additions & 1 deletion tests/test_dataset_load.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import io
import logging
import os
from typing import Any
from unittest.mock import MagicMock, patch

import polars as pl
from requests import Response

from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load
from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load, logger
from kagglehub.exceptions import KaggleApiHTTPError
from tests.fixtures import BaseTestCase

Expand Down Expand Up @@ -42,6 +44,22 @@ def _load_hf_dataset_with_invalid_file_type_and_assert_raises(self) -> None:
)
self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception))

def _load_hf_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None:
    """Loading with polars-only kwargs on the HF adapter should warn that they are ignored."""
    output_stream = io.StringIO()
    handler = logging.StreamHandler(output_stream)
    logger.addHandler(handler)
    try:
        dataset_load(
            KaggleDatasetAdapter.HUGGING_FACE,
            DATASET_HANDLE,
            AUTO_COMPRESSED_FILE_NAME,
            polars_frame_type=PolarsFrameType.LAZY_FRAME,
            polars_kwargs={},
        )
        captured_output = output_stream.getvalue()
        self.assertIn(
            "polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.HUGGING_FACE", captured_output
        )
    finally:
        # Fix: the handler was never detached, so it leaked onto the shared module
        # logger and duplicated captured output in every later test.
        logger.removeHandler(handler)

def _load_hf_dataset_with_multiple_tables_and_assert_raises(self) -> None:
with self.assertRaises(ValueError) as cm:
dataset_load(
Expand Down Expand Up @@ -85,6 +103,10 @@ def _load_hf_dataset_with_splits_and_assert_loaded(self) -> None:
self.assertEqual(TEST_SPLIT_SIZE if split_name == "test" else TRAIN_SPLIT_SIZE, dataset.num_rows)
self.assertEqual(SHAPES_COLUMNS, dataset.column_names)

def test_hf_dataset_with_other_loader_kwargs_prints_warning(self) -> None:
    # Polars-only kwargs passed to the HUGGING_FACE adapter should log a warning.
    with create_test_cache():
        self._load_hf_dataset_with_other_loader_kwargs_and_assert_warning()

def test_hf_dataset_with_invalid_file_type_raises(self) -> None:
    # Unsupported file extensions should raise rather than silently load.
    with create_test_cache():
        self._load_hf_dataset_with_invalid_file_type_and_assert_raises()
Expand Down Expand Up @@ -139,6 +161,23 @@ def _load_pandas_dataset_with_invalid_file_type_and_assert_raises(self) -> None:
)
self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception))

def _load_pandas_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None:
    """Loading with HF/polars-only kwargs on the pandas adapter should warn that they are ignored."""
    output_stream = io.StringIO()
    handler = logging.StreamHandler(output_stream)
    logger.addHandler(handler)
    try:
        dataset_load(
            KaggleDatasetAdapter.PANDAS,
            DATASET_HANDLE,
            AUTO_COMPRESSED_FILE_NAME,
            hf_kwargs={},
            polars_frame_type=PolarsFrameType.LAZY_FRAME,
            polars_kwargs={},
        )
        captured_output = output_stream.getvalue()
        self.assertIn(
            "hf_kwargs, polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.PANDAS", captured_output
        )
    finally:
        # Fix: the handler was never detached, so it leaked onto the shared module
        # logger and duplicated captured output in every later test.
        logger.removeHandler(handler)

def _load_pandas_simple_dataset_and_assert_loaded(
self,
file_extension: str,
Expand Down Expand Up @@ -187,6 +226,10 @@ def test_pandas_dataset_with_invalid_file_type_raises(self) -> None:
with create_test_cache():
self._load_pandas_dataset_with_invalid_file_type_and_assert_raises()

def test_pandas_dataset_with_other_loader_kwargs_prints_warning(self) -> None:
    # HF/polars-only kwargs passed to the PANDAS adapter should log a warning.
    with create_test_cache():
        self._load_pandas_dataset_with_other_loader_kwargs_and_assert_warning()

def test_pandas_dataset_with_multiple_tables_succeeds(self) -> None:
    # Multi-sheet/multi-table sources should load via the pandas adapter.
    with create_test_cache():
        self._load_pandas_dataset_with_multiple_tables_and_assert_loaded()
Expand Down Expand Up @@ -249,6 +292,20 @@ def _load_polars_dataset_with_invalid_file_type_and_assert_raises(self) -> None:
)
self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception))

def _load_polars_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None:
    """Loading with pandas/HF-only kwargs on the polars adapter should warn that they are ignored."""
    output_stream = io.StringIO()
    handler = logging.StreamHandler(output_stream)
    logger.addHandler(handler)
    try:
        dataset_load(
            KaggleDatasetAdapter.POLARS,
            DATASET_HANDLE,
            AUTO_COMPRESSED_FILE_NAME,
            pandas_kwargs={},
            hf_kwargs={},
        )
        captured_output = output_stream.getvalue()
        self.assertIn("pandas_kwargs, hf_kwargs are invalid for KaggleDatasetAdapter.POLARS", captured_output)
    finally:
        # Fix: the handler was never detached, so it leaked onto the shared module
        # logger and duplicated captured output in every later test.
        logger.removeHandler(handler)

def _load_polars_simple_dataset_and_assert_loaded(
self,
file_extension: str,
Expand Down Expand Up @@ -332,6 +389,10 @@ def test_polars_dataset_with_invalid_file_type_raises(self) -> None:
with create_test_cache():
self._load_polars_dataset_with_invalid_file_type_and_assert_raises()

def test_polars_dataset_with_other_loader_kwargs_prints_warning(self) -> None:
    # pandas/HF-only kwargs passed to the POLARS adapter should log a warning.
    with create_test_cache():
        self._load_polars_dataset_with_other_loader_kwargs_and_assert_warning()

def test_polars_dataset_with_multiple_tables_succeeds(self) -> None:
    # Multi-sheet/multi-table sources should load as LazyFrames by default.
    with create_test_cache():
        self._load_polars_dataset_with_multiple_tables_and_assert_loaded(PolarsFrameType.LAZY_FRAME)
Expand Down