Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 55 additions & 8 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from typing import Literal as TypingLiteral

from pydantic import Field
from pydantic import ConfigDict, Field, field_serializer, field_validator

from pyiceberg.expressions.literals import (
AboveMax,
Expand All @@ -52,8 +52,14 @@
ConfigDict = dict


def _to_unbound_term(term: Union[str, UnboundTerm[Any]]) -> UnboundTerm[Any]:
return Reference(term) if isinstance(term, str) else term
def _to_unbound_term(term: Union[str, UnboundTerm[Any], BoundReference[Any]]) -> UnboundTerm[Any]:
if isinstance(term, str):
return Reference(term)
if isinstance(term, UnboundTerm):
return term
if isinstance(term, BoundReference):
return Reference(term.field.name)
raise ValueError(f"Expected UnboundTerm | BoundReference | str, got {type(term).__name__}")


def _to_literal_set(values: Union[Iterable[L], Iterable[Literal[L]]]) -> Set[Literal[L]]:
Expand Down Expand Up @@ -743,12 +749,33 @@ def as_bound(self) -> Type[BoundNotIn[L]]:
return BoundNotIn[L]


class LiteralPredicate(UnboundPredicate[L], ABC):
literal: Literal[L]
class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC):
type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type")
term: UnboundTerm[L]
literal: Literal[L] = Field(serialization_alias="value")

model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(self, *args: Any, **kwargs: Any) -> None:
if args:
if len(args) != 2:
raise TypeError("Expected (term, literal)")
kwargs = {"term": args[0], "literal": args[1], **kwargs}
super().__init__(**kwargs)
Comment on lines +759 to +764
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After having many issues with an init such as:

def __init__(self, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]]):
        super().__init__(term=_to_unbound_term(term), items=_to_literal_set(literals))

Because there are some typing errors with _transform_literal in pyiceberg/transforms.py for example:

  pyiceberg/transforms.py:1113: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[str | None], str | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1113: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[bool | None], bool | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1113: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[int | None], int | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1113: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[float | None], float | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1113: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[bytes | None], bytes | None]"; expected "Callable[[str], str]"  [arg-type]
  pyiceberg/transforms.py:1113: error: Argument 1 to "_transform_literal" has incompatible type "Callable[[UUID | None], UUID | None]"; expected "Callable[[str], str]"  [arg-type]

I decided to just go for this implementation of init. The problem now is that:

assert_type(EqualTo("a", "b"), EqualTo[str])  # <-- Fails
------
  tests/expressions/test_expressions.py:1238: error: Expression is of type "LiteralPredicate[L]", not "EqualTo[str]"  [assert-type]

So I am really stuck, would you mind lending a hand here? @Fokko


@field_validator("term", mode="before")
@classmethod
def _coerce_term(cls, v: Any) -> UnboundTerm[Any]:
return _to_unbound_term(v)

def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): # pylint: disable=W0621
super().__init__(term)
self.literal = _to_literal(literal) # pylint: disable=W0621
@field_validator("literal", mode="before")
@classmethod
def _coerce_literal(cls, v: Union[L, Literal[L]]) -> Literal[L]:
return _to_literal(v)

@field_serializer("literal")
def ser_literal(self, literal: Literal[L]) -> str:
return "Any"

def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]:
bound_term = self.term.bind(schema, case_sensitive)
Expand All @@ -773,6 +800,10 @@ def __eq__(self, other: Any) -> bool:
return self.term == other.term and self.literal == other.literal
return False

def __str__(self) -> str:
"""Return the string representation of the LiteralPredicate class."""
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"

def __repr__(self) -> str:
"""Return the string representation of the LiteralPredicate class."""
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"
Expand Down Expand Up @@ -886,6 +917,8 @@ def as_unbound(self) -> Type[NotStartsWith[L]]:


class EqualTo(LiteralPredicate[L]):
type: TypingLiteral["eq"] = Field(default="eq", alias="type")

def __invert__(self) -> NotEqualTo[L]:
"""Transform the Expression into its negated version."""
return NotEqualTo[L](self.term, self.literal)
Expand All @@ -896,6 +929,8 @@ def as_bound(self) -> Type[BoundEqualTo[L]]:


class NotEqualTo(LiteralPredicate[L]):
type: TypingLiteral["not-eq"] = Field(default="not-eq", alias="type")

def __invert__(self) -> EqualTo[L]:
"""Transform the Expression into its negated version."""
return EqualTo[L](self.term, self.literal)
Expand All @@ -906,6 +941,8 @@ def as_bound(self) -> Type[BoundNotEqualTo[L]]:


class LessThan(LiteralPredicate[L]):
type: TypingLiteral["lt"] = Field(default="lt", alias="type")

def __invert__(self) -> GreaterThanOrEqual[L]:
"""Transform the Expression into its negated version."""
return GreaterThanOrEqual[L](self.term, self.literal)
Expand All @@ -916,6 +953,8 @@ def as_bound(self) -> Type[BoundLessThan[L]]:


class GreaterThanOrEqual(LiteralPredicate[L]):
type: TypingLiteral["gt-eq"] = Field(default="gt-eq", alias="type")

def __invert__(self) -> LessThan[L]:
"""Transform the Expression into its negated version."""
return LessThan[L](self.term, self.literal)
Expand All @@ -926,6 +965,8 @@ def as_bound(self) -> Type[BoundGreaterThanOrEqual[L]]:


class GreaterThan(LiteralPredicate[L]):
type: TypingLiteral["gt"] = Field(default="gt", alias="type")

def __invert__(self) -> LessThanOrEqual[L]:
"""Transform the Expression into its negated version."""
return LessThanOrEqual[L](self.term, self.literal)
Expand All @@ -936,6 +977,8 @@ def as_bound(self) -> Type[BoundGreaterThan[L]]:


