Skip to content

Commit

Permalink
Issue #104: Use new extract/insert functions in messages
Browse files Browse the repository at this point in the history
Use the correct Insert/Extract functions depending on the Byte_Order
specification of the message.
Adding a test.
This requires Recordflux-parser 0.10.0
  • Loading branch information
kanigsson committed Jan 13, 2022
1 parent 5e0cde5 commit aad3fae
Show file tree
Hide file tree
Showing 38 changed files with 3,780 additions and 128 deletions.
10 changes: 6 additions & 4 deletions rflx/generator/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
)
from rflx.common import unique
from rflx.const import BUILTINS_PACKAGE
from rflx.model import FINAL, INITIAL, Composite, Field, Message, Scalar, Type
from rflx.model import FINAL, INITIAL, ByteOrder, Composite, Field, Message, Scalar, Type

from . import common, const

Expand All @@ -70,11 +70,13 @@ class ParserGenerator:
def __init__(self, prefix: str = "") -> None:
self.prefix = prefix

def extract_function(self, type_identifier: ID) -> Subprogram:
def extract_function(self, type_identifier: ID, byte_order: ByteOrder) -> Subprogram:
return GenericFunctionInstantiation(
"Extract",
FunctionSpecification(
const.TYPES * "Extract",
const.TYPES * "Extract"
if byte_order == ByteOrder.HIGH_ORDER_FIRST
else const.TYPES * "Extract_LE",
type_identifier,
[
Parameter(["Buffer"], const.TYPES_BYTES),
Expand Down Expand Up @@ -167,7 +169,7 @@ def result(field: Field, message: Message) -> NamedAggregate:
else common.field_byte_bounds_declarations()
),
*unique(
self.extract_function(common.full_base_type_name(t))
self.extract_function(common.full_base_type_name(t), message.byte_order)
for f, t in message.field_types.items()
if isinstance(t, Scalar)
),
Expand Down
10 changes: 6 additions & 4 deletions rflx/generator/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
)
from rflx.common import unique
from rflx.const import BUILTINS_PACKAGE
from rflx.model import FINAL, Enumeration, Field, Message, Opaque, Scalar, Sequence, Type
from rflx.model import FINAL, ByteOrder, Enumeration, Field, Message, Opaque, Scalar, Sequence, Type

from . import common, const

Expand All @@ -73,11 +73,13 @@ class SerializerGenerator:
def __init__(self, prefix: str = "") -> None:
self.prefix = prefix

def insert_function(self, type_identifier: ID) -> Subprogram:
def insert_function(self, type_identifier: ID, byte_order: ByteOrder) -> Subprogram:
return GenericProcedureInstantiation(
"Insert",
ProcedureSpecification(
const.TYPES * "Insert",
const.TYPES * "Insert"
if byte_order == ByteOrder.HIGH_ORDER_FIRST
else const.TYPES * "Insert_LE",
[
Parameter(["Val"], type_identifier),
InOutParameter(["Buffer"], const.TYPES_BYTES),
Expand Down Expand Up @@ -106,7 +108,7 @@ def create_internal_functions(
*common.field_bit_location_declarations(Variable("Val.Fld")),
*common.field_byte_location_declarations(),
*unique(
self.insert_function(common.full_base_type_name(t))
self.insert_function(common.full_base_type_name(t), message.byte_order)
for t in message.field_types.values()
if isinstance(t, Scalar)
),
Expand Down
1 change: 1 addition & 0 deletions rflx/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
FINAL as FINAL,
INITIAL as INITIAL,
AbstractMessage as AbstractMessage,
ByteOrder as ByteOrder,
DerivedMessage as DerivedMessage,
Field as Field,
Link as Link,
Expand Down
80 changes: 52 additions & 28 deletions rflx/model/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections import defaultdict
from copy import copy
from dataclasses import dataclass, field as dataclass_field
from enum import Enum
from typing import Dict, List, Mapping, Optional, Sequence, Set, Tuple, Union

import rflx.typing_ as rty
Expand All @@ -19,6 +20,11 @@
from . import type_ as mty


class ByteOrder(Enum):
HIGH_ORDER_FIRST = 1
LOW_ORDER_FIRST = 2


class Field(Base):
def __init__(self, identifier: StrID) -> None:
self.identifier = ID(identifier)
Expand Down Expand Up @@ -110,7 +116,8 @@ def __init__(
identifier: StrID,
structure: Sequence[Link],
types: Mapping[Field, mty.Type],
aspects: Mapping[ID, Mapping[ID, Sequence[expr.Expr]]] = None,
checksum_aspects: Mapping[ID, Sequence[expr.Expr]] = None,
byte_order: ByteOrder = None,
location: Location = None,
error: RecordFluxError = None,
state: MessageState = None,
Expand All @@ -122,9 +129,9 @@ def __init__(
self.structure = sorted(structure)

self.__types = types
self.__aspects = aspects or {}
self.__has_unreachable = False
self.__paths_cache: Dict[Field, Set[Tuple[Link, ...]]] = {}
self.byte_order = byte_order if byte_order else ByteOrder.HIGH_ORDER_FIRST

self._state = state or MessageState()
self._unqualified_enum_literals = {
Expand All @@ -146,8 +153,9 @@ def __init__(
self._state.parameter_types = {
f: t for f, t in self.__types.items() if f not in fields
}
if ID("Checksum") in self.__aspects:
self._state.checksums = self.__aspects[ID("Checksum")]
if checksum_aspects:
self._state.checksums = checksum_aspects

except RecordFluxError:
pass

Expand All @@ -160,12 +168,13 @@ def __eq__(self, other: object) -> bool:
self.identifier == other.identifier
and self.structure == other.structure
and self.types == other.types
and self.aspects == other.aspects
and self.byte_order == other.byte_order
and self._state.checksums == self._state.checksums
)
return NotImplemented

def __repr__(self) -> str:
return verbose_repr(self, ["identifier", "structure", "types", "aspects"])
return verbose_repr(self, ["identifier", "structure", "types", "checksums", "byte_order"])

def __str__(self) -> str:
if not self.structure or not self.types:
Expand Down Expand Up @@ -215,7 +224,8 @@ def copy(
identifier: StrID = None,
structure: Sequence[Link] = None,
types: Mapping[Field, mty.Type] = None,
aspects: Mapping[ID, Mapping[ID, Sequence[expr.Expr]]] = None,
checksum_aspects: Mapping[ID, Sequence[expr.Expr]] = None,
byte_order: ByteOrder = None,
location: Location = None,
error: RecordFluxError = None,
) -> "AbstractMessage":
Expand Down Expand Up @@ -253,10 +263,6 @@ def types(self) -> Mapping[Field, mty.Type]:
"""Return parameters, fields and corresponding types topologically sorted."""
return {**self._state.parameter_types, **self._state.field_types}

@property
def aspects(self) -> Mapping[ID, Mapping[ID, Sequence[expr.Expr]]]:
return self.__aspects

@property
def checksums(self) -> Mapping[ID, Sequence[expr.Expr]]:
return self._state.checksums or {}
Expand Down Expand Up @@ -767,14 +773,17 @@ def __init__(
identifier: StrID,
structure: Sequence[Link],
types: Mapping[Field, mty.Type],
aspects: Mapping[ID, Mapping[ID, Sequence[expr.Expr]]] = None,
checksum_aspects: Mapping[ID, Sequence[expr.Expr]] = None,
byte_order: ByteOrder = None,
location: Location = None,
error: RecordFluxError = None,
state: MessageState = None,
skip_proof: bool = False,
workers: int = 1,
) -> None:
super().__init__(identifier, structure, types, aspects, location, error, state)
super().__init__(
identifier, structure, types, checksum_aspects, byte_order, location, error, state
)

self._refinements: List["Refinement"] = []
self._skip_proof = skip_proof
Expand Down Expand Up @@ -808,15 +817,17 @@ def copy(
identifier: StrID = None,
structure: Sequence[Link] = None,
types: Mapping[Field, mty.Type] = None,
aspects: Mapping[ID, Mapping[ID, Sequence[expr.Expr]]] = None,
checksum_aspects: Mapping[ID, Sequence[expr.Expr]] = None,
byte_order: ByteOrder = None,
location: Location = None,
error: RecordFluxError = None,
) -> "Message":
return Message(
identifier if identifier else self.identifier,
structure if structure else copy(self.structure),
types if types else copy(self.types),
aspects if aspects else copy(self.aspects),
checksum_aspects if checksum_aspects else copy(self.checksums),
byte_order if byte_order else self.byte_order,
location if location else self.location,
error if error else self.error,
skip_proof=self._skip_proof,
Expand Down Expand Up @@ -1822,15 +1833,17 @@ def __init__(
base: Message,
structure: Sequence[Link] = None,
types: Mapping[Field, mty.Type] = None,
aspects: Mapping[ID, Mapping[ID, Sequence[expr.Expr]]] = None,
checksum_aspects: Mapping[ID, Sequence[expr.Expr]] = None,
byte_order: ByteOrder = None,
location: Location = None,
error: RecordFluxError = None,
) -> None:
super().__init__(
identifier,
structure if structure else copy(base.structure),
types if types else copy(base.types),
aspects if aspects else copy(base.aspects),
checksum_aspects if checksum_aspects else copy(base.checksums),
byte_order if byte_order else base.byte_order,
location if location else base.location,
error if error else base.error,
)
Expand All @@ -1841,7 +1854,8 @@ def copy(
identifier: StrID = None,
structure: Sequence[Link] = None,
types: Mapping[Field, mty.Type] = None,
aspects: Mapping[ID, Mapping[ID, Sequence[expr.Expr]]] = None,
checksum_aspects: Mapping[ID, Sequence[expr.Expr]] = None,
byte_order: ByteOrder = None,
location: Location = None,
error: RecordFluxError = None,
) -> "DerivedMessage":
Expand All @@ -1850,7 +1864,8 @@ def copy(
self.base,
structure if structure else copy(self.structure),
types if types else copy(self.types),
aspects if aspects else copy(self.aspects),
checksum_aspects if checksum_aspects else copy(self.checksums),
byte_order if byte_order else self.byte_order,
location if location else self.location,
error if error else self.error,
)
Expand All @@ -1866,15 +1881,17 @@ def copy(
identifier: StrID = None,
structure: Sequence[Link] = None,
types: Mapping[Field, mty.Type] = None,
aspects: Mapping[ID, Mapping[ID, Sequence[expr.Expr]]] = None,
checksum_aspects: Mapping[ID, Sequence[expr.Expr]] = None,
byte_order: ByteOrder = None,
location: Location = None,
error: RecordFluxError = None,
) -> "UnprovenMessage":
return UnprovenMessage(
identifier if identifier else self.identifier,
structure if structure else copy(self.structure),
types if types else copy(self.types),
aspects if aspects else copy(self.aspects),
checksum_aspects if checksum_aspects else copy(self.checksums),
byte_order if byte_order else self.byte_order,
location if location else self.location,
error if error else self.error,
)
Expand All @@ -1884,7 +1901,8 @@ def proven(self, skip_proof: bool = False, workers: int = 1) -> Message:
identifier=self.identifier,
structure=self.structure,
types=self.types,
aspects=self.aspects,
checksum_aspects=self.checksums,
byte_order=self.byte_order,
location=self.location,
error=self.error,
state=self._state,
Expand Down Expand Up @@ -2054,7 +2072,8 @@ def replace(expression: expr.Expr) -> expr.Expr:
for l in message.structure
],
message.types,
message.aspects,
message.checksums,
message.byte_order,
message.location,
message.error,
)
Expand Down Expand Up @@ -2160,15 +2179,17 @@ def __init__(
base: Union[UnprovenMessage, Message],
structure: Sequence[Link] = None,
types: Mapping[Field, mty.Type] = None,
aspects: Mapping[ID, Mapping[ID, Sequence[expr.Expr]]] = None,
checksum_aspects: Mapping[ID, Sequence[expr.Expr]] = None,
byte_order: ByteOrder = None,
location: Location = None,
error: RecordFluxError = None,
) -> None:
super().__init__(
identifier,
structure if structure else copy(base.structure),
types if types else copy(base.types),
aspects if aspects else copy(base.aspects),
checksum_aspects if checksum_aspects else copy(base.checksums),
byte_order if byte_order else base.byte_order,
location if location else base.location,
error if error else base.error,
)
Expand Down Expand Up @@ -2199,7 +2220,8 @@ def copy(
identifier: StrID = None,
structure: Sequence[Link] = None,
types: Mapping[Field, mty.Type] = None,
aspects: Mapping[ID, Mapping[ID, Sequence[expr.Expr]]] = None,
checksum_aspects: Mapping[ID, Sequence[expr.Expr]] = None,
byte_order: ByteOrder = None,
location: Location = None,
error: RecordFluxError = None,
) -> "UnprovenDerivedMessage":
Expand All @@ -2208,7 +2230,8 @@ def copy(
self.base,
structure if structure else copy(self.structure),
types if types else copy(self.types),
aspects if aspects else copy(self.aspects),
checksum_aspects if checksum_aspects else copy(self.checksums),
byte_order if byte_order else self.byte_order,
location if location else self.location,
error if error else self.error,
)
Expand All @@ -2219,7 +2242,8 @@ def proven(self, skip_proof: bool = False, workers: int = 1) -> DerivedMessage:
self.base if isinstance(self.base, Message) else self.base.proven(),
self.structure,
self.types,
self.aspects,
self.checksums,
self.byte_order,
self.location,
self.error,
)
Expand Down
Loading

0 comments on commit aad3fae

Please sign in to comment.