Skip to content

Commit

Permalink
Enable code generation for parameterized messages in sessions
Browse files Browse the repository at this point in the history
Ref. #609
  • Loading branch information
treiher committed Sep 13, 2021
1 parent 9dbb14b commit 1419bfe
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 51 deletions.
38 changes: 29 additions & 9 deletions rflx/ada.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,22 +451,31 @@ class Call(Name):
def __init__(
self,
identifier: StrID,
args: Sequence[Expr] = None,
arguments: Sequence[Expr] = None,
named_arguments: Mapping[ID, Expr] = None,
negative: bool = False,
) -> None:
self.identifier = ID(identifier)
self.args = args or []
self.arguments = arguments or []
self.named_arguments = named_arguments or {}
super().__init__(negative)

def __neg__(self) -> "Call":
return self.__class__(self.identifier, self.args, not self.negative)
return self.__class__(
self.identifier, self.arguments, self.named_arguments, not self.negative
)

@property
def _representation(self) -> str:
args = ", ".join(map(str, self.args))
if args:
args = f" ({args})"
call = f"{self.identifier}{args}"
arguments = ", ".join(
[
*(str(a) for a in self.arguments),
*(f"{n} => {a}" for n, a in self.named_arguments.items()),
]
)
if arguments:
arguments = f" ({arguments})"
call = f"{self.identifier}{arguments}"
return call


Expand Down Expand Up @@ -1416,12 +1425,23 @@ def __str__(self) -> str:


class CallStatement(Statement):
def __init__(self, identifier: StrID, arguments: Sequence[Expr] = None) -> None:
def __init__(
self,
identifier: StrID,
arguments: Sequence[Expr] = None,
named_arguments: Mapping[ID, Expr] = None,
) -> None:
self.identifier = ID(identifier)
self.arguments = arguments or []
self.named_arguments = named_arguments or {}

def __str__(self) -> str:
arguments = ", ".join(map(str, self.arguments))
arguments = ", ".join(
[
*(str(a) for a in self.arguments),
*(f"{n} => {a}" for n, a in self.named_arguments.items()),
]
)
arguments = f" ({arguments})" if arguments else ""
return f"{self.identifier}{arguments};"

Expand Down
18 changes: 11 additions & 7 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,8 +1397,8 @@ def __neg__(self) -> "Selected":
def _check_type_subexpr(self) -> RecordFluxError:
error = RecordFluxError()
if isinstance(self.prefix.type_, rty.Message):
if self.selector in self.prefix.type_.fields:
self.type_ = self.prefix.type_.field_types[self.selector]
if self.selector in self.prefix.type_.types:
self.type_ = self.prefix.type_.types[self.selector]
else:
error.extend(
[
Expand All @@ -1409,7 +1409,9 @@ def _check_type_subexpr(self) -> RecordFluxError:
self.location,
),
*_similar_field_names(
self.selector, self.prefix.type_.fields, self.location
self.selector,
self.prefix.type_.parameters | self.prefix.type_.fields,
self.location,
),
]
)
Expand Down Expand Up @@ -1512,7 +1514,9 @@ def representation(self) -> str:
return call

def ada_expr(self) -> ada.Expr:
return ada.Call(ada.ID(self.identifier), [a.ada_expr() for a in self.args], self.negative)
return ada.Call(
ada.ID(self.identifier), [a.ada_expr() for a in self.args], {}, self.negative
)

@lru_cache(maxsize=None)
def z3expr(self) -> z3.ExprRef:
Expand Down Expand Up @@ -2316,7 +2320,7 @@ def _check_type_subexpr(self) -> RecordFluxError:
field_combinations = set(self.type_.field_combinations)

for i, (field, expr) in enumerate(self.field_values.items()):
if field not in self.type_.fields:
if field not in self.type_.types:
error.extend(
[
(
Expand All @@ -2325,12 +2329,12 @@ def _check_type_subexpr(self) -> RecordFluxError:
Severity.ERROR,
field.location,
),
*_similar_field_names(field, self.type_.fields, field.location),
*_similar_field_names(field, self.type_.types, field.location),
]
)
continue

field_type = self.type_.field_types[field]
field_type = self.type_.types[field]

