Skip to content

Commit

Permalink
Fix returning of wrong predecessor field
Browse files Browse the repository at this point in the history
Ref. #289
  • Loading branch information
rssen committed Jun 17, 2020
1 parent 4eae7c9 commit 735990b
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions rflx/pyrflx/typevalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,8 @@ def __init__(self, model: Message, refinements: Sequence[Refinement] = None) ->
self._fields[INITIAL.name] = initial
self._simplified_mapping: Mapping[Name, Expr] = {}
self._preset_fields(INITIAL.name)
self.accessible_fields: List[str]
self._update_accessible_fields()

def __copy__(self) -> "MessageValue":
return MessageValue(self._type, self._refinements)
Expand Down Expand Up @@ -515,17 +517,28 @@ def _next_field(self, fld: str) -> str:
def _prev_field(self, fld: str) -> str:
if fld == INITIAL.name:
return ""
prev: List[str] = []
for l in self._type.incoming(Field(fld)):
if self.__simplified(l.condition) == TRUE:
return l.source.name
prev.append(l.source.name)

if len(prev) == 1:
return prev[0]
for field in prev:
if field in self.accessible_fields:
return field
return ""

def _get_length_unchecked(self, fld: str) -> Expr:
for l in self._type.incoming(Field(fld)):
if self.__simplified(l.condition) == TRUE and l.length != UNDEFINED:
return self.__simplified(l.length)

typeval = self._fields[fld].typeval
if isinstance(typeval, CompositeValue):
for l in self._type.incoming(Field(fld)):
if (
self.__simplified(l.condition) == TRUE
and l.length != UNDEFINED
and self._fields[l.source.name].set
):
return self.__simplified(l.length)
if isinstance(typeval, ScalarValue):
return typeval.size
return UNDEFINED
Expand Down Expand Up @@ -699,6 +712,7 @@ def set_refinement(fld: MessageValue.Field, fld_name: str) -> None:
)

self._preset_fields(field_name)
self._update_accessible_fields()

def _preset_fields(self, fld: str) -> None:
nxt = self._next_field(fld)
Expand Down Expand Up @@ -762,8 +776,7 @@ def bytestring(self) -> bytes:
def fields(self) -> List[str]:
return [f.name for f in self._type.fields]

@property
def accessible_fields(self) -> List[str]:
def _update_accessible_fields(self) -> None:
nxt = self._next_field(INITIAL.name)
fields: List[str] = []
while nxt and nxt != FINAL.name:
Expand All @@ -781,7 +794,7 @@ def accessible_fields(self) -> List[str]:

fields.append(nxt)
nxt = self._next_field(nxt)
return fields
self.accessible_fields = fields

def _is_valid_opaque_field(self, field: str) -> bool:
if self._get_length_unchecked(field) == UNDEFINED:
Expand Down

0 comments on commit 735990b

Please sign in to comment.