Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,8 @@ def visit_less_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
if not isinstance(field.field_type, PrimitiveType):
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound_bytes = self.lower_bounds.get(field_id)
if lower_bound_bytes is not None:
lower_bound = from_bytes(field.field_type, lower_bound_bytes)

if self._is_nan(lower_bound):
Expand All @@ -1263,7 +1264,8 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
if not isinstance(field.field_type, PrimitiveType):
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound_bytes = self.lower_bounds.get(field_id)
if lower_bound_bytes is not None:
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand All @@ -1284,7 +1286,8 @@ def visit_greater_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
if not isinstance(field.field_type, PrimitiveType):
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound_bytes = self.upper_bounds.get(field_id)
if upper_bound_bytes is not None:
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
if upper_bound <= literal.value:
if self._is_nan(upper_bound):
Expand All @@ -1305,7 +1308,8 @@ def visit_greater_than_or_equal(self, term: BoundTerm, literal: LiteralValue) ->
if not isinstance(field.field_type, PrimitiveType):
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound_bytes = self.upper_bounds.get(field_id)
if upper_bound_bytes is not None:
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
if upper_bound < literal.value:
if self._is_nan(upper_bound):
Expand All @@ -1326,7 +1330,8 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
if not isinstance(field.field_type, PrimitiveType):
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound_bytes = self.lower_bounds.get(field_id)
if lower_bound_bytes is not None:
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand All @@ -1335,7 +1340,8 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
if lower_bound > literal.value:
return ROWS_CANNOT_MATCH

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound_bytes = self.upper_bounds.get(field_id)
if upper_bound_bytes is not None:
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
if self._is_nan(upper_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand Down Expand Up @@ -1363,7 +1369,8 @@ def visit_in(self, term: BoundTerm, literals: set[L]) -> bool:
if not isinstance(field.field_type, PrimitiveType):
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound_bytes = self.lower_bounds.get(field_id)
if lower_bound_bytes is not None:
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand All @@ -1373,7 +1380,8 @@ def visit_in(self, term: BoundTerm, literals: set[L]) -> bool:
if len(literals) == 0:
return ROWS_CANNOT_MATCH

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound_bytes = self.upper_bounds.get(field_id)
if upper_bound_bytes is not None:
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
# this is different from Java, here NaN is always larger
if self._is_nan(upper_bound):
Expand Down Expand Up @@ -1403,14 +1411,16 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
prefix = str(literal.value)
len_prefix = len(prefix)

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound_bytes = self.lower_bounds.get(field_id)
if lower_bound_bytes is not None:
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))

# truncate lower bound so that its length is not greater than the length of prefix
if lower_bound and lower_bound[:len_prefix] > prefix:
return ROWS_CANNOT_MATCH

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound_bytes = self.upper_bounds.get(field_id)
if upper_bound_bytes is not None:
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))

# truncate upper bound so that its length is not greater than the length of prefix
Expand All @@ -1434,7 +1444,9 @@ def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:

# not_starts_with will match unless all values must start with the prefix. This happens when
# the lower and upper bounds both start with the prefix.
if (lower_bound_bytes := self.lower_bounds.get(field_id)) and (upper_bound_bytes := self.upper_bounds.get(field_id)):
lower_bound_bytes = self.lower_bounds.get(field_id)
upper_bound_bytes = self.upper_bounds.get(field_id)
if lower_bound_bytes is not None and upper_bound_bytes is not None:
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))

Expand Down Expand Up @@ -1558,7 +1570,8 @@ def visit_less_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
upper_bytes = self.upper_bounds.get(field_id)
if upper_bytes is not None:
field = self._get_field(field_id)
upper = _from_byte_buffer(field.field_type, upper_bytes)

Expand All @@ -1575,7 +1588,8 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
upper_bytes = self.upper_bounds.get(field_id)
if upper_bytes is not None:
field = self._get_field(field_id)
upper = _from_byte_buffer(field.field_type, upper_bytes)

Expand All @@ -1592,7 +1606,8 @@ def visit_greater_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if lower_bytes := self.lower_bounds.get(field_id):
lower_bytes = self.lower_bounds.get(field_id)
if lower_bytes is not None:
field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)

Expand All @@ -1613,7 +1628,8 @@ def visit_greater_than_or_equal(self, term: BoundTerm, literal: LiteralValue) ->
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if lower_bytes := self.lower_bounds.get(field_id):
lower_bytes = self.lower_bounds.get(field_id)
if lower_bytes is not None:
field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)

