diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 71ee7cd98d..14ee9328c1 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -17,14 +17,14 @@ from __future__ import annotations -import builtins from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from functools import cached_property -from typing import Any +from typing import Any, TypeAlias from typing import Literal as TypingLiteral -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, SerializeAsAny, model_validator +from pydantic_core.core_schema import ValidatorFunctionWrapHandler from pyiceberg.expressions.literals import AboveMax, BelowMin, Literal, literal from pyiceberg.schema import Accessor, Schema @@ -48,7 +48,7 @@ def _to_literal(value: L | Literal[L]) -> Literal[L]: return literal(value) -class BooleanExpression(ABC): +class BooleanExpression(IcebergBaseModel, ABC): """An expression that evaluates to a boolean.""" @abstractmethod @@ -69,6 +69,66 @@ def __or__(self, other: BooleanExpression) -> BooleanExpression: return Or(self, other) + @model_validator(mode="wrap") + @classmethod + def handle_primitive_type(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> BooleanExpression: + """Apply custom deserialization logic before validation.""" + # Already a BooleanExpression? return as-is so we keep the concrete subclass. + if isinstance(v, BooleanExpression): + return v + + # Handle different input formats + if isinstance(v, bool): + return AlwaysTrue() if v is True else AlwaysFalse() + + if isinstance(v, dict) and (field_type := v.get("type")): + # Unary + if field_type == "is-null": + return IsNull(**v) + elif field_type == "not-null": + return NotNull(**v) + elif field_type == "is-nan": + return IsNaN(**v) + elif field_type == "not-nan": + return NotNaN(**v) + + # Literal + elif field_type == "lt": + return LessThan(**v) + elif field_type == "lt-eq": + return LessThanOrEqual(**v) + elif field_type == "gt": + return GreaterThan(**v) + elif field_type == "gt-eq": + return GreaterThanOrEqual(**v) + elif field_type == "eq": + return EqualTo(**v) + elif field_type == "not-eq": + return NotEqualTo(**v) + elif field_type == "starts-with": + return StartsWith(**v) + elif field_type == "not-starts-with": + return NotStartsWith(**v) + + # Set + elif field_type == "in": + return In(**v) + elif field_type == "not-in": + return NotIn(**v) + + # Other + elif field_type == "and": + return And(**v) + elif field_type == "or": + return Or(**v) + elif field_type == "not": + return Not(**v) + + return handler(v) + + +SerializableBooleanExpression: TypeAlias = SerializeAsAny["BooleanExpression"] + def _build_balanced_tree( operator_: Callable[[BooleanExpression, BooleanExpression], BooleanExpression], items: Sequence[BooleanExpression] @@ -237,20 +297,20 @@ def as_bound(self) -> type[BoundReference]: return BoundReference -class And(IcebergBaseModel, BooleanExpression): +class And(BooleanExpression): """AND operation expression - logical conjunction.""" model_config = ConfigDict(arbitrary_types_allowed=True) type: TypingLiteral["and"] = Field(default="and", alias="type") - left: BooleanExpression - right: BooleanExpression + left: SerializableBooleanExpression = Field() + right: SerializableBooleanExpression = Field() - def __init__(self, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> None: + def __init__(self, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression, **_: Any) -> None: if isinstance(self, And) and not hasattr(self, "left") and not hasattr(self, "right"): super().__init__(left=left, right=right) - def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> BooleanExpression: # type: ignore + def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression, **_: Any) -> BooleanExpression: if rest: return _build_balanced_tree(And, (left, right, *rest)) if left is AlwaysFalse() or right is AlwaysFalse(): @@ -260,8 +320,7 @@ def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: Boole elif right is AlwaysTrue(): return left else: - obj = super().__new__(cls) - return obj + return super().__new__(cls) def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the And class.""" @@ -285,20 +344,20 @@ def __getnewargs__(self) -> tuple[BooleanExpression, BooleanExpression]: return (self.left, self.right) -class Or(IcebergBaseModel, BooleanExpression): +class Or(BooleanExpression): """OR operation expression - logical disjunction.""" model_config = ConfigDict(arbitrary_types_allowed=True) type: TypingLiteral["or"] = Field(default="or", alias="type") - left: BooleanExpression - right: BooleanExpression + left: SerializableBooleanExpression = Field() + right: SerializableBooleanExpression = Field() - def __init__(self, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> None: + def __init__(self, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression, **_: Any) -> None: if isinstance(self, Or) and not hasattr(self, "left") and not hasattr(self, "right"): super().__init__(left=left, right=right) - def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> BooleanExpression: # type: ignore + def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression, **_: Any) -> BooleanExpression: if rest: return _build_balanced_tree(Or, (left, right, *rest)) if left is AlwaysTrue() or right is AlwaysTrue(): @@ -308,8 +367,7 @@ def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: Boole elif right is AlwaysFalse(): return left else: - obj = super().__new__(cls) - return obj + return super().__new__(cls) def __str__(self) -> str: """Return the string representation of the Or class.""" @@ -333,26 +391,26 @@ def __getnewargs__(self) -> tuple[BooleanExpression, BooleanExpression]: return (self.left, self.right) -class Not(IcebergBaseModel, BooleanExpression): +class Not(BooleanExpression): """NOT operation expression - logical negation.""" model_config = ConfigDict(arbitrary_types_allowed=True) type: TypingLiteral["not"] = Field(default="not") - child: BooleanExpression = Field() + child: SerializableBooleanExpression = Field() def __init__(self, child: BooleanExpression, **_: Any) -> None: super().__init__(child=child) - def __new__(cls, child: BooleanExpression, **_: Any) -> BooleanExpression: # type: ignore + def __new__(cls, child: BooleanExpression, **_: Any) -> BooleanExpression: if child is AlwaysTrue(): return AlwaysFalse() elif child is AlwaysFalse(): return AlwaysTrue() elif isinstance(child, Not): return child.child - obj = super().__new__(cls) - return obj + else: + return super().__new__(cls) def __str__(self) -> str: """Return the string representation of the Not class.""" @@ -412,10 +470,12 @@ def __repr__(self) -> str: class BoundPredicate(Bound, BooleanExpression, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + term: BoundTerm - def __init__(self, term: BoundTerm): - self.term = term + def __init__(self, term: BoundTerm, **kwargs: Any) -> None: + super().__init__(term=term, **kwargs) def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the BoundPredicate class.""" @@ -423,16 +483,22 @@ def __eq__(self, other: Any) -> bool: return self.term == other.term return False + def __str__(self) -> str: + """Return the string representation of the BoundPredicate class.""" + return f"{self.__class__.__name__}(term={str(self.term)})" + @property @abstractmethod def as_unbound(self) -> type[UnboundPredicate]: ... class UnboundPredicate(Unbound, BooleanExpression, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + term: UnboundTerm - def __init__(self, term: str | UnboundTerm): - self.term = _to_unbound_term(term) + def __init__(self, term: str | UnboundTerm, **kwargs: Any) -> None: + super().__init__(term=_to_unbound_term(term), **kwargs) def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the UnboundPredicate class.""" @@ -446,12 +512,12 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BooleanExpression def as_bound(self) -> type[BoundPredicate]: ... -class UnaryPredicate(IcebergBaseModel, UnboundPredicate, ABC): - type: str +class UnaryPredicate(UnboundPredicate, ABC): + type: TypingLiteral["is-null", "not-null", "is-nan", "not-nan"] = Field() model_config = {"arbitrary_types_allowed": True} - def __init__(self, term: str | UnboundTerm): + def __init__(self, term: str | UnboundTerm, **_: Any) -> None: unbound = _to_unbound_term(term) super().__init__(term=unbound) @@ -462,7 +528,8 @@ def __str__(self) -> str: def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate: bound_term = self.term.bind(schema, case_sensitive) - return self.as_bound(bound_term) # type: ignore + bound_type = self.as_bound + return bound_type(bound_term) # type: ignore[misc] def __repr__(self) -> str: """Return the string representation of the UnaryPredicate class.""" @@ -488,7 +555,7 @@ def __getnewargs__(self) -> tuple[BoundTerm]: class BoundIsNull(BoundUnaryPredicate): - def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 + def __new__(cls, term: BoundTerm) -> BooleanExpression: # pylint: disable=W0221 if term.ref().field.required: return AlwaysFalse() return super().__new__(cls) @@ -503,7 +570,7 @@ def as_unbound(self) -> type[IsNull]: class BoundNotNull(BoundUnaryPredicate): - def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 + def __new__(cls, term: BoundTerm) -> BooleanExpression: # pylint: disable=W0221 if term.ref().field.required: return AlwaysTrue() return super().__new__(cls) @@ -518,31 +585,31 @@ def as_unbound(self) -> type[NotNull]: class IsNull(UnaryPredicate): - type: str = "is-null" + type: TypingLiteral["is-null"] = Field(default="is-null") def __invert__(self) -> NotNull: """Transform the Expression into its negated version.""" return NotNull(self.term) @property - def as_bound(self) -> builtins.type[BoundIsNull]: + def as_bound(self) -> type[BoundIsNull]: # type: ignore return BoundIsNull class NotNull(UnaryPredicate): - type: str = "not-null" + type: TypingLiteral["not-null"] = Field(default="not-null") def __invert__(self) -> IsNull: """Transform the Expression into its negated version.""" return IsNull(self.term) @property - def as_bound(self) -> builtins.type[BoundNotNull]: + def as_bound(self) -> type[BoundNotNull]: # type: ignore return BoundNotNull class BoundIsNaN(BoundUnaryPredicate): - def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 + def __new__(cls, term: BoundTerm) -> BooleanExpression: # pylint: disable=W0221 bound_type = term.ref().field.field_type if isinstance(bound_type, (FloatType, DoubleType)): return super().__new__(cls) @@ -558,7 +625,7 @@ def as_unbound(self) -> type[IsNaN]: class BoundNotNaN(BoundUnaryPredicate): - def __new__(cls, term: BoundTerm) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 + def __new__(cls, term: BoundTerm) -> BooleanExpression: # pylint: disable=W0221 bound_type = term.ref().field.field_type if isinstance(bound_type, (FloatType, DoubleType)): return super().__new__(cls) @@ -574,44 +641,51 @@ def as_unbound(self) -> type[NotNaN]: class IsNaN(UnaryPredicate): - type: str = "is-nan" + type: TypingLiteral["is-nan"] = Field(default="is-nan") def __invert__(self) -> NotNaN: """Transform the Expression into its negated version.""" return NotNaN(self.term) @property - def as_bound(self) -> builtins.type[BoundIsNaN]: + def as_bound(self) -> type[BoundIsNaN]: # type: ignore return BoundIsNaN class NotNaN(UnaryPredicate): - type: str = "not-nan" + type: TypingLiteral["not-nan"] = Field(default="not-nan") def __invert__(self) -> IsNaN: """Transform the Expression into its negated version.""" return IsNaN(self.term) @property - def as_bound(self) -> builtins.type[BoundNotNaN]: + def as_bound(self) -> type[BoundNotNaN]: # type: ignore return BoundNotNaN -class SetPredicate(IcebergBaseModel, UnboundPredicate, ABC): +class SetPredicate(UnboundPredicate, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) type: TypingLiteral["in", "not-in"] = Field(default="in") literals: set[LiteralValue] = Field(alias="values") - def __init__(self, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue]): - literal_set = _to_literal_set(literals) - super().__init__(term=_to_unbound_term(term), values=literal_set) # type: ignore - object.__setattr__(self, "literals", literal_set) + def __init__( + self, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] | None = None, **kwargs: Any + ) -> None: + if literals is None and "values" in kwargs: + literals = kwargs["values"] + + if literals is None: + literal_set: set[LiteralValue] = set() + else: + literal_set = _to_literal_set(literals) + super().__init__(term=_to_unbound_term(term), values=literal_set) def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSetPredicate: bound_term = self.term.bind(schema, case_sensitive) literal_set = self.literals - return self.as_bound(bound_term, {lit.to(bound_term.ref().field.field_type) for lit in literal_set}) + return self.as_bound(bound_term, {lit.to(bound_term.ref().field.field_type) for lit in literal_set}) # type: ignore def __str__(self) -> str: """Return the string representation of the SetPredicate class.""" @@ -633,16 +707,16 @@ def __getnewargs__(self) -> tuple[UnboundTerm, set[Any]]: @property @abstractmethod - def as_bound(self) -> builtins.type[BoundSetPredicate]: + def as_bound(self) -> type[BoundSetPredicate]: # type: ignore return BoundSetPredicate class BoundSetPredicate(BoundPredicate, ABC): literals: set[LiteralValue] - def __init__(self, term: BoundTerm, literals: set[LiteralValue]): - super().__init__(term) - self.literals = _to_literal_set(literals) # pylint: disable=W0621 + def __init__(self, term: BoundTerm, literals: set[LiteralValue]) -> None: + literal_set = _to_literal_set(literals) + super().__init__(term=term, literals=literal_set) @cached_property def value_set(self) -> set[Any]: @@ -672,7 +746,7 @@ def as_unbound(self) -> type[SetPredicate]: ... class BoundIn(BoundSetPredicate): - def __new__(cls, term: BoundTerm, literals: set[LiteralValue]) -> BooleanExpression: # type: ignore[misc] # pylint: disable=W0221 + def __new__(cls, term: BoundTerm, literals: set[LiteralValue]) -> BooleanExpression: # pylint: disable=W0221 count = len(literals) if count == 0: return AlwaysFalse() @@ -695,7 +769,7 @@ def as_unbound(self) -> type[In]: class BoundNotIn(BoundSetPredicate): - def __new__( # type: ignore[misc] # pylint: disable=W0221 + def __new__( # pylint: disable=W0221 cls, term: BoundTerm, literals: set[LiteralValue], @@ -720,15 +794,21 @@ def as_unbound(self) -> type[NotIn]: class In(SetPredicate): type: TypingLiteral["in"] = Field(default="in", alias="type") - def __new__( # type: ignore[misc] # pylint: disable=W0221 - cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] - ) -> BooleanExpression: - literals_set: set[LiteralValue] = _to_literal_set(literals) + def __new__( # pylint: disable=W0221 + cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] | None = None, **kwargs: Any + ) -> In: + if literals is None and "values" in kwargs: + literals = kwargs["values"] + + if literals is None: + literals_set: set[LiteralValue] = set() + else: + literals_set = _to_literal_set(literals) count = len(literals_set) if count == 0: return AlwaysFalse() elif count == 1: - return EqualTo(term, next(iter(literals))) + return EqualTo(term, next(iter(literals_set))) else: return super().__new__(cls) @@ -737,17 +817,23 @@ def __invert__(self) -> NotIn: return NotIn(self.term, self.literals) @property - def as_bound(self) -> builtins.type[BoundIn]: + def as_bound(self) -> type[BoundIn]: # type: ignore return BoundIn class NotIn(SetPredicate, ABC): type: TypingLiteral["not-in"] = Field(default="not-in", alias="type") - def __new__( # type: ignore[misc] # pylint: disable=W0221 - cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] - ) -> BooleanExpression: - literals_set: set[LiteralValue] = _to_literal_set(literals) + def __new__( # pylint: disable=W0221 + cls, term: str | UnboundTerm, literals: Iterable[Any] | Iterable[LiteralValue] | None = None, **kwargs: Any + ) -> NotIn: + if literals is None and "values" in kwargs: + literals = kwargs["values"] + + if literals is None: + literals_set: set[LiteralValue] = set() + else: + literals_set = _to_literal_set(literals) count = len(literals_set) if count == 0: return AlwaysTrue() @@ -761,18 +847,21 @@ def __invert__(self) -> In: return In(self.term, self.literals) @property - def as_bound(self) -> builtins.type[BoundNotIn]: + def as_bound(self) -> type[BoundNotIn]: # type: ignore return BoundNotIn -class LiteralPredicate(IcebergBaseModel, UnboundPredicate, ABC): +class LiteralPredicate(UnboundPredicate, ABC): type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type") term: UnboundTerm value: LiteralValue = Field() model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True) - def __init__(self, term: str | UnboundTerm, literal: Any): - super().__init__(term=_to_unbound_term(term), value=_to_literal(literal)) # type: ignore[call-arg] + def __init__(self, term: str | UnboundTerm, literal: Any | None = None, **kwargs: Any) -> None: + if literal is None and "value" in kwargs: + literal = kwargs["value"] + + super().__init__(term=_to_unbound_term(term), value=_to_literal(literal)) @property def literal(self) -> LiteralValue: @@ -793,7 +882,7 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredi elif isinstance(self, (LessThan, LessThanOrEqual, EqualTo)): return AlwaysFalse() - return self.as_bound(bound_term, lit) + return self.as_bound(bound_term, lit) # type: ignore def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the LiteralPredicate class.""" @@ -811,15 +900,14 @@ def __repr__(self) -> str: @property @abstractmethod - def as_bound(self) -> builtins.type[BoundLiteralPredicate]: ... + def as_bound(self) -> type[BoundLiteralPredicate]: ... # type: ignore class BoundLiteralPredicate(BoundPredicate, ABC): literal: LiteralValue def __init__(self, term: BoundTerm, literal: LiteralValue): # pylint: disable=W0621 - super().__init__(term) - self.literal = literal # pylint: disable=W0621 + super().__init__(term=term, literal=literal) def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the BoundLiteralPredicate class.""" @@ -827,6 +915,10 @@ def __eq__(self, other: Any) -> bool: return self.term == other.term and self.literal == other.literal return False + def __str__(self) -> str: + """Return the string representation of the BoundLiteralPredicate class.""" + return f"{self.__class__.__name__}(term={str(self.term)}, literal={repr(self.literal)})" + def __repr__(self) -> str: """Return the string representation of the BoundLiteralPredicate class.""" return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})" @@ -924,7 +1016,7 @@ def __invert__(self) -> NotEqualTo: return NotEqualTo(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundEqualTo]: + def as_bound(self) -> type[BoundEqualTo]: # type: ignore return BoundEqualTo @@ -936,7 +1028,7 @@ def __invert__(self) -> EqualTo: return EqualTo(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundNotEqualTo]: + def as_bound(self) -> type[BoundNotEqualTo]: # type: ignore return BoundNotEqualTo @@ -948,7 +1040,7 @@ def __invert__(self) -> GreaterThanOrEqual: return GreaterThanOrEqual(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundLessThan]: + def as_bound(self) -> type[BoundLessThan]: # type: ignore return BoundLessThan @@ -960,7 +1052,7 @@ def __invert__(self) -> LessThan: return LessThan(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundGreaterThanOrEqual]: + def as_bound(self) -> type[BoundGreaterThanOrEqual]: # type: ignore return BoundGreaterThanOrEqual @@ -972,7 +1064,7 @@ def __invert__(self) -> LessThanOrEqual: return LessThanOrEqual(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundGreaterThan]: + def as_bound(self) -> type[BoundGreaterThan]: # type: ignore return BoundGreaterThan @@ -984,7 +1076,7 @@ def __invert__(self) -> GreaterThan: return GreaterThan(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundLessThanOrEqual]: + def as_bound(self) -> type[BoundLessThanOrEqual]: # type: ignore return BoundLessThanOrEqual @@ -996,7 +1088,7 @@ def __invert__(self) -> NotStartsWith: return NotStartsWith(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundStartsWith]: + def as_bound(self) -> type[BoundStartsWith]: # type: ignore return BoundStartsWith @@ -1008,5 +1100,5 @@ def __invert__(self) -> StartsWith: return StartsWith(self.term, self.literal) @property - def as_bound(self) -> builtins.type[BoundNotStartsWith]: + def as_bound(self) -> type[BoundNotStartsWith]: # type: ignore return BoundNotStartsWith diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 7381e85008..739e18a6e6 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -280,12 +280,12 @@ def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: elif isinstance(pred, BoundEqualTo): return pred.as_unbound(Reference(name), _transform_literal(transformer, pred.literal)) elif isinstance(pred, BoundIn): # NotIn can't be projected - return pred.as_unbound(Reference(name), {_transform_literal(transformer, literal) for literal in pred.literals}) + return pred.as_unbound(Reference(name), {_transform_literal(transformer, literal) for literal in pred.literals}) # type: ignore else: # - Comparison predicates can't be projected, notEq can't be projected # - Small ranges can be projected: # For example, (x > 0) and (x < 3) can be turned into in({1, 2}) and projected. - return None + return None # type: ignore def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) @@ -297,10 +297,10 @@ def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | elif isinstance(pred, BoundNotEqualTo): return pred.as_unbound(Reference(name), _transform_literal(transformer, pred.literal)) elif isinstance(pred, BoundNotIn): - return pred.as_unbound(Reference(name), {_transform_literal(transformer, literal) for literal in pred.literals}) + return pred.as_unbound(Reference(name), {_transform_literal(transformer, literal) for literal in pred.literals}) # type: ignore else: # no strict projection for comparison or equality - return None + return None # type: ignore def can_transform(self, source: IcebergType) -> bool: return isinstance( @@ -431,8 +431,6 @@ def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return _truncate_number(name, pred, transformer) elif isinstance(pred, BoundIn): # NotIn can't be projected return _set_apply_transform(name, pred, transformer) - else: - return None def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: transformer = self.transform(pred.term.ref().field.field_type) @@ -444,8 +442,6 @@ def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | return _truncate_number_strict(name, pred, transformer) elif isinstance(pred, BoundNotIn): return _set_apply_transform(name, pred, transformer) - else: - return None @property def dedup_name(self) -> str: @@ -812,13 +808,12 @@ def project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: return pred.as_unbound(Reference(name)) elif isinstance(pred, BoundIn): return _set_apply_transform(name, pred, self.transform(field_type)) - elif isinstance(field_type, (IntegerType, LongType, DecimalType)): + elif isinstance(field_type, (IntegerType, LongType, DecimalType)): # type: ignore if isinstance(pred, BoundLiteralPredicate): return _truncate_number(name, pred, self.transform(field_type)) elif isinstance(field_type, (BinaryType, StringType)): if isinstance(pred, BoundLiteralPredicate): return _truncate_array(name, pred, self.transform(field_type)) - return None def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | None: field_type = pred.term.ref().field.field_type @@ -835,7 +830,7 @@ def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | elif isinstance(pred, BoundNotIn): return _set_apply_transform(name, pred, self.transform(field_type)) else: - return None + return None # type: ignore if isinstance(pred, BoundLiteralPredicate): if isinstance(pred, BoundStartsWith): @@ -860,7 +855,7 @@ def strict_project(self, name: str, pred: BoundPredicate) -> UnboundPredicate | elif isinstance(pred, BoundNotIn): return _set_apply_transform(name, pred, self.transform(field_type)) else: - return None + return None # type: ignore @property def width(self) -> int: @@ -1135,7 +1130,7 @@ def _remove_transform(partition_name: str, pred: BoundPredicate) -> UnboundPredi elif isinstance(pred, BoundLiteralPredicate): return pred.as_unbound(Reference(partition_name), pred.literal) elif isinstance(pred, (BoundIn, BoundNotIn)): - return pred.as_unbound(Reference(partition_name), pred.literals) + return pred.as_unbound(Reference(partition_name), pred.literals) # type: ignore else: raise ValueError(f"Cannot replace transform in unknown predicate: {pred}") diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 252da478d8..157c1adaf1 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -502,7 +502,7 @@ def test_less_than_or_equal_invert() -> None: ], ) def test_bind(pred: UnboundPredicate, table_schema_simple: Schema) -> None: - assert pred.bind(table_schema_simple, case_sensitive=True).term.field == table_schema_simple.find_field( # type: ignore + assert pred.bind(table_schema_simple, case_sensitive=True).term.field == table_schema_simple.find_field( pred.term.name, # type: ignore case_sensitive=True, ) @@ -522,7 +522,7 @@ def test_bind(pred: UnboundPredicate, table_schema_simple: Schema) -> None: ], ) def test_bind_case_insensitive(pred: UnboundPredicate, table_schema_simple: Schema) -> None: - assert pred.bind(table_schema_simple, case_sensitive=False).term.field == table_schema_simple.find_field( # type: ignore + assert pred.bind(table_schema_simple, case_sensitive=False).term.field == table_schema_simple.find_field( pred.term.name, # type: ignore case_sensitive=False, ) @@ -727,11 +727,10 @@ def test_and() -> None: def test_and_serialization() -> None: expr = And(EqualTo("x", 1), GreaterThan("y", 2)) + json_repr = '{"type":"and","left":{"term":"x","type":"eq","value":1},"right":{"term":"y","type":"gt","value":2}}' - assert ( - expr.model_dump_json() - == '{"type":"and","left":{"term":"x","type":"eq","value":1},"right":{"term":"y","type":"gt","value":2}}' - ) + assert expr.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == expr def test_or() -> None: @@ -755,11 +754,10 @@ def test_or_serialization() -> None: left = EqualTo("a", 10) right = EqualTo("b", 20) or_ = Or(left, right) + json_repr = '{"type":"or","left":{"term":"a","type":"eq","value":10},"right":{"term":"b","type":"eq","value":20}}' - assert ( - or_.model_dump_json() - == '{"type":"or","left":{"term":"a","type":"eq","value":10},"right":{"term":"b","type":"eq","value":20}}' - ) + assert or_.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == or_ def test_not() -> None: @@ -780,6 +778,7 @@ def test_not_json_serialization_and_deserialization() -> None: def test_always_true() -> None: always_true = AlwaysTrue() assert always_true.model_dump_json() == "true" + assert BooleanExpression.model_validate_json("true") == always_true assert str(always_true) == "AlwaysTrue()" assert repr(always_true) == "AlwaysTrue()" assert always_true == eval(repr(always_true)) @@ -789,6 +788,7 @@ def test_always_true() -> None: def test_always_false() -> None: always_false = AlwaysFalse() assert always_false.model_dump_json() == "false" + assert BooleanExpression.model_validate_json("false") == always_false assert str(always_false) == "AlwaysFalse()" assert repr(always_false) == "AlwaysFalse()" assert always_false == eval(repr(always_false)) @@ -823,6 +823,10 @@ def test_is_null() -> None: assert repr(is_null) == f"IsNull(term={repr(ref)})" assert is_null == eval(repr(is_null)) assert is_null == pickle.loads(pickle.dumps(is_null)) + pred = IsNull(term="foo") + json_repr = '{"term":"foo","type":"is-null"}' + assert pred.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == pred def test_not_null() -> None: @@ -832,16 +836,10 @@ def test_not_null() -> None: assert repr(non_null) == f"NotNull(term={repr(ref)})" assert non_null == eval(repr(non_null)) assert non_null == pickle.loads(pickle.dumps(non_null)) - - -def test_serialize_is_null() -> None: - pred = IsNull(term="foo") - assert pred.model_dump_json() == '{"term":"foo","type":"is-null"}' - - -def test_serialize_not_null() -> None: pred = NotNull(term="foo") - assert pred.model_dump_json() == '{"term":"foo","type":"not-null"}' + json_repr = '{"term":"foo","type":"not-null"}' + assert pred.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == pred def test_bound_is_nan(accessor: Accessor) -> None: @@ -877,6 +875,9 @@ def test_is_nan() -> None: assert repr(is_nan) == f"IsNaN(term={repr(ref)})" assert is_nan == eval(repr(is_nan)) assert is_nan == pickle.loads(pickle.dumps(is_nan)) + json_repr = '{"term":"a","type":"is-nan"}' + assert is_nan.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == is_nan def test_not_nan() -> None: @@ -886,6 +887,9 @@ def test_not_nan() -> None: assert repr(not_nan) == f"NotNaN(term={repr(ref)})" assert not_nan == eval(repr(not_nan)) assert not_nan == pickle.loads(pickle.dumps(not_nan)) + json_repr = '{"term":"a","type":"not-nan"}' + assert not_nan.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == not_nan def test_bound_in(term: BoundReference) -> None: @@ -906,7 +910,10 @@ def test_bound_not_in(term: BoundReference) -> None: def test_in() -> None: ref = Reference("a") - unbound_in = In(ref, {"a", "b", "c"}) + unbound_in = In(ref, ["a", "b", "c"]) + json_repr = unbound_in.model_dump_json() + assert json_repr.startswith('{"term":"a","type":"in","values":[') + assert BooleanExpression.model_validate_json(json_repr) == unbound_in assert str(unbound_in) == f"In({str(ref)}, {{a, b, c}})" assert repr(unbound_in) == f"In({repr(ref)}, {{literal('a'), literal('b'), literal('c')}})" assert unbound_in == eval(repr(unbound_in)) @@ -915,23 +922,16 @@ def test_in() -> None: def test_not_in() -> None: ref = Reference("a") - not_in = NotIn(ref, {"a", "b", "c"}) + not_in = NotIn(ref, ["a", "b", "c"]) + json_repr = not_in.model_dump_json() + assert not_in.model_dump_json().startswith('{"term":"a","type":"not-in","values":') + assert BooleanExpression.model_validate_json(json_repr) == not_in assert str(not_in) == f"NotIn({str(ref)}, {{a, b, c}})" assert repr(not_in) == f"NotIn({repr(ref)}, {{literal('a'), literal('b'), literal('c')}})" assert not_in == eval(repr(not_in)) assert not_in == pickle.loads(pickle.dumps(not_in)) -def test_serialize_in() -> None: - pred = In(term="foo", literals=[1, 2, 3]) - assert pred.model_dump_json() == '{"term":"foo","type":"in","values":[1,2,3]}' - - -def test_serialize_not_in() -> None: - pred = NotIn(term="foo", literals=[1, 2, 3]) - assert pred.model_dump_json() == '{"term":"foo","type":"not-in","values":[1,2,3]}' - - def test_bound_equal_to(term: BoundReference) -> None: bound_equal_to = BoundEqualTo(term, literal("a")) assert str(bound_equal_to) == f"BoundEqualTo(term={str(term)}, literal=literal('a'))" @@ -982,7 +982,9 @@ def test_bound_less_than_or_equal(term: BoundReference) -> None: def test_equal_to() -> None: equal_to = EqualTo(Reference("a"), literal("a")) - assert equal_to.model_dump_json() == '{"term":"a","type":"eq","value":"a"}' + json_repr = '{"term":"a","type":"eq","value":"a"}' + assert equal_to.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == equal_to assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))" assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))" assert equal_to == eval(repr(equal_to)) @@ -991,7 +993,9 @@ def test_equal_to() -> None: def test_not_equal_to() -> None: not_equal_to = NotEqualTo(Reference("a"), literal("a")) - assert not_equal_to.model_dump_json() == '{"term":"a","type":"not-eq","value":"a"}' + json_repr = '{"term":"a","type":"not-eq","value":"a"}' + assert not_equal_to.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == not_equal_to assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))" assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))" assert not_equal_to == eval(repr(not_equal_to)) @@ -1000,7 +1004,9 @@ def test_not_equal_to() -> None: def test_greater_than_or_equal_to() -> None: greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a")) - assert greater_than_or_equal_to.model_dump_json() == '{"term":"a","type":"gt-eq","value":"a"}' + json_repr = '{"term":"a","type":"gt-eq","value":"a"}' + assert greater_than_or_equal_to.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == greater_than_or_equal_to assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))" assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))" assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to)) @@ -1009,7 +1015,9 @@ def test_greater_than_or_equal_to() -> None: def test_greater_than() -> None: greater_than = GreaterThan(Reference("a"), literal("a")) - assert greater_than.model_dump_json() == '{"term":"a","type":"gt","value":"a"}' + json_repr = '{"term":"a","type":"gt","value":"a"}' + assert greater_than.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == greater_than assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))" assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))" assert greater_than == eval(repr(greater_than)) @@ -1018,7 +1026,9 @@ def test_greater_than() -> None: def test_less_than() -> None: less_than = LessThan(Reference("a"), literal("a")) - assert less_than.model_dump_json() == '{"term":"a","type":"lt","value":"a"}' + json_repr = '{"term":"a","type":"lt","value":"a"}' + assert less_than.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == less_than assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))" assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))" assert less_than == eval(repr(less_than)) @@ -1027,7 +1037,9 @@ def test_less_than() -> None: def test_less_than_or_equal() -> None: less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a")) - assert less_than_or_equal.model_dump_json() == '{"term":"a","type":"lt-eq","value":"a"}' + json_repr = '{"term":"a","type":"lt-eq","value":"a"}' + assert less_than_or_equal.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == less_than_or_equal assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))" assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))" assert less_than_or_equal == eval(repr(less_than_or_equal)) @@ -1036,12 +1048,16 @@ def test_less_than_or_equal() -> None: def test_starts_with() -> None: starts_with = StartsWith(Reference("a"), literal("a")) - assert starts_with.model_dump_json() == '{"term":"a","type":"starts-with","value":"a"}' + json_repr = '{"term":"a","type":"starts-with","value":"a"}' + assert starts_with.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == starts_with def test_not_starts_with() -> None: not_starts_with = NotStartsWith(Reference("a"), literal("a")) - assert not_starts_with.model_dump_json() == '{"term":"a","type":"not-starts-with","value":"a"}' + json_repr = '{"term":"a","type":"not-starts-with","value":"a"}' + assert not_starts_with.model_dump_json() == json_repr + assert BooleanExpression.model_validate_json(json_repr) == not_starts_with def test_bound_reference_eval(table_schema_simple: Schema) -> None: