Skip to content

Commit

Permalink
Add case expression to model
Browse files Browse the repository at this point in the history
Ref. #907
  • Loading branch information
senier committed Jul 12, 2022
1 parent f1c408e commit 9d371e1
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 0 deletions.
100 changes: 100 additions & 0 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2773,6 +2773,106 @@ def _entity_name(expr: Expr) -> str:
return f'{expr_type} "{expr_name}"'


class Case(Expr):
def __init__(
self, expr: Expr, choices: List[Tuple[List[StrID], Expr]], location: Location = None
) -> None:
super().__init__(rty.Undefined(), location)
self.expr = expr
self.choices = choices

def _update_str(self) -> None:
data = ",\n".join(f"\n when {' | '.join(map(str, c))} => {e}" for c, e in self.choices)
self._str = intern(f"(case {self.expr} is{data})")

def _check_type_subexpr(self) -> RecordFluxError:
error = RecordFluxError()
resulttype: rty.Type = rty.Any()
for _, expr in self.choices:
error += expr.check_type_instance(rty.Any)
resulttype = resulttype.common_type(expr.type_)

for i1, (_, e1) in enumerate(self.choices):
for i2, (_, e2) in enumerate(self.choices):
if i1 < i2:
if not e1.type_.is_compatible(e2.type_):
error.extend(
[
(
f'dependent expression "{e1}" has incompatible type {e1.type_}',
Subsystem.MODEL,
Severity.ERROR,
e1.location,
),
(
f'conflicting with "{e2}" which has type {e2.type_}',
Subsystem.MODEL,
Severity.WARNING,
e2.location,
),
]
)

error += self.expr.check_type_instance(rty.Any)
error.propagate()
self.type_ = resulttype

return error

def __neg__(self) -> Expr:
raise NotImplementedError

def findall(self, match: Callable[["Expr"], bool]) -> Sequence["Expr"]:
return [
*([self] if match(self) else []),
*self.expr.findall(match),
*[e for _, v in self.choices for e in v.findall(match)],
]

def simplified(self) -> Expr:
return self.__class__(
self.expr.simplified(),
[(c, e.simplified()) for c, e in self.choices],
location=self.location,
)

def substituted(
self, func: Callable[[Expr], Expr] = None, mapping: Mapping[Name, Expr] = None
) -> Expr:
func = substitution(mapping or {}, func)
expr = func(self)
if isinstance(expr, Case):
return expr.__class__(
expr.expr.substituted(func),
[(c, e.substituted(func)) for c, e in self.choices],
location=expr.location,
)
return expr

@property
def precedence(self) -> Precedence:
raise NotImplementedError

def ada_expr(self) -> ada.Expr:
raise NotImplementedError

@lru_cache(maxsize=None)
def z3expr(self) -> z3.ExprRef:
raise NotImplementedError

def variables(self) -> List["Variable"]:
simplified = self.simplified()
assert isinstance(simplified, Case)
return list(
unique(
[
*simplified.expr.variables(),
*[v for _, e in simplified.choices for v in e.variables()],
]
)
)


def _similar_field_names(
field: ID, fields: Iterable[ID], location: Optional[Location]
) -> List[Tuple[str, Subsystem, Severity, Optional[Location]]]:
Expand Down
18 changes: 18 additions & 0 deletions rflx/specification/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,22 @@ def create_selected(expression: lang.Expr, filename: Path) -> expr.Expr:
)


def create_case(expression: lang.Expr, filename: Path) -> expr.Expr:
assert isinstance(expression, lang.CaseExpression)
choices: List[Tuple[List[StrID], expr.Expr]] = [
(
[create_id(s, filename) for s in c.f_selectors if isinstance(s, lang.AbstractID)],
create_expression(c.f_expression, filename),
)
for c in expression.f_choices
]
return expr.Case(
create_expression(expression.f_expression, filename),
choices,
location=node_location(expression, filename),
)


def create_conversion(expression: lang.Expr, filename: Path) -> expr.Expr:
assert isinstance(expression, lang.Conversion)
return expr.Conversion(
Expand Down Expand Up @@ -689,6 +705,7 @@ def create_message_aggregate(expression: lang.Expr, filename: Path) -> expr.Expr
"Conversion": create_conversion,
"MessageAggregate": create_message_aggregate,
"BinOp": create_binop,
"CaseExpression": create_case,
}


