Skip to content

Commit

Permalink
Add calculation of message size based on field values
Browse files Browse the repository at this point in the history
Ref. #292
  • Loading branch information
treiher committed Mar 29, 2021
1 parent 0afd0f6 commit ef67206
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 9 deletions.
70 changes: 61 additions & 9 deletions rflx/model/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from dataclasses import dataclass, field as dataclass_field
from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Union

import z3

import rflx.typing_ as rty
from rflx import expression as expr
from rflx.common import Base, flat_name, indent, indent_next, verbose_repr
Expand Down Expand Up @@ -385,11 +387,6 @@ def get_constraints(aggregate: expr.Aggregate, field: expr.Variable) -> Sequence
if isinstance(r.left, expr.Variable) and isinstance(r.right, expr.Aggregate):
aggregate_constraints.extend(get_constraints(r.right, r.left))

message_constraints: List[expr.Expr] = [
expr.Equal(expr.Mod(expr.First("Message"), expr.Number(8)), expr.Number(1)),
expr.Equal(expr.Mod(expr.Size("Message"), expr.Number(8)), expr.Number(0)),
]

scalar_constraints = [
c
for n, t in scalar_types
Expand All @@ -403,12 +400,18 @@ def get_constraints(aggregate: expr.Aggregate, field: expr.Variable) -> Sequence
]

return [
*message_constraints,
*aggregate_constraints,
*scalar_constraints,
*type_size_constraints,
]

@classmethod
def message_constraints(cls) -> List[expr.Expr]:
return [
expr.Equal(expr.Mod(expr.First("Message"), expr.Number(8)), expr.Number(1)),
expr.Equal(expr.Mod(expr.Size("Message"), expr.Number(8)), expr.Number(0)),
]

def __validate(self) -> None:
# pylint: disable=too-many-branches, too-many-locals

Expand Down Expand Up @@ -652,6 +655,44 @@ def is_possibly_empty(self, field: Field) -> bool:

return False

def size(self, field_values: Mapping[Field, expr.Expr]) -> expr.Expr:
for path in self.paths(FINAL):
opt = z3.Optimize()

opt.add(
expr.Equal(
expr.Size("Message"),
expr.Add(
*[
expr.Size(l.target.name)
for l in path
if l.target != FINAL and l.first == expr.UNDEFINED
]
).simplified(),
).z3expr(),
*[
expr.Equal(expr.Variable(f.name), v).z3expr()
for f, v in field_values.items()
if isinstance(v, (expr.Number, expr.Variable))
],
*[
expr.Equal(expr.Size(f.name), expr.Number(len(v.elements) * 8)).z3expr()
for f, v in field_values.items()
if isinstance(v, expr.Aggregate)
],
*[fact.z3expr() for link in path for fact in self.__link_expression(link)],
*[e.z3expr() for e in self.type_constraints(expr.TRUE)],
)

size = opt.maximize(expr.Size("Message").z3expr())

if opt.check() == z3.sat:
value = size.value()
if isinstance(value, z3.IntNumRef):
return expr.Number(value.as_long())

return expr.UNDEFINED

def __verify_expression_types(self) -> None:
types: Dict[ID, mty.Type] = {}

Expand Down Expand Up @@ -959,7 +1000,7 @@ def __prove_contradictions(self) -> None:
for c in self.outgoing(f):
paths += 1
contradiction = c.condition
constraints = self.type_constraints(contradiction)
constraints = self.message_constraints() + self.type_constraints(contradiction)
proof = contradiction.check([*constraints, *facts])
if proof.result == expr.ProofResult.SAT:
continue
Expand Down Expand Up @@ -1148,7 +1189,13 @@ def __prove_field_positions(self) -> None:
last.location,
)
)
proof = start_aligned.check([*facts, *self.type_constraints(start_aligned)])
proof = start_aligned.check(
[
*facts,
*self.message_constraints(),
*self.type_constraints(start_aligned),
]
)
if proof.result != expr.ProofResult.UNSAT:
path_message = " -> ".join([p.target.name for p in path])
self.error.append(
Expand All @@ -1168,7 +1215,11 @@ def __prove_field_positions(self) -> None:
)
)
proof = is_multiple_of_element_size.check(
[*facts, *self.type_constraints(is_multiple_of_element_size)]
[
*facts,
*self.message_constraints(),
*self.type_constraints(is_multiple_of_element_size),
]
)
if proof.result != expr.ProofResult.UNSAT:
path_message = " -> ".join([p.target.name for p in path])
Expand Down Expand Up @@ -1435,6 +1486,7 @@ def prune_dangling_fields(
merged_condition = expr.And(link.condition, final_link.condition)
proof = merged_condition.check(
[
*inner_message.message_constraints(),
*inner_message.type_constraints(merged_condition),
inner_message.field_condition(final_link.source),
]
Expand Down
16 changes: 16 additions & 0 deletions stubs/z3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class ArithRef(ExprRef):
def __le__(self, other: "ArithRef") -> BoolRef: ...
def __neg__(self) -> "ArithRef": ...

class IntNumRef(ArithRef):
def as_long(self) -> int: ...
def as_string(self) -> str: ...
def as_binary_string(self) -> bytes: ...

def Int(name: str, ctx: Optional[Context] = None) -> ArithRef: ...
def IntVal(val: int, ctx: Optional[Context] = None) -> ArithRef: ...
def Sum(*args: ArithRef) -> ArithRef: ...
Expand All @@ -48,3 +53,14 @@ class Solver:
def assert_and_track(self, expr: ExprRef, name: str) -> None: ...
def unsat_core(self) -> Iterable[ExprRef]: ...
def set(self, unsat_core: bool) -> None: ...

class OptimizeObjective:
def lower(self) -> ExprRef: ...
def upper(self) -> ExprRef: ...
def value(self) -> ExprRef: ...

class Optimize:
def add(self, *args: ExprRef) -> None: ...
def maximize(self, arg: ExprRef) -> OptimizeObjective: ...
def minimize(self, arg: ExprRef) -> OptimizeObjective: ...
def check(self, *assumptions: ExprRef) -> CheckSatResult: ...
39 changes: 39 additions & 0 deletions tests/unit/model/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,6 +2482,45 @@ def test_is_possibly_empty() -> None:
assert message.is_possibly_empty(c)


def test_size() -> None:
assert TLV_MESSAGE.size({Field("Tag"): Variable("Msg_Error")}) == Number(2)
assert TLV_MESSAGE.size(
{Field("Tag"): Variable("Msg_Data"), Field("Length"): Number(4)}
) == Number(48)

assert (
ETHERNET_FRAME.size(
{
Field("Type_Length_TPID"): Number(46),
Field("Type_Length"): Number(46),
Field("Payload"): Aggregate(*[Number(0)] * 46),
}
)
== Number(480)
)
assert (
ETHERNET_FRAME.size(
{
Field("Type_Length_TPID"): Number(0x8100),
Field("TPID"): Number(0x8100),
Field("Type_Length"): Number(46),
Field("Payload"): Aggregate(*[Number(0)] * 46),
}
)
== Number(512)
)
assert (
ETHERNET_FRAME.size(
{
Field("Type_Length_TPID"): Number(1536),
Field("Type_Length"): Number(1536),
Field("Payload"): Aggregate(*[Number(0)] * 46),
}
)
== Number(480)
)


def test_derived_message_incorrect_base_name() -> None:
with pytest.raises(
RecordFluxError,
Expand Down

0 comments on commit ef67206

Please sign in to comment.