Skip to content

Commit

Permalink
Prevent code generator from substituting literal field names
Browse files Browse the repository at this point in the history
Ref. #47
  • Loading branch information
Alexander Senier committed Aug 1, 2020
1 parent 7a3a09f commit b4bda68
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 173 deletions.
27 changes: 22 additions & 5 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,9 +825,12 @@ def z3expr(self) -> z3.ArithRef:


class Name(Expr):
def __init__(self, negative: bool = False, location: Location = None) -> None:
def __init__(
self, negative: bool = False, immutable: bool = False, location: Location = None
) -> None:
super().__init__(location)
self.negative = negative
self.immutable = immutable

def __str__(self) -> str:
if self.negative:
Expand All @@ -851,6 +854,8 @@ def representation(self) -> str:
def substituted(
self, func: Callable[[Expr], Expr] = None, mapping: Mapping["Name", Expr] = None
) -> Expr:
if self.immutable:
return self
func = substitution(mapping or {}, func)
positive_self = copy(self)
positive_self.negative = False
Expand All @@ -865,11 +870,23 @@ def z3expr(self) -> z3.ExprRef:

class Variable(Name):
def __init__(
self, identifier: StrID, negative: bool = False, location: Location = None
self,
identifier: StrID,
negative: bool = False,
immutable: bool = False,
location: Location = None,
) -> None:
super().__init__(negative, location)
super().__init__(negative, immutable, location)
self.identifier = ID(identifier)

def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self.negative == other.negative and self.identifier == other.identifier
return NotImplemented

def __hash__(self) -> int:
return hash(self.identifier)

@property
def name(self) -> str:
return str(self.identifier)
Expand Down Expand Up @@ -905,7 +922,7 @@ def __init__(self, prefix: Union[StrID, Expr], negative: bool = False) -> None:
prefix = Variable(prefix)

self.prefix: Expr = prefix
super().__init__(negative, prefix.location)
super().__init__(negative, location=prefix.location)

@property
def representation(self) -> str:
Expand Down Expand Up @@ -1058,7 +1075,7 @@ class Selected(Name):
def __init__(
self, prefix: Expr, selector_name: StrID, negative: bool = False, location: Location = None
) -> None:
super().__init__(negative, location)
super().__init__(negative, location=location)
self.prefix = prefix
self.selector_name = ID(selector_name)