Expand Down Expand Up @@ -743,6 +760,7 @@ def create_bool_expression(expression: lang.Expr, filename: Path) -> expr.Expr:
"QuantifiedExpression": create_quantified_expression,
"Binding": create_binding,
"SelectNode": create_selected,
"CaseExpression": create_case,
}
return handlers[expression.kind_name](expression, filename)

Expand Down
76 changes: 76 additions & 0 deletions tests/unit/expression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Attribute,
Binding,
Call,
Case,
Comprehension,
Conversion,
Div,
Expand Down Expand Up @@ -2453,3 +2454,78 @@ def test_proof_invalid_logic() -> None:
None,
)
]


def test_case_variables() -> None:
assert_equal(
Case(
Variable("C"), [([ID("V1"), ID("V2")], Number(1)), ([ID("V3")], Variable("E"))]
).variables(),
[Variable("C"), Variable("E")],
)


def test_case_substituted() -> None:
c = Case(Variable("C"), [([ID("V1"), ID("V2")], Variable("E1")), ([ID("V3")], Variable("E2"))])
assert_equal(
c.substituted(lambda x: Number(42) if x == Variable("E1") else x),
Case(Variable("C"), [([ID("V1"), ID("V2")], Number(42)), ([ID("V3")], Variable("E2"))]),
)
assert_equal(
c.substituted(
lambda x: Number(42) if isinstance(x, Variable) and x.name.startswith("E") else x
),
Case(Variable("C"), [([ID("V1"), ID("V2")], Number(42)), ([ID("V3")], Number(42))]),
)
assert_equal(
c.substituted(lambda x: Variable("C_Subst") if x == Variable("C") else x),
Case(
Variable("C_Subst"),
[([ID("V1"), ID("V2")], Variable("E1")), ([ID("V3")], Variable("E2"))],
),
)
assert_equal(
c.substituted(lambda x: Variable("C_Subst") if isinstance(x, Case) else x),
Variable("C_Subst"),
)


def test_case_substituted_location() -> None:
c = Case(
Variable("C"),
[([ID("V1"), ID("V2")], Variable("E1")), ([ID("V3")], Variable("E2"))],
location=Location((1, 2)),
).substituted(lambda x: x)
assert c.location


def test_case_findall() -> None:
assert_equal(
Case(
Variable("C1"), [([ID("V1"), ID("V2")], Variable("E1")), ([ID("V3")], Variable("E2"))]
).findall(lambda x: isinstance(x, Variable) and x.name.endswith("1")),
[Variable("C1"), Variable("E1")],
)


def test_case_type() -> None:
assert_type(Case(Number(1), [([ID("V1"), ID("V2")], TRUE), ([ID("V3")], FALSE)]), rty.BOOLEAN)
assert_type(
Case(Number(1), [([ID("V1"), ID("V2")], Number(1)), ([ID("V3")], Number(2))]),
rty.UniversalInteger(rty.Bounds(1, 2)),
)
assert_type_error(
Case(Number(1), [([ID("V1"), ID("V2")], TRUE), ([ID("V3")], Number(1))]),
r'^model: error: dependent expression "True" has incompatible type enumeration type '
r'"__BUILTINS__::Boolean"\n'
r'model: warning: conflicting with "1" which has type type universal integer \(1\)$',
)


def test_case_simplified() -> None:
assert_equal(
Case(
Variable("C"), [([ID("V1"), ID("V2")], And(TRUE, FALSE)), ([ID("V3")], FALSE)]
).simplified(),
Case(Variable("C"), [([ID("V1"), ID("V2")], FALSE), ([ID("V3")], FALSE)]),
)
10 changes: 10 additions & 0 deletions tests/unit/specification/grammar_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,16 @@ def test_expression_base(string: str, expected: expr.Expr) -> None:
},
),
),
(
"(case C is when V1 | V2 => 8, when V3 => 16)",
expr.Case(
expr.Variable("C"),
[
([ID("V1"), ID("V2")], expr.Number(8)),
([ID("V3")], expr.Number(16)),
],
),
),
],
)
def test_expression_complex(string: str, expected: expr.Expr) -> None:
Expand Down

0 comments on commit 9d371e1

Please sign in to comment.