diff --git a/chia/_tests/blockchain/test_blockchain.py b/chia/_tests/blockchain/test_blockchain.py index fc000b951385..4aee79ede938 100644 --- a/chia/_tests/blockchain/test_blockchain.py +++ b/chia/_tests/blockchain/test_blockchain.py @@ -9,7 +9,7 @@ from typing import List, Optional import pytest -from chia_rs import AugSchemeMPL, G2Element +from chia_rs import AugSchemeMPL, G2Element, MerkleSet from clvm.casts import int_to_bytes from chia._tests.blockchain.blockchain_test_utils import ( @@ -54,7 +54,6 @@ from chia.util.generator_tools import get_block_header from chia.util.hash import std_hash from chia.util.ints import uint8, uint32, uint64 -from chia.util.merkle_set import MerkleSet from chia.util.misc import available_logical_cores from chia.util.recursive_replace import recursive_replace from chia.util.vdf_prover import get_vdf_info_and_proof diff --git a/chia/_tests/core/test_merkle_set.py b/chia/_tests/core/test_merkle_set.py index 80d953a71152..239e889a3370 100644 --- a/chia/_tests/core/test_merkle_set.py +++ b/chia/_tests/core/test_merkle_set.py @@ -8,13 +8,12 @@ from typing import List, Optional, Tuple import pytest -from chia_rs import Coin, compute_merkle_set_root +from chia_rs import Coin, MerkleSet, compute_merkle_set_root, confirm_included_already_hashed from chia.simulator.block_tools import BlockTools from chia.types.blockchain_format.sized_bytes import bytes32 from chia.util.hash import std_hash from chia.util.ints import uint64 -from chia.util.merkle_set import MerkleSet, confirm_included_already_hashed from chia.util.misc import to_batches from chia.wallet.util.wallet_sync_utils import validate_additions, validate_removals @@ -53,21 +52,21 @@ def hashdown(buf: bytes) -> bytes32: @pytest.mark.anyio async def test_merkle_set_invalid_hash_size() -> None: # this is too large - with pytest.raises(AssertionError): + with pytest.raises(ValueError): MerkleSet([bytes([0x80] + [0] * 32)]) # type: ignore[list-item] with pytest.raises(ValueError, match="could not convert slice to array"): compute_merkle_set_root([bytes([0x80] + [0] * 32)]) # this is too small - with pytest.raises(AssertionError): + with pytest.raises(ValueError): MerkleSet([bytes([0x80] + [0] * 30)]) # type: ignore[list-item] with pytest.raises(ValueError, match="could not convert slice to array"): compute_merkle_set_root([bytes([0x80] + [0] * 30)]) # empty - with pytest.raises(AssertionError): + with pytest.raises(ValueError): MerkleSet([b""]) # type: ignore[list-item] with pytest.raises(ValueError, match="could not convert slice to array"): diff --git a/chia/_tests/wallet/sync/test_wallet_sync.py b/chia/_tests/wallet/sync/test_wallet_sync.py index 728929164188..a6dad788e223 100644 --- a/chia/_tests/wallet/sync/test_wallet_sync.py +++ b/chia/_tests/wallet/sync/test_wallet_sync.py @@ -9,6 +9,7 @@ import pytest from aiosqlite import Error as AIOSqliteError +from chia_rs import confirm_not_included_already_hashed from colorlog import getLogger from chia._tests.connection_utils import disconnect_all, disconnect_all_and_reconnect @@ -526,8 +527,10 @@ async def test_request_additions_errors(simulator_and_wallet: OldSimulatorsAndWa await full_node_api.request_additions(RequestAdditions(last_block.height, std_hash(b""), [ph])) # No results + fake_coin = std_hash(b"") + assert ph != fake_coin res1 = await full_node_api.request_additions( - RequestAdditions(last_block.height, last_block.header_hash, [std_hash(b"")]) + RequestAdditions(last_block.height, last_block.header_hash, [fake_coin]) ) assert res1 is not None response = RespondAdditions.from_bytes(res1.data) @@ -536,6 +539,16 @@ async def test_request_additions_errors(simulator_and_wallet: OldSimulatorsAndWa assert response.proofs is not None assert len(response.proofs) == 1 assert len(response.coins) == 1 + full_block = await full_node_api.full_node.block_store.get_full_block(last_block.header_hash) + assert full_block is not None + assert full_block.foliage_transaction_block is not None + root = full_block.foliage_transaction_block.additions_root + assert confirm_not_included_already_hashed(root, response.proofs[0][0], response.proofs[0][1]) + # proofs is a tuple of (puzzlehash, proof, proof_2) + # proof is a proof of inclusion (or exclusion) of that puzzlehash + # proof_2 is a proof of all the coins with that puzzlehash + # all coin names are concatenated and hashed into one entry in the merkle set for proof_2 + # the response contains the list of coins so you can check the proof_2 assert response.proofs[0][0] == std_hash(b"") assert response.proofs[0][1] is not None diff --git a/chia/full_node/full_node_api.py b/chia/full_node/full_node_api.py index 97bd31ce5467..20533b4964ea 100644 --- a/chia/full_node/full_node_api.py +++ b/chia/full_node/full_node_api.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import anyio -from chia_rs import AugSchemeMPL, G1Element, G2Element +from chia_rs import AugSchemeMPL, G1Element, G2Element, MerkleSet from chiabip158 import PyBIP158 from chia.consensus.block_creation import create_unfinished_block @@ -66,7 +66,6 @@ from chia.util.hash import std_hash from chia.util.ints import uint8, uint32, uint64, uint128 from chia.util.limited_semaphore import LimitedSemaphoreFullError -from chia.util.merkle_set import MerkleSet if TYPE_CHECKING: from chia.full_node.full_node import FullNode diff --git a/chia/util/merkle_set.py b/chia/util/merkle_set.py deleted file mode 100644 index 77a47b82cd43..000000000000 --- a/chia/util/merkle_set.py +++ /dev/null @@ -1,371 +0,0 @@ -from __future__ import annotations - -from abc import ABCMeta, abstractmethod -from hashlib import sha256 -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple - -from chia.types.blockchain_format.sized_bytes import bytes32 - -if TYPE_CHECKING: - from hashlib import _Hash - -""" -A simple, confidence-inspiring Merkle Set standard - -Advantages of this standard: -Low CPU requirements -Small proofs of inclusion/exclusion -Reasonably simple implementation - -The main tricks in this standard are: - -Skips repeated hashing of exactly two things even when they share prefix bits - - -Proofs support proving including/exclusion for a large number of values in -a single string. They're a serialization of a subset of the tree. - -Proof format: - -multiproof: subtree -subtree: middle or terminal or truncated or empty -middle: MIDDLE 1 subtree subtree -terminal: TERMINAL 1 hash 32 -# If the sibling is empty truncated implies more than two children. -truncated: TRUNCATED 1 hash 32 -empty: EMPTY 1 -EMPTY: \x00 -TERMINAL: \x01 -MIDDLE: \x02 -TRUNCATED: \x03 -""" - -EMPTY = bytes([0]) -TERMINAL = bytes([1]) -MIDDLE = bytes([2]) -TRUNCATED = bytes([3]) - -BLANK = bytes32([0] * 32) - -prehashed: Dict[bytes, _Hash] = {} - - -def init_prehashed() -> None: - for x in [EMPTY, TERMINAL, MIDDLE]: - for y in [EMPTY, TERMINAL, MIDDLE]: - prehashed[x + y] = sha256(bytes([0] * 30) + x + y) - - -init_prehashed() - - -def hashdown(mystr: bytes) -> bytes: - assert len(mystr) == 66 - h = prehashed[bytes(mystr[0:1] + mystr[33:34])].copy() - h.update(mystr[1:33] + mystr[34:]) - return h.digest()[:32] - - -def compress_root(mystr: bytes) -> bytes32: - assert len(mystr) == 33 - if mystr[0:1] == MIDDLE: - return bytes32(mystr[1:]) - if mystr[0:1] == EMPTY: - assert mystr[1:] == BLANK - return BLANK - return bytes32(sha256(mystr).digest()[:32]) - - -def get_bit(mybytes: bytes, pos: int) -> int: - assert len(mybytes) == 32 - return (mybytes[pos // 8] >> (7 - (pos % 8))) & 1 - - -class Node(metaclass=ABCMeta): - hash: bytes - - @abstractmethod - def get_hash(self) -> bytes: - pass - - @abstractmethod - def is_empty(self) -> bool: - pass - - @abstractmethod - def is_terminal(self) -> bool: - pass - - @abstractmethod - def is_double(self) -> bool: - pass - - @abstractmethod - def add(self, toadd: bytes, depth: int) -> Node: - pass - - @abstractmethod - def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool: - pass - - @abstractmethod - def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None: - pass - - @abstractmethod - def _audit(self, hashes: List[bytes], bits: List[int]) -> None: - pass - - -class MerkleSet: - root: Node - - def __init__(self, leafs: Iterable[bytes32]): - self.root = _empty - for leaf in leafs: - self.root = self.root.add(leaf, 0) - - def get_root(self) -> bytes32: - return compress_root(self.root.get_hash()) - - def add_already_hashed(self, toadd: bytes) -> None: - self.root = self.root.add(toadd, 0) - - def is_included_already_hashed(self, tocheck: bytes) -> Tuple[bool, bytes]: - proof: List[bytes] = [] - r = self.root.is_included(tocheck, 0, proof) - return r, b"".join(proof) - - def _audit(self, hashes: List[bytes]) -> None: - newhashes: List[bytes] = [] - self.root._audit(newhashes, []) - assert newhashes == sorted(newhashes) - - @staticmethod - def _with_root(root: Node) -> MerkleSet: - s = MerkleSet([]) - s.root = root - return s - - -class EmptyNode(Node): - def __init__(self) -> None: - self.hash = BLANK - - def get_hash(self) -> bytes: - return EMPTY + BLANK - - def is_empty(self) -> bool: - return True - - def is_terminal(self) -> bool: - return False - - def is_double(self) -> bool: - raise SetError() - - def add(self, toadd: bytes, depth: int) -> Node: - return TerminalNode(toadd) - - def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool: - p.append(EMPTY) - return False - - def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None: - p.append(EMPTY) - - def _audit(self, hashes: List[bytes], bits: List[int]) -> None: - pass - - -_empty = EmptyNode() - - -def _make_middle(children: Any, depth: int) -> Node: - cbits = [get_bit(child.hash, depth) for child in children] - if cbits[0] != cbits[1]: - return MiddleNode(children) - nextvals: List[Node] = [_empty, _empty] - nextvals[cbits[0] ^ 1] = _empty - nextvals[cbits[0]] = _make_middle(children, depth + 1) - return MiddleNode(nextvals) - - -class TerminalNode(Node): - def __init__(self, hash: bytes, bits: Optional[List[int]] = None) -> None: - assert len(hash) == 32 - self.hash = hash - if bits is not None: - self._audit([], bits) - - def get_hash(self) -> bytes: - return TERMINAL + self.hash - - def is_empty(self) -> bool: - return False - - def is_terminal(self) -> bool: - return True - - def is_double(self) -> bool: - raise SetError() - - def add(self, toadd: bytes, depth: int) -> Node: - if toadd == self.hash: - return self - if toadd > self.hash: - return _make_middle([self, TerminalNode(toadd)], depth) - else: - return _make_middle([TerminalNode(toadd), self], depth) - - def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool: - p.append(TERMINAL + self.hash) - return tocheck == self.hash - - def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None: - p.append(TERMINAL + self.hash) - - def _audit(self, hashes: List[bytes], bits: List[int]) -> None: - hashes.append(self.hash) - for pos, v in enumerate(bits): - assert get_bit(self.hash, pos) == v - - -class MiddleNode(Node): - def __init__(self, children: List[Node]): - self.children = children - if children[0].is_empty() and children[1].is_double(): - self.hash = children[1].hash - elif children[1].is_empty() and children[0].is_double(): - self.hash = children[0].hash - else: - if children[0].is_empty() and (children[1].is_empty() or children[1].is_terminal()): - raise SetError() - if children[1].is_empty() and children[0].is_terminal(): - raise SetError - if children[0].is_terminal() and children[1].is_terminal() and children[0].hash >= children[1].hash: - raise SetError - self.hash = hashdown(children[0].get_hash() + children[1].get_hash()) - - def get_hash(self) -> bytes: - return MIDDLE + self.hash - - def is_empty(self) -> bool: - return False - - def is_terminal(self) -> bool: - return False - - def is_double(self) -> bool: - if self.children[0].is_empty(): - return self.children[1].is_double() - if self.children[1].is_empty(): - return self.children[0].is_double() - return self.children[0].is_terminal() and self.children[1].is_terminal() - - def add(self, toadd: bytes, depth: int) -> Node: - bit = get_bit(toadd, depth) - child = self.children[bit] - newchild = child.add(toadd, depth + 1) - if newchild is child: - return self - newvals = [x for x in self.children] - newvals[bit] = newchild - return MiddleNode(newvals) - - def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool: - p.append(MIDDLE) - if get_bit(tocheck, depth) == 0: - r = self.children[0].is_included(tocheck, depth + 1, p) - self.children[1].other_included(tocheck, depth + 1, p, not self.children[0].is_empty()) - return r - else: - self.children[0].other_included(tocheck, depth + 1, p, not self.children[1].is_empty()) - return self.children[1].is_included(tocheck, depth + 1, p) - - def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None: - if collapse or not self.is_double(): - p.append(TRUNCATED + self.hash) - else: - self.is_included(tocheck, depth, p) - - def _audit(self, hashes: List[bytes], bits: List[int]) -> None: - self.children[0]._audit(hashes, bits + [0]) - self.children[1]._audit(hashes, bits + [1]) - - -class TruncatedNode(Node): - def __init__(self, hash: bytes): - self.hash = hash - - def get_hash(self) -> bytes: - return MIDDLE + self.hash - - def is_empty(self) -> bool: - return False - - def is_terminal(self) -> bool: - return False - - def is_double(self) -> bool: - return False - - def add(self, toadd: bytes, depth: int) -> Node: - return self - - def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool: - raise SetError() - - def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None: - p.append(TRUNCATED + self.hash) - - def _audit(self, hashes: List[bytes], bits: List[int]) -> None: - pass - - -class SetError(Exception): - pass - - -def confirm_included_already_hashed(root: bytes32, val: bytes, proof: bytes) -> bool: - return _confirm(root, val, proof, True) - - -def confirm_not_included_already_hashed(root: bytes32, val: bytes, proof: bytes) -> bool: - return _confirm(root, val, proof, False) - - -def _confirm(root: bytes32, val: bytes, proof: bytes, expected: bool) -> bool: - try: - p = deserialize_proof(proof) - if p.get_root() != root: - return False - r, junk = p.is_included_already_hashed(val) - return r == expected - except SetError: - return False - - -def deserialize_proof(proof: bytes) -> MerkleSet: - try: - r, pos = _deserialize(proof, 0, []) - if pos != len(proof): - raise SetError() - return MerkleSet._with_root(r) - except IndexError: - raise SetError() - - -def _deserialize(proof: bytes, pos: int, bits: List[int]) -> Tuple[Node, int]: - t = proof[pos : pos + 1] # flake8: noqa - if t == EMPTY: - return _empty, pos + 1 - if t == TERMINAL: - return TerminalNode(proof[pos + 1 : pos + 33], bits), pos + 33 # flake8: noqa - if t == TRUNCATED: - return TruncatedNode(proof[pos + 1 : pos + 33]), pos + 33 # flake8: noqa - if t != MIDDLE: - raise SetError() - v0, pos = _deserialize(proof, pos + 1, bits + [0]) - v1, pos = _deserialize(proof, pos, bits + [1]) - return MiddleNode([v0, v1]), pos diff --git a/chia/wallet/util/wallet_sync_utils.py b/chia/wallet/util/wallet_sync_utils.py index b320322df2e2..05f401f1e3bd 100644 --- a/chia/wallet/util/wallet_sync_utils.py +++ b/chia/wallet/util/wallet_sync_utils.py @@ -5,7 +5,7 @@ import random from typing import Any, List, Optional, Set, Tuple, Union -from chia_rs import compute_merkle_set_root +from chia_rs import compute_merkle_set_root, confirm_included_already_hashed, confirm_not_included_already_hashed from chia.full_node.full_node_api import FullNodeAPI from chia.protocols.shared_protocol import Capability @@ -36,7 +36,6 @@ from chia.types.coin_spend import CoinSpend, make_spend from chia.types.header_block import HeaderBlock from chia.util.ints import uint32 -from chia.util.merkle_set import confirm_included_already_hashed, confirm_not_included_already_hashed from chia.wallet.util.peer_request_cache import PeerRequestCache log = logging.getLogger(__name__)