Skip to content

Commit

Permalink
Ensure consistent sorting of links and expressions
Browse files Browse the repository at this point in the history
Ref. #1064
  • Loading branch information
treiher committed Jun 28, 2022
1 parent 08d74ba commit 733a946
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 60 deletions.
60 changes: 14 additions & 46 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,36 +163,36 @@ def __eq__(self, other: object) -> bool:
return str(self) == str(other)
return NotImplemented

def __str__(self) -> str:
try:
return self._str
except AttributeError:
self._update_str()
return self._str

def __hash__(self) -> int:
return hash(self.__class__.__name__)

def __lt__(self, other: object) -> bool:
if isinstance(other, Expr):
return False
return str(self) < str(other)
return NotImplemented

def __le__(self, other: object) -> bool:
if isinstance(other, Expr):
return self == other
return str(self) <= str(other)
return NotImplemented

def __gt__(self, other: object) -> bool:
if isinstance(other, Expr):
return False
return str(self) > str(other)
return NotImplemented

def __ge__(self, other: object) -> bool:
if isinstance(other, Expr):
return self == other
return str(self) >= str(other)
return NotImplemented

def __str__(self) -> str:
try:
return self._str
except AttributeError:
self._update_str()
return self._str

def __hash__(self) -> int:
return hash(self.__class__.__name__)

def __contains__(self, item: "Expr") -> bool:
return item == self

Expand Down Expand Up @@ -430,38 +430,6 @@ def __neg__(self) -> Expr:
def __contains__(self, item: Expr) -> bool:
return item == self or any(item in term for term in self.terms)

def __lt__(self, other: object) -> bool:
if isinstance(other, AssExpr):
if len(self.terms) == len(other.terms):
lt = [x < y for x, y in zip(self.terms, other.terms)]
eq = [x == y for x, y in zip(self.terms, other.terms)]
return any(lt) and all(map((lambda x: x[0] or x[1]), zip(lt, eq)))
return False
return NotImplemented

def __le__(self, other: object) -> bool:
if isinstance(other, AssExpr):
if len(self.terms) == len(other.terms):
return all(x <= y for x, y in zip(self.terms, other.terms))
return False
return NotImplemented

def __gt__(self, other: object) -> bool:
if isinstance(other, AssExpr):
if len(self.terms) == len(other.terms):
gt = [x > y for x, y in zip(self.terms, other.terms)]
eq = [x == y for x, y in zip(self.terms, other.terms)]
return any(gt) and all(map((lambda x: x[0] or x[1]), zip(gt, eq)))
return False
return NotImplemented

def __ge__(self, other: object) -> bool:
if isinstance(other, AssExpr):
if len(self.terms) == len(other.terms):
return all(x >= y for x, y in zip(self.terms, other.terms))
return False
return NotImplemented

@property
@abstractmethod
def precedence(self) -> Precedence:
Expand Down
8 changes: 5 additions & 3 deletions rflx/model/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ def __hash__(self) -> int:
def __repr__(self) -> str:
return f'Field("{self.identifier}")'

def __lt__(self, other: "Field") -> int:
return self.identifier < other.identifier
def __lt__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self.identifier < other.identifier
return NotImplemented

@property
def name(self) -> str:
Expand Down Expand Up @@ -1082,7 +1084,7 @@ def remove_variable_prefix(expression: expr.Expr) -> expr.Expr:
if path_condition != expr.TRUE
else field_size
for path_condition, groups in itertools.groupby(
conditional_field_size,
sorted(conditional_field_size),
lambda x: x[0],
)
for field_size in [expr.Add(*(s for _, s in groups))]
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/expression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def test_number_gt() -> None:
assert not Number(1) > Number(2)
assert not Number(2) > Number(2)
assert Number(3) > Number(2)
assert not Variable("X") > Number(2)
assert Variable("X") > Number(2)
assert not Number(2) > Variable("X")


Expand All @@ -502,7 +502,7 @@ def test_number_ge() -> None:
assert not Number(1) >= Number(2)
assert Number(2) >= Number(2)
assert Number(3) >= Number(2)
assert not Variable("X") >= Number(2)
assert Variable("X") >= Number(2)
assert not Number(2) >= Variable("X")


