Skip to content

Commit

Permalink
Rework relation checks
Browse files Browse the repository at this point in the history
Fixes #282
  • Loading branch information
Alexander Senier committed Jun 16, 2020
1 parent 90b3c4f commit fdadf36
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 68 deletions.
196 changes: 136 additions & 60 deletions rflx/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from copy import copy
from pathlib import Path
from typing import Dict, List, Mapping, NamedTuple, Optional, Sequence, Set, Tuple
from typing import Dict, List, Mapping, NamedTuple, Optional, Sequence, Set, Tuple, Union

from rflx.common import flat_name, generic_repr
from rflx.contract import ensure, invariant
Expand All @@ -19,6 +19,7 @@
Equal,
Expr,
First,
Greater,
GreaterEqual,
Last,
Length,
Expand Down Expand Up @@ -433,6 +434,9 @@ class MessageState(Base):
field_condition: Mapping[Field, Expr] = {}


TypeExpr = Union[Type, Expr]


@invariant(lambda self: valid_message_field_types(self))
class AbstractMessage(Type):
# pylint: disable=too-many-arguments
Expand Down Expand Up @@ -664,12 +668,16 @@ def __verify_conditions(self) -> None:
self.__check_vars(l.length, state, l.length.location)
self.__check_vars(l.first, state, l.first.location)
self.__check_attributes(l.condition, l.condition.location)
self.__check_relations(l.condition, l.condition.location)
self.__check_relations(l.condition, literals)
self.__check_first_expression(l, l.first.location)
self.__check_length_expression(l)
self.error.propagate()

def __check_vars(
self, expression: Expr, state: Tuple[Set[ID], Set[ID], Set[ID]], location: Location = None,
self,
expression: Expr,
state: Tuple[Set[ID], Dict[ID, Enumeration], Set[ID]],
location: Location = None,
) -> None:
variables, literals, seen = state
for v in expression.variables(True):
Expand Down Expand Up @@ -699,68 +707,134 @@ def __check_attributes(self, expression: Expr, location: Location = None) -> Non
location,
)

def __check_relations(self, expression: Expr, location: Location = None) -> None:
for r in expression.findall(lambda x: isinstance(x, Relation)):
if (
isinstance(r, Relation)
and not isinstance(r, (Equal, NotEqual))
and (isinstance(r.left, Aggregate) or isinstance(r.right, Aggregate))
):
def __resolve_types(
self, left: Expr, right: Expr, literals: Dict[ID, Enumeration],
) -> Tuple[TypeExpr, TypeExpr]:

lefttype: TypeExpr
if isinstance(left, Variable):
if left.identifier in literals:
lefttype = literals[left.identifier]
elif Field(left.name) in self.types:
lefttype = self.types[Field(left.name)]
else:
self.error.append(
f'invalid relation "{r.symbol}" to aggregate',
'undefined variable "{left.identifier}" referenced',
Subsystem.MODEL,
Severity.ERROR,
location,
left.location,
)
else:
lefttype = left

righttype: TypeExpr
if isinstance(right, Variable):
if right.identifier in literals:
righttype = literals[right.identifier]
elif Field(right.name) in self.types:
righttype = self.types[Field(right.name)]
else:
self.error.append(
'undefined variable "{right.identifier}" referenced',
Subsystem.MODEL,
Severity.ERROR,
right.location,
)
else:
righttype = right

if isinstance(r, (Equal, NotEqual)) and (
isinstance(r.left, Aggregate) or isinstance(r.right, Aggregate)
):
if isinstance(r.left, Aggregate):
other = r.right
aggregate = r.left
elif isinstance(r.right, Aggregate):
other = r.left
aggregate = r.right
if not (
isinstance(other, Variable)
and Field(other.name) in self.fields
and isinstance(self.types[Field(other.name)], Composite)
):
self.error.propagate()
return (lefttype, righttype)

def __check_relations(self, expression: Expr, literals: Dict[ID, Enumeration]) -> None:
def check_composite_element_range(
relation: Relation, aggregate: Aggregate, composite: Composite
) -> None:
first: Expr
last: Expr
if isinstance(composite, Opaque):
first = Number(0)
last = Number(255)

