Commit d99936a

emilie-wang and Hanzhi Wang authored
Add expression to table.inspect.partitions() (#2596)
Closes [2562](#2562)

# Rationale for this change

Allow users to query specific partitions with a predicate while inspecting table partitions. As suggested, add a predicate as an argument to `table.inspect.partitions()` that defaults to ALWAYS_TRUE.

## Are these changes tested?

Yes, new integration tests were added.

## Are there any user-facing changes?

I believe no.

Co-authored-by: Hanzhi Wang <hanzhi_wang@apple.com>
1 parent 2a9f2ea commit d99936a
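A minimal usage sketch of the new argument, based on the signature added in `pyiceberg/table/inspect.py` below (the catalog name and table identifier are illustrative, not part of this change):

```python
from pyiceberg.catalog import load_catalog
from pyiceberg.expressions import And, GreaterThanOrEqual, LessThan

catalog = load_catalog("default")                 # hypothetical catalog name
tbl = catalog.load_table("default.sales_by_day")  # hypothetical table

# Unfiltered call: behaviour is unchanged, every partition is reported.
all_partitions = tbl.inspect.partitions()

# row_filter accepts a string predicate ...
january = tbl.inspect.partitions(row_filter="dt >= '2021-01-01' and dt < '2021-02-01'")

# ... or a BooleanExpression, mirroring table.scan(); case_sensitive defaults to True.
january_expr = tbl.inspect.partitions(
    row_filter=And(GreaterThanOrEqual("dt", "2021-01-01"), LessThan("dt", "2021-02-01")),
    case_sensitive=True,
)
print(january.to_pandas())
```

The return type is the same `pyarrow.Table` as before; the predicate only narrows which manifests and files are scanned, so only matching partitions are reported.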

File tree: 3 files changed (+189, −100 lines)

pyiceberg/table/__init__.py

Lines changed: 29 additions & 21 deletions

@@ -31,6 +31,7 @@
     Callable,
     Dict,
     Iterable,
+    Iterator,
     List,
     Optional,
     Set,
@@ -1942,11 +1943,11 @@ def _check_sequence_number(min_sequence_number: int, manifest: ManifestFile) ->
             and (manifest.sequence_number or INITIAL_SEQUENCE_NUMBER) >= min_sequence_number
         )

-    def plan_files(self) -> Iterable[FileScanTask]:
-        """Plans the relevant files by filtering on the PartitionSpecs.
+    def scan_plan_helper(self) -> Iterator[List[ManifestEntry]]:
+        """Filter and return manifest entries based on partition and metrics evaluators.

         Returns:
-            List of FileScanTasks that contain both data and delete files.
+            Iterator of ManifestEntry objects that match the scan's partition filter.
         """
         snapshot = self.snapshot()
         if not snapshot:
@@ -1957,8 +1958,6 @@ def plan_files(self) -> Iterable[FileScanTask]:

         manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)

-        residual_evaluators: Dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator)
-
         manifests = [
             manifest_file
             for manifest_file in snapshot.manifests(self.io)
@@ -1972,25 +1971,34 @@ def plan_files(self) -> Iterable[FileScanTask]:

         min_sequence_number = _min_sequence_number(manifests)

+        executor = ExecutorFactory.get_or_create()
+
+        return executor.map(
+            lambda args: _open_manifest(*args),
+            [
+                (
+                    self.io,
+                    manifest,
+                    partition_evaluators[manifest.partition_spec_id],
+                    self._build_metrics_evaluator(),
+                )
+                for manifest in manifests
+                if self._check_sequence_number(min_sequence_number, manifest)
+            ],
+        )
+
+    def plan_files(self) -> Iterable[FileScanTask]:
+        """Plans the relevant files by filtering on the PartitionSpecs.
+
+        Returns:
+            List of FileScanTasks that contain both data and delete files.
+        """
         data_entries: List[ManifestEntry] = []
         positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER)

-        executor = ExecutorFactory.get_or_create()
-        for manifest_entry in chain(
-            *executor.map(
-                lambda args: _open_manifest(*args),
-                [
-                    (
-                        self.io,
-                        manifest,
-                        partition_evaluators[manifest.partition_spec_id],
-                        self._build_metrics_evaluator(),
-                    )
-                    for manifest in manifests
-                    if self._check_sequence_number(min_sequence_number, manifest)
-                ],
-            )
-        ):
+        residual_evaluators: Dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator)
+
+        for manifest_entry in chain.from_iterable(self.scan_plan_helper()):
             data_file = manifest_entry.data_file
             if data_file.content == DataFileContent.DATA:
                 data_entries.append(manifest_entry)
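For context, the refactor above pulls the executor-backed manifest planning out of `plan_files()` into `scan_plan_helper()`, so other callers (such as `inspect.partitions()` below) can reuse the same partition and metrics pruning. A minimal sketch of how a caller might consume the helper once this commit is applied (the catalog and table names are illustrative):

```python
from itertools import chain

from pyiceberg.catalog import load_catalog
from pyiceberg.manifest import DataFileContent

catalog = load_catalog("default")                 # hypothetical catalog name
tbl = catalog.load_table("default.sales_by_day")  # hypothetical table

# table.scan() returns a DataScan; scan_plan_helper() yields one list of
# ManifestEntry objects per manifest, already pruned by the row filter.
scan = tbl.scan(row_filter="dt >= '2021-01-01'")
for entry in chain.from_iterable(scan.scan_plan_helper()):
    if entry.data_file.content == DataFileContent.DATA:
        print(entry.data_file.file_path, entry.data_file.record_count)
```

`plan_files()` itself now iterates `chain.from_iterable(self.scan_plan_helper())` instead of inlining the executor call, which keeps its behaviour unchanged.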

pyiceberg/table/inspect.py

Lines changed: 70 additions & 71 deletions

@@ -16,11 +16,13 @@
 # under the License.
 from __future__ import annotations

+import itertools
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, Union

 from pyiceberg.conversions import from_bytes
-from pyiceberg.manifest import DataFileContent, ManifestContent, ManifestFile, PartitionFieldSummary
+from pyiceberg.expressions import AlwaysTrue, BooleanExpression
+from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, ManifestFile, PartitionFieldSummary
 from pyiceberg.partitioning import PartitionSpec
 from pyiceberg.table.snapshots import Snapshot, ancestors_of
 from pyiceberg.types import PrimitiveType
@@ -32,6 +34,8 @@

     from pyiceberg.table import Table

+ALWAYS_TRUE = AlwaysTrue()
+

 class InspectTable:
     tbl: Table
@@ -255,10 +259,16 @@ def refs(self) -> "pa.Table":

         return pa.Table.from_pylist(ref_results, schema=ref_schema)

-    def partitions(self, snapshot_id: Optional[int] = None) -> "pa.Table":
+    def partitions(
+        self,
+        snapshot_id: Optional[int] = None,
+        row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
+        case_sensitive: bool = True,
+    ) -> "pa.Table":
         import pyarrow as pa

         from pyiceberg.io.pyarrow import schema_to_pyarrow
+        from pyiceberg.table import DataScan

         table_schema = pa.schema(
             [
@@ -289,85 +299,74 @@ def partitions(self, snapshot_id: Optional[int] = None) -> "pa.Table":
         table_schema = pa.unify_schemas([partitions_schema, table_schema])

         snapshot = self._get_snapshot(snapshot_id)
-        executor = ExecutorFactory.get_or_create()
-        local_partitions_maps = executor.map(self._process_manifest, snapshot.manifests(self.tbl.io))
-
-        partitions_map: Dict[Tuple[str, Any], Any] = {}
-        for local_map in local_partitions_maps:
-            for partition_record_key, partition_row in local_map.items():
-                if partition_record_key not in partitions_map:
-                    partitions_map[partition_record_key] = partition_row
-                else:
-                    existing = partitions_map[partition_record_key]
-                    existing["record_count"] += partition_row["record_count"]
-                    existing["file_count"] += partition_row["file_count"]
-                    existing["total_data_file_size_in_bytes"] += partition_row["total_data_file_size_in_bytes"]
-                    existing["position_delete_record_count"] += partition_row["position_delete_record_count"]
-                    existing["position_delete_file_count"] += partition_row["position_delete_file_count"]
-                    existing["equality_delete_record_count"] += partition_row["equality_delete_record_count"]
-                    existing["equality_delete_file_count"] += partition_row["equality_delete_file_count"]
-
-                    if partition_row["last_updated_at"] and (
-                        not existing["last_updated_at"] or partition_row["last_updated_at"] > existing["last_updated_at"]
-                    ):
-                        existing["last_updated_at"] = partition_row["last_updated_at"]
-                        existing["last_updated_snapshot_id"] = partition_row["last_updated_snapshot_id"]

-        return pa.Table.from_pylist(
-            partitions_map.values(),
-            schema=table_schema,
+        scan = DataScan(
+            table_metadata=self.tbl.metadata,
+            io=self.tbl.io,
+            row_filter=row_filter,
+            case_sensitive=case_sensitive,
+            snapshot_id=snapshot.snapshot_id,
         )

-    def _process_manifest(self, manifest: ManifestFile) -> Dict[Tuple[str, Any], Any]:
         partitions_map: Dict[Tuple[str, Any], Any] = {}
-        for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
+
+        for entry in itertools.chain.from_iterable(scan.scan_plan_helper()):
             partition = entry.data_file.partition
             partition_record_dict = {
-                field.name: partition[pos]
-                for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+                field.name: partition[pos] for pos, field in enumerate(self.tbl.metadata.specs()[entry.data_file.spec_id].fields)
             }
             entry_snapshot = self.tbl.snapshot_by_id(entry.snapshot_id) if entry.snapshot_id is not None else None
+            self._update_partitions_map_from_manifest_entry(
+                partitions_map, entry.data_file, partition_record_dict, entry_snapshot
+            )

-            partition_record_key = _convert_to_hashable_type(partition_record_dict)
-            if partition_record_key not in partitions_map:
-                partitions_map[partition_record_key] = {
-                    "partition": partition_record_dict,
-                    "spec_id": entry.data_file.spec_id,
-                    "record_count": 0,
-                    "file_count": 0,
-                    "total_data_file_size_in_bytes": 0,
-                    "position_delete_record_count": 0,
-                    "position_delete_file_count": 0,
-                    "equality_delete_record_count": 0,
-                    "equality_delete_file_count": 0,
-                    "last_updated_at": entry_snapshot.timestamp_ms if entry_snapshot else None,
-                    "last_updated_snapshot_id": entry_snapshot.snapshot_id if entry_snapshot else None,
-                }
+        return pa.Table.from_pylist(
+            partitions_map.values(),
+            schema=table_schema,
+        )

-            partition_row = partitions_map[partition_record_key]
-
-            if entry_snapshot is not None:
-                if (
-                    partition_row["last_updated_at"] is None
-                    or partition_row["last_updated_snapshot_id"] < entry_snapshot.timestamp_ms
-                ):
-                    partition_row["last_updated_at"] = entry_snapshot.timestamp_ms
-                    partition_row["last_updated_snapshot_id"] = entry_snapshot.snapshot_id
-
-            if entry.data_file.content == DataFileContent.DATA:
-                partition_row["record_count"] += entry.data_file.record_count
-                partition_row["file_count"] += 1
-                partition_row["total_data_file_size_in_bytes"] += entry.data_file.file_size_in_bytes
-            elif entry.data_file.content == DataFileContent.POSITION_DELETES:
-                partition_row["position_delete_record_count"] += entry.data_file.record_count
-                partition_row["position_delete_file_count"] += 1
-            elif entry.data_file.content == DataFileContent.EQUALITY_DELETES:
-                partition_row["equality_delete_record_count"] += entry.data_file.record_count
-                partition_row["equality_delete_file_count"] += 1
-            else:
-                raise ValueError(f"Unknown DataFileContent ({entry.data_file.content})")
+    def _update_partitions_map_from_manifest_entry(
+        self,
+        partitions_map: Dict[Tuple[str, Any], Any],
+        file: DataFile,
+        partition_record_dict: Dict[str, Any],
+        snapshot: Optional[Snapshot],
+    ) -> None:
+        partition_record_key = _convert_to_hashable_type(partition_record_dict)
+        if partition_record_key not in partitions_map:
+            partitions_map[partition_record_key] = {
+                "partition": partition_record_dict,
+                "spec_id": file.spec_id,
+                "record_count": 0,
+                "file_count": 0,
+                "total_data_file_size_in_bytes": 0,
+                "position_delete_record_count": 0,
+                "position_delete_file_count": 0,
+                "equality_delete_record_count": 0,
+                "equality_delete_file_count": 0,
+                "last_updated_at": snapshot.timestamp_ms if snapshot else None,
+                "last_updated_snapshot_id": snapshot.snapshot_id if snapshot else None,
+            }

-        return partitions_map
+        partition_row = partitions_map[partition_record_key]
+
+        if snapshot is not None:
+            if partition_row["last_updated_at"] is None or partition_row["last_updated_snapshot_id"] < snapshot.timestamp_ms:
+                partition_row["last_updated_at"] = snapshot.timestamp_ms
+                partition_row["last_updated_snapshot_id"] = snapshot.snapshot_id
+
+        if file.content == DataFileContent.DATA:
+            partition_row["record_count"] += file.record_count
+            partition_row["file_count"] += 1
+            partition_row["total_data_file_size_in_bytes"] += file.file_size_in_bytes
+        elif file.content == DataFileContent.POSITION_DELETES:
+            partition_row["position_delete_record_count"] += file.record_count
+            partition_row["position_delete_file_count"] += 1
+        elif file.content == DataFileContent.EQUALITY_DELETES:
+            partition_row["equality_delete_record_count"] += file.record_count
+            partition_row["equality_delete_file_count"] += 1
+        else:
+            raise ValueError(f"Unknown DataFileContent ({file.content})")

     def _get_manifests_schema(self) -> "pa.Schema":
         import pyarrow as pa
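Because `partitions()` now routes through `DataScan`, the predicate is evaluated with the same machinery as a regular scan: it is written against source columns and projected onto the partition transforms when pruning manifests (the integration tests below exercise identity, year, month, and day transforms). A small sketch under those assumptions, reusing one of the test tables by name:

```python
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")                             # hypothetical catalog name
tbl = catalog.load_table("default.test_partitioned_by_days")  # table name from the tests below

# The filter references the source column (ts); pruning happens on the
# derived partition value (ts_day), as the transform tests verify.
df = tbl.inspect.partitions(row_filter="ts >= '2023-03-05T00:00:00+00:00'")
print(df.to_pandas()[["partition", "record_count", "file_count"]])
```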

tests/integration/test_inspect_table.py

Lines changed: 90 additions & 8 deletions

@@ -18,6 +18,7 @@

 import math
 from datetime import date, datetime
+from typing import Union

 import pyarrow as pa
 import pytest
@@ -26,6 +27,13 @@

 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.expressions import (
+    And,
+    BooleanExpression,
+    EqualTo,
+    GreaterThanOrEqual,
+    LessThan,
+)
 from pyiceberg.schema import Schema
 from pyiceberg.table import Table
 from pyiceberg.typedef import Properties
@@ -198,6 +206,14 @@ def _inspect_files_asserts(df: pa.Table, spark_df: DataFrame) -> None:
             assert left == right, f"Difference in column {column}: {left} != {right}"


+def _check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
+    lhs = df.to_pandas().sort_values("last_updated_at")
+    rhs = spark_df.toPandas().sort_values("last_updated_at")
+    for column in df.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            assert left == right, f"Difference in column {column}: {left} != {right}"
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_inspect_snapshots(
@@ -581,18 +597,84 @@ def test_inspect_partitions_partitioned(spark: SparkSession, session_catalog: Ca
     """
     )

-    def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
-        lhs = df.to_pandas().sort_values("spec_id")
-        rhs = spark_df.toPandas().sort_values("spec_id")
-        for column in df.column_names:
-            for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
-                assert left == right, f"Difference in column {column}: {left} != {right}"
-
     tbl = session_catalog.load_table(identifier)
     for snapshot in tbl.metadata.snapshots:
         df = tbl.inspect.partitions(snapshot_id=snapshot.snapshot_id)
         spark_df = spark.sql(f"SELECT * FROM {identifier}.partitions VERSION AS OF {snapshot.snapshot_id}")
-        check_pyiceberg_df_equals_spark_df(df, spark_df)
+        _check_pyiceberg_df_equals_spark_df(df, spark_df)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_inspect_partitions_partitioned_with_filter(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = "default.table_metadata_partitions_with_filter"
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    spark.sql(
+        f"""
+        CREATE TABLE {identifier} (
+            name string,
+            dt date
+        )
+        PARTITIONED BY (dt)
+    """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {identifier} VALUES ('John', CAST('2021-01-01' AS date))
+    """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {identifier} VALUES ('Doe', CAST('2021-01-05' AS date))
+    """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {identifier} VALUES ('Jenny', CAST('2021-02-01' AS date))
+    """
+    )
+
+    tbl = session_catalog.load_table(identifier)
+    for snapshot in tbl.metadata.snapshots:
+        test_cases: list[tuple[Union[str, BooleanExpression], str]] = [
+            ("dt >= '2021-01-01'", "partition.dt >= '2021-01-01'"),
+            (GreaterThanOrEqual("dt", "2021-01-01"), "partition.dt >= '2021-01-01'"),
+            ("dt >= '2021-01-01' and dt < '2021-03-01'", "partition.dt >= '2021-01-01' AND partition.dt < '2021-03-01'"),
+            (
+                And(GreaterThanOrEqual("dt", "2021-01-01"), LessThan("dt", "2021-03-01")),
+                "partition.dt >= '2021-01-01' AND partition.dt < '2021-03-01'",
+            ),
+            ("dt == '2021-02-01'", "partition.dt = '2021-02-01'"),
+            (EqualTo("dt", "2021-02-01"), "partition.dt = '2021-02-01'"),
+        ]
+        for filter_predicate_lt, filter_predicate_rt in test_cases:
+            df = tbl.inspect.partitions(snapshot_id=snapshot.snapshot_id, row_filter=filter_predicate_lt)
+            spark_df = spark.sql(
+                f"SELECT * FROM {identifier}.partitions VERSION AS OF {snapshot.snapshot_id} WHERE {filter_predicate_rt}"
+            )
+            _check_pyiceberg_df_equals_spark_df(df, spark_df)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog")])
+def test_inspect_partitions_partitioned_transform_with_filter(spark: SparkSession, catalog: Catalog) -> None:
+    for table_name, predicate, partition_predicate in [
+        ("test_partitioned_by_identity", "ts >= '2023-03-05T00:00:00+00:00'", "ts >= '2023-03-05T00:00:00+00:00'"),
+        ("test_partitioned_by_years", "dt >= '2023-03-05'", "dt_year >= 53"),
+        ("test_partitioned_by_months", "dt >= '2023-03-05'", "dt_month >= 638"),
+        ("test_partitioned_by_days", "ts >= '2023-03-05T00:00:00+00:00'", "ts_day >= '2023-03-05'"),
+    ]:
+        table = catalog.load_table(f"default.{table_name}")
+        df = table.inspect.partitions(row_filter=predicate)
+        expected_df = spark.sql(f"select * from default.{table_name}.partitions where partition.{partition_predicate}")
+        assert len(df.to_pandas()) == len(expected_df.toPandas())


 @pytest.mark.integration
