Skip to content

Commit

Permalink
Improve parallelization of message verification
Browse files Browse the repository at this point in the history
Ref. eng/recordflux/RecordFlux#444
  • Loading branch information
treiher committed Nov 24, 2023
1 parent f143459 commit 071d6e1
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 51 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Support for FSF GNAT 13.2 (eng/recordflux/RecordFlux#1458)

### Changed

- Improve parallelization of message verification (eng/recordflux/RecordFlux#444)

### Fixed

- Proving of validity of message field after update with valid sequence (eng/recordflux/RecordFlux#1444)
Expand Down
34 changes: 12 additions & 22 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from abc import abstractmethod
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from copy import copy
from dataclasses import dataclass
from enum import Enum
from itertools import groupby
Expand Down Expand Up @@ -100,8 +99,7 @@ class ProofJob:

class ParallelProofs:
def __init__(self, workers: int) -> None:
self._proofs: list[list[ProofJob]] = []
self._current: list[ProofJob] = []
self._proofs: list[ProofJob] = []
self._workers = workers

def add( # noqa: PLR0913
Expand All @@ -123,7 +121,7 @@ def add( # noqa: PLR0913
When add_unsat is set to True, unsatisfied facts will added as extra info messages.
This option should only be set to True if ProofResult.UNSAT is considered an error.
"""
self._current.append(
self._proofs.append(
ProofJob(
goal,
facts,
Expand All @@ -136,32 +134,24 @@ def add( # noqa: PLR0913
),
)

def push(self) -> None:
if self._current:
self._proofs.append(copy(self._current))
self._current.clear()

@staticmethod
def check_proof(jobs: Sequence[ProofJob]) -> RecordFluxError:
def check_proof(job: ProofJob) -> RecordFluxError:
result = RecordFluxError()
for job in jobs:
proof = job.goal.check(job.facts)
result.extend(job.results[proof.result])
if job.add_unsat and proof.result != ProofResult.SAT:
result.extend(
[
(f'unsatisfied "{m}"', Subsystem.MODEL, Severity.INFO, locn)
for m, locn in proof.error
],
)
proof = job.goal.check(job.facts)
result.extend(job.results[proof.result])
if job.add_unsat and proof.result != ProofResult.SAT:
result.extend(
[
(f'unsatisfied "{m}"', Subsystem.MODEL, Severity.INFO, locn)
for m, locn in proof.error
],
)
return result

def check(self, error: RecordFluxError) -> None:
self.push()
with ProcessPoolExecutor(max_workers=self._workers, mp_context=MP_CONTEXT) as executor:
for e in executor.map(ParallelProofs.check_proof, self._proofs):
error.extend(e)
error.propagate()


class Expr(DBC, Base):
Expand Down
44 changes: 17 additions & 27 deletions rflx/model/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,14 +1216,19 @@ def _verify(self) -> None:

self.error.propagate()

self._prove_static_conditions()
self._prove_conflicting_conditions()
proofs = expr.ParallelProofs(self._workers)

self._prove_static_conditions(proofs)
self._prove_conflicting_conditions(proofs)
self._prove_coverage(proofs)
self._prove_overlays(proofs)
self._prove_field_positions(proofs)
self._prove_message_size(proofs)

proofs.check(self.error)

self._prove_reachability()
self._prove_contradictions()
self._prove_coverage()
self._prove_overlays()
self._prove_field_positions()
self._prove_message_size()

self.error.propagate()

Expand Down Expand Up @@ -1615,8 +1620,7 @@ def valid_upper(expression: expr.Expr) -> bool:
],
)

def _prove_static_conditions(self) -> None:
proofs = expr.ParallelProofs(self._workers)
def _prove_static_conditions(self, proofs: expr.ParallelProofs) -> None:
for l in self._structure:
if l.condition == expr.TRUE:
continue
Expand Down Expand Up @@ -1653,10 +1657,8 @@ def _prove_static_conditions(self) -> None:
unsat_error=unsat_error,
unknown_error=unknown_error,
)
proofs.check(self.error)

def _prove_conflicting_conditions(self) -> None:
proofs = expr.ParallelProofs(self._workers)
def _prove_conflicting_conditions(self, proofs: expr.ParallelProofs) -> None:
for f in (INITIAL, *self.fields):
for i1, c1 in enumerate(self.outgoing(f)):
for i2, c2 in enumerate(self.outgoing(f)):
Expand Down Expand Up @@ -1699,8 +1701,6 @@ def _prove_conflicting_conditions(self) -> None:
sat_error=error,
unknown_error=error,
)
proofs.push()
proofs.check(self.error)

def _prove_reachability(self) -> None:
def has_final(field: Field) -> bool:
Expand Down Expand Up @@ -1812,7 +1812,7 @@ def _prove_contradictions(self) -> None:
],
)

def _prove_coverage(self) -> None:
def _prove_coverage(self, proofs: expr.ParallelProofs) -> None:
"""
Prove that the fields of a message cover all message bits.
Expand All @@ -1825,7 +1825,6 @@ def _prove_coverage(self) -> None:
effectively pruning the range that this field covers from the bit range of the message. For
the overall expression, prove that it is false for all f, i.e. no bits are left.
"""
proofs = expr.ParallelProofs(self._workers)
for path in [p[:-1] for p in self.paths(FINAL) if p]:
facts: Sequence[expr.Expr]

Expand Down Expand Up @@ -1877,10 +1876,8 @@ def _prove_coverage(self) -> None:
],
)
proofs.add(expr.TRUE, facts, sat_error=error, unknown_error=error)
proofs.check(self.error)

def _prove_overlays(self) -> None:
proofs = expr.ParallelProofs(self._workers)
def _prove_overlays(self, proofs: expr.ParallelProofs) -> None:
for f in (INITIAL, *self.fields):
for p, l in [(p, p[-1]) for p in self.paths(f) if p]:
if l.first != expr.UNDEFINED and isinstance(l.first, expr.First):
Expand Down Expand Up @@ -1908,11 +1905,8 @@ def _prove_overlays(self) -> None:
unknown_error=error,
add_unsat=True,
)
proofs.push()
proofs.check(self.error)

def _prove_field_positions(self) -> None:
proofs = expr.ParallelProofs(self._workers)
def _prove_field_positions(self, proofs: expr.ParallelProofs) -> None:
for f in (*self.fields, FINAL):
for path in self.paths(f):
last = path[-1]
Expand Down Expand Up @@ -2025,12 +2019,9 @@ def _prove_field_positions(self) -> None:
sat_error=error,
unknown_error=error,
)
proofs.push()
proofs.check(self.error)

def _prove_message_size(self) -> None:
def _prove_message_size(self, proofs: expr.ParallelProofs) -> None:
"""Prove that all paths lead to a message with a size that is a multiple of 8 bit."""
proofs = expr.ParallelProofs(self._workers)
type_constraints = self.type_constraints(expr.TRUE)
field_size_constraints = [
expr.Equal(expr.Mod(expr.Size(f.name), expr.Number(8)), expr.Number(0))
Expand Down Expand Up @@ -2073,7 +2064,6 @@ def _prove_message_size(self) -> None:
sat_error=error,
unknown_error=error,
)
proofs.check(self.error)

def _prove_path_property(self, prop: expr.Expr, path: Sequence[Link]) -> expr.Proof:
conditions = [l.condition for l in path if l.condition != expr.TRUE]
Expand Down
14 changes: 12 additions & 2 deletions tests/unit/model/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,7 @@ def test_sequence_aggregate_invalid_element_type() -> None:


def test_opaque_not_byte_aligned() -> None:
t = Integer("P::T", Number(0), Number(3), Number(4))
o = Field(ID("O", location=Location((44, 3))))
with pytest.raises(
RecordFluxError,
Expand All @@ -1100,8 +1101,17 @@ def test_opaque_not_byte_aligned() -> None:
):
Message(
"P::M",
[Link(INITIAL, Field("P")), Link(Field("P"), o, size=Number(128)), Link(o, FINAL)],
{Field("P"): Integer("P::T", Number(0), Number(3), Number(2)), o: OPAQUE},
[
Link(INITIAL, Field("P")),
Link(Field("P"), o, size=Number(128)),
Link(o, Field("Q")),
Link(Field("Q"), FINAL),
],
{
Field("P"): t,
o: OPAQUE,
Field("Q"): t,
},
)


Expand Down

0 comments on commit 071d6e1

Please sign in to comment.