In [56]:
from hexbytes import HexBytes
from web3 import Web3
%run stateless_mmr.ipynb

class MMR:
    def __init__(self):
        self.tree_root = None  # Root hash of the tree
        self.node_index_to_root = {}  # Mapping of node index to relative root hash
        self.node_index_to_peaks = {}  # Mapping of node index to peaks
        self.last_peaks = []  # Peaks of the last tree
        self.last_elements_count = 0  # Latest elements count
        self.last_root = None  # Latest root hash
        self.hashes = {}  # List of hashes of the tree
        self.elements_count = 0  # Number of elements in the tree
        


    def append(self, element):
        # Convert element to bytes if needed
        if not isinstance(element, bytes):
            element = element.to_bytes(32, 'big')

        # Append element to the tree
        next_elements_count, next_root_hash, next_peaks = StatelessMmr.append_with_peaks_retrieval(  # type: ignore
            element, self.last_peaks, self.last_elements_count, self.last_root
        )
        print(f"next_elements_count: {next_elements_count}")
        print(f"next_root_hash: {next_root_hash.hex()}")
        print(f"next_peaks: {[peak.hex() for peak in next_peaks]}")

        # Convert next_peaks to hex format
        next_peaks_hex = [HexBytes(peak) for peak in next_peaks]

        # Update contract state
        self.last_peaks = next_peaks_hex
        self.last_elements_count = next_elements_count
        self.last_root = next_root_hash
        self.hashes[get_element_index(next_elements_count)] = element
        self.elements_count += 1

        # Emit event
        self.emit_appended(element, self.last_root, self.last_elements_count)

    def multi_append(self, elements):
        # Append multiple elements to the tree
        next_elements_count = self.last_elements_count
        next_root = self.last_root
        next_peaks = self.last_peaks

        for element in elements:
            next_elements_count, next_root, next_peaks = StatelessMmr.append_with_peaks_retrieval( # type: ignore
                element, next_peaks, next_elements_count, next_root
            )

            # Emit event for each appended element
            self.emit_appended(element, next_root, next_elements_count)

        # Update contract state
        self.last_peaks = next_peaks
        self.last_elements_count = next_elements_count
        self.last_root = next_root
        self.node_index_to_root[next_elements_count] = self.last_root
        self.node_index_to_peaks[next_elements_count] = self.last_peaks

    def get_root_hash(self):
        # Return the root hash of the tree
        return self.last_root

    def get_elements_count(self):
        # Return the number of nodes in the tree
        return self.last_elements_count

    @staticmethod
    def verify_proof(index, value, proof, peaks, elements_count, root):
        # Verify the proof
        StatelessMmr.verify_proof(index, value, proof, peaks, elements_count, root) # type: ignore

    def get_proof(self, element_index, options=None):
        if element_index == 0:
            raise InvalidElementIndex()

        options = options or ProofOptions()
        tree_size = self.get_elements_count()

        if element_index > tree_size:
            raise InvalidElementIndex()

        peaks = find_peaks(tree_size)

        siblings = find_siblings(element_index, tree_size)
        print(f"Siblings: {siblings}")

        formatting_opts = options.formatting_opts.peaks if options.formatting_opts else None
        peaks_hashes = self.last_peaks
        print(f"Peaks: {peaks_hashes}")

        siblings_hashes_vec = [StatelessMmr.get_node_hash(node, self.last_peaks, ) for node in siblings]
        for idx in siblings:
            print("idx: ", idx)
            sibling_hash = self.hashes[idx]
            if sibling_hash is not None:
                siblings_hashes_vec.append(hex(int(sibling_hash.hex(), 16)))

        element_hash = self.hashes[element_index]
        if element_hash is None:
            raise NoHashFoundForIndex(element_index)

        return Proof(
            element_index=element_index,
            element_hash=hex(int(element_hash.hex(), 16)),
            siblings_hashes=siblings_hashes_vec,
            peaks_hashes=[peak_hash for peak_hash in peaks_hashes],
            elements_count=tree_size,
        )
    
    def retrieve_peaks_hashes(self, peaks):
        peaks_hashes = []
        for peak in peaks:
            peak_hash = Web3.keccak(peak)
            if peak_hash is not None:
                peaks_hashes.append(peak_hash.hex())
        return peaks_hashes

    def emit_appended(self, element, root_hash, elements_count):
        # Emit the Appended event
        event = {
            "element": int.from_bytes(element, byteorder='big'),
            "rootHash": root_hash,
            "elementsCount": elements_count
        }
        print(f"Appended event: {event}")
        