class LessThanOrEqual(LiteralPredicate[L]):
type: TypingLiteral["lt-eq"] = Field(default="lt-eq", alias="type")

def __invert__(self) -> GreaterThan[L]:
"""Transform the Expression into its negated version."""
return GreaterThan[L](self.term, self.literal)
Expand All @@ -946,6 +989,8 @@ def as_bound(self) -> Type[BoundLessThanOrEqual[L]]:


class StartsWith(LiteralPredicate[L]):
type: TypingLiteral["starts-with"] = Field(default="starts-with", alias="type")

def __invert__(self) -> NotStartsWith[L]:
"""Transform the Expression into its negated version."""
return NotStartsWith[L](self.term, self.literal)
Expand All @@ -956,6 +1001,8 @@ def as_bound(self) -> Type[BoundStartsWith[L]]:


class NotStartsWith(LiteralPredicate[L]):
type: TypingLiteral["not-starts-with"] = Field(default="not-starts-with", alias="type")

def __invert__(self) -> StartsWith[L]:
"""Transform the Expression into its negated version."""
return StartsWith[L](self.term, self.literal)
Expand Down
4 changes: 2 additions & 2 deletions tests/expressions/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def data_file_nan() -> DataFile:


def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None:
for operator in [LessThan, LessThanOrEqual]: # type: ignore
for operator in [LessThan, LessThanOrEqual]:
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
assert not should_read, "Should not match: all nan column doesn't contain number"

Expand Down Expand Up @@ -711,7 +711,7 @@ def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_f
def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal(
schema_data_file_nan: Schema, data_file_nan: DataFile
) -> None:
for operator in [GreaterThan, GreaterThanOrEqual]: # type: ignore
for operator in [GreaterThan, GreaterThanOrEqual]:
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
assert not should_read, "Should not match: all nan column doesn't contain number"

Expand Down
30 changes: 24 additions & 6 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@
NotIn,
NotNaN,
NotNull,
NotStartsWith,
Or,
Reference,
StartsWith,
UnboundPredicate,
)
from pyiceberg.expressions.literals import Literal, literal
Expand Down Expand Up @@ -427,14 +429,14 @@ def test_bound_less_than_or_equal_invert(table_schema_simple: Schema) -> None:

def test_not_equal_to_invert() -> None:
bound = NotEqualTo(
term=BoundReference( # type: ignore
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal="hello",
)
assert ~bound == EqualTo(
term=BoundReference( # type: ignore
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
Expand All @@ -444,14 +446,14 @@ def test_not_equal_to_invert() -> None:

def test_greater_than_or_equal_invert() -> None:
bound = GreaterThanOrEqual(
term=BoundReference( # type: ignore
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal="hello",
)
assert ~bound == LessThan(
term=BoundReference( # type: ignore
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
Expand All @@ -461,14 +463,14 @@ def test_greater_than_or_equal_invert() -> None:

def test_less_than_or_equal_invert() -> None:
bound = LessThanOrEqual(
term=BoundReference( # type: ignore
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal="hello",
)
assert ~bound == GreaterThan(
term=BoundReference( # type: ignore
term=BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
Expand Down Expand Up @@ -933,6 +935,7 @@ def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None:

def test_equal_to() -> None:
equal_to = EqualTo(Reference("a"), literal("a"))
assert equal_to.model_dump_json() == '{"term":"a","type":"eq","value":"Any"}'
assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
assert equal_to == eval(repr(equal_to))
Expand All @@ -941,6 +944,7 @@ def test_equal_to() -> None:

def test_not_equal_to() -> None:
not_equal_to = NotEqualTo(Reference("a"), literal("a"))
assert not_equal_to.model_dump_json() == '{"term":"a","type":"not-eq","value":"Any"}'
assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
assert not_equal_to == eval(repr(not_equal_to))
Expand All @@ -949,6 +953,7 @@ def test_not_equal_to() -> None:

def test_greater_than_or_equal_to() -> None:
greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a"))
assert greater_than_or_equal_to.model_dump_json() == '{"term":"a","type":"gt-eq","value":"Any"}'
assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to))
Expand All @@ -957,6 +962,7 @@ def test_greater_than_or_equal_to() -> None:

def test_greater_than() -> None:
greater_than = GreaterThan(Reference("a"), literal("a"))
assert greater_than.model_dump_json() == '{"term":"a","type":"gt","value":"Any"}'
assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
assert greater_than == eval(repr(greater_than))
Expand All @@ -965,6 +971,7 @@ def test_greater_than() -> None:

def test_less_than() -> None:
less_than = LessThan(Reference("a"), literal("a"))
assert less_than.model_dump_json() == '{"term":"a","type":"lt","value":"Any"}'
assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
assert less_than == eval(repr(less_than))
Expand All @@ -973,12 +980,23 @@ def test_less_than() -> None:

def test_less_than_or_equal() -> None:
less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a"))
assert less_than_or_equal.model_dump_json() == '{"term":"a","type":"lt-eq","value":"Any"}'
assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
assert less_than_or_equal == eval(repr(less_than_or_equal))
assert less_than_or_equal == pickle.loads(pickle.dumps(less_than_or_equal))


def test_starts_with() -> None:
starts_with = StartsWith(Reference("a"), literal("a"))
assert starts_with.model_dump_json() == '{"term":"a","type":"starts-with","value":"Any"}'


def test_not_starts_with() -> None:
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
assert not_starts_with.model_dump_json() == '{"term":"a","type":"not-starts-with","value":"Any"}'


def test_bound_reference_eval(table_schema_simple: Schema) -> None:
"""Test creating a BoundReference and evaluating it on a StructProtocol"""
struct = Record("foovalue", 123, True)
Expand Down