diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index d0824cc315..883a042d54 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -33,7 +33,7 @@ ) from typing import Literal as TypingLiteral -from pydantic import Field +from pydantic import ConfigDict, Field, field_serializer, field_validator from pyiceberg.expressions.literals import ( AboveMax, @@ -52,8 +52,14 @@ ConfigDict = dict -def _to_unbound_term(term: Union[str, UnboundTerm[Any]]) -> UnboundTerm[Any]: - return Reference(term) if isinstance(term, str) else term +def _to_unbound_term(term: Union[str, UnboundTerm[Any], BoundReference[Any]]) -> UnboundTerm[Any]: + if isinstance(term, str): + return Reference(term) + if isinstance(term, UnboundTerm): + return term + if isinstance(term, BoundReference): + return Reference(term.field.name) + raise ValueError(f"Expected UnboundTerm | BoundReference | str, got {type(term).__name__}") def _to_literal_set(values: Union[Iterable[L], Iterable[Literal[L]]]) -> Set[Literal[L]]: @@ -743,12 +749,33 @@ def as_bound(self) -> Type[BoundNotIn[L]]: return BoundNotIn[L] -class LiteralPredicate(UnboundPredicate[L], ABC): - literal: Literal[L] +class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC): + type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type") + term: UnboundTerm[L] + literal: Literal[L] = Field(serialization_alias="value") + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def __init__(self, *args: Any, **kwargs: Any) -> None: + if args: + if len(args) != 2: + raise TypeError("Expected (term, literal)") + kwargs = {"term": args[0], "literal": args[1], **kwargs} + super().__init__(**kwargs) + + @field_validator("term", mode="before") + @classmethod + def _coerce_term(cls, v: Any) -> UnboundTerm[Any]: + return _to_unbound_term(v) - def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): # pylint: disable=W0621 - super().__init__(term) - self.literal = _to_literal(literal) # pylint: disable=W0621 + @field_validator("literal", mode="before") + @classmethod + def _coerce_literal(cls, v: Union[L, Literal[L]]) -> Literal[L]: + return _to_literal(v) + + @field_serializer("literal") + def ser_literal(self, literal: Literal[L]) -> str: + return "Any" def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]: bound_term = self.term.bind(schema, case_sensitive) @@ -773,6 +800,10 @@ def __eq__(self, other: Any) -> bool: return self.term == other.term and self.literal == other.literal return False + def __str__(self) -> str: + """Return the string representation of the LiteralPredicate class.""" + return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})" + def __repr__(self) -> str: """Return the string representation of the LiteralPredicate class.""" return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})" @@ -886,6 +917,8 @@ def as_unbound(self) -> Type[NotStartsWith[L]]: class EqualTo(LiteralPredicate[L]): + type: TypingLiteral["eq"] = Field(default="eq", alias="type") + def __invert__(self) -> NotEqualTo[L]: """Transform the Expression into its negated version.""" return NotEqualTo[L](self.term, self.literal) @@ -896,6 +929,8 @@ def as_bound(self) -> Type[BoundEqualTo[L]]: class NotEqualTo(LiteralPredicate[L]): + type: TypingLiteral["not-eq"] = Field(default="not-eq", alias="type") + def __invert__(self) -> EqualTo[L]: """Transform the Expression into its negated version.""" return EqualTo[L](self.term, self.literal) @@ -906,6 +941,8 @@ def as_bound(self) -> Type[BoundNotEqualTo[L]]: class LessThan(LiteralPredicate[L]): + type: TypingLiteral["lt"] = Field(default="lt", alias="type") + def __invert__(self) -> GreaterThanOrEqual[L]: """Transform the Expression into its negated version.""" return GreaterThanOrEqual[L](self.term, self.literal) @@ -916,6 +953,8 @@ def as_bound(self) -> Type[BoundLessThan[L]]: class GreaterThanOrEqual(LiteralPredicate[L]): + type: TypingLiteral["gt-eq"] = Field(default="gt-eq", alias="type") + def __invert__(self) -> LessThan[L]: """Transform the Expression into its negated version.""" return LessThan[L](self.term, self.literal) @@ -926,6 +965,8 @@ def as_bound(self) -> Type[BoundGreaterThanOrEqual[L]]: class GreaterThan(LiteralPredicate[L]): + type: TypingLiteral["gt"] = Field(default="gt", alias="type") + def __invert__(self) -> LessThanOrEqual[L]: """Transform the Expression into its negated version.""" return LessThanOrEqual[L](self.term, self.literal) @@ -936,6 +977,8 @@ def as_bound(self) -> Type[BoundGreaterThan[L]]: class LessThanOrEqual(LiteralPredicate[L]): + type: TypingLiteral["lt-eq"] = Field(default="lt-eq", alias="type") + def __invert__(self) -> GreaterThan[L]: """Transform the Expression into its negated version.""" return GreaterThan[L](self.term, self.literal) @@ -946,6 +989,8 @@ def as_bound(self) -> Type[BoundLessThanOrEqual[L]]: class StartsWith(LiteralPredicate[L]): + type: TypingLiteral["starts-with"] = Field(default="starts-with", alias="type") + def __invert__(self) -> NotStartsWith[L]: """Transform the Expression into its negated version.""" return NotStartsWith[L](self.term, self.literal) @@ -956,6 +1001,8 @@ def as_bound(self) -> Type[BoundStartsWith[L]]: class NotStartsWith(LiteralPredicate[L]): + type: TypingLiteral["not-starts-with"] = Field(default="not-starts-with", alias="type") + def __invert__(self) -> StartsWith[L]: """Transform the Expression into its negated version.""" return StartsWith[L](self.term, self.literal) diff --git a/tests/expressions/test_evaluator.py b/tests/expressions/test_evaluator.py index cfc32d9b6b..7b15099105 100644 --- a/tests/expressions/test_evaluator.py +++ b/tests/expressions/test_evaluator.py @@ -683,7 +683,7 @@ def data_file_nan() -> DataFile: def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None: - for operator in [LessThan, LessThanOrEqual]: # type: ignore + for operator in [LessThan, LessThanOrEqual]: should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type] assert not should_read, "Should not match: all nan column doesn't contain number" @@ -711,7 +711,7 @@ def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_f def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal( schema_data_file_nan: Schema, data_file_nan: DataFile ) -> None: - for operator in [GreaterThan, GreaterThanOrEqual]: # type: ignore + for operator in [GreaterThan, GreaterThanOrEqual]: should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type] assert not should_read, "Should not match: all nan column doesn't contain number" diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 5a0c8c9241..16ceb5ad2a 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -55,8 +55,10 @@ NotIn, NotNaN, NotNull, + NotStartsWith, Or, Reference, + StartsWith, UnboundPredicate, ) from pyiceberg.expressions.literals import Literal, literal @@ -427,14 +429,14 @@ def test_bound_less_than_or_equal_invert(table_schema_simple: Schema) -> None: def test_not_equal_to_invert() -> None: bound = NotEqualTo( - term=BoundReference( # type: ignore + term=BoundReference( field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), accessor=Accessor(position=0, inner=None), ), literal="hello", ) assert ~bound == EqualTo( - term=BoundReference( # type: ignore + term=BoundReference( field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), accessor=Accessor(position=0, inner=None), ), @@ -444,14 +446,14 @@ def test_not_equal_to_invert() -> None: def test_greater_than_or_equal_invert() -> None: bound = GreaterThanOrEqual( - term=BoundReference( # type: ignore + term=BoundReference( field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), accessor=Accessor(position=0, inner=None), ), literal="hello", ) assert ~bound == LessThan( - term=BoundReference( # type: ignore + term=BoundReference( field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), accessor=Accessor(position=0, inner=None), ), @@ -461,14 +463,14 @@ def test_greater_than_or_equal_invert() -> None: def test_less_than_or_equal_invert() -> None: bound = LessThanOrEqual( - term=BoundReference( # type: ignore + term=BoundReference( field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), accessor=Accessor(position=0, inner=None), ), literal="hello", ) assert ~bound == GreaterThan( - term=BoundReference( # type: ignore + term=BoundReference( field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), accessor=Accessor(position=0, inner=None), ), @@ -933,6 +935,7 @@ def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None: def test_equal_to() -> None: equal_to = EqualTo(Reference("a"), literal("a")) + assert equal_to.model_dump_json() == '{"term":"a","type":"eq","value":"Any"}' assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))" assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))" assert equal_to == eval(repr(equal_to)) @@ -941,6 +944,7 @@ def test_equal_to() -> None: def test_not_equal_to() -> None: not_equal_to = NotEqualTo(Reference("a"), literal("a")) + assert not_equal_to.model_dump_json() == '{"term":"a","type":"not-eq","value":"Any"}' assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))" assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))" assert not_equal_to == eval(repr(not_equal_to)) @@ -949,6 +953,7 @@ def test_not_equal_to() -> None: def test_greater_than_or_equal_to() -> None: greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a")) + assert greater_than_or_equal_to.model_dump_json() == '{"term":"a","type":"gt-eq","value":"Any"}' assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))" assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))" assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to)) @@ -957,6 +962,7 @@ def test_greater_than_or_equal_to() -> None: def test_greater_than() -> None: greater_than = GreaterThan(Reference("a"), literal("a")) + assert greater_than.model_dump_json() == '{"term":"a","type":"gt","value":"Any"}' assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))" assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))" assert greater_than == eval(repr(greater_than)) @@ -965,6 +971,7 @@ def test_greater_than() -> None: def test_less_than() -> None: less_than = LessThan(Reference("a"), literal("a")) + assert less_than.model_dump_json() == '{"term":"a","type":"lt","value":"Any"}' assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))" assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))" assert less_than == eval(repr(less_than)) @@ -973,12 +980,23 @@ def test_less_than() -> None: def test_less_than_or_equal() -> None: less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a")) + assert less_than_or_equal.model_dump_json() == '{"term":"a","type":"lt-eq","value":"Any"}' assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))" assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))" assert less_than_or_equal == eval(repr(less_than_or_equal)) assert less_than_or_equal == pickle.loads(pickle.dumps(less_than_or_equal)) +def test_starts_with() -> None: + starts_with = StartsWith(Reference("a"), literal("a")) + assert starts_with.model_dump_json() == '{"term":"a","type":"starts-with","value":"Any"}' + + +def test_not_starts_with() -> None: + not_starts_with = NotStartsWith(Reference("a"), literal("a")) + assert not_starts_with.model_dump_json() == '{"term":"a","type":"not-starts-with","value":"Any"}' + + def test_bound_reference_eval(table_schema_simple: Schema) -> None: """Test creating a BoundReference and evaluating it on a StructProtocol""" struct = Record("foovalue", 123, True)