diff --git a/rflx/expression.py b/rflx/expression.py index 30f495f25..6eaf4d179 100644 --- a/rflx/expression.py +++ b/rflx/expression.py @@ -1209,6 +1209,8 @@ def ada_expr(self) -> ada.Expr: @lru_cache(maxsize=None) def z3expr(self) -> z3.ExprRef: + if self.type_ == rty.BOOLEAN: + return z3.Bool(self.name) if self.negative: return -z3.Int(self.name) return z3.Int(self.name) diff --git a/rflx/model/message.py b/rflx/model/message.py index e781b37bd..f87b723c6 100644 --- a/rflx/model/message.py +++ b/rflx/model/message.py @@ -846,6 +846,13 @@ def __init__( self.error.propagate() + def _set_types(self) -> None: + def set_types(expression: expr.Expr) -> expr.Expr: + return self._typed_variable(expression, self.types) + + for link in self.structure: + link.condition = link.condition.substituted(set_types) + def verify(self) -> None: if not self.is_null: self._verify_parameters() @@ -854,6 +861,7 @@ def verify(self) -> None: self.error.propagate() self._verify_expression_types() + self._set_types() self._verify_expressions() self._verify_checksums() diff --git a/rflx/model/type_.py b/rflx/model/type_.py index 8a55921f5..6b79901af 100644 --- a/rflx/model/type_.py +++ b/rflx/model/type_.py @@ -181,8 +181,12 @@ def constraints( ) -> abc.Sequence[expr.Expr]: if proof: return [ - expr.Less(expr.Variable(name), self._modulus, location=self.location), - expr.GreaterEqual(expr.Variable(name), expr.Number(0), location=self.location), + expr.Less( + expr.Variable(name, type_=self.type_), self._modulus, location=self.location + ), + expr.GreaterEqual( + expr.Variable(name, type_=self.type_), expr.Number(0), location=self.location + ), expr.Equal(expr.Size(name), self.size, location=self.location), ] @@ -344,8 +348,12 @@ def constraints( ) -> abc.Sequence[expr.Expr]: if proof: return [ - expr.GreaterEqual(expr.Variable(name), self.first, location=self.location), - expr.LessEqual(expr.Variable(name), self.last, location=self.location), + expr.GreaterEqual( + expr.Variable(name, type_=self.type_), self.first, location=self.location + ), + expr.LessEqual( + expr.Variable(name, type_=self.type_), self.last, location=self.location + ), expr.Equal(expr.Size(name), self.size, location=self.location), ] @@ -531,14 +539,19 @@ def constraints( result: list[expr.Expr] = [ expr.Or( *[ - expr.Equal(expr.Variable(name), expr.Literal(l), self.location) + expr.Equal( + expr.Variable(name, type_=self.type_), expr.Literal(l), self.location + ) for l in literals ], location=self.location, ) ] result.extend( - [expr.Equal(expr.Literal(l), v, self.location) for l, v in literals.items()] + [ + expr.Equal(expr.Literal(l, type_=self.type_), v, self.location) + for l, v in literals.items() + ] ) result.append(expr.Equal(expr.Size(name), self.size, self.location)) return result diff --git a/tests/unit/model/message_test.py b/tests/unit/model/message_test.py index 7f1ea52b5..5a377c05f 100644 --- a/tests/unit/model/message_test.py +++ b/tests/unit/model/message_test.py @@ -4287,3 +4287,19 @@ def test_refinement_invalid_condition_unqualified_literal() -> None: r' of "P::M"' r"$", ) + + +def test_boolean_value_as_condition() -> None: + Message( + "P::M", + [ + Link(INITIAL, Field("Tag_1")), + Link(Field("Tag_1"), Field("Tag_2"), condition=Variable("Has_Tag")), + Link(Field("Tag_2"), FINAL), + ], + { + Field("Tag_1"): MODULAR_INTEGER, + Field("Tag_2"): MODULAR_INTEGER, + Field("Has_Tag"): BOOLEAN, + }, + )