From 0e5dc2e88cfe0118bc30d87f4e4ad16496907ff9 Mon Sep 17 00:00:00 2001 From: Yftach Zur Date: Wed, 15 Oct 2025 10:02:54 +0100 Subject: [PATCH 1/3] Fix: Support nested struct field filtering with PyArrow (#953) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes filtering on nested struct fields when using PyArrow for scan operations. ## Problem When filtering on nested struct fields (e.g., `mazeMetadata.run_id == 'value'`), PyArrow would fail with: ``` ArrowInvalid: No match for FieldRef.Name(run_id) in ... ``` The issue occurred because PyArrow requires nested field references as tuples (e.g., `("parent", "child")`) rather than dotted strings (e.g., `"parent.child"`). ## Solution 1. Modified `_ConvertToArrowExpression` to accept an optional `Schema` parameter 2. Added `_get_field_name()` method that converts dotted field paths to tuples for nested struct fields 3. Updated `expression_to_pyarrow()` to accept and pass the schema parameter 4. Updated all call sites to pass the schema when available ## Changes - `pyiceberg/io/pyarrow.py`: - Modified `_ConvertToArrowExpression` class to handle nested field paths - Updated `expression_to_pyarrow()` signature to accept schema - Updated `_expression_to_complementary_pyarrow()` signature - `pyiceberg/table/__init__.py`: - Updated call to `_expression_to_complementary_pyarrow()` to pass schema - Tests: - Added `test_ref_binding_nested_struct_field()` for comprehensive nested field testing - Enhanced `test_nested_fields()` with issue #953 scenarios ## Example ```python # Now works correctly: table.scan(row_filter="mazeMetadata.run_id == 'abc123'").to_polars() ``` The fix converts the field reference from: - ❌ `FieldRef.Name(run_id)` (fails - field not found) - ✅ `FieldRef.Nested(FieldRef.Name(mazeMetadata) FieldRef.Name(run_id))` (works!) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- pyiceberg/io/pyarrow.py | 80 ++++++++++++++++++++------- pyiceberg/table/__init__.py | 2 +- tests/expressions/test_expressions.py | 52 +++++++++++++++++ tests/expressions/test_parser.py | 3 + 4 files changed, 117 insertions(+), 20 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index b6ad5659b1..f71369fbea 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -810,51 +810,83 @@ def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar: class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]): + """Convert Iceberg bound expressions to PyArrow expressions. + + Args: + schema: Optional Iceberg schema to resolve full field paths for nested fields. + If not provided, only the field name will be used (not dotted path). + """ + + _schema: Optional[Schema] + + def __init__(self, schema: Optional[Schema] = None): + self._schema = schema + + def _get_field_name(self, term: BoundTerm[Any]) -> Union[str, Tuple[str, ...]]: + """Get the field name or nested field path for a bound term. + + For nested struct fields, returns a tuple of field names (e.g., ("mazeMetadata", "run_id")). + For top-level fields, returns just the field name as a string. + + PyArrow requires nested field references as tuples, not dotted strings. + """ + if self._schema is not None: + # Use the schema to get the full dotted path for nested fields + full_name = self._schema.find_column_name(term.ref().field.field_id) + if full_name is not None: + # If the field name contains dots, it's a nested field + # Convert "parent.child" to ("parent", "child") for PyArrow + if '.' in full_name: + return tuple(full_name.split('.')) + return full_name + # Fallback to just the field name if schema is not available + return term.ref().field.name + def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression: pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type)) - return pc.field(term.ref().field.name).isin(pyarrow_literals) + return pc.field(self._get_field_name(term)).isin(pyarrow_literals) def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression: pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type)) - return ~pc.field(term.ref().field.name).isin(pyarrow_literals) + return ~pc.field(self._get_field_name(term)).isin(pyarrow_literals) def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression: - ref = pc.field(term.ref().field.name) + ref = pc.field(self._get_field_name(term)) return pc.is_nan(ref) def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression: - ref = pc.field(term.ref().field.name) + ref = pc.field(self._get_field_name(term)) return ~pc.is_nan(ref) def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression: - return pc.field(term.ref().field.name).is_null(nan_is_null=False) + return pc.field(self._get_field_name(term)).is_null(nan_is_null=False) def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression: - return pc.field(term.ref().field.name).is_valid() + return pc.field(self._get_field_name(term)).is_valid() def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: - return pc.field(term.ref().field.name) == _convert_scalar(literal.value, term.ref().field.field_type) + return pc.field(self._get_field_name(term)) == _convert_scalar(literal.value, term.ref().field.field_type) def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: - return pc.field(term.ref().field.name) != _convert_scalar(literal.value, term.ref().field.field_type) + return pc.field(self._get_field_name(term)) != _convert_scalar(literal.value, term.ref().field.field_type) def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: - return pc.field(term.ref().field.name) >= _convert_scalar(literal.value, term.ref().field.field_type) + return pc.field(self._get_field_name(term)) >= _convert_scalar(literal.value, term.ref().field.field_type) def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: - return pc.field(term.ref().field.name) > _convert_scalar(literal.value, term.ref().field.field_type) + return pc.field(self._get_field_name(term)) > _convert_scalar(literal.value, term.ref().field.field_type) def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: - return pc.field(term.ref().field.name) < _convert_scalar(literal.value, term.ref().field.field_type) + return pc.field(self._get_field_name(term)) < _convert_scalar(literal.value, term.ref().field.field_type) def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: - return pc.field(term.ref().field.name) <= _convert_scalar(literal.value, term.ref().field.field_type) + return pc.field(self._get_field_name(term)) <= _convert_scalar(literal.value, term.ref().field.field_type) def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: - return pc.starts_with(pc.field(term.ref().field.name), literal.value) + return pc.starts_with(pc.field(self._get_field_name(term)), literal.value) def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression: - return ~pc.starts_with(pc.field(term.ref().field.name), literal.value) + return ~pc.starts_with(pc.field(self._get_field_name(term)), literal.value) def visit_true(self) -> pc.Expression: return pc.scalar(True) @@ -990,11 +1022,21 @@ def collect( boolean_expression_visit(expr, self) -def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression: - return boolean_expression_visit(expr, _ConvertToArrowExpression()) +def expression_to_pyarrow(expr: BooleanExpression, schema: Optional[Schema] = None) -> pc.Expression: + """Convert an Iceberg boolean expression to a PyArrow expression. + + Args: + expr: The Iceberg boolean expression to convert. + schema: Optional Iceberg schema to resolve full field paths for nested fields. + If provided, nested struct fields will use dotted paths (e.g., "parent.child"). + + Returns: + A PyArrow compute expression. + """ + return boolean_expression_visit(expr, _ConvertToArrowExpression(schema)) -def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression: +def _expression_to_complementary_pyarrow(expr: BooleanExpression, schema: Optional[Schema] = None) -> pc.Expression: """Complementary filter conversion function of expression_to_pyarrow. Could not use expression_to_pyarrow(Not(expr)) to achieve this complementary effect because ~ in pyarrow.compute.Expression does not handle null. @@ -1015,7 +1057,7 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expressi preserve_expr = Or(preserve_expr, BoundIsNull(term=term)) for term in nan_unmentioned_bound_terms: preserve_expr = Or(preserve_expr, BoundIsNaN(term=term)) - return expression_to_pyarrow(preserve_expr) + return expression_to_pyarrow(preserve_expr, schema) @lru_cache @@ -1550,7 +1592,7 @@ def _task_to_record_batches( bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields ) bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) - pyarrow_filter = expression_to_pyarrow(bound_file_filter) + pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema) file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 972efc8c47..337acf5498 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -677,7 +677,7 @@ def delete( # Check if there are any files that require an actual rewrite of a data file if delete_snapshot.rewrites_needed is True: bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive) - preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter) + preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter, self.table_metadata.schema()) file_scan = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive) if branch is not None: diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 5a0c8c9241..026b59ef49 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -248,6 +248,58 @@ def test_ref_binding_case_insensitive_failure(table_schema_simple: Schema) -> No ref.bind(table_schema_simple, case_sensitive=False) +def test_ref_binding_nested_struct_field() -> None: + """Test binding references to nested struct fields (issue #953).""" + schema = Schema( + NestedField(field_id=1, name="age", field_type=IntegerType(), required=True), + NestedField( + field_id=2, + name="employment", + field_type=StructType( + NestedField(field_id=3, name="status", field_type=StringType(), required=False), + NestedField(field_id=4, name="company", field_type=StringType(), required=False), + ), + required=False, + ), + NestedField( + field_id=5, + name="contact", + field_type=StructType( + NestedField(field_id=6, name="email", field_type=StringType(), required=False), + ), + required=False, + ), + schema_id=1, + ) + + # Test that nested field names are in the index + assert "employment.status" in schema._name_to_id + assert "employment.company" in schema._name_to_id + assert "contact.email" in schema._name_to_id + + # Test binding a reference to nested fields + ref = Reference("employment.status") + bound = ref.bind(schema, case_sensitive=True) + assert bound.field.field_id == 3 + assert bound.field.name == "status" + + # Test with different nested field + ref2 = Reference("contact.email") + bound2 = ref2.bind(schema, case_sensitive=True) + assert bound2.field.field_id == 6 + assert bound2.field.name == "email" + + # Test case-insensitive binding + ref3 = Reference("EMPLOYMENT.STATUS") + bound3 = ref3.bind(schema, case_sensitive=False) + assert bound3.field.field_id == 3 + + # Test that binding fails for non-existent nested field + ref4 = Reference("employment.department") + with pytest.raises(ValueError): + ref4.bind(schema, case_sensitive=True) + + def test_in_to_eq() -> None: assert In("x", (34.56,)) == EqualTo("x", 34.56) diff --git a/tests/expressions/test_parser.py b/tests/expressions/test_parser.py index 28d7cf110f..2f0c444dc4 100644 --- a/tests/expressions/test_parser.py +++ b/tests/expressions/test_parser.py @@ -225,6 +225,9 @@ def test_with_function() -> None: def test_nested_fields() -> None: assert EqualTo("foo.bar", "data") == parser.parse("foo.bar = 'data'") assert LessThan("location.x", DecimalLiteral(Decimal(52.00))) == parser.parse("location.x < 52.00") + # Test issue #953 scenario - nested struct field filtering + assert EqualTo("employment.status", "Employed") == parser.parse("employment.status = 'Employed'") + assert EqualTo("contact.email", "test@example.com") == parser.parse("contact.email = 'test@example.com'") def test_quoted_column_with_dots() -> None: From cedaa95926e3110665486cb835bae61ceb96395f Mon Sep 17 00:00:00 2001 From: Yftach Zur Date: Thu, 23 Oct 2025 13:30:20 +0200 Subject: [PATCH 2/3] Apply Ruff Formatting --- pyiceberg/io/pyarrow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index f71369fbea..caaef5c661 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -836,8 +836,8 @@ def _get_field_name(self, term: BoundTerm[Any]) -> Union[str, Tuple[str, ...]]: if full_name is not None: # If the field name contains dots, it's a nested field # Convert "parent.child" to ("parent", "child") for PyArrow - if '.' in full_name: - return tuple(full_name.split('.')) + if "." in full_name: + return tuple(full_name.split(".")) return full_name # Fallback to just the field name if schema is not available return term.ref().field.name From 8fe55e3fae560bd013c52ead9a5cfc46c2937d0e Mon Sep 17 00:00:00 2001 From: Yftach Zur Date: Sun, 26 Oct 2025 17:40:17 +0100 Subject: [PATCH 3/3] Remove unused type ignore --- tests/utils/test_manifest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py index 7c62b9564c..a5d5a6fefb 100644 --- a/tests/utils/test_manifest.py +++ b/tests/utils/test_manifest.py @@ -48,7 +48,7 @@ @pytest.fixture(autouse=True) def clear_global_manifests_cache() -> None: # Clear the global cache before each test - _manifests.cache_clear() # type: ignore + _manifests.cache_clear() def _verify_metadata_with_fastavro(avro_file: str, expected_metadata: Dict[str, str]) -> None: