Skip to content

Commit

Permalink
Implement validity checks for case expressions
Browse files Browse the repository at this point in the history
Ref. #907
  • Loading branch information
senier committed Jul 12, 2022
1 parent c307ebf commit 6f334ea
Show file tree
Hide file tree
Showing 8 changed files with 387 additions and 72 deletions.
167 changes: 158 additions & 9 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2775,24 +2775,139 @@ def _entity_name(expr: Expr) -> str:

class Case(Expr):
def __init__(
self, expr: Expr, choices: List[Tuple[List[StrID], Expr]], location: Location = None
self,
expr: Expr,
choices: List[Tuple[List[Union[ID, Number]], Expr]],
location: Location = None,
) -> None:
super().__init__(rty.Undefined(), location)
self.expr = expr
self.choices = choices

def _update_str(self) -> None:
data = ",\n".join(f"\n when {' | '.join(map(str, c))} => {e}" for c, e in self.choices)
self._str = intern(f"(case {self.expr} is{data})")
data = ",\n".join(f" when {' | '.join(map(str, c))} => {e}" for c, e in self.choices)
self._str = intern(f"(case {self.expr} is\n{data})")

def _check_enumeration(self) -> RecordFluxError:
assert isinstance(self.expr.type_, rty.Enumeration)
assert self.expr.type_.literals

def _check_type_subexpr(self) -> RecordFluxError:
error = RecordFluxError()
literals = [
c.name for (choice, _) in self.choices for c in choice if isinstance(c, (str, ID))
]
type_literals = [l.name for l in self.expr.type_.literals]
missing = set(type_literals) - set(literals)
if missing:
error.extend(
[
(
"not all enumeration literals covered by case expression",
Subsystem.MODEL,
Severity.ERROR,
self.location,
),
*[
(
f'missing literal "{l.name}"',
Subsystem.MODEL,
Severity.WARNING,
self.expr.type_.location,
)
for l in missing
],
]
)

invalid = set(literals) - set(type_literals)
if invalid:
error.extend(
[
(
"invalid literals used in case expression",
Subsystem.MODEL,
Severity.ERROR,
self.location,
),
*[
(
f'literal "{l.name}" not part of {self.expr.type_.identifier}',
Subsystem.MODEL,
Severity.WARNING,
self.expr.type_.location,
)
for l in invalid
],
]
)
return error

resulttype: rty.Type = rty.Any()
def _check_integer(self) -> RecordFluxError:
assert isinstance(self.expr.type_, rty.Integer)
assert self.expr.type_.bounds.lower
assert self.expr.type_.bounds.upper

error = RecordFluxError()
literals = [
c.value for (choice, _) in self.choices for c in choice if isinstance(c, Number)
]
type_literals = range(self.expr.type_.bounds.lower, self.expr.type_.bounds.upper + 1)

missing = set(type_literals) - set(literals)
if missing:
error.extend(
[
(
f"case expression does not cover full range of "
f"{self.expr.type_.identifier}",
Subsystem.MODEL,
Severity.ERROR,
self.location,
),
*[
(
f'missing literal "{l}"',
Subsystem.MODEL,
Severity.WARNING,
self.expr.type_.location,
)
for l in missing
],
]
)

invalid = set(literals) - set(type_literals)
if invalid:
error.extend(
[
(
"invalid literals used in case expression",
Subsystem.MODEL,
Severity.ERROR,
self.location,
),
*[
(
f'literal "{l}" not part of {self.expr.type_.identifier}',
Subsystem.MODEL,
Severity.WARNING,
self.expr.type_.location,
)
for l in invalid
],
]
)

return error

def _check_type_subexpr(self) -> RecordFluxError:
error = RecordFluxError()
result_type: rty.Type = rty.Any()
literals = [c for (choice, _) in self.choices for c in choice]

for _, expr in self.choices:
error += expr.check_type_instance(rty.Any)
resulttype = resulttype.common_type(expr.type_)
result_type = result_type.common_type(expr.type_)

for i1, (_, e1) in enumerate(self.choices):
for i2, (_, e2) in enumerate(self.choices):
Expand All @@ -2818,7 +2933,38 @@ def _check_type_subexpr(self) -> RecordFluxError:
error += self.expr.check_type_instance(rty.Any)
error.propagate()

if not isinstance(self.expr.type_, (rty.AnyInteger, rty.Enumeration)):
duplicates = [
e1
for i1, e1 in enumerate(literals)
for i2, e2 in enumerate(literals)
if i1 > i2 and e1 == e2
]
if duplicates:
error.extend(
[
(
"duplicate literals used in case expression",
Subsystem.MODEL,
Severity.ERROR,
self.location,
),
*[
(
f'duplicate literal "{l}"',
Subsystem.MODEL,
Severity.WARNING,
l.location,
)
for l in duplicates
],
]
)

if isinstance(self.expr.type_, rty.Enumeration):
error += self._check_enumeration()
elif isinstance(self.expr.type_, rty.Integer):
error += self._check_integer()
else:
error.extend(
[
(
Expand All @@ -2830,7 +2976,7 @@ def _check_type_subexpr(self) -> RecordFluxError:
]
)

self.type_ = resulttype
self.type_ = result_type

return error

Expand Down Expand Up @@ -2870,7 +3016,10 @@ def precedence(self) -> Precedence:

def ada_expr(self) -> ada.Expr:
choices = [
(Literal(choice).ada_expr(), expr.ada_expr())
(
Literal(choice).ada_expr() if isinstance(choice, (str, ID)) else choice.ada_expr(),
expr.ada_expr(),
)
for choices, expr in self.choices
for choice in choices
]
Expand Down
13 changes: 10 additions & 3 deletions rflx/model/type_.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def constraints(
class Integer(Scalar):
@property
def type_(self) -> rty.Type:
return rty.Integer(self.full_name, rty.Bounds(self.first.value, self.last.value))
return rty.Integer(
self.full_name, rty.Bounds(self.first.value, self.last.value), location=self.location
)

@property
def value_count(self) -> expr.Number:
Expand Down Expand Up @@ -498,7 +500,12 @@ def __str__(self) -> str:

@property
def type_(self) -> rty.Type:
return rty.Enumeration(self.full_name, self.always_valid)
return rty.Enumeration(
self.full_name,
list(map(ID, self.literals.keys())),
self.always_valid,
location=self.location,
)

@property
def value_count(self) -> expr.Number:
Expand Down Expand Up @@ -698,7 +705,7 @@ def type_(self) -> rty.Type:
],
expr.Number(1),
always_valid=False,
location=Location((0, 0), Path(str(const.BUILTINS_PACKAGE)), (0, 0)),
location=rty.BOOLEAN.location,
)

BUILTIN_TYPES = {
Expand Down
20 changes: 18 additions & 2 deletions rflx/specification/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,13 +645,29 @@ def create_selected(expression: lang.Expr, filename: Path) -> expr.Expr:

def create_case(expression: lang.Expr, filename: Path) -> expr.Expr:
assert isinstance(expression, lang.CaseExpression)
choices: List[Tuple[List[StrID], expr.Expr]] = [

def create_choice(
value: Union[lang.AbstractID, lang.Expr], filename: Path
) -> Union[ID, expr.Number]:
if isinstance(value, lang.AbstractID):
return create_id(value, filename)
assert isinstance(value, lang.Expr)
result = create_numeric_literal(value, filename)
assert isinstance(result, expr.Number)
return result

choices: List[Tuple[List[Union[ID, expr.Number]], expr.Expr]] = [
(
[create_id(s, filename) for s in c.f_selectors if isinstance(s, lang.AbstractID)],
[
create_choice(s, filename)
for s in c.f_selectors
if isinstance(s, (lang.AbstractID, lang.Expr))
],
create_expression(c.f_expression, filename),
)
for c in expression.f_choices
]

return expr.Case(
create_expression(expression.f_expression, filename),
choices,
Expand Down
10 changes: 9 additions & 1 deletion rflx/typing_.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as ty
from abc import abstractmethod
from pathlib import Path

import attr

Expand Down Expand Up @@ -108,13 +109,19 @@ def common_type(self, other: Type) -> Type:
@attr.s(frozen=True)
class Enumeration(IndependentType):
DESCRIPTIVE_NAME: ty.ClassVar[str] = "enumeration type"
literals: ty.Sequence[ID] = attr.ib()
always_valid: bool = attr.ib(False)
location: ty.Optional[Location] = attr.ib(default=None, cmp=False)

def __str__(self) -> str:
return f'{self.DESCRIPTIVE_NAME} "{self.identifier}"'


BOOLEAN = Enumeration(const.BUILTINS_PACKAGE * "Boolean")
BOOLEAN = Enumeration(
const.BUILTINS_PACKAGE * "Boolean",
[ID("False"), ID("True")],
location=Location((0, 0), Path(str(const.BUILTINS_PACKAGE)), (0, 0)),
)


@attr.s(frozen=True)
Expand Down Expand Up @@ -172,6 +179,7 @@ class Integer(AnyInteger):
DESCRIPTIVE_NAME: ty.ClassVar[str] = "integer type"
identifier: ID = attr.ib(converter=ID)
bounds: Bounds = attr.ib(Bounds(None, None))
location: ty.Optional[Location] = attr.ib(default=None, cmp=False)

def __str__(self) -> str:
return f'{self.DESCRIPTIVE_NAME} "{self.identifier}" ({self.bounds})'
Expand Down
2 changes: 2 additions & 0 deletions tests/data/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from rflx.error import Location
from rflx.expression import (
Aggregate,
And,
Expand Down Expand Up @@ -306,6 +307,7 @@
[("Zero", Number(0)), ("One", Number(1)), ("Two", Number(2))],
Number(8),
always_valid=False,
location=Location((10, 2)),
)

MESSAGE = Message(
Expand Down
Loading

0 comments on commit 6f334ea

Please sign in to comment.