Skip to content

Commit

Permalink
Fix generation of use clauses for sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
treiher committed Aug 30, 2021
1 parent 72eccc0 commit 0f32b57
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 16 deletions.
10 changes: 10 additions & 0 deletions rflx/ada.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,6 +1891,16 @@ class UnitPart:
private: List[Declaration] = dataclass_field(default_factory=list)
statements: List[Statement] = dataclass_field(default_factory=list)

def __add__(self, other: object) -> "UnitPart":
if isinstance(other, UnitPart):
return UnitPart(
self.specification + other.specification,
self.body + other.body,
self.private + other.private,
self.statements + other.statements,
)
return NotImplemented

def __iadd__(self, other: object) -> "UnitPart":
if isinstance(other, UnitPart):
self.specification += other.specification
Expand Down
48 changes: 32 additions & 16 deletions rflx/generator/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
)
from rflx.const import BUILTINS_PACKAGE, INTERNAL_PACKAGE
from rflx.error import Subsystem, fail, fatal_fail
from rflx.model.type_ import is_builtin_type, is_internal_type

from . import const

Expand All @@ -87,6 +88,7 @@ class SessionContext:
referenced_types_body: Set[rid.ID] = dataclass_field(default_factory=set)
referenced_packages_body: Set[rid.ID] = dataclass_field(default_factory=set)
referenced_has_data: Set[rid.ID] = dataclass_field(default_factory=set)
used_types: Set[rid.ID] = dataclass_field(default_factory=set)
used_types_body: Set[rid.ID] = dataclass_field(default_factory=set)
state_exception: Set[rid.ID] = dataclass_field(default_factory=set)

Expand Down Expand Up @@ -342,9 +344,11 @@ def _create_function_parameters(
]

def _create_context(self) -> Tuple[List[ContextItem], List[ContextItem]]:
declaration_context: List[ContextItem] = [
WithClause(self._prefix * const.TYPES_PACKAGE),
]
declaration_context: List[ContextItem] = []

if self._session_context.used_types | self._session_context.used_types_body:
declaration_context.append(WithClause(self._prefix * const.TYPES_PACKAGE))

body_context: List[ContextItem] = [
*([WithClause("Ada.Text_IO")] if self._debug else []),
]
Expand Down Expand Up @@ -377,7 +381,9 @@ def _create_context(self) -> Tuple[List[ContextItem], List[ContextItem]]:
],
)

for type_identifier in self._session_context.used_types_body:
for type_identifier in (
self._session_context.used_types_body - self._session_context.used_types
):
if type_identifier.parent in [INTERNAL_PACKAGE, BUILTINS_PACKAGE]:
continue
if type_identifier in [const.TYPES_LENGTH, const.TYPES_INDEX, const.TYPES_BIT_LENGTH]:
Expand Down Expand Up @@ -411,7 +417,6 @@ def _create_state_machine(self) -> UnitPart:
)

unit = UnitPart()
unit += self._create_declarations(self._session, evaluated_declarations.global_declarations)
unit += self._create_uninitialized_function(self._session.declarations.values())
unit += self._create_states(self._session)
unit += self._create_initialized_procedure(self._session)
Expand All @@ -428,15 +433,21 @@ def _create_state_machine(self) -> UnitPart:
)
unit += self._create_tick_procedure(self._session)
unit += self._create_run_procedure()
return unit
return (
self._create_declarations(self._session, evaluated_declarations.global_declarations)
+ unit
)

@staticmethod
def _create_declarations(
session: model.Session, declarations: Sequence[Declaration]
self, session: model.Session, declarations: Sequence[Declaration]
) -> UnitPart:
return UnitPart(
private=[
UseTypeClause(const.TYPES_INDEX),
*[
UseTypeClause(self._prefix * ID(t))
for t in self._session_context.used_types
if not is_builtin_type(t) and not is_internal_type(t)
],
EnumerationType(
"Session_State", {ID(f"S_{s.identifier}"): None for s in session.states}
),
Expand Down Expand Up @@ -536,9 +547,16 @@ def _create_states(self, session: model.Session) -> UnitPart:

return UnitPart(body=unit_body)

@staticmethod
def _create_initialized_procedure(session: model.Session) -> UnitPart:
def _create_initialized_procedure(self, session: model.Session) -> UnitPart:
specification = FunctionSpecification("Initialized", "Boolean")
context_declarations = [
d
for d in session.declarations.values()
if isinstance(d, decl.VariableDeclaration)
and isinstance(d.type_, (rty.Message, rty.Sequence))
]
if context_declarations:
self._session_context.used_types.add(const.TYPES_INDEX)
return UnitPart(
[
SubprogramDeclaration(specification),
Expand All @@ -549,9 +567,7 @@ def _create_initialized_procedure(session: model.Session) -> UnitPart:
AndThen(
*[
e
for d in session.declarations.values()
if isinstance(d, decl.VariableDeclaration)
and isinstance(d.type_, (rty.Message, rty.Sequence))
for d in context_declarations
for e in [
Call(
ID(d.type_identifier) * "Has_Buffer",
Expand Down Expand Up @@ -2671,8 +2687,8 @@ def _update_context(
),
]

@staticmethod
def _allocate_buffer(identifier: rid.ID, initialization: Expr = None) -> Assignment:
def _allocate_buffer(self, identifier: rid.ID, initialization: Expr = None) -> Assignment:
self._session_context.used_types_body.add(const.TYPES_INDEX)
return Assignment(
buffer_id(identifier),
New(
Expand Down

0 comments on commit 0f32b57

Please sign in to comment.