Convert `_get_column_projection_values` to use Field-IDs #2293
@@ -131,7 +131,6 @@
 )
 from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
 from pyiceberg.schema import (
-    Accessor,
     PartnerAccessor,
     PreOrderSchemaVisitor,
     Schema,
@@ -1402,41 +1401,23 @@ def _field_id(self, field: pa.Field) -> int:


 def _get_column_projection_values(
     file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
-) -> Tuple[bool, Dict[str, Any]]:
+) -> Dict[int, Any]:
     """Apply Column Projection rules to File Schema."""
     project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
-    should_project_columns = len(project_schema_diff) > 0
-    projected_missing_fields: Dict[str, Any] = {}
-
-    if not should_project_columns:
-        return False, {}
-
-    partition_schema: StructType
-    accessors: Dict[int, Accessor]
-
-    if partition_spec is not None:
-        partition_schema = partition_spec.partition_type(projected_schema)
-        accessors = build_position_accessors(partition_schema)
-    else:
-        return False, {}
+    if len(project_schema_diff) == 0 or partition_spec is None:
+        return EMPTY_DICT
+
+    partition_schema = partition_spec.partition_type(projected_schema)
+    accessors = build_position_accessors(partition_schema)

+    projected_missing_fields = {}
     for field_id in project_schema_diff:
         for partition_field in partition_spec.fields_by_source_id(field_id):
             if isinstance(partition_field.transform, IdentityTransform):
-                accessor = accessors.get(partition_field.field_id)
-
-                if accessor is None:
-                    continue
-
-                # The partition field may not exist in the partition record of the data file.
-                # This can happen when new partition fields are introduced after the file was written.
-                try:
-                    if partition_value := accessor.get(file.partition):
-                        projected_missing_fields[partition_field.name] = partition_value
-                except IndexError:
-                    continue
+                if partition_value := accessors[partition_field.field_id].get(file.partition):
+                    projected_missing_fields[field_id] = partition_value

-    return True, projected_missing_fields
+    return projected_missing_fields

Review comments on lines +1417 to +1418:

- I think it makes sense to fail here, rather than suppress the error. In the case of an […]
- This was actually a bug. It always used the current spec, while we should use the spec that it was written with.
- Good catch, I see it's fixed in https://github.com/apache/iceberg-python/pull/2293/files#diff-8d5e63f2a87ead8cebe2fd8ac5dcf2198d229f01e16bb9e06e21f7277c328abdR1672. Now I'm worried about all the other places we use […]
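For context, a minimal sketch of why the returned mapping is now keyed by field ID (`Dict[int, Any]`) rather than by column name: field IDs are stable across renames, names are not. The dict literals below are hypothetical illustrations, not pyiceberg's API.

    # Hypothetical projected partition value, keyed two ways.
    by_name = {"partition_id": 42}  # old shape: Dict[str, Any]
    by_field_id = {2: 42}           # new shape: Dict[int, Any]

    # After a rename (partition_id -> part), the projected schema still carries
    # field_id=2, so a field-ID lookup keeps resolving while a name lookup misses.
    renamed = {"name": "part", "field_id": 2}
    assert by_field_id[renamed["field_id"]] == 42
    assert renamed["name"] not in by_name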
@@ -1460,9 +1441,8 @@ def _task_to_record_batches(
         # the table format version.
         file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)

-    # Apply column projection rules
-    # https://iceberg.apache.org/spec/#column-projection
-    should_project_columns, projected_missing_fields = _get_column_projection_values(
+    # Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
+    projected_missing_fields = _get_column_projection_values(
         task.file, projected_schema, partition_spec, file_schema.field_ids
     )
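The rule the spec link describes, in a minimal sketch (illustrative data, not pyiceberg's implementation): when a queried field is absent from the file but is the source of an identity partition transform, its value is taken from the file's partition record and repeated for every row.

    # Hypothetical file written before partition_id was stored as a data column.
    file_columns = {"other_field": ["foo", "bar"]}  # data physically in the file
    partition_record = {"partition_id": 42}         # file's partition metadata
    num_rows = 2

    projected = {}
    for name in ["other_field", "partition_id"]:
        if name in file_columns:
            projected[name] = file_columns[name]
        elif name in partition_record:  # identity-transform source column
            projected[name] = [partition_record[name]] * num_rows
    assert projected["partition_id"] == [42, 42]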
@@ -1517,16 +1497,9 @@ def _task_to_record_batches(
                 file_project_schema,
                 current_batch,
                 downcast_ns_timestamp_to_us=True,
+                projected_missing_fields=projected_missing_fields,
             )

-            # Inject projected column values if available
-            if should_project_columns:
-                for name, value in projected_missing_fields.items():
-                    index = result_batch.schema.get_field_index(name)
-                    if index != -1:
-                        arr = pa.repeat(value, result_batch.num_rows)
-                        result_batch = result_batch.set_column(index, name, arr)
-
             yield result_batch
@@ -1696,7 +1669,7 @@ def _record_batches_from_scan_tasks_and_deletes(
                 deletes_per_file.get(task.file.file_path),
                 self._case_sensitive,
                 self._table_metadata.name_mapping(),
-                self._table_metadata.spec(),
+                self._table_metadata.specs().get(task.file.spec_id),
             )
             for batch in batches:
                 if self._limit is not None:
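A minimal sketch of the distinction this one-line fix relies on (illustrative stand-ins, not pyiceberg's classes): `spec()` returns only the table's current partition spec, while `specs()` maps every historical spec ID to its spec, so a file written before a partition evolution must be resolved through its own `spec_id`.

    from dataclasses import dataclass
    from typing import Dict

    @dataclass
    class Metadata:
        specs_by_id: Dict[int, str]
        current_spec_id: int

        def spec(self) -> str:
            # Old code path: always the table's current spec.
            return self.specs_by_id[self.current_spec_id]

        def specs(self) -> Dict[int, str]:
            # Fixed code path: the caller resolves by the file's own spec_id.
            return self.specs_by_id

    metadata = Metadata({0: "identity(partition_id)", 1: "identity(category)"}, current_spec_id=1)
    file_spec_id = 0  # this data file predates the partition evolution

    assert metadata.spec() == "identity(category)"  # wrong spec for the old file
    assert metadata.specs().get(file_spec_id) == "identity(partition_id)"  # correct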
@@ -1714,12 +1687,15 @@ def _to_requested_schema(
     batch: pa.RecordBatch,
     downcast_ns_timestamp_to_us: bool = False,
     include_field_ids: bool = False,
+    projected_missing_fields: Dict[int, Any] = EMPTY_DICT,
 ) -> pa.RecordBatch:
     # We could reuse some of these visitors
     struct_array = visit_with_partner(
         requested_schema,
         batch,
-        ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids),
+        ArrowProjectionVisitor(
+            file_schema, downcast_ns_timestamp_to_us, include_field_ids, projected_missing_fields=projected_missing_fields
+        ),
         ArrowAccessor(file_schema),
     )
     return pa.RecordBatch.from_struct_array(struct_array)
|
@@ -1730,18 +1706,21 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra | |
_include_field_ids: bool | ||
_downcast_ns_timestamp_to_us: bool | ||
_use_large_types: Optional[bool] | ||
_projected_missing_fields: Dict[int, Any] | ||
|
||
def __init__( | ||
self, | ||
file_schema: Schema, | ||
downcast_ns_timestamp_to_us: bool = False, | ||
include_field_ids: bool = False, | ||
use_large_types: Optional[bool] = None, | ||
projected_missing_fields: Dict[int, Any] = EMPTY_DICT, | ||
) -> None: | ||
self._file_schema = file_schema | ||
self._include_field_ids = include_field_ids | ||
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us | ||
self._use_large_types = use_large_types | ||
self._projected_missing_fields = projected_missing_fields | ||
|
||
if use_large_types is not None: | ||
deprecation_message( | ||
|
@@ -1821,7 +1800,9 @@ def struct(
             elif field.optional or field.initial_default is not None:
                 # When an optional field is added, or when a required field with a non-null initial default is added
                 arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)
-                if field.initial_default is None:
+                if projected_value := self._projected_missing_fields.get(field.field_id):
+                    field_arrays.append(pa.repeat(pa.scalar(projected_value, type=arrow_type), len(struct_array)))
+                elif field.initial_default is None:
                     field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
                 else:
                     field_arrays.append(pa.repeat(field.initial_default, len(struct_array)))
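For readers unfamiliar with the visitor, a standalone sketch of the mechanism used above (real pyarrow calls, hypothetical values): `pa.repeat` materializes a constant column of the right length and type from the recovered partition value.

    import pyarrow as pa

    num_rows = 3
    arrow_type = pa.int32()
    projected_value = 42  # partition value recovered for a field missing from the file

    # pa.repeat builds a constant array, one entry per row in the batch.
    arr = pa.repeat(pa.scalar(projected_value, type=arrow_type), num_rows)
    assert arr.equals(pa.array([42, 42, 42], type=pa.int32()))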
@@ -1730,6 +1730,33 @@ def test_translate_column_names_missing_column_match_null() -> None:
     assert translated_expr == AlwaysTrue()


+def test_translate_column_names_missing_column_match_explicit_null() -> None:
+    """Test translate_column_names when missing column matches null."""
+    # Original schema
+    original_schema = Schema(
+        NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
+        schema_id=1,
+    )
+
+    # Create bound expression for the missing column
+    unbound_expr = IsNull("missing_col")
+    bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
+
+    # File schema only has the existing column (field_id=1), missing field_id=2
+    file_schema = Schema(
+        NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
+        schema_id=1,
+    )
+
+    # Translate column names
+    translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True, projected_field_values={2: None})
+
+    # Should evaluate to AlwaysTrue because the missing column is treated as null
+    # missing_col's default initial_default (None) satisfies the IsNull predicate
+    assert translated_expr == AlwaysTrue()
+
+
 def test_translate_column_names_missing_column_with_initial_default() -> None:
     """Test translate_column_names when missing column's initial_default matches expression."""
     # Original schema
Review comment on the `projected_field_values={2: None}` line: A partition can be null as well 👍
@@ -1801,7 +1828,7 @@ def test_translate_column_names_missing_column_with_projected_field_matches() ->
     )

     # Projected column that is missing in the file schema
-    projected_field_values = {"missing_col": 42}
+    projected_field_values = {2: 42}

     # Translate column names
     translated_expr = translate_column_names(

@@ -1833,7 +1860,7 @@ def test_translate_column_names_missing_column_with_projected_field_mismatch() -
     )

     # Projected column that is missing in the file schema
-    projected_field_values = {"missing_col": 1}
+    projected_field_values = {2: 1}

     # Translate column names
     translated_expr = translate_column_names(

@@ -1864,7 +1891,7 @@ def test_translate_column_names_missing_column_projected_field_fallbacks_to_init
     )

     # Projected field value that differs from both the expression literal and initial_default
-    projected_field_values = {"missing_col": 10}  # This doesn't match expression literal (42)
+    projected_field_values = {2: 10}  # This doesn't match expression literal (42)

     # Translate column names
     translated_expr = translate_column_names(

@@ -1895,7 +1922,7 @@ def test_translate_column_names_missing_column_projected_field_matches_initial_d
     )

     # Projected field value that matches the expression literal
-    projected_field_values = {"missing_col": 10}  # This doesn't match expression literal (42)
+    projected_field_values = {2: 10}  # This doesn't match expression literal (42)

     # Translate column names
     translated_expr = translate_column_names(
@@ -970,6 +970,10 @@ def file_map(schema_map: Schema, tmpdir: str) -> str:
 def project(
     schema: Schema, files: List[str], expr: Optional[BooleanExpression] = None, table_schema: Optional[Schema] = None
 ) -> pa.Table:
+    def _set_spec_id(datafile: DataFile) -> DataFile:
+        datafile.spec_id = 0
+        return datafile
+
     return ArrowScan(
         table_metadata=TableMetadataV2(
             location="file://a/b/",

@@ -985,13 +989,15 @@ def project(
     ).to_table(
         tasks=[
             FileScanTask(
-                DataFile.from_args(
-                    content=DataFileContent.DATA,
-                    file_path=file,
-                    file_format=FileFormat.PARQUET,
-                    partition={},
-                    record_count=3,
-                    file_size_in_bytes=3,
+                _set_spec_id(
+                    DataFile.from_args(
+                        content=DataFileContent.DATA,
+                        file_path=file,
+                        file_format=FileFormat.PARQUET,
+                        partition={},
+                        record_count=3,
+                        file_size_in_bytes=3,
+                    )
                 )
             )
             for file in files
@@ -1189,7 +1195,7 @@ def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCa
     with transaction.update_snapshot().overwrite() as update:
         update.append_data_file(unpartitioned_file)

-    schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int64())])
+    schema = pa.schema([("other_field", pa.string()), ("partition_id", pa.int32())])
     assert table.scan().to_arrow() == pa.table(
         {
             "other_field": ["foo", "bar", "baz"],

Review comment on the `schema` line: The schema:

    schema = Schema(
        NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
    )

(`partition_id` is an Iceberg `IntegerType`, which maps to Arrow `int32`, hence the change from `pa.int64()`.)
@@ -1264,8 +1270,8 @@ def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryC
         str(table.scan().to_arrow())
         == """pyarrow.Table
 field_1: string
-field_2: int64
-field_3: int64
+field_2: int32
+field_3: int32
 ----
 field_1: [["foo"]]
 field_2: [[2]]
Review comment on the `if projected_value := ...` line in `struct()`: Removing the walrus `:=` here, since the projected value can also be `None`.
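A minimal sketch of the pitfall behind that comment (hypothetical values, plain Python): `dict.get` combined with the walrus in a truthiness test drops stored falsy values such as `None` or `0`, so an explicitly-null projected partition value would never be injected.

    projected = {2: None}  # field 2 carries an explicit null partition value

    hits = []
    if value := projected.get(2):
        hits.append(value)  # never reached: None is falsy
    assert hits == []       # the explicit null was silently skipped

    # An explicit membership test keeps "missing key" and "stored None" distinct:
    if 2 in projected:
        hits.append(projected[2])
    assert hits == [None]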