if isinstance(composite, Array):
if not isinstance(composite.element_type, Integer):
self.error.append(
f'invalid relation between "{other}" and aggregate',
f'invalid array element type "{composite.element_type.identifier}"'
" for aggregate comparison",
Subsystem.MODEL,
Severity.ERROR,
location,
relation.location,
)
else:
othertype = self.types[Field(other.name)]
first: Expr
last: Expr
if isinstance(othertype, Opaque):
first = Number(0)
last = Number(255)
elif isinstance(othertype, Array):
if not isinstance(othertype.element_type, Integer):
self.error.append(
f'invalid array element type "{othertype.element_type.identifier}"'
" for aggregate comparison",
Subsystem.MODEL,
Severity.ERROR,
r.location,
)
continue
first = othertype.element_type.first.simplified()
last = othertype.element_type.last.simplified()
return
first = composite.element_type.first.simplified()
last = composite.element_type.last.simplified()

for element in aggregate.elements:
if not first <= element <= last:
self.error.append(
f"aggregate element out of range {first} .. {last}",
Subsystem.MODEL,
Severity.ERROR,
element.location,
)
for element in aggregate.elements:
if not first <= element <= last:
self.error.append(
f"aggregate element out of range {first} .. {last}",
Subsystem.MODEL,
Severity.ERROR,
element.location,
)

def check_enumeration(relation: Relation, left: Enumeration, right: Enumeration) -> None:
if left != right:
self.error.append(
"comparison of incompatible enumeration literals",
Subsystem.MODEL,
Severity.ERROR,
relation.location,
)
self.error.append(
f'of type "{left.identifier}"', Subsystem.MODEL, Severity.INFO, left.location,
)
self.error.append(
f'and type "{right.identifier}"',
Subsystem.MODEL,
Severity.INFO,
right.location,
)

def relation_error(relation: Relation, left: TypeExpr, right: TypeExpr) -> None:
self.error.append(
f'invalid relation "{relation.symbol}" between {left.__class__.__name__} '
f"and {right.__class__.__name__}",
Subsystem.MODEL,
Severity.ERROR,
relation.location,
)

for relation in expression.findall(lambda x: isinstance(x, Relation)):
assert isinstance(relation, Relation)
left, right = self.__resolve_types(
relation.left.simplified(), relation.right.simplified(), literals
)
if isinstance(relation, (Less, LessEqual, Greater, GreaterEqual)):
if (isinstance(left, Aggregate) and isinstance(right, Type)) or (
isinstance(right, Aggregate) and isinstance(left, Type)
):
relation_error(relation, left, right)
elif isinstance(relation, (Equal, NotEqual)):
# pylint: disable=too-many-boolean-expressions
if (
(isinstance(left, Opaque) and not isinstance(right, (Opaque, Aggregate)))
or (isinstance(left, Array) and not isinstance(right, (Array, Aggregate)))
or (isinstance(left, Aggregate) and not isinstance(right, Composite))
):
relation_error(relation, left, right)
elif (
(not isinstance(left, (Opaque, Aggregate)) and isinstance(right, Opaque))
or (not isinstance(left, (Array, Aggregate)) and isinstance(right, Array))
or (not isinstance(left, Composite) and isinstance(right, Aggregate))
):
relation_error(relation, left=right, right=left)
elif isinstance(left, Aggregate) and isinstance(right, Composite):
check_composite_element_range(relation, left, right)
elif isinstance(left, Composite) and isinstance(right, Aggregate):
check_composite_element_range(relation, right, left)
elif isinstance(left, Enumeration) and isinstance(right, Enumeration):
check_enumeration(relation, left, right)

def __check_first_expression(self, link: Link, location: Location = None) -> None:
if link.first != UNDEFINED and not isinstance(link.first, First):
Expand Down Expand Up @@ -815,7 +889,9 @@ def get_constraints(aggregate: Aggregate, field: Variable) -> Sequence[Expr]:
scalar_types = [
(f.name, t)
for f, t in self.types.items()
if isinstance(t, Scalar) and f.name not in [*literals, "Message", "Final"]
if isinstance(t, Scalar)
and f.name not in literals
and f.name not in ["Message", "Final"]
]

aggregate_constraints: List[Expr] = []
Expand Down Expand Up @@ -1460,16 +1536,16 @@ def refinements(self) -> Sequence[Refinement]:
return [m for m in self.types if isinstance(m, Refinement)]


