12 changes: 9 additions & 3 deletions pyiceberg/io/pyarrow.py
@@ -1492,14 +1492,18 @@ def _field_id(self, field: pa.Field) -> int:
 
 
 def _get_column_projection_values(
-    file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
+    file: DataFile,
+    projected_schema: Schema,
+    table_schema: Schema,
+    partition_spec: Optional[PartitionSpec],
+    file_project_field_ids: Set[int],
 ) -> Dict[int, Any]:
     """Apply Column Projection rules to File Schema."""
     project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
     if len(project_schema_diff) == 0 or partition_spec is None:
         return EMPTY_DICT
 
-    partition_schema = partition_spec.partition_type(projected_schema)
+    partition_schema = partition_spec.partition_type(table_schema)
     accessors = build_position_accessors(partition_schema)
 
     projected_missing_fields = {}
@@ -1517,6 +1521,7 @@ def _task_to_record_batches(
     task: FileScanTask,
     bound_row_filter: BooleanExpression,
     projected_schema: Schema,
+    table_schema: Schema,
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
@@ -1541,7 +1546,7 @@
 
     # Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
     projected_missing_fields = _get_column_projection_values(
-        task.file, projected_schema, partition_spec, file_schema.field_ids
+        task.file, projected_schema, table_schema, partition_spec, file_schema.field_ids
     )
 
     pyarrow_filter = None
@@ -1763,6 +1768,7 @@ def _record_batches_from_scan_tasks_and_deletes(
                 task,
                 self._bound_row_filter,
                 self._projected_schema,
+                self._table_metadata.schema(),
                 self._projected_field_ids,
                 deletes_per_file.get(task.file.file_path),
                 self._case_sensitive,
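
Why the fix works: PartitionSpec.partition_type resolves each partition field's
source_id against whatever schema it is handed. The projected schema contains only
the columns a reader asked for, so a scan that omits the partition source column
(the situation in issue #2672) cannot resolve it, while the table's current schema
always can. A minimal sketch of the failure mode using pyiceberg's public Schema
and PartitionSpec APIs (the field ids and names below are illustrative, not taken
from this PR):

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import DateType, IntegerType, NestedField, StringType

table_schema = Schema(
    NestedField(1, "partition_date", DateType(), required=False),
    NestedField(2, "id", IntegerType(), required=False),
    NestedField(5, "new_column", StringType(), required=False),
)
spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="partition_date"),
)

# A reader may project a subset that leaves out the partition source column:
projected_schema = table_schema.select("id", "new_column")

spec.partition_type(table_schema)      # source_id=1 found in the full schema: ok
spec.partition_type(projected_schema)  # source_id=1 missing: raises (ValueError in current pyiceberg)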
70 changes: 70 additions & 0 deletions tests/io/test_pyarrow.py
@@ -2846,6 +2846,7 @@ def test_task_to_record_batches_nanos(format_version: TableVersion, tmpdir: str)
            FileScanTask(data_file),
            bound_row_filter=AlwaysTrue(),
            projected_schema=table_schema,
+            table_schema=table_schema,
            projected_field_ids={1},
            positional_deletes=None,
            case_sensitive=True,
@@ -4590,3 +4591,72 @@ def test_orc_stripe_based_batching(tmp_path: Path) -> None:
     # Verify total rows
     total_rows = sum(batch.num_rows for batch in batches)
     assert total_rows == 10000, f"Expected 10000 total rows, got {total_rows}"
+
+
+def test_partition_column_projection_with_schema_evolution(catalog: InMemoryCatalog) -> None:
+    """Test column projection on partitioned table after schema evolution (https://github.com/apache/iceberg-python/issues/2672)."""
+    initial_schema = Schema(
+        NestedField(1, "partition_date", DateType(), required=False),
+        NestedField(2, "id", IntegerType(), required=False),
+        NestedField(3, "name", StringType(), required=False),
+        NestedField(4, "value", IntegerType(), required=False),
+    )
+
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="partition_date"),
+    )
+
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        "default.test_schema_evolution_projection",
+        schema=initial_schema,
+        partition_spec=partition_spec,
+    )
+
+    data_v1 = pa.Table.from_pylist(
+        [
+            {"partition_date": date(2024, 1, 1), "id": 1, "name": "Alice", "value": 100},
+            {"partition_date": date(2024, 1, 1), "id": 2, "name": "Bob", "value": 200},
+        ],
+        schema=pa.schema(
+            [
+                ("partition_date", pa.date32()),
+                ("id", pa.int32()),
+                ("name", pa.string()),
+                ("value", pa.int32()),
+            ]
+        ),
+    )
+
+    table.append(data_v1)
+
+    with table.update_schema() as update:
+        update.add_column("new_column", StringType())
+
+    table = catalog.load_table("default.test_schema_evolution_projection")
+
+    data_v2 = pa.Table.from_pylist(
+        [
+            {"partition_date": date(2024, 1, 2), "id": 3, "name": "Charlie", "value": 300, "new_column": "new1"},
+            {"partition_date": date(2024, 1, 2), "id": 4, "name": "David", "value": 400, "new_column": "new2"},
+        ],
+        schema=pa.schema(
+            [
+                ("partition_date", pa.date32()),
+                ("id", pa.int32()),
+                ("name", pa.string()),
+                ("value", pa.int32()),
+                ("new_column", pa.string()),
+            ]
+        ),
+    )
+
+    table.append(data_v2)
+
+    result = table.scan(selected_fields=("id", "name", "value", "new_column")).to_arrow()
+
+    assert set(result.schema.names) == {"id", "name", "value", "new_column"}
+    assert result.num_rows == 4
+    result_sorted = result.sort_by("name")
+    assert result_sorted["name"].to_pylist() == ["Alice", "Bob", "Charlie", "David"]
+    assert result_sorted["new_column"].to_pylist() == [None, None, "new1", "new2"]