Skip to content

Commit

Permalink
support variable buffer sizes in allocator
Browse files Browse the repository at this point in the history
Closes #713
  • Loading branch information
kanigsson committed Nov 15, 2021
1 parent 9bd0260 commit bde73e8
Show file tree
Hide file tree
Showing 23 changed files with 614 additions and 124 deletions.
107 changes: 78 additions & 29 deletions rflx/generator/allocator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, List, Optional, Sequence, Set
from itertools import zip_longest
from typing import Dict, List, Optional, Sequence, Set, Tuple

from rflx import expression as expr, typing_ as rty
from rflx.ada import (
Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(self, session: Session, prefix: str = "") -> None:
self._declaration_context: List[ContextItem] = []
self._body_context: List[ContextItem] = []
self._allocation_map: Dict[Location, int] = {}
self._size_map: Dict[int, int] = {}
self._unit_part = UnitPart(specification=[Pragma("Elaborate_Body")])
self._all_slots: Set[int] = set()
self._global_slots: Set[int] = set()
Expand Down Expand Up @@ -95,32 +97,46 @@ def free_buffer(self, identifier: ID, location: Optional[Location]) -> Sequence[
]

@staticmethod
def _ptr_type() -> ID:
return ID("Slot_Ptr_Type")
def _ptr_type(size: int) -> ID:
return ID(f"Slot_Ptr_Type_{size}")

def _create(self) -> None:
self._unit_part += self._create_ptr_subtype()
self._unit_part += self._create_ptr_subtypes()
self._unit_part += self._create_slots()
self._unit_part += self._create_init_pred()
self._unit_part += self._create_init_proc()
self._unit_part += self._create_global_allocated_pred()

def _create_ptr_subtype(self) -> UnitPart:
pred = OrElse(
Equal(Variable(self._ptr_type()), Variable("null")),
AndThen(
Equal(First(self._ptr_type()), First(const.TYPES_INDEX)),
Equal(Last(self._ptr_type()), Add(First(const.TYPES_INDEX), Number(4095))),
),
)
self._declaration_context.append(WithClause(self._prefix * const.TYPES_PACKAGE))
self._declaration_context.append(UseTypeClause(self._prefix * const.TYPES_INDEX))
self._declaration_context.append(UseTypeClause(self._prefix * const.TYPES_BYTES_PTR))
def _create_ptr_subtypes(self) -> UnitPart:
sizes_set = set(self._size_map.values())
unit = UnitPart(
specification=[
Subtype(self._ptr_type(), const.TYPES_BYTES_PTR, aspects=[DynamicPredicate(pred)]),
Subtype(
self._ptr_type(cur_size),
const.TYPES_BYTES_PTR,
aspects=[
DynamicPredicate(
OrElse(
Equal(Variable(self._ptr_type(cur_size)), Variable("null")),
AndThen(
Equal(
First(self._ptr_type(cur_size)), First(const.TYPES_INDEX)
),
Equal(
Last(self._ptr_type(cur_size)),
Add(First(const.TYPES_INDEX), Number(cur_size - 1)),
),
),
)
)
],
)
for cur_size in sizes_set
]
)
self._declaration_context.append(WithClause(self._prefix * const.TYPES_PACKAGE))
self._declaration_context.append(UseTypeClause(self._prefix * const.TYPES_INDEX))
self._declaration_context.append(UseTypeClause(self._prefix * const.TYPES_BYTES_PTR))
return unit

def _create_slots(self) -> UnitPart:
Expand All @@ -132,7 +148,8 @@ def _create_slots(self) -> UnitPart:
expression=NamedAggregate(
(
ValueRange(
First(const.TYPES_INDEX), Add(First(const.TYPES_INDEX), Number(4095))
First(const.TYPES_INDEX),
Add(First(const.TYPES_INDEX), Number(self._size_map[i] - 1)),
),
First(const.TYPES_BYTE),
)
Expand All @@ -144,7 +161,7 @@ def _create_slots(self) -> UnitPart:
pointer_decls: List[Declaration] = [
ObjectDeclaration(
[self._slot_id(i)],
self._ptr_type(),
self._ptr_type(self._size_map[i]),
)
for i in self._all_slots
]
Expand Down Expand Up @@ -218,27 +235,38 @@ def _needs_allocation(type_: rty.Type) -> bool:

def _allocate_slots(self) -> None:
"""Create memory slots for each construct in the session that requires memory."""
# pylint: disable=too-many-locals
count: int = 0

def insert(loc: Optional[Location]) -> None:
def create_slot_with_size(size: int) -> int:
nonlocal count
count += 1
assert loc is not None
self._allocation_map[loc] = count
self._all_slots.add(count)
self._size_map[count] = size
return count

def map_sloc_to_slot(loc: Optional[Location], slot: int) -> None:
assert loc is not None
self._allocation_map[loc] = slot

# global variables
for d in self._session.declarations.values():
if isinstance(d, decl.VariableDeclaration) and self._needs_allocation(d.type_):
insert(d.location)
self._global_slots.add(count)
slot = create_slot_with_size(self._session.buffer_sizes.get_size(d.identifier))
map_sloc_to_slot(d.location, slot)
self._global_slots.add(slot)

global_count = count
# local variables

# get all allocation points and required sizes in a list (one element per state)
# of lists of tuples (loc, size)

state_allocation_list = []
for s in self._session.states:
count = global_count
cur_state = []
for d in s.declarations.values():
if isinstance(d, decl.VariableDeclaration) and self._needs_allocation(d.type_):
insert(d.location)
cur_state.append((d.location, s.buffer_sizes.get_size(d.identifier)))
for a in s.actions:
if (
isinstance(a, stmt.Assignment)
Expand All @@ -247,12 +275,33 @@ def insert(loc: Optional[Location]) -> None:
and isinstance(a.expression.sequence.type_.element, rty.Message)
and isinstance(a.expression.sequence, (expr.Selected, expr.Variable))
):
insert(a.location)
cur_state.append((a.location, s.buffer_sizes.default_size))
if isinstance(a, stmt.Assignment) and isinstance(a.expression, expr.Head):
insert(a.location)
cur_state.append((a.location, s.buffer_sizes.default_size))
if (
isinstance(a, stmt.Write)
and isinstance(a.parameter.type_, rty.Message)
and not isinstance(a.parameter, expr.Variable)
):
insert(a.location)
cur_state.append((a.location, s.buffer_sizes.default_size))

state_allocation_list.append(cur_state)
# sort the list for each state by descending sizes
for entry in state_allocation_list:
entry.sort(key=lambda x: x[1], reverse=True)

# get the list of max sizes for each slot

def snd(arg: Tuple[Location, int]) -> int:
return arg[1]

max_sizes_list = [
max(map(snd, l)) for l in zip_longest(*state_allocation_list, fillvalue=(None, 0))
]
# create the actual slot with the max size for this element
states_slot_list = list(map(create_slot_with_size, max_sizes_list))
# map the slocs to the corresponding memory slot
for entry in state_allocation_list:
assert len(entry) <= len(states_slot_list)
for ((sloc, _), slot) in zip(entry, states_slot_list):
map_sloc_to_slot(sloc, slot)
5 changes: 4 additions & 1 deletion rflx/generator/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,10 @@ def _create_initialized_function(self, session: model.Session) -> UnitPart:
),
Equal( # ISSUE: Componolit/RecordFlux#713
Variable(context_id(d.identifier) * "Buffer_Last"),
Add(First(const.TYPES_INDEX), Number(4095)),
Add(
First(const.TYPES_INDEX),
Number(session.buffer_sizes.get_size(d.identifier) - 1),
),
),
]
],
Expand Down
1 change: 1 addition & 0 deletions rflx/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
State as State,
Transition as Transition,
UnprovenSession as UnprovenSession,
defaultsizes as defaultsizes,
)
from .type_ import ( # noqa: F401
BOOLEAN as BOOLEAN,
Expand Down
43 changes: 40 additions & 3 deletions rflx/model/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Mapping, Optional, Sequence

from rflx import expression as expr, typing_ as rty
Expand All @@ -12,6 +13,19 @@
from . import BasicDeclaration, declaration as decl, statement as stmt, type_ as mty


@dataclass(order=True)
class BufferSizes:
default_size: int
size_map: Dict[str, int]

def get_size(self, ident: ID) -> int:
name = ident.parts[-1]
return self.default_size if name not in self.size_map else self.size_map[name]


defaultsizes: BufferSizes = BufferSizes(default_size=4096, size_map={})


class Transition(Base):
def __init__(
self,
Expand All @@ -36,7 +50,7 @@ def __str__(self) -> str:
return f"goto {self.target}{with_aspects}{if_condition}"


class State(Base):
class State(Base): # pylint: disable=too-many-instance-attributes
def __init__(
self,
identifier: StrID,
Expand All @@ -45,6 +59,7 @@ def __init__(
actions: Sequence[stmt.Statement] = None,
declarations: Sequence[decl.BasicDeclaration] = None,
description: str = None,
buffer_sizes: BufferSizes = defaultsizes,
location: Location = None,
):
# pylint: disable=too-many-arguments
Expand All @@ -63,6 +78,7 @@ def __init__(
self.declarations = {d.identifier: d for d in declarations} if declarations else {}
self.description = description
self.location = location
self.buffer_sizes = buffer_sizes

def __repr__(self) -> str:
return verbose_repr(self, ["identifier", "transitions", "actions", "declarations"])
Expand Down Expand Up @@ -142,6 +158,7 @@ def __init__(
declarations: Sequence[decl.BasicDeclaration],
parameters: Sequence[decl.FormalDeclaration],
types: Sequence[mty.Type],
buffer_sizes: BufferSizes = defaultsizes,
location: Location = None,
):
super().__init__(identifier, location)
Expand All @@ -152,6 +169,7 @@ def __init__(
self.parameters = {p.identifier: p for p in parameters}
self.types = {t.identifier: t for t in types}
self.location = location
self.buffer_sizes = buffer_sizes

refinements = [t for t in types if isinstance(t, Refinement)]

Expand Down Expand Up @@ -244,10 +262,19 @@ def __init__(
declarations: Sequence[decl.BasicDeclaration],
parameters: Sequence[decl.FormalDeclaration],
types: Sequence[mty.Type],
buffer_sizes: BufferSizes = defaultsizes,
location: Location = None,
):
super().__init__(
identifier, initial, final, states, declarations, parameters, types, location
identifier,
initial,
final,
states,
declarations,
parameters,
types,
buffer_sizes,
location,
)
self.__validate()
self.error.propagate()
Expand Down Expand Up @@ -576,11 +603,20 @@ def __init__(
declarations: Sequence[decl.BasicDeclaration],
parameters: Sequence[decl.FormalDeclaration],
types: Sequence[mty.Type],
buffer_sizes: BufferSizes = defaultsizes,
location: Location = None,
):
# pylint: disable=useless-super-delegation
super().__init__(
identifier, initial, final, states, declarations, parameters, types, location
identifier,
initial,
final,
states,
declarations,
parameters,
types,
buffer_sizes,
location,
)

def proven(self) -> Session:
Expand All @@ -592,5 +628,6 @@ def proven(self) -> Session:
list(self.declarations.values()),
list(self.parameters.values()),
list(self.types.values()),
self.buffer_sizes,
self.location,
)
Loading

0 comments on commit bde73e8

Please sign in to comment.