19 changes: 11 additions & 8 deletions pyiceberg/expressions/visitors.py
@@ -861,7 +861,7 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
Args:
file_schema (Schema): The schema of the file.
case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True.
-        projected_field_values (Dict[str, Any]): Values for projected fields not present in the data file.
+        projected_field_values (Dict[int, Any]): Values for projected fields not present in the data file.

Raises:
TypeError: In the case of an UnboundPredicate.
@@ -870,12 +870,12 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):

file_schema: Schema
case_sensitive: bool
-    projected_field_values: Dict[str, Any]
+    projected_field_values: Dict[int, Any]

-    def __init__(self, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT) -> None:
+    def __init__(self, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT) -> None:
self.file_schema = file_schema
self.case_sensitive = case_sensitive
-        self.projected_field_values = projected_field_values or {}
+        self.projected_field_values = projected_field_values

def visit_true(self) -> BooleanExpression:
return AlwaysTrue()
@@ -897,7 +897,8 @@ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression:

def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression:
field = predicate.term.ref().field
-        file_column_name = self.file_schema.find_column_name(field.field_id)
+        field_id = field.field_id
+        file_column_name = self.file_schema.find_column_name(field_id)

if file_column_name is None:
# In the case of schema evolution or column projection, the field might not be present in the file schema.
@@ -915,8 +916,10 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression:
# In the order described by the "Column Projection" section of the Iceberg spec:
# https://iceberg.apache.org/spec/#column-projection
# Evaluate column projection first if it exists
-        if projected_field_value := self.projected_field_values.get(field.name):
Contributor Author: Removing the walrus `:=` here since the projected value can also be None.

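A minimal illustration of the pitfall the comment describes (the dict contents here are hypothetical): with the walrus operator, a key whose stored value is None is indistinguishable from a missing key, because `dict.get` returns a falsy None either way.

```python
projected = {2: None}  # field ID 2 is projected, but its partition value is None

if value := projected.get(2):  # branch never taken: None is falsy
    print("walrus saw", value)

if 2 in projected:  # membership check distinguishes "present but None" from "absent"
    print("membership saw", projected[2])  # -> membership saw None
```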
-            if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(projected_field_value)):
+        if field_id in self.projected_field_values:
+            if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(
+                Record(self.projected_field_values[field_id])
+            ):
return AlwaysTrue()

# Evaluate initial_default value
@@ -937,7 +940,7 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression:


def translate_column_names(
-    expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT
+    expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT
) -> BooleanExpression:
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values))
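A hedged usage sketch of the new signature (the schemas and values are illustrative, mirroring the tests further down): callers now key `projected_field_values` by field ID rather than column name, which survives renames and schema evolution.

```python
from pyiceberg.expressions import EqualTo
from pyiceberg.expressions.visitors import BindVisitor, translate_column_names, visit
from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType, NestedField, StringType

table_schema = Schema(
    NestedField(field_id=1, name="other_field", field_type=StringType(), required=False),
    NestedField(field_id=2, name="partition_id", field_type=IntegerType(), required=False),
)
# The data file predates partition_id, so its schema lacks field ID 2
file_schema = Schema(
    NestedField(field_id=1, name="other_field", field_type=StringType(), required=False),
)

bound = visit(EqualTo("partition_id", 42), visitor=BindVisitor(schema=table_schema, case_sensitive=True))
translated = translate_column_names(bound, file_schema, case_sensitive=True, projected_field_values={2: 42})
# The projected value satisfies the predicate, so the expression collapses to AlwaysTrue()
```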

65 changes: 23 additions & 42 deletions pyiceberg/io/pyarrow.py
@@ -131,7 +131,6 @@
)
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
from pyiceberg.schema import (
-    Accessor,
PartnerAccessor,
PreOrderSchemaVisitor,
Schema,
@@ -1402,41 +1401,23 @@ def _field_id(self, field: pa.Field) -> int:

def _get_column_projection_values(
file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
- ) -> Tuple[bool, Dict[str, Any]]:
+ ) -> Dict[int, Any]:
"""Apply Column Projection rules to File Schema."""
project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
-    should_project_columns = len(project_schema_diff) > 0
-    projected_missing_fields: Dict[str, Any] = {}
+    if len(project_schema_diff) == 0 or partition_spec is None:
+        return EMPTY_DICT

-    if not should_project_columns:
-        return False, {}

-    partition_schema: StructType
-    accessors: Dict[int, Accessor]

-    if partition_spec is not None:
-        partition_schema = partition_spec.partition_type(projected_schema)
-        accessors = build_position_accessors(partition_schema)
-    else:
-        return False, {}
+    partition_schema = partition_spec.partition_type(projected_schema)
+    accessors = build_position_accessors(partition_schema)

+    projected_missing_fields = {}
for field_id in project_schema_diff:
for partition_field in partition_spec.fields_by_source_id(field_id):
if isinstance(partition_field.transform, IdentityTransform):
-                accessor = accessors.get(partition_field.field_id)

-                if accessor is None:
-                    continue
+                if partition_value := accessors[partition_field.field_id].get(file.partition):
+                    projected_missing_fields[field_id] = partition_value
Comment on lines +1417 to +1418

Contributor Author: I think it makes sense to fail here, rather than suppress the error. In the case of an IndexError, I think your table is corrupt.

Contributor Author: This was actually a bug. It always used the current spec, while we should use the spec that the file was written with.

Contributor: Good catch, I see it's fixed in https://github.com/apache/iceberg-python/pull/2293/files#diff-8d5e63f2a87ead8cebe2fd8ac5dcf2198d229f01e16bb9e06e21f7277c328abdR1672. Now I'm worried about all the other places we use .spec()
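A minimal sketch of the lookup the fix performs, assuming the names used in this diff (a `TableMetadata` exposing `specs()` and a `DataFile` carrying the `spec_id` it was written under); the helper name is hypothetical:

```python
from typing import Optional

from pyiceberg.manifest import DataFile
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.table.metadata import TableMetadata


def spec_for_file(table_metadata: TableMetadata, data_file: DataFile) -> Optional[PartitionSpec]:
    # table_metadata.spec() returns the table's *current* spec, which may have
    # evolved since the file was written; resolve the file's own spec instead.
    return table_metadata.specs().get(data_file.spec_id)
```

Returning Optional keeps the corrupt-table case (an unknown spec_id) explicit at the call site.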


-                # The partition field may not exist in the partition record of the data file.
-                # This can happen when new partition fields are introduced after the file was written.
-                try:
-                    if partition_value := accessor.get(file.partition):
-                        projected_missing_fields[partition_field.name] = partition_value
-                except IndexError:
-                    continue

-    return True, projected_missing_fields
+    return projected_missing_fields


def _task_to_record_batches(
@@ -1460,9 +1441,8 @@ def _task_to_record_batches(
# the table format version.
file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)

-    # Apply column projection rules
-    # https://iceberg.apache.org/spec/#column-projection
-    should_project_columns, projected_missing_fields = _get_column_projection_values(
+    # Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
+    projected_missing_fields = _get_column_projection_values(
task.file, projected_schema, partition_spec, file_schema.field_ids
)

@@ -1517,16 +1497,9 @@ def _task_to_record_batches(
file_project_schema,
current_batch,
downcast_ns_timestamp_to_us=True,
+                projected_missing_fields=projected_missing_fields,
)

-            # Inject projected column values if available
-            if should_project_columns:
-                for name, value in projected_missing_fields.items():
-                    index = result_batch.schema.get_field_index(name)
-                    if index != -1:
-                        arr = pa.repeat(value, result_batch.num_rows)
-                        result_batch = result_batch.set_column(index, name, arr)

yield result_batch


@@ -1696,7 +1669,7 @@ def _record_batches_from_scan_tasks_and_deletes(
deletes_per_file.get(task.file.file_path),
self._case_sensitive,
self._table_metadata.name_mapping(),
-                self._table_metadata.spec(),
+                self._table_metadata.specs().get(task.file.spec_id),
)
for batch in batches:
if self._limit is not None:
@@ -1714,12 +1687,15 @@ def _to_requested_schema(
batch: pa.RecordBatch,
downcast_ns_timestamp_to_us: bool = False,
include_field_ids: bool = False,
+    projected_missing_fields: Dict[int, Any] = EMPTY_DICT,
) -> pa.RecordBatch:
# We could reuse some of these visitors
struct_array = visit_with_partner(
requested_schema,
batch,
-        ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids),
+        ArrowProjectionVisitor(
+            file_schema, downcast_ns_timestamp_to_us, include_field_ids, projected_missing_fields=projected_missing_fields
+        ),
ArrowAccessor(file_schema),
)
return pa.RecordBatch.from_struct_array(struct_array)
@@ -1730,18 +1706,21 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
_include_field_ids: bool
_downcast_ns_timestamp_to_us: bool
_use_large_types: Optional[bool]
+    _projected_missing_fields: Dict[int, Any]

def __init__(
self,
file_schema: Schema,
downcast_ns_timestamp_to_us: bool = False,
include_field_ids: bool = False,
use_large_types: Optional[bool] = None,
+        projected_missing_fields: Dict[int, Any] = EMPTY_DICT,
) -> None:
self._file_schema = file_schema
self._include_field_ids = include_field_ids
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
self._use_large_types = use_large_types
+        self._projected_missing_fields = projected_missing_fields

if use_large_types is not None:
deprecation_message(
@@ -1821,7 +1800,9 @@ def struct(
elif field.optional or field.initial_default is not None:
# When an optional field is added, or when a required field with a non-null initial default is added
arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)
-                if field.initial_default is None:
+                if projected_value := self._projected_missing_fields.get(field.field_id):
+                    field_arrays.append(pa.repeat(pa.scalar(projected_value, type=arrow_type), len(struct_array)))
+                elif field.initial_default is None:
field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
else:
field_arrays.append(pa.repeat(field.initial_default, len(struct_array)))
4 changes: 3 additions & 1 deletion tests/conftest.py
@@ -2375,8 +2375,10 @@ def data_file(table_schema_simple: Schema, tmp_path: str) -> str:

@pytest.fixture
def example_task(data_file: str) -> FileScanTask:
+    datafile = DataFile.from_args(file_path=data_file, file_format=FileFormat.PARQUET, file_size_in_bytes=1925)
+    datafile.spec_id = 0
    return FileScanTask(
-        data_file=DataFile.from_args(file_path=data_file, file_format=FileFormat.PARQUET, file_size_in_bytes=1925),
+        data_file=datafile,
)


35 changes: 31 additions & 4 deletions tests/expressions/test_visitors.py
@@ -1730,6 +1730,33 @@ def test_translate_column_names_missing_column_match_null() -> None:
assert translated_expr == AlwaysTrue()


+ def test_translate_column_names_missing_column_match_explicit_null() -> None:
+     """Test translate_column_names when a missing column matches an explicit null."""
+     # Original schema
+     original_schema = Schema(
+         NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
+         NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
+         schema_id=1,
+     )
+
+     # Create bound expression for the missing column
+     unbound_expr = IsNull("missing_col")
+     bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
+
+     # File schema only has the existing column (field_id=1), missing field_id=2
+     file_schema = Schema(
+         NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
+         schema_id=1,
+     )
+
+     # Translate column names
+     translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True, projected_field_values={2: None})
Contributor Author: A partition can be null as well 👍


+     # Should evaluate to AlwaysTrue because the missing column is treated as null:
+     # the explicitly projected None satisfies the IsNull predicate
+     assert translated_expr == AlwaysTrue()


def test_translate_column_names_missing_column_with_initial_default() -> None:
"""Test translate_column_names when missing column's initial_default matches expression."""
# Original schema
@@ -1801,7 +1828,7 @@ def test_translate_column_names_missing_column_with_projected_field_matches() -> None:
)

# Projected column that is missing in the file schema
projected_field_values = {"missing_col": 42}
projected_field_values = {2: 42}

# Translate column names
translated_expr = translate_column_names(
@@ -1833,7 +1860,7 @@ def test_translate_column_names_missing_column_with_projected_field_mismatch() -> None:
)

# Projected column that is missing in the file schema
projected_field_values = {"missing_col": 1}
projected_field_values = {2: 1}

# Translate column names
translated_expr = translate_column_names(
@@ -1864,7 +1891,7 @@ def test_translate_column_names_missing_column_projected_field_fallbacks_to_initial_default() -> None:
)

# Projected field value that differs from both the expression literal and initial_default
projected_field_values = {"missing_col": 10} # This doesn't match expression literal (42)
projected_field_values = {2: 10} # This doesn't match expression literal (42)

# Translate column names
translated_expr = translate_column_names(
@@ -1895,7 +1922,7 @@ def test_translate_column_names_missing_column_projected_field_matches_initial_default() -> None:
)

    # Projected field value that matches the initial_default
-    projected_field_values = {"missing_col": 10}  # This doesn't match the expression literal (42)
+    projected_field_values = {2: 10}  # This doesn't match the expression literal (42)

# Translate column names
translated_expr = translate_column_names(
4 changes: 3 additions & 1 deletion tests/integration/test_writes/test_partitioned_writes.py
@@ -711,8 +711,10 @@ def test_dynamic_partition_overwrite_evolve_partition(spark: SparkSession, sessi
)

identifier = f"default.partitioned_{format_version}_test_dynamic_partition_overwrite_evolve_partition"
-    with pytest.raises(NoSuchTableError):
+    try:
        session_catalog.drop_table(identifier)
+    except NoSuchTableError:
+        pass

tbl = session_catalog.create_table(
identifier=identifier,
26 changes: 16 additions & 10 deletions tests/io/test_pyarrow.py
@@ -970,6 +970,10 @@ def file_map(schema_map: Schema, tmpdir: str) -> str:
def project(
schema: Schema, files: List[str], expr: Optional[BooleanExpression] = None, table_schema: Optional[Schema] = None
) -> pa.Table:
+    def _set_spec_id(datafile: DataFile) -> DataFile:
+        datafile.spec_id = 0
+        return datafile

return ArrowScan(
table_metadata=TableMetadataV2(
location="file://a/b/",
@@ -985,13 +989,15 @@ def project(
).to_table(
tasks=[
FileScanTask(
-                DataFile.from_args(
-                    content=DataFileContent.DATA,
-                    file_path=file,
-                    file_format=FileFormat.PARQUET,
-                    partition={},
-                    record_count=3,
-                    file_size_in_bytes=3,
+                _set_spec_id(
+                    DataFile.from_args(
+                        content=DataFileContent.DATA,
+                        file_path=file,
+                        file_format=FileFormat.PARQUET,
+                        partition={},
+                        record_count=3,
+                        file_size_in_bytes=3,
+                    )
+                )
                )
for file in files
@@ -1189,7 +1195,7 @@ def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
with transaction.update_snapshot().overwrite() as update:
update.append_data_file(unpartitioned_file)

schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int64())])
Contributor Author: The IdentityTransform returns the same type as the one in the table:

    schema = Schema(
        NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
    )

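A quick, hedged check of that claim using pyiceberg's own schema_to_pyarrow converter (the schema below is illustrative): Iceberg's IntegerType is 32-bit, so an identity-transformed partition column materializes as int32, not int64.

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType, NestedField

schema = Schema(NestedField(2, "partition_id", IntegerType(), required=False))
arrow_schema = schema_to_pyarrow(schema)
assert arrow_schema.field("partition_id").type == pa.int32()  # not pa.int64()
```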
+    schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int32())])
assert table.scan().to_arrow() == pa.table(
{
"other_field": ["foo", "bar", "baz"],
@@ -1264,8 +1270,8 @@ def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
str(table.scan().to_arrow())
== """pyarrow.Table
field_1: string
- field_2: int64
- field_3: int64
+ field_2: int32
+ field_3: int32
----
field_1: [["foo"]]
field_2: [[2]]