def get_element_index(n):
    if n < 0:
        return None
    
    if n == 0:
        return 1
    elif n == 1:
        return 2
    else:
        # Calculate the largest power of 2 less than or equal to n
        largest_power_of_2 = 1 << (n.bit_length() - 1)
        
        if n < largest_power_of_2 * 2 - 1:
            # If n is within the range of current power of 2 peak
            return 2 * n - largest_power_of_2 + 1
        else:
            # If n is starting a new power of 2 peak
            return 2 * n - largest_power_of_2 + 2
 

def find_peaks(elements_count):
    mountain_elements_count = (1 << elements_count.bit_length()) - 1
    mountain_index_shift = 0
    peaks = []

    while mountain_elements_count > 0:
        if mountain_elements_count <= elements_count:
            mountain_index_shift += mountain_elements_count
            peaks.append(mountain_index_shift)
            elements_count -= mountain_elements_count
        
        mountain_elements_count >>= 1;

    if elements_count > 0:
        return []

    return peaks

def element_index_to_leaf_index(element_index):
    if element_index == 0:
        raise InvalidElementIndex()
    return elements_count_to_leaf_count(element_index - 1)

def elements_count_to_leaf_count(elements_count):
    leaf_count = 0
    mountain_leaf_count = 1 << elements_count.bit_length()
    current_elements_count = elements_count

    while mountain_leaf_count > 0:
        mountain_elements_count = 2 * mountain_leaf_count - 1;
        if mountain_elements_count <= current_elements_count:
            leaf_count += mountain_leaf_count;
            current_elements_count -= mountain_elements_count
        mountain_leaf_count >>= 1

    if current_elements_count > 0:
        raise InvalidElementCount()
    else:
        return leaf_count

def find_siblings(element_index, elements_count):
    leaf_index = element_index_to_leaf_index(element_index)
    height = 0
    siblings = []
    current_element_index = element_index

    while current_element_index < elements_count:
        siblings_offset = (2 << height) - 1
        if leaf_index % 2  == 1: 
            siblings.append(current_element_index - siblings_offset)
            current_element_index += 1
        else:
            siblings.append(current_element_index + siblings_offset)
            current_element_index += siblings_offset + 1
        leaf_index /= 2
        height += 1
    
    # if len(siblings) > 0:
        # siblings.pop()
    return siblings

def leaf_count_to_append_no_merges(leaf_count):
    return count_trailing_ones(leaf_count)

def count_trailing_ones(num):
    count = 0;
    while num != 0 and num & 1 == 1:
        num >>= 1;
        count += 1;
    
    return count


def format_proof(siblings_hashes, formatting_opts):
    # Format the siblings hashes based on the formatting options
    # Return the formatted siblings hashes
    pass

def format_peaks(peaks_hashes, formatting_opts):
    # Format the peaks hashes based on the formatting options
    # Return the formatted peaks hashes
    pass

class ProofOptions:
    def __init__(self, elements_count=None, formatting_opts=None):
        self.elements_count = elements_count
        self.formatting_opts = formatting_opts

class Proof:
    def __init__(self, element_index, element_hash, siblings_hashes, peaks_hashes, elements_count):
        self.element_index = element_index
        self.element_hash = element_hash
        self.siblings_hashes = siblings_hashes
        self.peaks_hashes = peaks_hashes
        self.elements_count = elements_count

    def __str__(self):
        return f"""{{
    "elementIndex": {self.element_index},
    "elementHash": "{self.element_hash}",
    "siblingHashes": {self.siblings_hashes},
    "peaksHashes": {self.peaks_hashes},
    "elementsCount": {self.elements_count}
}}"""

class InvalidElementIndex(Exception):
    pass

class InvalidElementCount(Exception):
    pass

class NoHashFoundForIndex(Exception):
    def __init__(self, index):
        self.index = index

In [57]:
mmr = MMR()
mmr.append(0)
mmr.append(1)
mmr.append(2)
mmr.append(3)
# mmr.append(4)
# mmr.append(5)
# mmr.append(6)

proof2 = mmr.get_proof(1)
# print(f"Proof for element: {proof2}")

