From 1eab2144f675f95cd4a16f98259b2acbaa36f823 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Fri, 7 May 2021 16:36:01 -0700 Subject: [PATCH] Refactor for `Dialect`. --- clvm/__init__.py | 3 +- clvm/chainable_multi_op_fn.py | 26 +++++ clvm/chia_dialect.py | 37 +++++++ clvm/dialect.py | 159 ++++++++++++++++++++++++++++++ clvm/handle_unknown_op.py | 124 ++++++++++++++++++++++++ clvm/operators.py | 175 +++------------------------------- clvm/run_program.py | 27 +++++- clvm/serialize.py | 4 +- clvm/types.py | 15 +++ tests/brun/trace-1.txt | 2 +- tests/brun/trace-2.txt | 2 +- tests/operatordict_test.py | 29 ------ tests/operators_test.py | 65 +++++++++---- 13 files changed, 453 insertions(+), 215 deletions(-) create mode 100644 clvm/chainable_multi_op_fn.py create mode 100644 clvm/chia_dialect.py create mode 100644 clvm/dialect.py create mode 100644 clvm/handle_unknown_op.py create mode 100644 clvm/types.py delete mode 100644 tests/operatordict_test.py diff --git a/clvm/__init__.py b/clvm/__init__.py index a7062e1e..b0c4bf4c 100644 --- a/clvm/__init__.py +++ b/clvm/__init__.py @@ -1,6 +1,7 @@ from .SExp import SExp +from .dialect import Dialect # noqa from .operators import ( # noqa - QUOTE_ATOM, + QUOTE_ATOM, # deprecated KEYWORD_TO_ATOM, KEYWORD_FROM_ATOM, ) diff --git a/clvm/chainable_multi_op_fn.py b/clvm/chainable_multi_op_fn.py new file mode 100644 index 00000000..be25c8f4 --- /dev/null +++ b/clvm/chainable_multi_op_fn.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +from .types import CLVMObjectType, MultiOpFn, OperatorDict + + +@dataclass +class ChainableMultiOpFn: + """ + This structure handles clvm operators. Given an atom, it looks it up in a `dict`, then + falls back to calling `unknown_op_handler`. + """ + op_lookup: OperatorDict + unknown_op_handler: MultiOpFn + + def __call__( + self, op: bytes, arguments: CLVMObjectType, max_cost: Optional[int] = None + ) -> Tuple[int, CLVMObjectType]: + f = self.op_lookup.get(op) + if f: + try: + return f(arguments) + except TypeError: + # some operators require `max_cost` + return f(arguments, max_cost) + return self.unknown_op_handler(op, arguments, max_cost) diff --git a/clvm/chia_dialect.py b/clvm/chia_dialect.py new file mode 100644 index 00000000..a35e4b21 --- /dev/null +++ b/clvm/chia_dialect.py @@ -0,0 +1,37 @@ +from .casts import int_to_bytes +from .dialect import ConversionFn, Dialect, new_dialect, opcode_table_for_backend + +KEYWORDS = ( + # core opcodes 0x01-x08 + ". q a i c f r l x " + + # opcodes on atoms as strings 0x09-0x0f + "= >s sha256 substr strlen concat . " + + # opcodes on atoms as ints 0x10-0x17 + "+ - * / divmod > ash lsh " + + # opcodes on atoms as vectors of bools 0x18-0x1c + "logand logior logxor lognot . " + + # opcodes for bls 1381 0x1d-0x1f + "point_add pubkey_for_exp . " + + # bool opcodes 0x20-0x23 + "not any all . " + + # misc 0x24 + "softfork " +).split() + +KEYWORD_FROM_ATOM = {int_to_bytes(k): v for k, v in enumerate(KEYWORDS)} +KEYWORD_TO_ATOM = {v: k for k, v in KEYWORD_FROM_ATOM.items()} + + +def chia_dialect(strict: bool, to_python: ConversionFn, backend=None) -> Dialect: + quote_kw = KEYWORD_TO_ATOM["q"] + apply_kw = KEYWORD_TO_ATOM["a"] + dialect = new_dialect(quote_kw, apply_kw, strict, to_python, backend=backend) + table = opcode_table_for_backend(KEYWORD_TO_ATOM, backend=backend) + dialect.update(table) + return dialect diff --git a/clvm/dialect.py b/clvm/dialect.py new file mode 100644 index 00000000..5873a736 --- /dev/null +++ b/clvm/dialect.py @@ -0,0 +1,159 @@ +from typing import Callable, Optional, Tuple + +try: + import clvm_rs +except ImportError: + clvm_rs = None + +from . import core_ops, more_ops +from .chainable_multi_op_fn import ChainableMultiOpFn +from .handle_unknown_op import ( + handle_unknown_op_softfork_ready, + handle_unknown_op_strict, +) +from .run_program import _run_program +from .types import CLVMObjectType, ConversionFn, MultiOpFn, OperatorDict + + +OP_REWRITE = { + "+": "add", + "-": "subtract", + "*": "multiply", + "/": "div", + "i": "if", + "c": "cons", + "f": "first", + "r": "rest", + "l": "listp", + "x": "raise", + "=": "eq", + ">": "gr", + ">s": "gr_bytes", +} + + +def op_table_for_module(mod): + + # python-implemented operators don't take `max_cost` and rust-implemented operators do + # So we make the `max_cost` operator optional with this trick + # TODO: have python-implemented ops also take `max_cost` and unify the API. + + def elide_max_cost(f): + def inner_op(sexp, max_cost=None): + try: + return f(sexp, max_cost) + except TypeError: + return f(sexp) + return inner_op + + return {k: elide_max_cost(v) for k, v in mod.__dict__.items() if k.startswith("op_")} + + +def op_imp_table_for_backend(backend): + if backend is None and clvm_rs: + backend = "native" + + if backend == "native": + if clvm_rs is None: + raise RuntimeError("native backend not installed") + return clvm_rs.native_opcodes_dict() + + table = {} + table.update(op_table_for_module(core_ops)) + table.update(op_table_for_module(more_ops)) + return table + + +def op_atom_to_imp_table(op_imp_table, keyword_to_atom, op_rewrite=OP_REWRITE): + op_atom_to_imp_table = {} + for op, bytecode in keyword_to_atom.items(): + op_name = "op_%s" % op_rewrite.get(op, op) + op_f = op_imp_table.get(op_name) + if op_f: + op_atom_to_imp_table[bytecode] = op_f + return op_atom_to_imp_table + + +def opcode_table_for_backend(keyword_to_atom, backend): + op_imp_table = op_imp_table_for_backend(backend) + return op_atom_to_imp_table(op_imp_table, keyword_to_atom) + + +class Dialect: + def __init__( + self, + quote_kw: bytes, + apply_kw: bytes, + multi_op_fn: MultiOpFn, + to_python: ConversionFn, + ): + self.quote_kw = quote_kw + self.apply_kw = apply_kw + self.opcode_lookup = dict() + self.multi_op_fn = ChainableMultiOpFn(self.opcode_lookup, multi_op_fn) + self.to_python = to_python + + def update(self, d: OperatorDict) -> None: + self.opcode_lookup.update(d) + + def clear(self) -> None: + self.opcode_lookup.clear() + + def run_program( + self, + program: CLVMObjectType, + env: CLVMObjectType, + max_cost: int, + pre_eval_f: Optional[ + Callable[[CLVMObjectType, CLVMObjectType], Tuple[int, CLVMObjectType]] + ] = None, + ) -> Tuple[int, CLVMObjectType]: + cost, r = _run_program( + program, + env, + self.multi_op_fn, + self.quote_kw, + self.apply_kw, + max_cost, + pre_eval_f, + ) + return cost, self.to_python(r) + + +def native_new_dialect( + quote_kw: bytes, apply_kw: bytes, strict: bool, to_python: ConversionFn +) -> Dialect: + unknown_op_callback = ( + clvm_rs.NATIVE_OP_UNKNOWN_STRICT + if strict + else clvm_rs.NATIVE_OP_UNKNOWN_NON_STRICT + ) + dialect = clvm_rs.Dialect( + quote_kw, + apply_kw, + unknown_op_callback, + to_python=to_python, + ) + return dialect + + +def python_new_dialect( + quote_kw: bytes, apply_kw: bytes, strict: bool, to_python: ConversionFn +) -> Dialect: + unknown_op_callback = ( + handle_unknown_op_strict if strict else handle_unknown_op_softfork_ready + ) + dialect = Dialect( + quote_kw, + apply_kw, + unknown_op_callback, + to_python=to_python, + ) + return dialect + + +def new_dialect(quote_kw: bytes, apply_kw: bytes, strict: bool, to_python: ConversionFn, backend=None): + if backend is None: + backend = "python" if clvm_rs is None else "native" + backend_f = native_new_dialect if backend == "native" else python_new_dialect + return backend_f(quote_kw, apply_kw, strict, to_python) diff --git a/clvm/handle_unknown_op.py b/clvm/handle_unknown_op.py new file mode 100644 index 00000000..d30211b1 --- /dev/null +++ b/clvm/handle_unknown_op.py @@ -0,0 +1,124 @@ +from typing import Tuple + +from .CLVMObject import CLVMObject +from .EvalError import EvalError + +from .costs import ( + ARITH_BASE_COST, + ARITH_COST_PER_BYTE, + ARITH_COST_PER_ARG, + MUL_BASE_COST, + MUL_COST_PER_OP, + MUL_LINEAR_COST_PER_BYTE, + MUL_SQUARE_COST_PER_BYTE_DIVIDER, + CONCAT_BASE_COST, + CONCAT_COST_PER_ARG, + CONCAT_COST_PER_BYTE, +) + + +def handle_unknown_op_strict(op, arguments, _max_cost=None): + raise EvalError("unimplemented operator", arguments.to(op)) + + +def args_len(op_name, args): + for arg in args.as_iter(): + if arg.pair: + raise EvalError("%s requires int args" % op_name, arg) + yield len(arg.as_atom()) + + +# unknown ops are reserved if they start with 0xffff +# otherwise, unknown ops are no-ops, but they have costs. The cost is computed +# like this: + +# byte index (reverse): +# | 4 | 3 | 2 | 1 | 0 | +# +---+---+---+---+------------+ +# | multiplier |XX | XXXXXX | +# +---+---+---+---+---+--------+ +# ^ ^ ^ +# | | + 6 bits ignored when computing cost +# cost_multiplier | +# + 2 bits +# cost_function + +# 1 is always added to the multiplier before using it to multiply the cost, this +# is since cost may not be 0. + +# cost_function is 2 bits and defines how cost is computed based on arguments: +# 0: constant, cost is 1 * (multiplier + 1) +# 1: computed like operator add, multiplied by (multiplier + 1) +# 2: computed like operator mul, multiplied by (multiplier + 1) +# 3: computed like operator concat, multiplied by (multiplier + 1) + +# this means that unknown ops where cost_function is 1, 2, or 3, may still be +# fatal errors if the arguments passed are not atoms. + + +def handle_unknown_op_softfork_ready( + op: bytes, args: CLVMObject, max_cost: int +) -> Tuple[int, CLVMObject]: + # any opcode starting with ffff is reserved (i.e. fatal error) + # opcodes are not allowed to be empty + if len(op) == 0 or op[:2] == b"\xff\xff": + raise EvalError("reserved operator", args.to(op)) + + # all other unknown opcodes are no-ops + # the cost of the no-ops is determined by the opcode number, except the + # 6 least significant bits. + + cost_function = (op[-1] & 0b11000000) >> 6 + # the multiplier cannot be 0. it starts at 1 + + if len(op) > 5: + raise EvalError("invalid operator", args.to(op)) + + cost_multiplier = int.from_bytes(op[:-1], "big", signed=False) + 1 + + # 0 = constant + # 1 = like op_add/op_sub + # 2 = like op_multiply + # 3 = like op_concat + if cost_function == 0: + cost = 1 + elif cost_function == 1: + # like op_add + cost = ARITH_BASE_COST + arg_size = 0 + for length in args_len("unknown op", args): + arg_size += length + cost += ARITH_COST_PER_ARG + cost += arg_size * ARITH_COST_PER_BYTE + elif cost_function == 2: + # like op_multiply + cost = MUL_BASE_COST + operands = args_len("unknown op", args) + try: + vs = next(operands) + for rs in operands: + cost += MUL_COST_PER_OP + cost += (rs + vs) * MUL_LINEAR_COST_PER_BYTE + cost += (rs * vs) // MUL_SQUARE_COST_PER_BYTE_DIVIDER + # this is an estimate, since we don't want to actually multiply the + # values + vs += rs + except StopIteration: + pass + + elif cost_function == 3: + # like concat + cost = CONCAT_BASE_COST + length = 0 + for arg in args.as_iter(): + if arg.pair: + raise EvalError("unknown op on list", arg) + cost += CONCAT_COST_PER_ARG + length += len(arg.atom) + cost += length * CONCAT_COST_PER_BYTE + + cost *= cost_multiplier + if cost >= 2**32: + raise EvalError("invalid operator", args.to(op)) + + return (cost, args.to(b"")) diff --git a/clvm/operators.py b/clvm/operators.py index a63e6a88..07d164e7 100644 --- a/clvm/operators.py +++ b/clvm/operators.py @@ -1,168 +1,14 @@ +# this API is deprecated in favor of dialects. See `dialect.py` and `chia_dialect.py` + from typing import Dict, Tuple from . import core_ops, more_ops from .CLVMObject import CLVMObject -from .SExp import SExp -from .EvalError import EvalError - -from .casts import int_to_bytes from .op_utils import operators_for_module - -from .costs import ( - ARITH_BASE_COST, - ARITH_COST_PER_BYTE, - ARITH_COST_PER_ARG, - MUL_BASE_COST, - MUL_COST_PER_OP, - MUL_LINEAR_COST_PER_BYTE, - MUL_SQUARE_COST_PER_BYTE_DIVIDER, - CONCAT_BASE_COST, - CONCAT_COST_PER_ARG, - CONCAT_COST_PER_BYTE, -) - -KEYWORDS = ( - # core opcodes 0x01-x08 - ". q a i c f r l x " - - # opcodes on atoms as strings 0x09-0x0f - "= >s sha256 substr strlen concat . " - - # opcodes on atoms as ints 0x10-0x17 - "+ - * / divmod > ash lsh " - - # opcodes on atoms as vectors of bools 0x18-0x1c - "logand logior logxor lognot . " - - # opcodes for bls 1381 0x1d-0x1f - "point_add pubkey_for_exp . " - - # bool opcodes 0x20-0x23 - "not any all . " - - # misc 0x24 - "softfork " -).split() - -KEYWORD_FROM_ATOM = {int_to_bytes(k): v for k, v in enumerate(KEYWORDS)} -KEYWORD_TO_ATOM = {v: k for k, v in KEYWORD_FROM_ATOM.items()} - -OP_REWRITE = { - "+": "add", - "-": "subtract", - "*": "multiply", - "/": "div", - "i": "if", - "c": "cons", - "f": "first", - "r": "rest", - "l": "listp", - "x": "raise", - "=": "eq", - ">": "gr", - ">s": "gr_bytes", -} - - -def args_len(op_name, args): - for arg in args.as_iter(): - if arg.pair: - raise EvalError("%s requires int args" % op_name, arg) - yield len(arg.as_atom()) - - -# unknown ops are reserved if they start with 0xffff -# otherwise, unknown ops are no-ops, but they have costs. The cost is computed -# like this: - -# byte index (reverse): -# | 4 | 3 | 2 | 1 | 0 | -# +---+---+---+---+------------+ -# | multiplier |XX | XXXXXX | -# +---+---+---+---+---+--------+ -# ^ ^ ^ -# | | + 6 bits ignored when computing cost -# cost_multiplier | -# + 2 bits -# cost_function - -# 1 is always added to the multiplier before using it to multiply the cost, this -# is since cost may not be 0. - -# cost_function is 2 bits and defines how cost is computed based on arguments: -# 0: constant, cost is 1 * (multiplier + 1) -# 1: computed like operator add, multiplied by (multiplier + 1) -# 2: computed like operator mul, multiplied by (multiplier + 1) -# 3: computed like operator concat, multiplied by (multiplier + 1) - -# this means that unknown ops where cost_function is 1, 2, or 3, may still be -# fatal errors if the arguments passed are not atoms. - -def default_unknown_op(op: bytes, args: CLVMObject) -> Tuple[int, CLVMObject]: - # any opcode starting with ffff is reserved (i.e. fatal error) - # opcodes are not allowed to be empty - if len(op) == 0 or op[:2] == b"\xff\xff": - raise EvalError("reserved operator", args.to(op)) - - # all other unknown opcodes are no-ops - # the cost of the no-ops is determined by the opcode number, except the - # 6 least significant bits. - - cost_function = (op[-1] & 0b11000000) >> 6 - # the multiplier cannot be 0. it starts at 1 - - if len(op) > 5: - raise EvalError("invalid operator", args.to(op)) - - cost_multiplier = int.from_bytes(op[:-1], "big", signed=False) + 1 - - # 0 = constant - # 1 = like op_add/op_sub - # 2 = like op_multiply - # 3 = like op_concat - if cost_function == 0: - cost = 1 - elif cost_function == 1: - # like op_add - cost = ARITH_BASE_COST - arg_size = 0 - for length in args_len("unknown op", args): - arg_size += length - cost += ARITH_COST_PER_ARG - cost += arg_size * ARITH_COST_PER_BYTE - elif cost_function == 2: - # like op_multiply - cost = MUL_BASE_COST - operands = args_len("unknown op", args) - try: - vs = next(operands) - for rs in operands: - cost += MUL_COST_PER_OP - cost += (rs + vs) * MUL_LINEAR_COST_PER_BYTE - cost += (rs * vs) // MUL_SQUARE_COST_PER_BYTE_DIVIDER - # this is an estimate, since we don't want to actually multiply the - # values - vs += rs - except StopIteration: - pass - - elif cost_function == 3: - # like concat - cost = CONCAT_BASE_COST - length = 0 - for arg in args.as_iter(): - if arg.pair: - raise EvalError("unknown op on list", arg) - cost += CONCAT_COST_PER_ARG - length += len(arg.atom) - cost += length * CONCAT_COST_PER_BYTE - - cost *= cost_multiplier - if cost >= 2**32: - raise EvalError("invalid operator", args.to(op)) - - return (cost, SExp.null()) +from .handle_unknown_op import handle_unknown_op_softfork_ready +from .chia_dialect import KEYWORDS, KEYWORD_FROM_ATOM, KEYWORD_TO_ATOM # noqa +from .dialect import OP_REWRITE class OperatorDict(dict): @@ -184,13 +30,16 @@ def __new__(class_, d: Dict, *args, **kwargs): if "unknown_op_handler" in kwargs: self.unknown_op_handler = kwargs["unknown_op_handler"] else: - self.unknown_op_handler = default_unknown_op + self.unknown_op_handler = handle_unknown_op_softfork_ready return self def __call__(self, op: bytes, arguments: CLVMObject) -> Tuple[int, CLVMObject]: f = self.get(op) if f is None: - return self.unknown_op_handler(op, arguments) + try: + return self.unknown_op_handler(op, arguments, max_cost=None) + except TypeError: + return self.unknown_op_handler(op, arguments) else: return f(arguments) @@ -199,6 +48,8 @@ def __call__(self, op: bytes, arguments: CLVMObject) -> Tuple[int, CLVMObject]: APPLY_ATOM = KEYWORD_TO_ATOM["a"] OPERATOR_LOOKUP = OperatorDict( - operators_for_module(KEYWORD_TO_ATOM, core_ops, OP_REWRITE), quote=QUOTE_ATOM, apply=APPLY_ATOM + operators_for_module(KEYWORD_TO_ATOM, core_ops, OP_REWRITE), + quote=QUOTE_ATOM, + apply=APPLY_ATOM, ) OPERATOR_LOOKUP.update(operators_for_module(KEYWORD_TO_ATOM, more_ops, OP_REWRITE)) diff --git a/clvm/run_program.py b/clvm/run_program.py index 20f4b75c..7b7fcb17 100644 --- a/clvm/run_program.py +++ b/clvm/run_program.py @@ -15,6 +15,8 @@ # the "Any" below should really be "OpStackType" but # recursive types aren't supported by mypy +MultiOpFn = Callable[[bytes, SExp, int], Tuple[int, SExp]] + OpCallable = Callable[[Any, "ValStackType"], int] ValStackType = List[SExp] @@ -53,6 +55,27 @@ def run_program( pre_eval_f=None, ) -> Tuple[int, CLVMObject]: + return _run_program( + program, + args, + operator_lookup, + operator_lookup.quote_atom, + operator_lookup.apply_atom, + max_cost, + pre_eval_f, + ) + + +def _run_program( + program: CLVMObject, + args: CLVMObject, + operator_lookup: MultiOpFn, + quote_atom: bytes, + apply_atom: bytes, + max_cost=None, + pre_eval_f=None, +) -> Tuple[int, CLVMObject]: + program = SExp.to(program) if pre_eval_f: pre_eval_op = to_pre_eval_op(pre_eval_f, program.to) @@ -137,7 +160,7 @@ def eval_op(op_stack: OpStackType, value_stack: ValStackType) -> int: op = operator.as_atom() operand_list = sexp.rest() - if op == operator_lookup.quote_atom: + if op == quote_atom: value_stack.append(operand_list) return QUOTE_COST @@ -160,7 +183,7 @@ def apply_op(op_stack: OpStackType, value_stack: ValStackType) -> int: raise EvalError("internal error", operator) op = operator.as_atom() - if op == operator_lookup.apply_atom: + if op == apply_atom: if operand_list.list_len() != 2: raise EvalError("apply requires exactly 2 parameters", operand_list) new_program = operand_list.first() diff --git a/clvm/serialize.py b/clvm/serialize.py index d45ed75a..accda2f1 100644 --- a/clvm/serialize.py +++ b/clvm/serialize.py @@ -22,13 +22,13 @@ def sexp_to_byte_iterator(sexp): todo_stack = [sexp] while todo_stack: sexp = todo_stack.pop() - pair = sexp.as_pair() + pair = sexp.pair if pair: yield bytes([CONS_BOX_MARKER]) todo_stack.append(pair[1]) todo_stack.append(pair[0]) else: - yield from atom_to_byte_iterator(sexp.as_atom()) + yield from atom_to_byte_iterator(sexp.atom) def atom_to_byte_iterator(as_atom): diff --git a/clvm/types.py b/clvm/types.py new file mode 100644 index 00000000..c9f794ba --- /dev/null +++ b/clvm/types.py @@ -0,0 +1,15 @@ +from typing import Any, Callable, Dict, Tuple, Union + + +CLVMAtom = Any +CLVMPair = Any + +CLVMObjectType = Union["CLVMAtom", "CLVMPair"] + +MultiOpFn = Callable[[bytes, CLVMObjectType, int], Tuple[int, CLVMObjectType]] + +ConversionFn = Callable[[CLVMObjectType], CLVMObjectType] + +OpFn = Callable[[CLVMObjectType, int], Tuple[int, CLVMObjectType]] + +OperatorDict = Dict[bytes, Callable[[CLVMObjectType, int], Tuple[int, CLVMObjectType]]] diff --git a/tests/brun/trace-1.txt b/tests/brun/trace-1.txt index 72fb303f..a1d5c92d 100644 --- a/tests/brun/trace-1.txt +++ b/tests/brun/trace-1.txt @@ -1,4 +1,4 @@ -brun --backend=python -c -v '(+ (q . 10) (f 1))' '(51)' +brun -c -v '(+ (q . 10) (f 1))' '(51)' cost = 860 61 diff --git a/tests/brun/trace-2.txt b/tests/brun/trace-2.txt index fb475ec5..08ba2687 100644 --- a/tests/brun/trace-2.txt +++ b/tests/brun/trace-2.txt @@ -1,4 +1,4 @@ -brun --backend=python -c -v '(x)' +brun -c -v '(x)' FAIL: clvm raise () (a 2 3) [((x))] => (didn't finish) diff --git a/tests/operatordict_test.py b/tests/operatordict_test.py deleted file mode 100644 index 897f6ffe..00000000 --- a/tests/operatordict_test.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest - -from clvm.operators import OperatorDict - - -class OperatorDictTest(unittest.TestCase): - def test_operatordict_constructor(self): - """Constructing should fail if quote or apply are not specified, - either by object property or by keyword argument. - Note that they cannot be specified in the operator dictionary itself. - """ - d = {1: "hello", 2: "goodbye"} - with self.assertRaises(AttributeError): - o = OperatorDict(d) - with self.assertRaises(AttributeError): - o = OperatorDict(d, apply=1) - with self.assertRaises(AttributeError): - o = OperatorDict(d, quote=1) - o = OperatorDict(d, apply=1, quote=2) - print(o) - # Why does the constructed Operator dict contain entries for "apply":1 and "quote":2 ? - # assert d == o - self.assertEqual(o.apply_atom, 1) - self.assertEqual(o.quote_atom, 2) - - # Test construction from an already existing OperatorDict - o2 = OperatorDict(o) - self.assertEqual(o2.apply_atom, 1) - self.assertEqual(o2.quote_atom, 2) diff --git a/tests/operators_test.py b/tests/operators_test.py index 9c84d719..c0282412 100644 --- a/tests/operators_test.py +++ b/tests/operators_test.py @@ -1,59 +1,90 @@ import unittest -from clvm.operators import (OPERATOR_LOOKUP, KEYWORD_TO_ATOM, default_unknown_op, OperatorDict) +from clvm.chainable_multi_op_fn import ChainableMultiOpFn +from clvm.costs import CONCAT_BASE_COST +from clvm.dialect import opcode_table_for_backend +from clvm.handle_unknown_op import handle_unknown_op_softfork_ready +from clvm.operators import KEYWORD_TO_ATOM from clvm.EvalError import EvalError from clvm import SExp -from clvm.costs import CONCAT_BASE_COST +OPERATOR_LOOKUP = opcode_table_for_backend(KEYWORD_TO_ATOM, backend=None) +MAX_COST = int(1e18) -class OperatorsTest(unittest.TestCase): +class OperatorsTest(unittest.TestCase): def setUp(self): self.handler_called = False - def unknown_handler(self, name, args): + def unknown_handler(self, name, args, _max_cost): self.handler_called = True - self.assertEqual(name, b'\xff\xff1337') + self.assertEqual(name, b"\xff\xff1337") self.assertEqual(args, SExp.to(1337)) - return 42, SExp.to(b'foobar') + return 42, SExp.to(b"foobar") def test_unknown_op(self): - self.assertRaises(EvalError, lambda: OPERATOR_LOOKUP(b'\xff\xff1337', SExp.to(1337))) - od = OperatorDict(OPERATOR_LOOKUP, unknown_op_handler=lambda name, args: self.unknown_handler(name, args)) - cost, ret = od(b'\xff\xff1337', SExp.to(1337)) + self.assertRaises( + KeyError, lambda: OPERATOR_LOOKUP[b"\xff\xff1337"](SExp.to(1337), None) + ) + od = ChainableMultiOpFn( + opcode_table_for_backend(KEYWORD_TO_ATOM, backend=None), + self.unknown_handler, + ) + cost, ret = od(b"\xff\xff1337", SExp.to(1337), None) self.assertTrue(self.handler_called) self.assertEqual(cost, 42) - self.assertEqual(ret, SExp.to(b'foobar')) + self.assertEqual(ret, SExp.to(b"foobar")) def test_plus(self): print(OPERATOR_LOOKUP) - self.assertEqual(OPERATOR_LOOKUP(KEYWORD_TO_ATOM['+'], SExp.to([3, 4, 5]))[1], SExp.to(12)) + self.assertEqual( + OPERATOR_LOOKUP[KEYWORD_TO_ATOM["+"]](SExp.to([3, 4, 5]), MAX_COST)[1], + SExp.to(12), + ) def test_unknown_op_reserved(self): # any op that starts with ffff is reserved, and results in a hard # failure with self.assertRaises(EvalError): - default_unknown_op(b"\xff\xff", SExp.null()) + handle_unknown_op_softfork_ready(b"\xff\xff", SExp.null(), max_cost=None) for suffix in [b"\xff", b"0", b"\x00", b"\xcc\xcc\xfe\xed\xfa\xce"]: with self.assertRaises(EvalError): - default_unknown_op(b"\xff\xff" + suffix, SExp.null()) + handle_unknown_op_softfork_ready( + b"\xff\xff" + suffix, SExp.null(), max_cost=None + ) with self.assertRaises(EvalError): # an empty atom is not a valid opcode - self.assertEqual(default_unknown_op(b"", SExp.null()), (1, SExp.null())) + self.assertEqual( + handle_unknown_op_softfork_ready(b"", SExp.null(), max_cost=None), + (1, SExp.null()), + ) # a single ff is not sufficient to be treated as a reserved opcode - self.assertEqual(default_unknown_op(b"\xff", SExp.null()), (CONCAT_BASE_COST, SExp.null())) + self.assertEqual( + handle_unknown_op_softfork_ready(b"\xff", SExp.null(), max_cost=None), + (CONCAT_BASE_COST, SExp.null()), + ) # leading zeroes count, and this does not count as a ffff-prefix # the cost is 0xffff00 = 16776960 - self.assertEqual(default_unknown_op(b"\x00\xff\xff\x00\x00", SExp.null()), (16776961, SExp.null())) + self.assertEqual( + handle_unknown_op_softfork_ready( + b"\x00\xff\xff\x00\x00", SExp.null(), max_cost=None + ), + (16776961, SExp.null()), + ) def test_unknown_ops_last_bits(self): # The last byte is ignored for no-op unknown ops for suffix in [b"\x3f", b"\x0f", b"\x00", b"\x2c"]: # the cost is unchanged by the last byte - self.assertEqual(default_unknown_op(b"\x3c" + suffix, SExp.null()), (61, SExp.null())) + self.assertEqual( + handle_unknown_op_softfork_ready( + b"\x3c" + suffix, SExp.null(), max_cost=None + ), + (61, SExp.null()), + )