Expand All @@ -1634,7 +1650,9 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
lower_bytes = self.lower_bounds.get(field_id)
upper_bytes = self.upper_bounds.get(field_id)
if lower_bytes is not None and upper_bytes is not None:
field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)
upper = _from_byte_buffer(field.field_type, upper_bytes)
Expand All @@ -1655,7 +1673,8 @@ def visit_not_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:

field = self._get_field(field_id)

if lower_bytes := self.lower_bounds.get(field_id):
lower_bytes = self.lower_bounds.get(field_id)
if lower_bytes is not None:
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
Expand All @@ -1666,7 +1685,8 @@ def visit_not_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
if lower > literal.value:
return ROWS_MUST_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
upper_bytes = self.upper_bounds.get(field_id)
if upper_bytes is not None:
upper = _from_byte_buffer(field.field_type, upper_bytes)

if upper < literal.value:
Expand All @@ -1682,7 +1702,9 @@ def visit_in(self, term: BoundTerm, literals: set[L]) -> bool:

field = self._get_field(field_id)

if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
lower_bytes = self.lower_bounds.get(field_id)
upper_bytes = self.upper_bounds.get(field_id)
if lower_bytes is not None and upper_bytes is not None:
# similar to the implementation in eq, first check if the lower bound is in the set
lower = _from_byte_buffer(field.field_type, lower_bytes)
if lower not in literals:
Expand Down Expand Up @@ -1711,7 +1733,8 @@ def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool:

field = self._get_field(field_id)

if lower_bytes := self.lower_bounds.get(field_id):
lower_bytes = self.lower_bounds.get(field_id)
if lower_bytes is not None:
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
Expand All @@ -1723,7 +1746,8 @@ def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool:
if len(literals) == 0:
return ROWS_MUST_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
upper_bytes = self.upper_bounds.get(field_id)
if upper_bytes is not None:
upper = _from_byte_buffer(field.field_type, upper_bytes)