def qualified_literals(types: Mapping[Field, Type], package: ID) -> Set[ID]:
literals = set()
def qualified_literals(types: Mapping[Field, Type], package: ID) -> Dict[ID, Enumeration]:
literals = {}

for t in types.values():
if isinstance(t, Enumeration):
for l in t.literals:
if t.package == BUILTINS_PACKAGE or t.package == package:
literals.add(l)
literals[l] = t
if t.package != BUILTINS_PACKAGE:
literals.add(t.package * l)
literals[t.package * l] = t

return literals

Expand Down
7 changes: 4 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def test_message_invalid_relation_to_aggregate() -> None:
types = {Field("F1"): Opaque()}
assert_type(
Message("P.M", structure, types),
r'^<stdin>:100:20: model: error: invalid relation " <= " to aggregate$',
r'^<stdin>:100:20: model: error: invalid relation " <= " between Opaque and Aggregate$',
)


Expand All @@ -694,7 +694,8 @@ def test_message_invalid_element_in_relation_to_aggregate() -> None:
types = {Field("F1"): MODULAR_INTEGER}
assert_type(
Message("P.M", structure, types),
r'^<stdin>:14:7: model: error: invalid relation between "F1" and aggregate$',
r'^<stdin>:14:7: model: error: invalid relation " = " '
r"between Aggregate and ModularInteger$",
)


Expand Down Expand Up @@ -1151,7 +1152,7 @@ class NewType(Type):

@pytest.mark.skipif(not __debug__, reason="depends on contract")
def test_invalid_message_field_type() -> None:
with pytest.raises(ViolationError, match=r"rflx/model.py, line 436"):
with pytest.raises(ViolationError):
Message(
"P.M", [Link(INITIAL, Field("F")), Link(Field("F"), FINAL)], {Field("F"): NewType("T")},
)
46 changes: 41 additions & 5 deletions tests/test_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,21 +330,42 @@ def test_invalid_type_condition_modular_lower() -> None:
)


@pytest.mark.skip(reason="ISSUE: Componolit/RecordFlux#87")
def test_invalid_type_condition_enum() -> None:
structure = [
Link(INITIAL, Field("F1")),
Link(Field("F1"), Field("F2"), condition=Equal(Variable("F1"), Variable("E4"))),
Link(
Field("F1"),
Field("F2"),
condition=Equal(Variable("F1"), Variable("E4"), location=Location((22, 10))),
),
Link(Field("F2"), FINAL),
]
e1 = Enumeration("P.E1", {"E1": Number(1), "E2": Number(2), "E3": Number(3)}, Number(8), False)
e2 = Enumeration("P.E2", {"E4": Number(1), "E5": Number(2), "E6": Number(3)}, Number(8), False)
e1 = Enumeration(
"P.E1",
{"E1": Number(1), "E2": Number(2), "E3": Number(3)},
Number(8),
False,
location=Location((10, 4)),
)
e2 = Enumeration(
"P.E2",
{"E4": Number(1), "E5": Number(2), "E6": Number(3)},
Number(8),
False,
location=Location((11, 4)),
)
types = {
Field("F1"): e1,
Field("F2"): e2,
}
assert_message_model_error(
structure, types, r'^invalid type of "E4" in condition 0 from field "F1" to "F2" in "P.M"',
structure,
types,
r"^"
r"<stdin>:22:10: model: error: comparison of incompatible enumeration literals\n"
r'<stdin>:10:4: model: info: of type "P.E1"\n'
r'<stdin>:11:4: model: info: and type "P.E2"'
r"$",
)


Expand Down Expand Up @@ -1135,3 +1156,18 @@ def test_no_contradiction_multi() -> None:
Field("F5"): RANGE_INTEGER,
}
Message("P.M", structure, types)


def test_opaque_equal_scalar() -> None:
final = Field(ID("Final", Location((10, 7))))
structure = [
Link(INITIAL, Field("Length")),
Link(Field("Length"), Field("Data"), length=Variable("Length")),
Link(Field("Data"), final, condition=Equal(Variable("Data"), Number(42))),
]
types = {Field("Length"): RANGE_INTEGER, Field("Data"): Opaque()}
assert_message_model_error(
structure,
types,
r"^" r'model: error: invalid relation " = " between Opaque and Number' r"$",
)

0 comments on commit fdadf36

Please sign in to comment.