next_elements_count: 1
next_root_hash: 0xada5013122d395ba3c54772283fb069b10426056ef8ca54750cb9bb552a59e7d
next_peaks: ['0000000000000000000000000000000000000000000000000000000000000000']
Appended event: {'element': 0, 'rootHash': HexBytes('0xada5013122d395ba3c54772283fb069b10426056ef8ca54750cb9bb552a59e7d'), 'elementsCount': 1}
next_elements_count: 3
next_root_hash: 0xc97a69e6e2de1bb9e27f629ecf2981a64edb688b55347fa4daae7dde857b7d91
next_peaks: ['0xa6eef7e35abe7026729641147f7915573c7e97b47efa546f5f6e3230263bcb49']
Appended event: {'element': 1, 'rootHash': HexBytes('0xc97a69e6e2de1bb9e27f629ecf2981a64edb688b55347fa4daae7dde857b7d91'), 'elementsCount': 3}
next_elements_count: 4
next_root_hash: 0xe9899fa57e4d849893dc39931799693814a469f16500b673d7d918e03e244a25
next_peaks: ['0xa6eef7e35abe7026729641147f7915573c7e97b47efa546f5f6e3230263bcb49', '0000000000000000000000000000000000000000000000000000000000000002']
Appended event: {'element': 2, 'rootHash': HexBytes('0xe9899fa57e4d849893dc399317

NameError: name 'node' is not defined

In [179]:
def test_verify_proof_one_leaf():
    peaks = []
    new_pos, new_root = StatelessMmr.append(
        (1).to_bytes(32, 'big'),
        peaks,
        0,
        Web3.keccak(b'')
    )
    assert new_pos == 1
    proof = StatelessMmr.generate_proof(new_pos, peaks, )

    node1 = (1).to_bytes(32, 'big')
    peaks = StatelessMmrHelpers.new_arr_with_elem(peaks, node1)

    try:
        computed_root = StatelessMmr.compute_root(peaks, new_pos.to_bytes(32, 'big'))

        StatelessMmr.verify_proof(
            1,
            (1).to_bytes(32, 'big'),
            [],
            peaks,
            new_pos,
            new_root
        )
        print("Proof verification passed!")
    except Exception as e:
        print(f"Proof verification failed: {str(e)}")

test_verify_proof_one_leaf()

TypeError: StatelessMmr.generate_proof() missing 1 required positional argument: 'elements_count'

In [39]:
from eth_abi import encode

def test_append_initial():
    peaks = []
    node1 = (1).to_bytes(32, 'big')

    new_pos, new_root = StatelessMmr.append(
        (1).to_bytes(32, 'big'),
        peaks,
        0,
        bytes(32)
    )
    assert new_pos == 1

    expected_root = Web3.keccak(encode(['bytes32', 'bytes32'], [(1).to_bytes(32, 'big'), node1]))
    assert new_root == expected_root

    peaks = [node1]
    expected_root_method2 = StatelessMmr.compute_root(peaks, new_pos.to_bytes(32, 'big'))
    assert new_root == expected_root_method2

    return new_pos, new_root, node1

def test_append_one():
    last_pos, last_root, node1 = test_append_initial()
    peaks = [node1]

    new_pos, new_root = StatelessMmr.append(
        (2).to_bytes(32, 'big'),
        peaks,
        last_pos,
        last_root
    )
    assert new_pos == 3

    node2 = (2).to_bytes(32, 'big')
    node3 = Web3.keccak(encode(['bytes32', 'bytes32'], [node1, node2]))
    expected_root = Web3.keccak(encode(['bytes32', 'bytes32'], [(3).to_bytes(32, 'big'), node3]))
    assert new_root == expected_root

    return new_pos, new_root, node3

def test_append_two():
    last_pos, last_root, node3 = test_append_one()
    peaks = [node3]

    new_pos, new_root = StatelessMmr.append(
        (4).to_bytes(32, 'big'),
        peaks,
        last_pos,
        last_root
    )
    assert new_pos == 4

    node4 = (4).to_bytes(32, 'big')
    peaks = StatelessMmrHelpers.new_arr_with_elem(peaks, node4)
    expected_root = StatelessMmr.compute_root(peaks, new_pos.to_bytes(32, 'big'))
    assert new_root == expected_root

    return new_pos, new_root, peaks

def test_verify_proof_four_leaves():
    last_pos, last_root, last_peaks = test_append_two()
    new_pos, new_root = StatelessMmr.append(
        (5).to_bytes(32, 'big'),
        last_peaks,
        last_pos,
        last_root
    )
    assert new_pos == 7

    proof = StatelessMmrHelpers.new_arr_with_elem([], last_peaks[1])
    proof = StatelessMmrHelpers.new_arr_with_elem(proof, last_peaks[0])

    node5 = (5).to_bytes(32, 'big')
    node6 = Web3.keccak(encode(['bytes32', 'bytes32'], [last_peaks[1], node5]))
    node7 = Web3.keccak(encode(['bytes32', 'bytes32'], [last_peaks[0], node6]))
    peaks = StatelessMmrHelpers.new_arr_with_elem([], node7)

    try:
        StatelessMmr.verify_proof(
            5,
            (5).to_bytes(32, 'big'),
            proof,
            peaks,
            new_pos,
            new_root
        )
        print("Proof verification passed!")
    except Exception as e:
        print(f"Proof verification failed: {str(e)}")

test_verify_proof_four_leaves()

Proof verification passed!


In [249]:
def calculate_element_index(n):
    index = 0
    shift = 1
    while n >= shift:
        index += shift
        n -= shift
        shift <<= 1
    return index + n + 1

# Testing the function with given indices
indices = [calculate_element_index(i) for i in range(10)]
indices


[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]