literals = {val for val in literals if upper >= val}
Expand Down
6 changes: 4 additions & 2 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,8 @@ def _get_column_projection_values(
for field_id in project_schema_diff:
for partition_field in partition_spec.fields_by_source_id(field_id):
if isinstance(partition_field.transform, IdentityTransform):
if partition_value := accessors[partition_field.field_id].get(file.partition):
partition_value = accessors[partition_field.field_id].get(file.partition)
if partition_value is not None:
projected_missing_fields[field_id] = partition_value

return projected_missing_fields
Expand Down Expand Up @@ -2010,7 +2011,8 @@ def struct(
elif field.optional or field.initial_default is not None:
# When an optional field is added, or when a required field with a non-null initial default is added
arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)
if projected_value := self._projected_missing_fields.get(field.field_id):
projected_value = self._projected_missing_fields.get(field.field_id)
if projected_value is not None:
field_arrays.append(pa.repeat(pa.scalar(projected_value, type=arrow_type), len(struct_array)))
elif field.initial_default is None:
field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
Expand Down
20 changes: 8 additions & 12 deletions pyiceberg/table/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
ALWAYS_TRUE = AlwaysTrue()


def _readable_bound(field_type: PrimitiveType, bound: bytes | None) -> Any | None:
    """Decode a serialized column bound into its readable value.

    Args:
        field_type: Primitive type handed to ``from_bytes`` for decoding.
        bound: Serialized bound, or ``None`` when no bound is recorded.

    Returns:
        ``None`` when ``bound`` is ``None``; otherwise the result of
        ``from_bytes(field_type, bound)``.
    """
    # Only a missing bound short-circuits; an empty bytes value is falsy but
    # is still passed on to from_bytes.
    if bound is None:
        return None
    return from_bytes(field_type, bound)


class InspectTable:
tbl: Table

Expand Down Expand Up @@ -180,12 +184,8 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
"null_value_count": null_value_counts.get(field.field_id),
"nan_value_count": nan_value_counts.get(field.field_id),
# Makes them readable
"lower_bound": from_bytes(field.field_type, lower_bound)
if (lower_bound := lower_bounds.get(field.field_id))
else None,
"upper_bound": from_bytes(field.field_type, upper_bound)
if (upper_bound := upper_bounds.get(field.field_id))
else None,
"lower_bound": _readable_bound(field.field_type, lower_bounds.get(field.field_id)),
"upper_bound": _readable_bound(field.field_type, upper_bounds.get(field.field_id)),
}
for field in self.tbl.metadata.schema().fields
}
Expand Down Expand Up @@ -570,12 +570,8 @@ def _get_files_from_manifest(
"value_count": value_counts.get(field.field_id),
"null_value_count": null_value_counts.get(field.field_id),
"nan_value_count": nan_value_counts.get(field.field_id),
"lower_bound": from_bytes(field.field_type, lower_bound)
if (lower_bound := lower_bounds.get(field.field_id))
else None,
"upper_bound": from_bytes(field.field_type, upper_bound)
if (upper_bound := upper_bounds.get(field.field_id))
else None,
"lower_bound": _readable_bound(field.field_type, lower_bounds.get(field.field_id)),
"upper_bound": _readable_bound(field.field_type, upper_bounds.get(field.field_id)),
}
for field in self.tbl.metadata.schema().fields
}
Expand Down
56 changes: 56 additions & 0 deletions tests/expressions/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,62 @@ def test_string_starts_with(
# assert not should_read, "Should not read: range doesn't match"


def test_inclusive_metrics_evaluator_uses_empty_byte_lower_bound() -> None:
    """Check that the inclusive evaluator honours bounds serialized from the empty string.

    Both bounds of the single required string column are ``to_bytes(StringType(), "")``.
    Every predicate below is expected to let the file be skipped.
    """
    column_name = "empty_string"
    empty_bound = to_bytes(StringType(), "")

    schema = Schema(NestedField(1, column_name, StringType(), required=True))
    data_file = DataFile.from_args(
        file_path="file.parquet",
        file_format=FileFormat.PARQUET,
        partition={},
        record_count=10,
        file_size_in_bytes=1,
        value_counts={1: 10},
        null_value_counts={1: 0},
        nan_value_counts=None,
        lower_bounds={1: empty_bound},
        upper_bounds={1: empty_bound},
    )

    skip_cases = (
        # Lower-bound branch: LessThan reads lower_bound only.
        (LessThan(column_name, ""), "Should not read: lower bound is present and equal to the literal"),
        # Upper-bound branch: GreaterThan reads upper_bound only.
        (GreaterThan(column_name, "abc"), "Should not read: upper bound '' is not greater than 'abc'"),
        # Both-bounds branch: EqualTo reads lower_bound and upper_bound.
        (EqualTo(column_name, "abc"), "Should not read: 'abc' falls outside ['', '']"),
    )

    for predicate, reason in skip_cases:
        should_read = _InclusiveMetricsEvaluator(schema, predicate).eval(data_file)
        assert not should_read, reason


def test_strict_metrics_evaluator_uses_empty_byte_bounds() -> None:
    """Check that the strict evaluator honours bounds serialized from the empty string.

    Both bounds of the single required string column are ``to_bytes(StringType(), "")``.
    Every predicate below is expected to be reported as matching.
    """
    column_name = "empty_string"
    empty_bound = to_bytes(StringType(), "")

    schema = Schema(NestedField(1, column_name, StringType(), required=True))
    data_file = DataFile.from_args(
        file_path="file.parquet",
        file_format=FileFormat.PARQUET,
        partition={},
        record_count=10,
        file_size_in_bytes=1,
        value_counts={1: 10},
        null_value_counts={1: 0},
        nan_value_counts=None,
        lower_bounds={1: empty_bound},
        upper_bounds={1: empty_bound},
    )

    match_cases = (
        # Both-bounds branch: EqualTo reads lower_bound and upper_bound.
        (EqualTo(column_name, ""), "Should match: lower and upper bounds are present and equal to the literal"),
        # Upper-bound branch: LessThan reads upper_bound only.
        (LessThan(column_name, "a"), "Should match: upper bound '' is strictly less than 'a'"),
        # Both-bounds branch: NotEqualTo reads lower_bound and upper_bound.
        (NotEqualTo(column_name, "abc"), "Should match: 'abc' falls outside ['', '']"),
    )

    for predicate, reason in match_cases:
        should_read = _StrictMetricsEvaluator(schema, predicate).eval(data_file)
        assert should_read, reason


def test_string_not_starts_with(
schema_data_file: Schema, data_file: DataFile, data_file_2: DataFile, data_file_3: DataFile, data_file_4: DataFile
) -> None:
Expand Down
Loading