From e37a8057d3a5bb7d7bcee5b83b425e0f1d904a57 Mon Sep 17 00:00:00 2001 From: Tobias Reiher Date: Fri, 21 Oct 2022 20:10:31 +0200 Subject: [PATCH] Add initial draft of intermediate representation Ref. #1204, #861 --- rflx/tac.py | 424 +++++++++++++++++++++++++++++++++++++++++ stubs/z3.pyi | 6 + tests/unit/tac_test.py | 207 ++++++++++++++++++++ 3 files changed, 637 insertions(+) create mode 100644 rflx/tac.py create mode 100644 tests/unit/tac_test.py diff --git a/rflx/tac.py b/rflx/tac.py new file mode 100644 index 000000000..f8a3caedc --- /dev/null +++ b/rflx/tac.py @@ -0,0 +1,424 @@ +# pylint: disable = fixme + +""" +Intermediate representation in three-address code (TAC) format. + +This module is still under development (cf. https://github.com/Componolit/RecordFlux/issues/1204). +""" + +from __future__ import annotations + +from abc import abstractmethod + +import z3 + +from rflx.common import Base +from rflx.identifier import ID, StrID + + +class Stmt(Base): + @abstractmethod + def z3expr(self) -> z3.BoolRef: + raise NotImplementedError + + +class Assign(Stmt): + def __init__(self, target: StrID, expression: Expr) -> None: + self._target = ID(target) + self._expression = expression + + @property + def target(self) -> ID: + return self._target + + @property + def expression(self) -> Expr: + return self._expression + + def z3expr(self) -> z3.BoolRef: + target = ( + IntVar(self._target) if isinstance(self._expression, IntExpr) else BoolVar(self._target) + ) + return target.z3expr() == self._expression.z3expr() + + +class FieldAssign(Stmt): + pass # TODO + + +class Append(Stmt): + pass # TODO + + +class Extend(Stmt): + pass # TODO + + +class Reset(Stmt): + pass # TODO + + +class Read(Stmt): + pass # TODO + + +class Write(Stmt): + pass # TODO + + +class Assert(Stmt): + def __init__(self, expression: BoolExpr) -> None: + self._expression = expression + + def z3expr(self) -> z3.BoolRef: + return self._expression.z3expr() + + +class Expr(Base): + @abstractmethod + def z3expr(self) -> z3.ExprRef: + raise NotImplementedError + + @property + def preconditions(self) -> list[Stmt]: + return [] + + +class BasicExpr(Expr): + pass + + +class IntExpr(Expr): + @abstractmethod + def z3expr(self) -> z3.ArithRef: + raise NotImplementedError + + +class BoolExpr(Expr): + @abstractmethod + def z3expr(self) -> z3.BoolRef: + raise NotImplementedError + + +class BasicIntExpr(BasicExpr, IntExpr): + pass + + +class BasicBoolExpr(BasicExpr, BoolExpr): + pass + + +class Var(BasicExpr): + def __init__(self, identifier: StrID) -> None: + self._identifier = ID(identifier) + + @property + def identifier(self) -> ID: + return self._identifier + + +class IntVar(Var, BasicIntExpr): + def __init__(self, identifier: StrID, negative: bool = False) -> None: + super().__init__(identifier) + self._negative = negative + + def z3expr(self) -> z3.ArithRef: + expr = z3.Int(str(self._identifier)) + return -expr if self._negative else expr + + +class BoolVar(Var, BasicBoolExpr): + def z3expr(self) -> z3.BoolRef: + return z3.Bool(str(self._identifier)) + + +class MsgVar(Var): + def z3expr(self) -> z3.ExprRef: + return z3.Const(str(self._identifier), z3.DeclareSort("Msg")) + + +class SeqVar(Var): + def z3expr(self) -> z3.ExprRef: + return z3.Const(str(self._identifier), z3.DeclareSort("Seq")) + + +class EnumLit(BasicIntExpr): + def __init__(self, identifier: StrID) -> None: + assert str(identifier) not in ("True", "False") + self._identifier = ID(identifier) + + def z3expr(self) -> z3.ArithRef: + return z3.Int(str(self._identifier)) + + # TODO: return value of literal + # def z3facts(self) -> list[z3.ExprRef]: + # pass + + +class IntVal(BasicIntExpr): + def __init__(self, value: int) -> None: + self._value = value + + def z3expr(self) -> z3.ArithRef: + return z3.IntVal(self._value) + + +class BoolVal(BasicBoolExpr): + def __init__(self, value: bool) -> None: + self._value = value + + def z3expr(self) -> z3.BoolRef: + return z3.BoolVal(self._value) + + +class UnaryExpr(Expr): + def __init__(self, expression: BasicExpr) -> None: + self._expression = expression + + +class UnaryBoolExpr(UnaryExpr, BoolExpr): + def __init__(self, expression: BasicBoolExpr) -> None: + super().__init__(expression) + self._expression: BasicBoolExpr + + +class BinaryExpr(Expr): + def __init__(self, left: BasicExpr, right: BasicExpr) -> None: + self._left = left + self._right = right + + +class BinaryIntExpr(BinaryExpr, IntExpr): + def __init__(self, left: BasicIntExpr, right: BasicIntExpr) -> None: + super().__init__(left, right) + self._left: BasicIntExpr + self._right: BasicIntExpr + + +class BinaryBoolExpr(BinaryExpr, BoolExpr): + def __init__(self, left: BasicBoolExpr, right: BasicBoolExpr) -> None: + super().__init__(left, right) + self._left: BasicBoolExpr + self._right: BasicBoolExpr + + +class Add(BinaryIntExpr): + def z3expr(self) -> z3.ArithRef: + return self._left.z3expr() + self._right.z3expr() + + @property + def preconditions(self) -> list[Stmt]: + type_last = IntVar("Type'Last") # TODO + return [ + # Left + Right <= Type'Last + Assign("D", Sub(type_last, self._right)), + Assert(LessEqual(self._left, IntVar("D"))), + ] + + +class Sub(BinaryIntExpr): + def z3expr(self) -> z3.ArithRef: + return self._left.z3expr() - self._right.z3expr() + + @property + def preconditions(self) -> list[Stmt]: + type_first = IntVar("Type'First") # TODO + return [ + # Left - Right >= Type'First + Assign("S", Add(type_first, self._right)), + Assert(GreaterEqual(self._left, IntVar("S"))), + ] + + +class Mul(BinaryIntExpr): + def z3expr(self) -> z3.ArithRef: + return self._left.z3expr() * self._right.z3expr() + + @property + def preconditions(self) -> list[Stmt]: + type_last = IntVar("Type'Last") # TODO + return [ + # Left * Right <= Type'Last + Assign("D", Div(type_last, self._right)), + Assert(LessEqual(self._left, IntVar("D"))), + ] + + +class Div(BinaryIntExpr): + def z3expr(self) -> z3.ArithRef: + return self._left.z3expr() / self._right.z3expr() + + @property + def preconditions(self) -> list[Stmt]: + return [ + # Right /= 0 + Assert(NotEqual(self._right, IntVal(0))), + ] + + +class Pow(BinaryIntExpr): + def z3expr(self) -> z3.ArithRef: + return self._left.z3expr() ** self._right.z3expr() + + @property + def preconditions(self) -> list[Stmt]: + type_last = IntVar("Type'Last") # TODO + return [ + # Left ** Right <= Type'Last + Assign("P", Pow(self._left, self._right)), + Assert(LessEqual(IntVar("P"), type_last)), + ] + + +class Mod(BinaryIntExpr): + def z3expr(self) -> z3.ArithRef: + return self._left.z3expr() % self._right.z3expr() + + @property + def preconditions(self) -> list[Stmt]: + return [ + # Right /= 0 + Assert(NotEqual(self._right, IntVal(0))), + ] + + +class Not(UnaryBoolExpr): + def z3expr(self) -> z3.BoolRef: + return z3.Not(self._expression.z3expr()) + + +class And(BinaryBoolExpr): + def z3expr(self) -> z3.BoolRef: + return z3.And(self._left.z3expr(), self._right.z3expr()) + + +class Or(BinaryBoolExpr): + def z3expr(self) -> z3.BoolRef: + return z3.Or(self._left.z3expr(), self._right.z3expr()) + + +class Relation(BoolExpr): + def __init__(self, left: BasicIntExpr, right: BasicIntExpr) -> None: + self._left = left + self._right = right + + +class Less(Relation): + def z3expr(self) -> z3.BoolRef: + return self._left.z3expr() < self._right.z3expr() + + +class LessEqual(Relation): + def z3expr(self) -> z3.BoolRef: + return self._left.z3expr() <= self._right.z3expr() + + +class Equal(Relation): + def z3expr(self) -> z3.BoolRef: + return self._left.z3expr() == self._right.z3expr() + + +class GreaterEqual(Relation): + def z3expr(self) -> z3.BoolRef: + return self._left.z3expr() >= self._right.z3expr() + + +class Greater(Relation): + def z3expr(self) -> z3.BoolRef: + return self._left.z3expr() > self._right.z3expr() + + +class NotEqual(Relation): + def z3expr(self) -> z3.BoolRef: + return self._left.z3expr() != self._right.z3expr() + + +# class QuantifiedExpr(Expr): +# def __init__(self, parameter: StrID, iterable: BasicExpr, predicate: Expr) -> None: +# self._parameter = ID(parameter) +# self._iterable = iterable +# self._predicate = predicate + + +# class ForAll(QuantifiedExpr): +# pass # TODO + + +# class ForSome(QuantifiedExpr): +# pass # TODO + + +class Call(Expr): + def __init__(self, identifier: StrID, *arguments: BasicExpr) -> None: + self._identifier = ID(identifier) + self._arguments = list(arguments) + self._preconditions: list[Stmt] = [] + + @property + def preconditions(self) -> list[Stmt]: + return self._preconditions + + @preconditions.setter + def preconditions(self, preconditions: list[Stmt]) -> None: + self._preconditions = preconditions + + +class IntCall(Call, IntExpr): + def __init__(self, identifier: StrID, *arguments: BasicExpr, negative: bool = False) -> None: + super().__init__(identifier, *arguments) + self._negative = negative + + def z3expr(self) -> z3.ArithRef: + # TODO: consider non-idempotent calls + expr = z3.Int(str(self._identifier)) + return -expr if self._negative else expr + + +class BoolCall(Call, BoolExpr): + def z3expr(self) -> z3.BoolRef: + # TODO: consider non-idempotent calls + return z3.Bool(str(self._identifier)) + + +class IntFieldAccess(IntExpr): + def __init__(self, message: StrID, field: StrID, negative: bool = False) -> None: + self._message = ID(message) + self._field = ID(field) + self._negative = negative + + def z3expr(self) -> z3.ArithRef: + expr = z3.Int(f"{self._message}.{self._field}") + return -expr if self._negative else expr + + +class BoolFieldAccess(BoolExpr): + def __init__(self, message: StrID, field: StrID) -> None: + self._message = ID(message) + self._field = ID(field) + + def z3expr(self) -> z3.BoolRef: + return z3.Bool(f"{self._message}.{self._field}") + + +class IntIfExpr(Expr): + def __init__( + self, condition: BasicBoolExpr, then_expr: BasicIntExpr, else_expr: BasicIntExpr + ) -> None: + self._condition = condition + self._then_expr = then_expr + self._else_expr = else_expr + + def z3expr(self) -> z3.ExprRef: + return z3.If(self._condition.z3expr(), self._then_expr.z3expr(), self._else_expr.z3expr()) + + +class BoolIfExpr(Expr): + def __init__( + self, condition: BasicBoolExpr, then_expr: BasicBoolExpr, else_expr: BasicBoolExpr + ) -> None: + self._condition = condition + self._then_expr = then_expr + self._else_expr = else_expr + + def z3expr(self) -> z3.ExprRef: + return z3.If(self._condition.z3expr(), self._then_expr.z3expr(), self._else_expr.z3expr()) diff --git a/stubs/z3.pyi b/stubs/z3.pyi index 388b91c66..3dc76b7ad 100644 --- a/stubs/z3.pyi +++ b/stubs/z3.pyi @@ -32,6 +32,12 @@ class IntNumRef(ArithRef): def as_string(self) -> str: ... def as_binary_string(self) -> bytes: ... +class Z3PPObject: ... +class AstRef(Z3PPObject): ... +class SortRef(AstRef): ... + +def DeclareSort(name: str) -> SortRef: ... +def Const(name: str, sort: SortRef) -> ExprRef: ... def Int(name: str, ctx: Optional[Context] = None) -> ArithRef: ... def IntVal(val: int, ctx: Optional[Context] = None) -> ArithRef: ... def Sum(*args: ArithRef) -> ArithRef: ... diff --git a/tests/unit/tac_test.py b/tests/unit/tac_test.py new file mode 100644 index 000000000..f8de72648 --- /dev/null +++ b/tests/unit/tac_test.py @@ -0,0 +1,207 @@ +import z3 + +from rflx import tac +from rflx.identifier import ID + + +def test_assign() -> None: + assign = tac.Assign("X", tac.IntVar("Y")) + assert assign.target == ID("X") + assert assign.expression == tac.IntVar("Y") + + +def test_assign_z3expr() -> None: + assert tac.Assign("X", tac.IntVar("Y")).z3expr() == (z3.Int("X") == z3.Int("Y")) + assert tac.Assign("X", tac.BoolVar("Y")).z3expr() == (z3.Bool("X") == z3.Bool("Y")) + + +def test_assert_z3expr() -> None: + assert tac.Assert(tac.BoolVar("X")).z3expr() == z3.Bool("X") + + +def test_int_var() -> None: + var = tac.IntVar("X") + assert var.identifier == ID("X") + + +def test_int_var_z3expr() -> None: + assert tac.IntVar("X").z3expr() == z3.Int("X") + + +def test_int_var_preconditions() -> None: + assert not tac.IntVar("X").preconditions + + +def test_bool_var() -> None: + var = tac.BoolVar("X") + assert var.identifier == ID("X") + + +def test_bool_var_z3expr() -> None: + assert tac.BoolVar("X").z3expr() == z3.Bool("X") + + +def test_bool_var_preconditions() -> None: + assert not tac.BoolVar("X").preconditions + + +def test_msg_var_z3expr() -> None: + assert tac.MsgVar("X").z3expr() == z3.Const("X", z3.DeclareSort("Msg")) + + +def test_seq_var_z3expr() -> None: + assert tac.SeqVar("X").z3expr() == z3.Const("X", z3.DeclareSort("Seq")) + + +def test_enum_lit_z3expr() -> None: + assert tac.EnumLit("Lit").z3expr() == z3.Int("Lit") + + +def test_int_val_z3expr() -> None: + assert tac.IntVal(1).z3expr() == z3.IntVal(1) + + +def test_bool_val_z3expr() -> None: + assert tac.BoolVal(True).z3expr() == z3.BoolVal(True) + assert tac.BoolVal(False).z3expr() == z3.BoolVal(False) + + +def test_add_z3expr() -> None: + assert tac.Add(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") + z3.IntVal(1)) + + +def test_add_preconditions() -> None: + assert tac.Add(tac.IntVar("X"), tac.IntVal(1)).preconditions == [ + tac.Assign("D", tac.Sub(tac.IntVar("Type'Last"), tac.IntVal(1))), + tac.Assert(tac.LessEqual(tac.IntVar("X"), tac.IntVar("D"))), + ] + + +def test_sub_z3expr() -> None: + assert tac.Sub(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") - z3.IntVal(1)) + + +def test_sub_preconditions() -> None: + assert tac.Sub(tac.IntVar("X"), tac.IntVal(1)).preconditions == [ + tac.Assign("S", tac.Add(tac.IntVar("Type'First"), tac.IntVal(1))), + tac.Assert(tac.GreaterEqual(tac.IntVar("X"), tac.IntVar("S"))), + ] + + +def test_mul_z3expr() -> None: + assert tac.Mul(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") * z3.IntVal(1)) + + +def test_mul_preconditions() -> None: + assert tac.Mul(tac.IntVar("X"), tac.IntVal(1)).preconditions == [ + tac.Assign("D", tac.Div(tac.IntVar("Type'Last"), tac.IntVal(1))), + tac.Assert(tac.LessEqual(tac.IntVar("X"), tac.IntVar("D"))), + ] + + +def test_div_z3expr() -> None: + assert tac.Div(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") / z3.IntVal(1)) + + +def test_div_preconditions() -> None: + assert tac.Div(tac.IntVar("X"), tac.IntVal(1)).preconditions == [ + tac.Assert(tac.NotEqual(tac.IntVal(1), tac.IntVal(0))), + ] + + +def test_pow_z3expr() -> None: + assert tac.Pow(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") ** z3.IntVal(1)) + + +def test_pow_preconditions() -> None: + assert tac.Pow(tac.IntVar("X"), tac.IntVal(1)).preconditions == [ + tac.Assign("P", tac.Pow(tac.IntVar("X"), tac.IntVal(1))), + tac.Assert(tac.LessEqual(tac.IntVar("P"), tac.IntVar("Type'Last"))), + ] + + +def test_mod_z3expr() -> None: + assert tac.Mod(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") % z3.IntVal(1)) + + +def test_mod_preconditions() -> None: + assert tac.Mod(tac.IntVar("X"), tac.IntVal(1)).preconditions == [ + tac.Assert(tac.NotEqual(tac.IntVal(1), tac.IntVal(0))), + ] + + +def test_not_z3expr() -> None: + assert tac.Not(tac.BoolVar("X")).z3expr() == z3.Not(z3.Bool("X")) + + +def test_and_z3expr() -> None: + assert tac.And(tac.BoolVar("X"), tac.BoolVal(True)).z3expr() == z3.And( + z3.Bool("X"), z3.BoolVal(True) + ) + + +def test_or_z3expr() -> None: + assert tac.Or(tac.BoolVar("X"), tac.BoolVal(True)).z3expr() == z3.Or( + z3.Bool("X"), z3.BoolVal(True) + ) + + +def test_less_z3expr() -> None: + assert tac.Less(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") < z3.IntVal(1)) + + +def test_less_equal_z3expr() -> None: + assert tac.LessEqual(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") <= z3.IntVal(1)) + + +def test_equal_z3expr() -> None: + assert tac.Equal(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") == z3.IntVal(1)) + + +def test_greater_equal_z3expr() -> None: + assert tac.GreaterEqual(tac.IntVar("X"), tac.IntVal(1)).z3expr() == ( + z3.Int("X") >= z3.IntVal(1) + ) + + +def test_greater_z3expr() -> None: + assert tac.Greater(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") > z3.IntVal(1)) + + +def test_not_equal_z3expr() -> None: + assert tac.NotEqual(tac.IntVar("X"), tac.IntVal(1)).z3expr() == (z3.Int("X") != z3.IntVal(1)) + + +def test_int_call_z3expr() -> None: + assert tac.IntCall("X", tac.IntVar("Y"), tac.BoolVal(True)).z3expr() == z3.Int("X") + + +def test_bool_call_z3expr() -> None: + assert tac.BoolCall("X", tac.BoolVar("Y"), tac.BoolVal(True)).z3expr() == z3.Bool("X") + + +def test_call_preconditions() -> None: + call = tac.IntCall("X", tac.IntVar("Y"), tac.BoolVal(True)) + assert not call.preconditions + call.preconditions = [tac.Assert(tac.Greater(tac.IntVar("Y"), tac.IntVal(0)))] + assert call.preconditions == [tac.Assert(tac.Greater(tac.IntVar("Y"), tac.IntVal(0)))] + + +def test_int_field_access_z3expr() -> None: + assert tac.IntFieldAccess("M", "F").z3expr() == z3.Int("M.F") + + +def test_bool_field_access_z3expr() -> None: + assert tac.BoolFieldAccess("M", "F").z3expr() == z3.Bool("M.F") + + +def test_int_if_expr_z3expr() -> None: + assert tac.IntIfExpr(tac.BoolVar("X"), tac.IntVar("Y"), tac.IntVal(1)).z3expr() == z3.If( + z3.Bool("X"), z3.Int("Y"), z3.IntVal(1) + ) + + +def test_bool_if_expr_z3expr() -> None: + assert tac.BoolIfExpr(tac.BoolVar("X"), tac.BoolVar("Y"), tac.BoolVal(False)).z3expr() == z3.If( + z3.Bool("X"), z3.Bool("Y"), z3.BoolVal(False) + )