Skip to content

Commit

Permalink
Remove contradictions check in messages
Browse files Browse the repository at this point in the history
All problems found are already covered by the reachability check.

Ref. eng/recordflux/RecordFlux#1476
  • Loading branch information
treiher committed Nov 24, 2023
1 parent a260b4a commit 96f8025
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 280 deletions.
21 changes: 20 additions & 1 deletion rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,35 @@ def result(self) -> ProofResult:
@property
def error(self) -> list[tuple[str, Optional[Location]]]:
assert self._result != ProofResult.SAT

if self._result == ProofResult.UNKNOWN:
assert self._unknown_reason is not None
return [(self._unknown_reason, None)]

solver = z3.SolverFor(self._logic)
solver.set(unsat_core=True)

facts = {f"H{index}": fact for index, fact in enumerate(self._facts)}

# Track facts for proof goals in disjunctive normal form
if isinstance(self._expr, Or):
for term in self._expr.terms:
index_start = len(facts)
if isinstance(term, And):
facts.update(
{
f"H{index}": fact
for index, fact in enumerate(term.terms, start=index_start)
},
)
else:
facts.update({f"H{index_start}": term})
else:
solver.assert_and_track(self._expr.z3expr(), "goal")

for name, fact in facts.items():
solver.assert_and_track(fact.z3expr(), name)

solver.assert_and_track(self._expr.z3expr(), "goal")
facts["goal"] = self._expr
result = solver.check()
assert result == z3.unsat, f"result should be unsat (is {result})"
Expand Down
156 changes: 69 additions & 87 deletions rflx/model/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,35 @@ def _validate_structure(self, structure_fields: set[Field]) -> bool:
],
)

def has_final(field: Field, seen: Optional[set[Field]] = None) -> bool:
"""Return True if the field has a path to the final field or a cycle was found."""

if seen is None:
seen = set()

if field in seen:
return True

seen = {field, *seen}

if field == FINAL:
return True

return any(has_final(o.target, seen) for o in self.outgoing(field))

for f in (INITIAL, *sorted(structure_fields)):
if not has_final(f):
self.error.extend(
[
(
f'no path to FINAL for field "{f.name}" in "{self.identifier}"',
Subsystem.MODEL,
Severity.ERROR,
f.identifier.location,
),
],
)

duplicate_links = defaultdict(list)
for link in self.structure:
duplicate_links[(link.source, link.target, link.condition)].append(link)
Expand Down Expand Up @@ -1230,7 +1259,6 @@ def _verify(self) -> None:
proofs.check(self.error)

self._prove_reachability()
self._prove_contradictions()

self.error.propagate()

Expand Down Expand Up @@ -1705,114 +1733,64 @@ def _prove_conflicting_conditions(self, proofs: expr.ParallelProofs) -> None:
)

def _prove_reachability(self) -> None:
def has_final(field: Field) -> bool:
if field == FINAL:
return True
return any(has_final(o.target) for o in self.outgoing(field))
"""
Find all fields that are unreachable due to contradictions on all paths to the field.
for f in (INITIAL, *self.fields):
if not has_final(f):
self.error.extend(
[
(
f'no path to FINAL for field "{f.name}" in "{self.identifier}"',
Subsystem.MODEL,
Severity.ERROR,
f.identifier.location,
),
],
)
Fields that can only be reached via an unreachable field, and are therefore unreachable due
to the same problem, are not mentioned in the resulting error message.
"""

def is_covered(path: tuple[Link, ...], subpaths: set[tuple[Link, ...]]) -> bool:
"""Check if `path` starts with any subpath contained in `subpaths`."""
return any(
p for p in subpaths if len(p) < len(path) and all(a == b for a, b in zip(p, path))
)

unreachable: set[tuple[Link, ...]] = set()

for f in (*self.fields, FINAL):
paths = []
for path in self.paths(f):
if is_covered(path, unreachable):
continue

