From b42a7ceb469ca7c8f87cc9a9a238c2079f950864 Mon Sep 17 00:00:00 2001 From: Robin Senn Date: Fri, 29 Apr 2022 15:02:57 +0200 Subject: [PATCH] Remove ability to compare numbers with aggregates Ref: #964 --- rflx/expression.py | 31 ++++++++++++++++++------------- rflx/pyrflx/typevalue.py | 9 +++++---- tests/unit/expression_test.py | 16 ---------------- 3 files changed, 23 insertions(+), 33 deletions(-) diff --git a/rflx/expression.py b/rflx/expression.py index ffb77a630..e3adfe709 100644 --- a/rflx/expression.py +++ b/rflx/expression.py @@ -1570,6 +1570,20 @@ def __init__(self, *elements: Expr, location: Location = None) -> None: super().__init__(rty.Aggregate(rty.common_type([e.type_ for e in elements])), location) self.elements = list(elements) + def __eq__(self, other: object) -> bool: + if ( + isinstance(other, Aggregate) + and all((isinstance(v, Number) for v in self.elements)) + and all((isinstance(v, Number) for v in other.elements)) + ): + return [v.value for v in self.elements if isinstance(v, Number)] == [ + v.value for v in other.elements if isinstance(v, Number) + ] + return super().__eq__(other) + + def __hash__(self) -> int: + return hash(tuple(self.elements)) + def _update_str(self) -> None: self._str = intern("[" + ", ".join(map(str, self.elements)) + "]") @@ -1719,19 +1733,10 @@ def _simplified(self, relation_operator: Callable[[Expr, Expr], bool]) -> Expr: } if (relation_operator, left, right) in mapping: return mapping[(relation_operator, left, right)] - if isinstance(left, (Number, Aggregate)) and isinstance(right, (Number, Aggregate)): - left_number = ( - Number(int.from_bytes(left.to_bytes(), "big")) - if isinstance(left, Aggregate) - else left - ) - right_number = ( - Number(int.from_bytes(right.to_bytes(), "big")) - if isinstance(right, Aggregate) - else right - ) - assert isinstance(left_number, Number) and isinstance(right_number, Number) - return TRUE if relation_operator(left_number, right_number) else FALSE + if isinstance(left, Number) and isinstance(right, Number): + return TRUE if relation_operator(left, right) else FALSE + if isinstance(left, Aggregate) and isinstance(right, Aggregate): + return TRUE if relation_operator(left, right) else FALSE return self.__class__(left, right) @property diff --git a/rflx/pyrflx/typevalue.py b/rflx/pyrflx/typevalue.py index 79032c921..90321fbe5 100644 --- a/rflx/pyrflx/typevalue.py +++ b/rflx/pyrflx/typevalue.py @@ -12,6 +12,7 @@ TRUE, UNDEFINED, Add, + Aggregate, And, Attribute, Expr, @@ -1161,6 +1162,8 @@ def _calculate_checksum(self, checksum: "MessageValue.Checksum") -> int: expr_tuple.evaluated_expression.lower.value, expr_tuple.evaluated_expression.upper.value, ) + elif isinstance(expr_tuple.evaluated_expression, Aggregate): + arguments[str(expr_tuple.expression)] = expr_tuple.evaluated_expression.to_bytes() else: assert isinstance(expr_tuple.evaluated_expression, Number) arguments[str(expr_tuple.expression)] = expr_tuple.evaluated_expression.value @@ -1360,15 +1363,13 @@ def subst(expression: Expr) -> Expr: if self._fields[expression.identifier.flat].set: exp_value = self._fields[expression.identifier.flat].typeval.value if isinstance(exp_value, bytes): - return Number(int.from_bytes(exp_value, "big")) + return Aggregate(*[Number(b) for b in exp_value]) if ( isinstance(exp_value, list) and len(exp_value) > 0 and isinstance(exp_value[0], IntegerValue) ): - return Number( - int.from_bytes(b"".join([bytes(v.bitstring) for v in exp_value]), "big") - ) + return Aggregate(*[Number(e.value) for e in exp_value]) return NotImplemented return expression diff --git a/tests/unit/expression_test.py b/tests/unit/expression_test.py index d6449dee4..1b9507836 100644 --- a/tests/unit/expression_test.py +++ b/tests/unit/expression_test.py @@ -1098,22 +1098,6 @@ def test_relation_simplified() -> None: ).simplified(), FALSE, ) - assert_equal( - Equal(Number(0), Aggregate(Number(0), Number(1), Number(2))).simplified(), - FALSE, - ) - assert_equal( - Equal(Aggregate(Number(0), Number(1), Number(2)), Number(0)).simplified(), - FALSE, - ) - assert_equal( - NotEqual(Number(4), Aggregate(Number(0), Number(1), Number(2))).simplified(), - TRUE, - ) - assert_equal( - NotEqual(Number(0), Aggregate(Number(0), Number(1), Number(2))).simplified(), - TRUE, - ) assert Equal(TRUE, TRUE).simplified() == TRUE assert Equal(TRUE, FALSE).simplified() == FALSE assert NotEqual(TRUE, TRUE).simplified() == FALSE