if field_type == rty.OPAQUE:
if not any(
Expand Down
89 changes: 66 additions & 23 deletions rflx/generator/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable = too-many-lines
from contextlib import contextmanager
from dataclasses import dataclass, field as dataclass_field
from typing import Callable, Generator, Iterable, List, NoReturn, Sequence, Set, Tuple
from typing import Callable, Generator, Iterable, List, Mapping, NoReturn, Sequence, Set, Tuple

import attr

Expand Down Expand Up @@ -91,7 +91,7 @@ class SessionContext:
referenced_packages_body: Set[rid.ID] = dataclass_field(default_factory=set)
referenced_has_data: Set[rid.ID] = dataclass_field(default_factory=set)
used_types: Set[rid.ID] = dataclass_field(default_factory=set)
used_types_body: Set[rid.ID] = dataclass_field(default_factory=set)
used_types_body: List[rid.ID] = dataclass_field(default_factory=list)
state_exception: Set[rid.ID] = dataclass_field(default_factory=set)


Expand Down Expand Up @@ -337,7 +337,7 @@ def _create_function_parameters(
)
if any(
isinstance(field_type, rty.Sequence) and not field_type == rty.OPAQUE
for field_type in function.type_.field_types.values()
for field_type in function.type_.types.values()
):
fail(
"message containing sequence fields"
Expand Down Expand Up @@ -391,7 +391,7 @@ def _create_function_parameters(
def _create_context(self) -> Tuple[List[ContextItem], List[ContextItem]]:
declaration_context: List[ContextItem] = []

if self._session_context.used_types | self._session_context.used_types_body:
if self._session_context.used_types or self._session_context.used_types_body:
declaration_context.append(WithClause(self._prefix * const.TYPES_PACKAGE))

body_context: List[ContextItem] = [
Expand Down Expand Up @@ -426,11 +426,11 @@ def _create_context(self) -> Tuple[List[ContextItem], List[ContextItem]]:
],
)

for type_identifier in (
self._session_context.used_types_body - self._session_context.used_types
):
for type_identifier in self._session_context.used_types_body:
if type_identifier.parent in [INTERNAL_PACKAGE, BUILTINS_PACKAGE]:
continue
if type_identifier in self._session_context.used_types:
continue
if type_identifier in [const.TYPES_LENGTH, const.TYPES_INDEX, const.TYPES_BIT_LENGTH]:
body_context.append(
WithClause(self._prefix * const.TYPES_PACKAGE),
Expand Down Expand Up @@ -926,7 +926,19 @@ def _declare(
[
*([reset_message_contexts] if session_global else []),
self._allocate_buffer(identifier),
self._initialize_context(identifier, type_identifier),
self._initialize_context(
identifier,
type_identifier,
parameters=(
{
ID(n): First(self._ada_type(t.identifier))
for n, t in type_.parameter_types.items()
if isinstance(t, (rty.Integer, rty.Enumeration))
}
if isinstance(type_, rty.Message)
else None
),
),
]
)
result.finalization.extend(self._free_context_buffer(ID(identifier), type_identifier))
Expand Down Expand Up @@ -1072,7 +1084,7 @@ def _assign_to_binding( # pylint: disable = too-many-branches
assert isinstance(binding.expr.type_, rty.Message)
for f, f_v in binding.expr.field_values.items():
if expr.Variable(k) == f_v:
type_ = binding.expr.type_.field_types[f]
type_ = binding.expr.type_.types[f]
break
if expr.Opaque(k) == f_v:
type_ = v.type_
Expand Down Expand Up @@ -1200,15 +1212,28 @@ def _assign_to_message_aggregate(
) -> Sequence[Statement]:
assert isinstance(message_aggregate.type_, rty.Message)

self._session_context.used_types_body.add(const.TYPES_BIT_LENGTH)
self._session_context.used_types_body.append(const.TYPES_BIT_LENGTH)

size = self._message_size(message_aggregate)
size_variables = [
v for v in size.variables() if isinstance(v.type_, (rty.Message, rty.Sequence))
]
required_space = size.substituted(self._substitution()).ada_expr()
required_space = (
size.substituted(
lambda x: expr.Call(const.TYPES_BIT_LENGTH, [x])
if isinstance(x, expr.Selected)
else x
)
.substituted(self._substitution())
.ada_expr()
)
target_type = ID(message_aggregate.type_.identifier)
target_context = context_id(target)
parameter_values = {
f: v
for f, v in message_aggregate.field_values.items()
if f in message_aggregate.type_.parameter_types
}

assign_to_message_aggregate = [
self._if_sufficient_space(
Expand Down Expand Up @@ -1246,6 +1271,10 @@ def _assign_to_message_aggregate(
-Number(1),
),
],
{
ID(p): v.substituted(self._substitution()).ada_expr()
for p, v in parameter_values.items()
},
),
*self._set_message_fields(
target_type, target_context, message_aggregate, exception_handler
Expand Down Expand Up @@ -1355,7 +1384,7 @@ def _assign_to_head(

assert isinstance(head.prefix.type_.element, rty.Message)

self._session_context.used_types_body.add(const.TYPES_LENGTH)
self._session_context.used_types_body.append(const.TYPES_LENGTH)
self._session_context.referenced_types_body.add(target_type)

target_context = context_id(target)
Expand Down Expand Up @@ -1461,7 +1490,7 @@ def _assign_to_comprehension(
)
assert isinstance(comprehension.sequence.type_, rty.Sequence)

self._session_context.used_types_body.add(const.TYPES_BIT_LENGTH)
self._session_context.used_types_body.append(const.TYPES_BIT_LENGTH)

target_id = ID(target)
sequence_type_id = ID(comprehension.sequence.type_.identifier)
Expand Down Expand Up @@ -1723,7 +1752,7 @@ def _append(
) -> Sequence[Statement]:
assert isinstance(append.type_, rty.Sequence)

self._session_context.used_types_body.add(const.TYPES_BIT_LENGTH)
self._session_context.used_types_body.append(const.TYPES_BIT_LENGTH)

def check(sequence_type: ID, required_space: Expr) -> Statement:
return IfStatement(
Expand Down Expand Up @@ -1906,7 +1935,11 @@ def _reset(
target_type = ID(reset.type_.identifier)
target_context = context_id(reset.identifier)
return [
CallStatement(target_type * "Reset", [Variable(target_context)]),
CallStatement(
target_type * "Reset",
[Variable(target_context)],
{ID(n): e.ada_expr() for n, e in reset.associations.items()},
),
]

def _message_size(self, message_aggregate: expr.MessageAggregate) -> expr.Expr:
Expand All @@ -1922,12 +1955,12 @@ def func(expression: expr.Expr) -> expr.Expr:
if isinstance(expression, (expr.Relation, expr.MathBinExpr)):
for e in [expression.left, expression.right]:
if isinstance(e.type_, (rty.Integer, rty.Enumeration)):
self._session_context.used_types_body.add(e.type_.identifier)
self._session_context.used_types_body.append(e.type_.identifier)

if isinstance(expression, expr.MathAssExpr):
for e in expression.terms:
if isinstance(e.type_, (rty.Integer, rty.Enumeration)):
self._session_context.used_types_body.add(e.type_.identifier)
self._session_context.used_types_body.append(e.type_.identifier)

if isinstance(expression, expr.And):
return expr.AndThen(*expression.terms)
Expand All @@ -1938,6 +1971,11 @@ def func(expression: expr.Expr) -> expr.Expr:
if isinstance(expression, expr.Selected):
if isinstance(expression.prefix, expr.Variable):
assert isinstance(expression.prefix.type_, rty.Message)
if expression.selector in expression.prefix.type_.parameter_types:
return expr.Selected(
expr.Variable(context_id(expression.prefix.identifier)),
expression.selector,
)
return expr.Call(
ID(expression.prefix.type_.identifier) * ID(f"Get_{expression.selector}"),
[expr.Variable(context_id(expression.prefix.identifier))],
Expand Down Expand Up @@ -2051,7 +2089,7 @@ def func(expression: expr.Expr) -> expr.Expr:

if selected and literal:
assert isinstance(literal.type_, rty.Enumeration)
self._session_context.used_types_body.add(
self._session_context.used_types_body.append(
literal.type_.identifier + "_Enum"
)
return expr.AndThen(
Expand Down Expand Up @@ -2421,7 +2459,7 @@ def _set_message_fields(
isinstance(v, (expr.Variable, expr.MathBinExpr, expr.MathAssExpr))
and isinstance(v.type_, (rty.AnyInteger, rty.Enumeration, rty.Aggregate))
):
field_type = message_aggregate.type_.field_types[f]
field_type = message_aggregate.type_.types[f]
if isinstance(v, expr.Aggregate) and len(v.elements) == 0:
statements.append(
CallStatement(target_type * f"Set_{f}_Empty", [Variable(target_context)])
Expand Down Expand Up @@ -2459,7 +2497,7 @@ def _set_message_fields(
):
message_type = ID(v.prefix.type_.identifier)
message_context = context_id(v.prefix.identifier)
target_field_type = message_aggregate.type_.field_types[f]
target_field_type = message_aggregate.type_.types[f]
assert isinstance(target_field_type, (rty.Integer, rty.Enumeration, rty.Sequence))
statements = self._ensure(
statements,
Expand Down Expand Up @@ -2851,7 +2889,7 @@ def _update_context(
]

def _allocate_buffer(self, identifier: rid.ID, initialization: Expr = None) -> Assignment:
self._session_context.used_types_body.add(const.TYPES_INDEX)
self._session_context.used_types_body.append(const.TYPES_INDEX)
return Assignment(
buffer_id(identifier),
New(
Expand All @@ -2874,7 +2912,11 @@ def _allocate_buffer(self, identifier: rid.ID, initialization: Expr = None) -> A

@staticmethod
def _initialize_context(
identifier: rid.ID, type_: ID, first: Expr = None, last: Expr = None
identifier: rid.ID,
type_: ID,
first: Expr = None,
last: Expr = None,
parameters: Mapping[ID, Expr] = None,
) -> CallStatement:
return CallStatement(
type_ * "Initialize",
Expand All @@ -2891,13 +2933,14 @@ def _initialize_context(
else []
),
],
parameters,
)

def _copy_to_buffer(
self, type_: ID, source_context: ID, target_buffer: ID, exception_handler: ExceptionHandler
) -> IfStatement:
"""A deferred exception might be raised."""
self._session_context.used_types_body.add(const.TYPES_LENGTH)
self._session_context.used_types_body.append(const.TYPES_LENGTH)
return IfStatement(
[
(
Expand Down
Loading

0 comments on commit 1419bfe

Please sign in to comment.