Expand Down Expand Up @@ -574,19 +574,19 @@ def test_add_lt() -> None:
assert Add(Variable("X"), Number(1)) < Add(Variable("X"), Number(2))
assert not Add(Variable("X"), Number(2)) < Add(Variable("X"), Number(2))
assert not Add(Variable("X"), Number(3)) < Add(Variable("X"), Number(2))
assert not Add(Variable("X"), Number(1)) < Add(Variable("Y"), Number(2))
assert not Add(Variable("X"), Number(2)) < Add(Variable("Y"), Number(1))
assert not Add(Variable("X"), Number(2)) < Add(Variable("Y"), Variable("Z"), Number(1))
assert Add(Variable("X"), Number(1)) < Add(Variable("Y"), Number(2))
assert Add(Variable("X"), Number(2)) < Add(Variable("Y"), Number(1))
assert Add(Variable("X"), Number(2)) < Add(Variable("Y"), Variable("Z"), Number(1))


def test_add_le() -> None:
# pylint: disable=unneeded-not
assert Add(Variable("X"), Number(1)) <= Add(Variable("X"), Number(2))
assert Add(Variable("X"), Number(2)) <= Add(Variable("X"), Number(2))
assert not Add(Variable("X"), Number(3)) <= Add(Variable("X"), Number(2))
assert not Add(Variable("X"), Number(1)) <= Add(Variable("Y"), Number(2))
assert not Add(Variable("X"), Number(2)) <= Add(Variable("Y"), Number(1))
assert not Add(Variable("X"), Number(2)) <= Add(Variable("Y"), Variable("Z"), Number(1))
assert Add(Variable("X"), Number(1)) <= Add(Variable("Y"), Number(2))
assert Add(Variable("X"), Number(2)) <= Add(Variable("Y"), Number(1))
assert Add(Variable("X"), Number(2)) <= Add(Variable("Y"), Variable("Z"), Number(1))


def test_add_gt() -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ def test_dot_graph_with_double_edge(tmp_path: Path) -> None:
Initial -> intermediate_0 [arrowhead=none];
intermediate_0 -> X [minlen=1];
intermediate_1 [color="#6f6f6f", fontcolor="#6f6f6f", fontname="Fira Code", height=0,
label="(X > 100, 0, ⋆)", penwidth=0, style="", width=0];
label="(X < 50, 0, ⋆)", penwidth=0, style="", width=0];
X -> intermediate_1 [arrowhead=none];
intermediate_1 -> Final [minlen=1];
intermediate_2 [color="#6f6f6f", fontcolor="#6f6f6f", fontname="Fira Code", height=0,
label="(X < 50, 0, ⋆)", penwidth=0, style="", width=0];
label="(X > 100, 0, ⋆)", penwidth=0, style="", width=0];
X -> intermediate_2 [arrowhead=none];
intermediate_2 -> Final [minlen=1];
Final [fillcolor="#6f6f6f", label="", shape=circle, width="0.5"];
Expand Down
11 changes: 10 additions & 1 deletion tests/unit/model/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,15 @@ def assert_message(actual: Message, expected: Message, msg: str = None) -> None:
assert actual.fields == expected.fields, msg


def test_link_order() -> None:
l1 = Link(FINAL, INITIAL, condition=Equal(Variable("X"), FALSE))
l2 = Link(FINAL, INITIAL, condition=Equal(Variable("X"), TRUE))
l3 = Link(INITIAL, FINAL, condition=Equal(Variable("X"), FALSE))
l4 = Link(INITIAL, FINAL, condition=Equal(Variable("X"), TRUE))
assert sorted([l1, l2, l3, l4]) == [l1, l2, l3, l4]
assert sorted([l4, l3, l2, l1]) == [l1, l2, l3, l4]


def test_invalid_identifier() -> None:
with pytest.raises(
RecordFluxError,
Expand Down Expand Up @@ -2900,8 +2909,8 @@ def test_size() -> None:
}
) == Number(32)
assert optional_overlayed_field.size() == Add(
IfExpr([(Greater(Variable("A"), Number(0)), Number(32))], Number(0)),
IfExpr([(Equal(Variable("A"), Number(0)), Number(16))], Number(0)),
IfExpr([(Greater(Variable("A"), Number(0)), Number(32))], Number(0)),
)


Expand Down

0 comments on commit 733a946

Please sign in to comment.