Expand Down
126 changes: 91 additions & 35 deletions rflx/generator/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def byte_aggregate(aggregate: Aggregate) -> Aggregate:
[
Selected(
Indexed(
Variable("Cursors"), Variable(field.affixed_name),
Variable("Cursors"),
Variable(field.affixed_name, immutable=True),
),
"First",
)
Expand All @@ -111,7 +112,8 @@ def byte_aggregate(aggregate: Aggregate) -> Aggregate:
[
Selected(
Indexed(
Variable("Cursors"), Variable(field.affixed_name),
Variable("Cursors"),
Variable(field.affixed_name, immutable=True),
),
"Last",
)
Expand All @@ -122,7 +124,8 @@ def byte_aggregate(aggregate: Aggregate) -> Aggregate:
aggregate,
)
equal_call = Call(
"Equal", [Variable("Ctx"), Variable(field.affixed_name), aggregate]
"Equal",
[Variable("Ctx"), Variable(field.affixed_name, immutable=True), aggregate],
)
return equal_call if isinstance(expression, Equal) else Not(equal_call)

Expand All @@ -132,7 +135,7 @@ def field_value(field: Field) -> Expr:
return Selected(
Indexed(
Variable("Ctx.Cursors" if not embedded else "Cursors"),
Variable(field.affixed_name),
Variable(field.affixed_name, immutable=True),
),
f"Value.{field.name}_Value",
)
Expand Down Expand Up @@ -175,21 +178,27 @@ def prefixed(name: str) -> Expr:

def field_first(field: Field) -> Expr:
if public:
return Call("Field_First", [Variable("Ctx"), Variable(field.affixed_name)])
return Selected(Indexed(cursors, Variable(field.affixed_name)), "First")
return Call(
"Field_First", [Variable("Ctx"), Variable(field.affixed_name, immutable=True)]
)
return Selected(Indexed(cursors, Variable(field.affixed_name, immutable=True)), "First")

def field_last(field: Field) -> Expr:
if public:
return Call("Field_Last", [Variable("Ctx"), Variable(field.affixed_name)])
return Selected(Indexed(cursors, Variable(field.affixed_name)), "Last")
return Call(
"Field_Last", [Variable("Ctx"), Variable(field.affixed_name, immutable=True)]
)
return Selected(Indexed(cursors, Variable(field.affixed_name, immutable=True)), "Last")

def field_length(field: Field) -> Expr:
if public:
return Call("Field_Length", [Variable("Ctx"), Variable(field.affixed_name)])
return Call(
"Field_Length", [Variable("Ctx"), Variable(field.affixed_name, immutable=True)]
)
return Add(
Sub(
Selected(Indexed(cursors, Variable(field.affixed_name)), "Last"),
Selected(Indexed(cursors, Variable(field.affixed_name)), "First"),
Selected(Indexed(cursors, Variable(field.affixed_name, immutable=True)), "Last"),
Selected(Indexed(cursors, Variable(field.affixed_name, immutable=True)), "First"),
),
Number(1),
)
Expand All @@ -204,7 +213,8 @@ def field_value(field: Field, field_type: Type) -> Expr:
target_type,
[
Selected(
Indexed(cursors, Variable(field.affixed_name)), f"Value.{field.name}_Value"
Indexed(cursors, Variable(field.affixed_name, immutable=True)),
f"Value.{field.name}_Value",
)
],
)
Expand All @@ -215,7 +225,8 @@ def field_value(field: Field, field_type: Type) -> Expr:
target_type,
[
Selected(
Indexed(cursors, Variable(field.affixed_name)), f"Value.{field.name}_Value"
Indexed(cursors, Variable(field.affixed_name, immutable=True)),
f"Value.{field.name}_Value",
)
],
)
Expand Down Expand Up @@ -280,7 +291,10 @@ def prefixed(name: str) -> Expr:
.substituted(
mapping={
UNDEFINED: Add(
Selected(Indexed(prefixed("Cursors"), Variable(source.affixed_name)), "Last"),
Selected(
Indexed(prefixed("Cursors"), Variable(source.affixed_name, immutable=True)),
"Last",
),
Number(1),
)
}
Expand All @@ -294,7 +308,11 @@ def prefixed(name: str) -> Expr:
AndThen(
Call(
"Structural_Valid",
[Indexed(prefixed("Cursors"), Variable(target.affixed_name))],
[
Indexed(
prefixed("Cursors"), Variable(target.affixed_name, immutable=True)
)
],
),
condition,
),
Expand All @@ -303,11 +321,17 @@ def prefixed(name: str) -> Expr:
Add(
Sub(
Selected(
Indexed(prefixed("Cursors"), Variable(target.affixed_name)),
Indexed(
prefixed("Cursors"),
Variable(target.affixed_name, immutable=True),
),
"Last",
),
Selected(
Indexed(prefixed("Cursors"), Variable(target.affixed_name)),
Indexed(
prefixed("Cursors"),
Variable(target.affixed_name, immutable=True),
),
"First",
),
),
Expand All @@ -317,14 +341,19 @@ def prefixed(name: str) -> Expr:
),
Equal(
Selected(
Indexed(prefixed("Cursors"), Variable(target.affixed_name)),
Indexed(
prefixed("Cursors"), Variable(target.affixed_name, immutable=True)
),
"Predecessor",
),
Variable(source.affixed_name),
Variable(source.affixed_name, immutable=True),
),
Equal(
Selected(
Indexed(prefixed("Cursors"), Variable(target.affixed_name)), "First"
Indexed(
prefixed("Cursors"), Variable(target.affixed_name, immutable=True)
),
"First",
),
first,
),
Expand All @@ -347,7 +376,12 @@ def valid_predecessors_invariant() -> Expr:
(
Call(
"Structural_Valid",
[Indexed(Variable("Cursors"), Variable(f.affixed_name))],
[
Indexed(
Variable("Cursors"),
Variable(f.affixed_name, immutable=True),
)
],
),
Or(
*[
Expand All @@ -359,18 +393,19 @@ def valid_predecessors_invariant() -> Expr:
[
Indexed(
Variable("Cursors"),
Variable(l.source.affixed_name),
Variable(l.source.affixed_name, immutable=True),
)
],
),
Equal(
Selected(
Indexed(
Variable("Cursors"), Variable(f.affixed_name),
Variable("Cursors"),
Variable(f.affixed_name, immutable=True),
),
"Predecessor",
),
Variable(l.source.affixed_name),
Variable(l.source.affixed_name, immutable=True),
),
l.condition.substituted(
substitution(message, embedded=True)
Expand All @@ -397,13 +432,24 @@ def invalid_successors_invariant() -> Expr:
*[
Call(
"Invalid",
[Indexed(Variable("Cursors"), Variable(p.affixed_name))],
[
Indexed(
Variable("Cursors"),
Variable(p.affixed_name, immutable=True),
)
],
)
for p in message.direct_predecessors(f)
]
),
Call(
"Invalid", [Indexed(Variable("Cursors"), Variable(f.affixed_name))],
"Invalid",
[
Indexed(
Variable("Cursors"),
Variable(f.affixed_name, immutable=True),
)
],
),
)
]
Expand Down Expand Up @@ -485,11 +531,15 @@ def valid_path_to_next_field_condition(message: Message, field: Field) -> Sequen
And(
Equal(
Call(
"Predecessor", [Variable("Ctx"), Variable(l.target.affixed_name)],
"Predecessor",
[Variable("Ctx"), Variable(l.target.affixed_name, immutable=True)],
),
Variable(field.affixed_name),
Variable(field.affixed_name, immutable=True),
),
Call("Valid_Next", [Variable("Ctx"), Variable(l.target.affixed_name)])
Call(
"Valid_Next",
[Variable("Ctx"), Variable(l.target.affixed_name, immutable=True)],
)
if l.target != FINAL
else TRUE,
),
Expand All @@ -510,7 +560,10 @@ def sufficient_space_for_field_condition(field_name: Name) -> Expr:

def initialize_field_statements(message: Message, field: Field, prefix: str) -> Sequence[Statement]:
return [
CallStatement("Reset_Dependent_Fields", [Variable("Ctx"), Variable(field.affixed_name)],),
CallStatement(
"Reset_Dependent_Fields",
[Variable("Ctx"), Variable(field.affixed_name, immutable=True)],
),
Assignment(
"Ctx",
Aggregate(
Expand All @@ -528,16 +581,18 @@ def initialize_field_statements(message: Message, field: Field, prefix: str) ->
# predicate as assert
PragmaStatement("Assert", [str(message_structure_invariant(message, prefix))],),
Assignment(
Indexed(Variable("Ctx.Cursors"), Variable(field.affixed_name)),
Indexed(Variable("Ctx.Cursors"), Variable(field.affixed_name, immutable=True)),
NamedAggregate(
("State", Variable("S_Structural_Valid")),
("First", Variable("First")),
("Last", Variable("Last")),
("Value", NamedAggregate(("Fld", Variable(field.affixed_name))),),
("Value", NamedAggregate(("Fld", Variable(field.affixed_name, immutable=True))),),
(
"Predecessor",
Selected(
Indexed(Variable("Ctx.Cursors"), Variable(field.affixed_name),),
Indexed(
Variable("Ctx.Cursors"), Variable(field.affixed_name, immutable=True),
),
"Predecessor",
),
),
Expand All @@ -546,10 +601,11 @@ def initialize_field_statements(message: Message, field: Field, prefix: str) ->
Assignment(
Indexed(
Variable("Ctx.Cursors"),
Call("Successor", [Variable("Ctx"), Variable(field.affixed_name)]),
Call("Successor", [Variable("Ctx"), Variable(field.affixed_name, immutable=True)]),
),
NamedAggregate(
("State", Variable("S_Invalid")), ("Predecessor", Variable(field.affixed_name)),
("State", Variable("S_Invalid")),
("Predecessor", Variable(field.affixed_name, immutable=True)),
),
),
]
Expand Down
Loading

0 comments on commit b4bda68

Please sign in to comment.