diff --git a/rflx/model.py b/rflx/model.py index afdd1c993..411a6870b 100644 --- a/rflx/model.py +++ b/rflx/model.py @@ -4,7 +4,7 @@ from collections import defaultdict from copy import copy from pathlib import Path -from typing import Dict, List, Mapping, NamedTuple, Optional, Sequence, Set, Tuple +from typing import Dict, List, Mapping, NamedTuple, Optional, Sequence, Set, Tuple, Union from rflx.common import flat_name, generic_repr from rflx.contract import ensure, invariant @@ -19,6 +19,7 @@ Equal, Expr, First, + Greater, GreaterEqual, Last, Length, @@ -433,6 +434,9 @@ class MessageState(Base): field_condition: Mapping[Field, Expr] = {} +TypeExpr = Union[Type, Expr] + + @invariant(lambda self: valid_message_field_types(self)) class AbstractMessage(Type): # pylint: disable=too-many-arguments @@ -664,12 +668,16 @@ def __verify_conditions(self) -> None: self.__check_vars(l.length, state, l.length.location) self.__check_vars(l.first, state, l.first.location) self.__check_attributes(l.condition, l.condition.location) - self.__check_relations(l.condition, l.condition.location) + self.__check_relations(l.condition, literals) self.__check_first_expression(l, l.first.location) self.__check_length_expression(l) + self.error.propagate() def __check_vars( - self, expression: Expr, state: Tuple[Set[ID], Set[ID], Set[ID]], location: Location = None, + self, + expression: Expr, + state: Tuple[Set[ID], Dict[ID, Enumeration], Set[ID]], + location: Location = None, ) -> None: variables, literals, seen = state for v in expression.variables(True): @@ -699,68 +707,134 @@ def __check_attributes(self, expression: Expr, location: Location = None) -> Non location, ) - def __check_relations(self, expression: Expr, location: Location = None) -> None: - for r in expression.findall(lambda x: isinstance(x, Relation)): - if ( - isinstance(r, Relation) - and not isinstance(r, (Equal, NotEqual)) - and (isinstance(r.left, Aggregate) or isinstance(r.right, Aggregate)) - ): + def __resolve_types( + self, left: Expr, right: Expr, literals: Dict[ID, Enumeration], + ) -> Tuple[TypeExpr, TypeExpr]: + + lefttype: TypeExpr + if isinstance(left, Variable): + if left.identifier in literals: + lefttype = literals[left.identifier] + elif Field(left.name) in self.types: + lefttype = self.types[Field(left.name)] + else: self.error.append( - f'invalid relation "{r.symbol}" to aggregate', + 'undefined variable "{left.identifier}" referenced', Subsystem.MODEL, Severity.ERROR, - location, + left.location, + ) + else: + lefttype = left + + righttype: TypeExpr + if isinstance(right, Variable): + if right.identifier in literals: + righttype = literals[right.identifier] + elif Field(right.name) in self.types: + righttype = self.types[Field(right.name)] + else: + self.error.append( + 'undefined variable "{right.identifier}" referenced', + Subsystem.MODEL, + Severity.ERROR, + right.location, ) + else: + righttype = right - if isinstance(r, (Equal, NotEqual)) and ( - isinstance(r.left, Aggregate) or isinstance(r.right, Aggregate) - ): - if isinstance(r.left, Aggregate): - other = r.right - aggregate = r.left - elif isinstance(r.right, Aggregate): - other = r.left - aggregate = r.right - if not ( - isinstance(other, Variable) - and Field(other.name) in self.fields - and isinstance(self.types[Field(other.name)], Composite) - ): + self.error.propagate() + return (lefttype, righttype) + + def __check_relations(self, expression: Expr, literals: Dict[ID, Enumeration]) -> None: + def check_composite_element_range( + relation: Relation, aggregate: Aggregate, composite: Composite + ) -> None: + first: Expr + last: Expr + if isinstance(composite, Opaque): + first = Number(0) + last = Number(255) + + if isinstance(composite, Array): + if not isinstance(composite.element_type, Integer): self.error.append( - f'invalid relation between "{other}" and aggregate', + f'invalid array element type "{composite.element_type.identifier}"' + " for aggregate comparison", Subsystem.MODEL, Severity.ERROR, - location, + relation.location, ) - else: - othertype = self.types[Field(other.name)] - first: Expr - last: Expr - if isinstance(othertype, Opaque): - first = Number(0) - last = Number(255) - elif isinstance(othertype, Array): - if not isinstance(othertype.element_type, Integer): - self.error.append( - f'invalid array element type "{othertype.element_type.identifier}"' - " for aggregate comparison", - Subsystem.MODEL, - Severity.ERROR, - r.location, - ) - continue - first = othertype.element_type.first.simplified() - last = othertype.element_type.last.simplified() + return + first = composite.element_type.first.simplified() + last = composite.element_type.last.simplified() - for element in aggregate.elements: - if not first <= element <= last: - self.error.append( - f"aggregate element out of range {first} .. {last}", - Subsystem.MODEL, - Severity.ERROR, - element.location, - ) + for element in aggregate.elements: + if not first <= element <= last: + self.error.append( + f"aggregate element out of range {first} .. {last}", + Subsystem.MODEL, + Severity.ERROR, + element.location, + ) + + def check_enumeration(relation: Relation, left: Enumeration, right: Enumeration) -> None: + if left != right: + self.error.append( + "comparison of incompatible enumeration literals", + Subsystem.MODEL, + Severity.ERROR, + relation.location, + ) + self.error.append( + f'of type "{left.identifier}"', Subsystem.MODEL, Severity.INFO, left.location, + ) + self.error.append( + f'and type "{right.identifier}"', + Subsystem.MODEL, + Severity.INFO, + right.location, + ) + + def relation_error(relation: Relation, left: TypeExpr, right: TypeExpr) -> None: + self.error.append( + f'invalid relation "{relation.symbol}" between {left.__class__.__name__} ' + f"and {right.__class__.__name__}", + Subsystem.MODEL, + Severity.ERROR, + relation.location, + ) + + for relation in expression.findall(lambda x: isinstance(x, Relation)): + assert isinstance(relation, Relation) + left, right = self.__resolve_types( + relation.left.simplified(), relation.right.simplified(), literals + ) + if isinstance(relation, (Less, LessEqual, Greater, GreaterEqual)): + if (isinstance(left, Aggregate) and isinstance(right, Type)) or ( + isinstance(right, Aggregate) and isinstance(left, Type) + ): + relation_error(relation, left, right) + elif isinstance(relation, (Equal, NotEqual)): + # pylint: disable=too-many-boolean-expressions + if ( + (isinstance(left, Opaque) and not isinstance(right, (Opaque, Aggregate))) + or (isinstance(left, Array) and not isinstance(right, (Array, Aggregate))) + or (isinstance(left, Aggregate) and not isinstance(right, Composite)) + ): + relation_error(relation, left, right) + elif ( + (not isinstance(left, (Opaque, Aggregate)) and isinstance(right, Opaque)) + or (not isinstance(left, (Array, Aggregate)) and isinstance(right, Array)) + or (not isinstance(left, Composite) and isinstance(right, Aggregate)) + ): + relation_error(relation, left=right, right=left) + elif isinstance(left, Aggregate) and isinstance(right, Composite): + check_composite_element_range(relation, left, right) + elif isinstance(left, Composite) and isinstance(right, Aggregate): + check_composite_element_range(relation, right, left) + elif isinstance(left, Enumeration) and isinstance(right, Enumeration): + check_enumeration(relation, left, right) def __check_first_expression(self, link: Link, location: Location = None) -> None: if link.first != UNDEFINED and not isinstance(link.first, First): @@ -815,7 +889,9 @@ def get_constraints(aggregate: Aggregate, field: Variable) -> Sequence[Expr]: scalar_types = [ (f.name, t) for f, t in self.types.items() - if isinstance(t, Scalar) and f.name not in [*literals, "Message", "Final"] + if isinstance(t, Scalar) + and f.name not in literals + and f.name not in ["Message", "Final"] ] aggregate_constraints: List[Expr] = [] @@ -1460,16 +1536,16 @@ def refinements(self) -> Sequence[Refinement]: return [m for m in self.types if isinstance(m, Refinement)] -def qualified_literals(types: Mapping[Field, Type], package: ID) -> Set[ID]: - literals = set() +def qualified_literals(types: Mapping[Field, Type], package: ID) -> Dict[ID, Enumeration]: + literals = {} for t in types.values(): if isinstance(t, Enumeration): for l in t.literals: if t.package == BUILTINS_PACKAGE or t.package == package: - literals.add(l) + literals[l] = t if t.package != BUILTINS_PACKAGE: - literals.add(t.package * l) + literals[t.package * l] = t return literals diff --git a/tests/test_model.py b/tests/test_model.py index a6518988b..3cdee0bde 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -678,7 +678,7 @@ def test_message_invalid_relation_to_aggregate() -> None: types = {Field("F1"): Opaque()} assert_type( Message("P.M", structure, types), - r'^:100:20: model: error: invalid relation " <= " to aggregate$', + r'^:100:20: model: error: invalid relation " <= " between Opaque and Aggregate$', ) @@ -694,7 +694,8 @@ def test_message_invalid_element_in_relation_to_aggregate() -> None: types = {Field("F1"): MODULAR_INTEGER} assert_type( Message("P.M", structure, types), - r'^:14:7: model: error: invalid relation between "F1" and aggregate$', + r'^:14:7: model: error: invalid relation " = " ' + r"between Aggregate and ModularInteger$", ) @@ -1151,7 +1152,7 @@ class NewType(Type): @pytest.mark.skipif(not __debug__, reason="depends on contract") def test_invalid_message_field_type() -> None: - with pytest.raises(ViolationError, match=r"rflx/model.py, line 436"): + with pytest.raises(ViolationError): Message( "P.M", [Link(INITIAL, Field("F")), Link(Field("F"), FINAL)], {Field("F"): NewType("T")}, ) diff --git a/tests/test_verification.py b/tests/test_verification.py index cc18803e6..7704e62f3 100644 --- a/tests/test_verification.py +++ b/tests/test_verification.py @@ -330,21 +330,42 @@ def test_invalid_type_condition_modular_lower() -> None: ) -@pytest.mark.skip(reason="ISSUE: Componolit/RecordFlux#87") def test_invalid_type_condition_enum() -> None: structure = [ Link(INITIAL, Field("F1")), - Link(Field("F1"), Field("F2"), condition=Equal(Variable("F1"), Variable("E4"))), + Link( + Field("F1"), + Field("F2"), + condition=Equal(Variable("F1"), Variable("E4"), location=Location((22, 10))), + ), Link(Field("F2"), FINAL), ] - e1 = Enumeration("P.E1", {"E1": Number(1), "E2": Number(2), "E3": Number(3)}, Number(8), False) - e2 = Enumeration("P.E2", {"E4": Number(1), "E5": Number(2), "E6": Number(3)}, Number(8), False) + e1 = Enumeration( + "P.E1", + {"E1": Number(1), "E2": Number(2), "E3": Number(3)}, + Number(8), + False, + location=Location((10, 4)), + ) + e2 = Enumeration( + "P.E2", + {"E4": Number(1), "E5": Number(2), "E6": Number(3)}, + Number(8), + False, + location=Location((11, 4)), + ) types = { Field("F1"): e1, Field("F2"): e2, } assert_message_model_error( - structure, types, r'^invalid type of "E4" in condition 0 from field "F1" to "F2" in "P.M"', + structure, + types, + r"^" + r":22:10: model: error: comparison of incompatible enumeration literals\n" + r':10:4: model: info: of type "P.E1"\n' + r':11:4: model: info: and type "P.E2"' + r"$", ) @@ -1135,3 +1156,18 @@ def test_no_contradiction_multi() -> None: Field("F5"): RANGE_INTEGER, } Message("P.M", structure, types) + + +def test_opaque_equal_scalar() -> None: + final = Field(ID("Final", Location((10, 7)))) + structure = [ + Link(INITIAL, Field("Length")), + Link(Field("Length"), Field("Data"), length=Variable("Length")), + Link(Field("Data"), final, condition=Equal(Variable("Data"), Number(42))), + ] + types = {Field("Length"): RANGE_INTEGER, Field("Data"): Opaque()} + assert_message_model_error( + structure, + types, + r"^" r'model: error: invalid relation " = " between Opaque and Number' r"$", + )