-
Notifications
You must be signed in to change notification settings - Fork 87
Add validation of loader kwargs to dataset_load
#241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,11 +1,13 @@ | ||||||||||||||||||||||||||||||||||||||||||||||
| import io | ||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||||||||||
| from unittest.mock import MagicMock, patch | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| import polars as pl | ||||||||||||||||||||||||||||||||||||||||||||||
| from requests import Response | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load | ||||||||||||||||||||||||||||||||||||||||||||||
| from kagglehub.datasets import KaggleDatasetAdapter, PolarsFrameType, dataset_load, logger | ||||||||||||||||||||||||||||||||||||||||||||||
| from kagglehub.exceptions import KaggleApiHTTPError | ||||||||||||||||||||||||||||||||||||||||||||||
| from tests.fixtures import BaseTestCase | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -42,6 +44,22 @@ def _load_hf_dataset_with_invalid_file_type_and_assert_raises(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _load_hf_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| output_stream = io.StringIO() | ||||||||||||||||||||||||||||||||||||||||||||||
| handler = logging.StreamHandler(output_stream) | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.addHandler(handler) | ||||||||||||||||||||||||||||||||||||||||||||||
| dataset_load( | ||||||||||||||||||||||||||||||||||||||||||||||
| KaggleDatasetAdapter.HUGGING_FACE, | ||||||||||||||||||||||||||||||||||||||||||||||
| DATASET_HANDLE, | ||||||||||||||||||||||||||||||||||||||||||||||
| AUTO_COMPRESSED_FILE_NAME, | ||||||||||||||||||||||||||||||||||||||||||||||
| polars_frame_type=PolarsFrameType.LAZY_FRAME, | ||||||||||||||||||||||||||||||||||||||||||||||
| polars_kwargs={}, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| captured_output = output_stream.getvalue() | ||||||||||||||||||||||||||||||||||||||||||||||
| self.assertIn( | ||||||||||||||||||||||||||||||||||||||||||||||
| "polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.HUGGING_FACE", captured_output | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _load_hf_dataset_with_multiple_tables_and_assert_raises(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| with self.assertRaises(ValueError) as cm: | ||||||||||||||||||||||||||||||||||||||||||||||
| dataset_load( | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -85,6 +103,10 @@ def _load_hf_dataset_with_splits_and_assert_loaded(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||
| self.assertEqual(TEST_SPLIT_SIZE if split_name == "test" else TRAIN_SPLIT_SIZE, dataset.num_rows) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.assertEqual(SHAPES_COLUMNS, dataset.column_names) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def test_hf_dataset_with_other_loader_kwargs_prints_warning(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| with create_test_cache(): | ||||||||||||||||||||||||||||||||||||||||||||||
| self._load_hf_dataset_with_other_loader_kwargs_and_assert_warning() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def test_hf_dataset_with_invalid_file_type_raises(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| with create_test_cache(): | ||||||||||||||||||||||||||||||||||||||||||||||
| self._load_hf_dataset_with_invalid_file_type_and_assert_raises() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -139,6 +161,23 @@ def _load_pandas_dataset_with_invalid_file_type_and_assert_raises(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _load_pandas_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| output_stream = io.StringIO() | ||||||||||||||||||||||||||||||||||||||||||||||
| handler = logging.StreamHandler(output_stream) | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.addHandler(handler) | ||||||||||||||||||||||||||||||||||||||||||||||
| dataset_load( | ||||||||||||||||||||||||||||||||||||||||||||||
| KaggleDatasetAdapter.PANDAS, | ||||||||||||||||||||||||||||||||||||||||||||||
| DATASET_HANDLE, | ||||||||||||||||||||||||||||||||||||||||||||||
| AUTO_COMPRESSED_FILE_NAME, | ||||||||||||||||||||||||||||||||||||||||||||||
| hf_kwargs={}, | ||||||||||||||||||||||||||||||||||||||||||||||
| polars_frame_type=PolarsFrameType.LAZY_FRAME, | ||||||||||||||||||||||||||||||||||||||||||||||
| polars_kwargs={}, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| captured_output = output_stream.getvalue() | ||||||||||||||||||||||||||||||||||||||||||||||
| self.assertIn( | ||||||||||||||||||||||||||||||||||||||||||||||
| "hf_kwargs, polars_frame_type, polars_kwargs are invalid for KaggleDatasetAdapter.PANDAS", captured_output | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _load_pandas_simple_dataset_and_assert_loaded( | ||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||
| file_extension: str, | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -187,6 +226,10 @@ def test_pandas_dataset_with_invalid_file_type_raises(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||
| with create_test_cache(): | ||||||||||||||||||||||||||||||||||||||||||||||
| self._load_pandas_dataset_with_invalid_file_type_and_assert_raises() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def test_pandas_dataset_with_other_loader_kwargs_prints_warning(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| with create_test_cache(): | ||||||||||||||||||||||||||||||||||||||||||||||
| self._load_pandas_dataset_with_other_loader_kwargs_and_assert_warning() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def test_pandas_dataset_with_multiple_tables_succeeds(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| with create_test_cache(): | ||||||||||||||||||||||||||||||||||||||||||||||
| self._load_pandas_dataset_with_multiple_tables_and_assert_loaded() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -249,6 +292,20 @@ def _load_polars_dataset_with_invalid_file_type_and_assert_raises(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.assertIn(f"Unsupported file extension: '{os.path.splitext(TEXT_FILE)[1]}'", str(cm.exception)) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _load_polars_dataset_with_other_loader_kwargs_and_assert_warning(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| output_stream = io.StringIO() | ||||||||||||||||||||||||||||||||||||||||||||||
| handler = logging.StreamHandler(output_stream) | ||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess the unittest works, but in our code base we use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I see. I found this example in the code base and based my implementation off of that: Lines 57 to 67 in 4fe7fd0
So something like this would be a better pattern to follow?: kagglehub/tests/test_logger.py Lines 45 to 55 in 4fe7fd0
|
||||||||||||||||||||||||||||||||||||||||||||||
| logger.addHandler(handler) | ||||||||||||||||||||||||||||||||||||||||||||||
| dataset_load( | ||||||||||||||||||||||||||||||||||||||||||||||
| KaggleDatasetAdapter.POLARS, | ||||||||||||||||||||||||||||||||||||||||||||||
| DATASET_HANDLE, | ||||||||||||||||||||||||||||||||||||||||||||||
| AUTO_COMPRESSED_FILE_NAME, | ||||||||||||||||||||||||||||||||||||||||||||||
| pandas_kwargs={}, | ||||||||||||||||||||||||||||||||||||||||||||||
| hf_kwargs={}, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| captured_output = output_stream.getvalue() | ||||||||||||||||||||||||||||||||||||||||||||||
| self.assertIn("pandas_kwargs, hf_kwargs are invalid for KaggleDatasetAdapter.POLARS", captured_output) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _load_polars_simple_dataset_and_assert_loaded( | ||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||
| file_extension: str, | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -332,6 +389,10 @@ def test_polars_dataset_with_invalid_file_type_raises(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||
| with create_test_cache(): | ||||||||||||||||||||||||||||||||||||||||||||||
| self._load_polars_dataset_with_invalid_file_type_and_assert_raises() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def test_polars_dataset_with_other_loader_kwargs_prints_warning(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| with create_test_cache(): | ||||||||||||||||||||||||||||||||||||||||||||||
| self._load_polars_dataset_with_other_loader_kwargs_and_assert_warning() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def test_polars_dataset_with_multiple_tables_succeeds(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| with create_test_cache(): | ||||||||||||||||||||||||||||||||||||||||||||||
| self._load_polars_dataset_with_multiple_tables_and_assert_loaded(PolarsFrameType.LAZY_FRAME) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: This can become tedious as the methods get more complex. Instead, we can use the inspect module to get the function signature. See https://docs.python.org/3/library/inspect.html#introspecting-callables-with-the-signature-object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, very cool. When we do the next data loader, I'll make a point to refactor this then.