Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions paimon-python/pypaimon/daft/daft_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def __init__(
table_path = getattr(table, "table_path", None)
self._table_path = str(table_path) if table_path is not None else None
self._table_options = _extract_table_options(table)
self._pushed_filters: list[PyExpr] | None = None
self._paimon_predicate: Predicate | None = None
self._remaining_filters: list[PyExpr] | None = None
self._init_table(table)
Expand All @@ -295,6 +296,7 @@ def __getstate__(self) -> dict[str, Any]:
"_table_identifier": self._table_identifier,
"_table_path": self._table_path,
"_table_options": self._table_options,
"_pushed_filters": self._pushed_filters,
"_paimon_predicate": self._paimon_predicate,
"_remaining_filters": self._remaining_filters,
}
Expand All @@ -305,6 +307,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
self._table_identifier = state["_table_identifier"]
self._table_path = state["_table_path"]
self._table_options = state["_table_options"]
self._pushed_filters = state.get("_pushed_filters")
self._paimon_predicate = state["_paimon_predicate"]
self._remaining_filters = state["_remaining_filters"]
self._storage_config = _build_storage_config(
Expand Down Expand Up @@ -361,10 +364,6 @@ def _init_table(self, table: FileStoreTable) -> None:
else {}
)

self._pushed_filters: list[PyExpr] | None = None
self._paimon_predicate: Predicate | None = None
self._remaining_filters: list[PyExpr] | None = None

@property
def name(self) -> str:
table_path = getattr(self._table, "table_path", None)
Expand Down
66 changes: 57 additions & 9 deletions paimon-python/pypaimon/tests/daft/daft_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,18 @@ def _write_to_paimon(table, arrow_table, mode="append", overwrite_partition=None
table_commit.close()


async def _collect_paimon_source_batches(source, pushdowns):
    """Drain every task produced by ``source`` and return the batches as dicts.

    Args:
        source: Object whose ``get_tasks(pushdowns)`` is an async iterable of
            tasks. Each task's ``read()`` is an async iterable of batches that
            expose ``to_pydict()``.
        pushdowns: Passed through unchanged to ``source.get_tasks``.

    Returns:
        A list holding one ``to_pydict()`` result per batch, in read order.

    Raises:
        AssertionError: If none of the tasks is a ``_PaimonPKSplitTask``.
    """
    batches = []
    fallback_task_count = 0
    async for task in source.get_tasks(pushdowns):
        # The task type is identified purely by its class name.
        if type(task).__name__ == "_PaimonPKSplitTask":
            fallback_task_count += 1
        # NOTE(review): nesting reconstructed from flattened text; every task
        # is read here, not only the fallback ones -- confirm against upstream.
        batches.extend([batch.to_pydict() async for batch in task.read()])
    assert fallback_task_count > 0, (
        "expected at least one _PaimonPKSplitTask fallback task"
    )
    return batches


async def _read_paimon_source_batches(
table,
filter_expr=None,
Expand All @@ -111,16 +123,8 @@ async def _read_paimon_source_batches(
assert pushed_filters
assert not remaining_filters

batches = []
fallback_task_count = 0
pushdowns = Pushdowns(filters=filter_expr, columns=columns, limit=limit)
async for task in source.get_tasks(pushdowns):
if type(task).__name__ == "_PaimonPKSplitTask":
fallback_task_count += 1
async for batch in task.read():
batches.append(batch.to_pydict())
assert fallback_task_count > 0
return batches
return await _collect_paimon_source_batches(source, pushdowns)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -245,6 +249,50 @@ def test_read_paimon_source_is_serializable(append_only_table):
assert restored._storage_config.multithreaded_io is False


def test_read_paimon_source_serialization_preserves_pushed_filter_for_fallback(local_paimon_catalog):
    """A serialized source must keep filters accepted by SupportsPushdownFilters.

    The filter ``id == 999`` is pushed into the source *before* it is pickled.
    The restored copy is then read with ``filters=None`` and ``limit=1``, and
    the only batch returned must be the matching row.
    """
    from daft import context, runners
    from daft.daft import StorageConfig
    from daft.io.pushdowns import Pushdowns
    from daft.pickle import dumps, loads

    from pypaimon.daft.daft_datasource import PaimonDataSource

    catalog, _ = local_paimon_catalog
    schema = pypaimon.Schema.from_pyarrow_schema(
        pa.schema([
            pa.field("id", pa.int64()),
            pa.field("name", pa.string()),
        ]),
        # NOTE(review): presumably the tiny split sizes keep the two writes in
        # separate splits -- confirm against the Paimon option semantics.
        options={
            "file.format": "avro",
            "source.split.target-size": "800b",
            "source.split.open-file-cost": "600b",
        },
    )
    catalog.create_table("test_db.avro_serialized_pushdown_filter", schema, ignore_if_exists=False)
    table = catalog.get_table("test_db.avro_serialized_pushdown_filter")
    # Two separate writes: a non-matching row first, then the matching one.
    _write_to_paimon(table, pa.table({"id": [1], "name": ["first"]}))
    _write_to_paimon(table, pa.table({"id": [999], "name": ["match"]}))

    io_config = context.get_context().daft_planning_config.default_io_config
    storage_config = StorageConfig(runners.get_or_create_runner().name != "ray", io_config)
    source = PaimonDataSource(table, storage_config=storage_config, catalog_options={})
    pushed_filters, remaining_filters = source.push_filters([(col("id") == 999)._expr])
    # The source must accept the whole filter, leaving nothing for the engine.
    assert pushed_filters
    assert not remaining_filters

    # Round-trip through pickling; the pushdowns below carry no filter, so the
    # expected result can only come from state kept on the restored source.
    restored = loads(dumps(source))
    batches = asyncio.run(
        _collect_paimon_source_batches(
            restored,
            Pushdowns(filters=None, limit=1),
        )
    )

    assert batches == [{"id": [999], "name": ["match"]}]


def test_read_paimon_remote_ray_task_is_serializable(pk_table, monkeypatch):
"""A fallback PK split task must reopen the table from metadata on Ray workers.

Expand Down
Loading