facts = [fact for link in path for fact in self._link_expressions(link)]
last_field = path[-1].target
outgoing = self.outgoing(last_field)
if last_field != FINAL and outgoing:
facts.append(
expr.Or(
*[o.condition for o in outgoing],
location=last_field.identifier.location,
),
)
proof = expr.TRUE.check(facts)
outgoing = self.outgoing(path[-1].target)
condition = expr.Or(*[o.condition for o in outgoing]) if outgoing else expr.TRUE
proof = condition.check(
[*facts, *self.message_constraints(), *self.type_constraints(condition)],
)
if proof.result == expr.ProofResult.SAT:
break

paths.append((path, proof.error))
unreachable.add(path)
else:
error = []
error.append(
(
f'unreachable field "{f.name}" in "{self.identifier}"',
Subsystem.MODEL,
Severity.ERROR,
f.identifier.location,
),
)
for index, (path, errors) in enumerate(sorted(paths)):
if paths:
error = []
error.append(
(
f"path {index} (" + " -> ".join([l.target.name for l in path]) + "):",
f'unreachable field "{f.identifier}"',
Subsystem.MODEL,
Severity.INFO,
Severity.ERROR,
f.identifier.location,
),
)
error.extend(
[
(f'unsatisfied "{m}"', Subsystem.MODEL, Severity.INFO, l)
for m, l in errors
],
)
self.error.extend(error)

def _prove_contradictions(self) -> None:
for f in (INITIAL, *self.fields):
contradictions = []
paths = 0
for path in self.paths(f):
facts = [fact for link in path for fact in self._link_expressions(link)]
for c in self.outgoing(f):
paths += 1
contradiction = c.condition
constraints = self.message_constraints() + self.type_constraints(contradiction)
proof = contradiction.check([*constraints, *facts])
if proof.result == expr.ProofResult.SAT:
continue

contradictions.append((path, c.condition, proof.error))

if paths == len(contradictions):
for path, cond, errors in sorted(contradictions):
self.error.extend(
[
(
f'contradicting condition in "{self.identifier}"',
Subsystem.MODEL,
Severity.ERROR,
cond.location,
),
],
)
self.error.extend(
[
for path, errors in sorted(paths):
error.extend(
(
f'on path: "{l.target.identifier}"',
Subsystem.MODEL,
Severity.INFO,
l.target.identifier.location,
)
for l in path
],
)
self.error.extend(
[
)
error.extend(
(f'unsatisfied "{m}"', Subsystem.MODEL, Severity.INFO, l)
for m, l in errors
],
)
)
self.error.extend(error)

def _prove_coverage(self, proofs: expr.ParallelProofs) -> None:
"""
Expand Down Expand Up @@ -3092,13 +3070,17 @@ def _aggregate_constraints(
types: Mapping[Field, mty.Type],
expression: expr.Expr = expr.TRUE,
) -> list[expr.Expr]:
def get_constraints(aggregate: expr.Aggregate, field: expr.Variable) -> Sequence[expr.Expr]:
def get_constraints(
aggregate: expr.Aggregate,
field: expr.Variable,
location: Optional[Location],
) -> Sequence[expr.Expr]:
comp = types[Field(field.name)]
assert isinstance(comp, mty.Composite)
result = expr.Equal(
expr.Mul(aggregate.length, comp.element_size),
expr.Size(field),
location=expression.location,
location=location,
)
if isinstance(comp, mty.Sequence) and isinstance(comp.element_type, mty.Scalar):
return [
Expand All @@ -3111,9 +3093,9 @@ def get_constraints(aggregate: expr.Aggregate, field: expr.Variable) -> Sequence
for r in expression.findall(lambda x: isinstance(x, (expr.Equal, expr.NotEqual))):
assert isinstance(r, (expr.Equal, expr.NotEqual))
if isinstance(r.left, expr.Aggregate) and isinstance(r.right, expr.Variable):
aggregate_constraints.extend(get_constraints(r.left, r.right))
aggregate_constraints.extend(get_constraints(r.left, r.right, r.location))
if isinstance(r.left, expr.Variable) and isinstance(r.right, expr.Aggregate):
aggregate_constraints.extend(get_constraints(r.right, r.left))
aggregate_constraints.extend(get_constraints(r.right, r.left, r.location))

return aggregate_constraints

Expand Down
Loading

0 comments on commit 96f8025

Please sign in to comment.