Skip to content

Commit

Permalink
Only merge satisfiable final conditions
Browse files Browse the repository at this point in the history
Ref. #410
  • Loading branch information
Alexander Senier committed Aug 23, 2020
1 parent 02d9feb commit d83631f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 19 deletions.
64 changes: 45 additions & 19 deletions rflx/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,9 +841,7 @@ def is_possibly_empty(self, field: Field) -> bool:
conditions = [l.condition for l in p if l.condition != TRUE]
lengths = [Equal(Length(l.target.name), l.length) for l in p if l.length != UNDEFINED]
empty_field = Equal(Length(field.name), Number(0))
proof = empty_field.check(
[*self.__type_constraints(empty_field), *conditions, *lengths]
)
proof = empty_field.check([*self.type_constraints(empty_field), *conditions, *lengths])
if proof.result == ProofResult.sat:
return True

Expand Down Expand Up @@ -1159,7 +1157,7 @@ def __check_length_expression(self, link: Link) -> None:
link.target.identifier.location,
)

def __type_constraints(self, expr: Expr) -> Sequence[Expr]:
def type_constraints(self, expr: Expr) -> Sequence[Expr]:
def get_constraints(aggregate: Aggregate, field: Variable) -> Sequence[Expr]:
comp = self.types[Field(field.name)]
assert isinstance(comp, Composite)
Expand Down Expand Up @@ -1301,7 +1299,7 @@ def __prove_conflicting_conditions(self) -> None:
for i2, c2 in enumerate(self.outgoing(f)):
if i1 < i2:
conflict = And(c1.condition, c2.condition)
proof = conflict.check(self.__type_constraints(conflict))
proof = conflict.check(self.type_constraints(conflict))
if proof.result == ProofResult.sat:
c1_message = str(c1.condition).replace("\n", " ")
c2_message = str(c2.condition).replace("\n", " ")
Expand Down Expand Up @@ -1388,7 +1386,7 @@ def __prove_contradictions(self) -> None:
for c in self.outgoing(f):
paths += 1
contradiction = c.condition
constraints = self.__type_constraints(contradiction)
constraints = self.type_constraints(contradiction)
proof = contradiction.check([*constraints, *facts])
if proof.result == ProofResult.sat:
continue
Expand Down Expand Up @@ -1477,8 +1475,8 @@ def __prove_field_positions(self) -> None:
Or(*[o.condition for o in outgoing], location=f.identifier.location)
)

facts.extend(self.__type_constraints(negative))
facts.extend(self.__type_constraints(start))
facts.extend(self.type_constraints(negative))
facts.extend(self.type_constraints(start))

proof = TRUE.check(facts)

Expand Down Expand Up @@ -1524,7 +1522,7 @@ def __prove_field_positions(self) -> None:
Mod(self.__target_first(last), element_size), Number(1), last.location
)
)
proof = start_aligned.check([*facts, *self.__type_constraints(start_aligned)])
proof = start_aligned.check([*facts, *self.type_constraints(start_aligned)])
if proof.result != ProofResult.unsat:
path_message = " -> ".join([p.target.name for p in path])
self.error.append(
Expand All @@ -1542,7 +1540,7 @@ def __prove_field_positions(self) -> None:
)
)
proof = length_multiple_element_size.check(
[*facts, *self.__type_constraints(length_multiple_element_size)]
[*facts, *self.type_constraints(length_multiple_element_size)]
)
if proof.result != ProofResult.unsat:
path_message = " -> ".join([p.target.name for p in path])
Expand Down Expand Up @@ -1822,6 +1820,25 @@ def proven(self, skip_proof: bool = False) -> Message:

@ensure(lambda result: valid_message_field_types(result))
def merged(self) -> "UnprovenMessage":
def prune_dangling_states(
structure: List[Link], types: Dict[Field, Type]
) -> Tuple[List[Link], Dict[Field, Type]]:
dangling = []
progress = True
while progress:
progress = False
states = {x for l in structure for x in (l.source, l.target) if x != FINAL}
for s in states:
if not [l for l in structure if l.source == s]:
dangling.append(s)
progress = True
structure = [l for l in structure if l.target not in dangling]

return (
structure,
{k: v for k, v in types.items() if k not in dangling},
)

message = self

while True:
Expand Down Expand Up @@ -1877,16 +1894,24 @@ def merged(self) -> "UnprovenMessage":
)
elif link.source == field:
for final_link in inner_message.incoming(FINAL):
structure.append(
Link(
final_link.source,
link.target,
And(link.condition, final_link.condition).simplified(),
link.length,
link.first,
link.location,
)
merged_condition = And(link.condition, final_link.condition)
proof = merged_condition.check(
[
*inner_message.type_constraints(merged_condition),
inner_message.field_condition(final_link.source),
]
)
if proof.result != ProofResult.unsat:
structure.append(
Link(
final_link.source,
link.target,
merged_condition.simplified(),
link.length,
link.first,
link.location,
)
)
else:
structure.append(link)

Expand All @@ -1899,6 +1924,7 @@ def merged(self) -> "UnprovenMessage":
**inner_message.types,
}

structure, types = prune_dangling_states(structure, types)
message = message.copy(structure=structure, types=types)

return message
Expand Down
47 changes: 47 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,53 @@ def test_merge_message_simple_derived() -> None:
)


def test_merge_message_constrained() -> None:
m1 = UnprovenMessage(
"P.M1",
[
Link(INITIAL, Field("F1")),
Link(Field("F1"), Field("F3"), Equal(Variable("F1"), Variable("True"))),
Link(Field("F1"), Field("F2")),
Link(Field("F2"), FINAL, Equal(Variable("F1"), Variable("False"))),
Link(Field("F3"), FINAL),
],
{Field("F1"): BOOLEAN, Field("F2"): BOOLEAN, Field("F3"): BOOLEAN},
)
m2 = UnprovenMessage(
"P.M2",
[
Link(INITIAL, Field("F4")),
Link(
Field("F4"),
FINAL,
And(
Equal(Variable("F4_F1"), Variable("True")),
Equal(Variable("F4_F3"), Variable("False")),
),
),
],
{Field("F4"): m1},
)
expected = UnprovenMessage(
"P.M2",
[
Link(INITIAL, Field("F4_F1"),),
Link(Field("F4_F1"), Field("F4_F3"), Equal(Variable("F4_F1"), Variable("True"))),
Link(
Field("F4_F3"),
FINAL,
And(
Equal(Variable("F4_F1"), Variable("True")),
Equal(Variable("F4_F3"), Variable("False")),
),
),
],
{Field("F4_F1"): BOOLEAN, Field("F4_F3"): BOOLEAN},
)
merged = m2.merged()
assert merged == expected


def test_merge_message_error_name_conflict() -> None:
m2_f2 = Field(ID("F2", Location((10, 5))))

Expand Down

0 comments on commit d83631f

Please sign in to comment.