In [19]:
from collections import deque
from web3 import Web3
from eth_abi import encode

In [288]:
GENESIS_STRING = "brave new world"

class KeccakHasher:
    def __init__(self, options=None):
        if options is None:
            options = {
                "blockSizeBits": 256
            }
        self.options = options

    def array_to_uint8_array(self, data):
        result = []
        for e in data:
            if e is None:
                continue  # Skip None values
            if isinstance(e, str) and e.startswith('0x'):
                n = int(e, 16)  # Convert hexadecimal string to integer
            else:
                n = int(e)  # Convert decimal string or integer
            hex_str = format(n, '064x')
            bytes_list = [int(hex_str[i:i+2], 16) for i in range(0, len(hex_str), 2)]
            result.extend(bytes_list)
        return bytearray(result)

    def hash(self, data):
        # If the data is empty, return the hash of an empty array
        if len(data) == 0:
            return Web3.keccak(b'').hex()
        # If the data has a single element, return the hash of that element using `web3` keccak
        if len(data) == 1:
            return Web3.keccak(hexstr=data[0]).hex()

        bytes_data = self.array_to_uint8_array(data)
        return Web3.keccak(bytes(bytes_data)).hex()

    
    def is_element_size_valid(self, element):
        return KeccakHasher.byte_size(element) <= self.options["blockSizeBits"]

    @staticmethod
    def byte_size(value):
        if isinstance(value, str):
            hex_str = value[2:] if value.startswith("0x") else value
            return len(str.encode(hex_str))
        elif isinstance(value, int):
            return len(value.to_bytes((value.bit_length() + 7) // 8, 'big'))
        else:
            raise ValueError("Invalid input type. Expected string or integer.")

    def hash_single(self, data):
        return self.hash([data])

    def get_genesis(self):
        s = GENESIS_STRING
        hex_str = "0x"
        for char in s:
            hex_str += hex(ord(char))[2:].zfill(2)
        return self.hash_single(hex_str)

In [289]:
class Store:
    def __init__(self):
        self.store = {}

    def get(self, key):
        return self.store.get(key)

    def get_many(self, keys):
        result = {}
        for key in keys:
            value = self.get(key)
            if value:
                result[key] = value
        return result

    def set(self, key, value):
        self.store[key] = value

    def set_many(self, entries):
        for key, value in entries.items():
            self.set(key, value)

    def delete(self, key):
        if key in self.store:
            del self.store[key]

    def delete_many(self, keys):
        for key in keys:
            self.delete(key)

In [290]:
def bit_length(num):
    return len(bin(num)[2:])

def all_ones(num):
    return (1 << bit_length(num)) - 1 == num

def get_height(element_index):
    h = element_index
    while not all_ones(h):
        h = h - ((1 << (bit_length(h) - 1)) - 1)
    return bit_length(h) - 1

def sibling_offset(height):
    return parent_offset(height) - 1

def parent_offset(height):
    return 2 << height

def bintree_jump_right_sibling(element_index):
    height = get_height(element_index)
    return element_index + (1 << (height + 1)) - 1

def bintree_move_down_left(element_index):
    height = get_height(element_index)
    if height == 0:
        return 0
    return element_index - (1 << height)

def find_peaks(elements_count):
    if elements_count == 0:
        return []

    if get_height(elements_count + 1) > get_height(elements_count):
        return []

    top = 1
    while top - 1 <= elements_count:
        top <<= 1
    top = (top >> 1) - 1
    if top == 0:
        return [1]

    peaks = [top]
    peak = top
    outer = True
    while outer:
        peak = bintree_jump_right_sibling(peak)
        while peak > elements_count:
            peak = bintree_move_down_left(peak)
            if peak == 0:
                outer = False
                break
        if outer:
            peaks.append(peak)
    return peaks

In [291]:
from functools import reduce


class MMR:
    def __init__(self, store, hasher, mmr_id=None):
        self.store = store
        self.hasher = hasher
        self.mmr_id = mmr_id
        self.elements_count = 0
        self.root_hash = None
        self.leaves_count = 0
        self.hashes = {}

    @staticmethod
    def create_with_genesis(store, hasher, mmr_id=None):
        mmr = MMR(store, hasher, mmr_id)
        if mmr.get_elements_count() != 0:
            raise Exception("Cannot call createWithGenesis on a non-empty MMR. Please provide an empty store or change the MMR id.")
        mmr.append(hasher.get_genesis())
        return mmr

    def append(self, value):
        if not self.hasher.is_element_size_valid(value):
            raise Exception("Element size is too big to hash with this hasher")

        elements_count = self.get_elements_count(self)
        peaks = self.retrieve_peaks_hashes(find_peaks(elements_count))

        self.elements_count += 1
        last_element_idx = self.elements_count
        leaf_element_index = last_element_idx

        self.hashes[last_element_idx] = value

        peaks.append(value)

        height = 0
        while get_height(last_element_idx + 1) > height:
            print("height", height)
            last_element_idx += 1

            right_hash = peaks.pop()
            left_hash = peaks.pop()
            print("left_hash", left_hash)
            print("right_hash", right_hash)

            parent_hash = self.hasher.hash([left_hash, right_hash])
            self.hashes[last_element_idx] = parent_hash 
            peaks.append(parent_hash)

            height += 1

        self.set_elements_count = last_element_idx

        bag = self.bag_the_peaks()

        root_hash = self.calculate_root_hash(bag, last_element_idx)
        self.root_hash = root_hash

        self.leaves_count += 1
        leaves = self.leaves_count

        return {
            "leavesCount": leaves,
            "elementsCount": last_element_idx,
            "elementIndex": leaf_element_index,
            "rootHash": root_hash,
        }
        
    def get_proof(self, element_index, options=None):
        if element_index < 1:
            raise Exception("Index must be greater than 1")

        if options is None:
            options = {}

        elements_count = options.get("elementsCount", self.get_elements_count())
        if element_index > elements_count:
            raise Exception("Index must be less than the tree tree size")

        peaks = find_peaks(elements_count)
        siblings = []

        index = element_index
        while index not in peaks:
            is_right = get_height(index + 1) == get_height(index) + 1
            sib = index - sibling_offset(get_height(index)) if is_right else index + sibling_offset(get_height(index))
            siblings.append(sib)

            index = index + 1 if is_right else index + parent_offset(get_height(index))

        peaks_hashes = self.retrieve_peaks_hashes(peaks, options.get("formattingOpts", {}).get("peaks"))
        siblings_hashes = list(self.get_hashes(siblings))

        return {
            "elementIndex": element_index,
            "elementHash": self.get_hash(element_index),
            "siblingsHashes": siblings_hashes,
            "peaksHashes": peaks_hashes,
            "elementsCount": elements_count,
        }
        
    def verify_proof(self, proof, element_value, options=None):
        if options is None:
            options = {}

        elements_count = options.get("elementsCount", self.get_elements_count())

        if "formattingOpts" in options:
            proof_format = options["formattingOpts"].get("proof")
            peaks_format = options["formattingOpts"].get("peaks")

            proof_null_values_count = proof["siblingsHashes"].count(proof_format["nullValue"])
            proof["siblingsHashes"] = proof["siblingsHashes"][:len(proof["siblingsHashes"]) - proof_null_values_count]

            peaks_null_values_count = proof["peaksHashes"].count(peaks_format["nullValue"])
            proof["peaksHashes"] = proof["peaksHashes"][:len(proof["peaksHashes"]) - peaks_null_values_count]

        element_index = proof["elementIndex"]
        siblings_hashes = proof["siblingsHashes"]

        if element_index < 1:
            raise Exception("Index must be greater than 1")
        if element_index > elements_count:
            raise Exception("Index must be in the tree")

        hash_value = element_value

        for proof_hash in siblings_hashes:
            is_right = get_height(element_index + 1) == get_height(element_index) + 1
            element_index = element_index + 1 if is_right else element_index + parent_offset(get_height(element_index))
            hash_value = self.hasher.hash([proof_hash, hash_value]) if is_right else self.hasher.hash([hash_value, proof_hash])

        return hash_value in self.retrieve_peaks_hashes(find_peaks(elements_count))
    
    def bag_the_peaks(self, elements_count=None):
        tree_size = elements_count if elements_count is not None else self.elements_count
        peaks_idxs = find_peaks(tree_size)
        peaks_hashes = self.retrieve_peaks_hashes(peaks_idxs)

        if len(peaks_idxs) == 0:
            return "0x0"
        elif len(peaks_idxs) == 1:
            return peaks_hashes[0]

        root0 = self.hasher.hash([peaks_hashes[-2], peaks_hashes[-1]])
        root = reduce(lambda prev, cur: self.hasher.hash([cur, prev]), reversed(peaks_hashes[:-2]), root0)

        return 
    
    def get_peaks(self):
        if options is None:
            options = {}

        elements_count = options.get("elementsCount", self.get_elements_count())
        formatting_opts = options.get("formattingOpts")

        peaks_idxs = find_peaks(elements_count)
        peaks = self.retrieve_peaks_hashes(peaks_idxs)

        return peaks
    
    def calculate_root_hash(self, bag, leaf_count):
        print("bag", bag)
        print("leaf_count", leaf_count)
        return self.hasher.hash([leaf_count, bag])
    
    @staticmethod
    def count_ones(value):
        n = value
        ones_count = 0
        while n > 0:
            n = n & (n - 1)
            ones_count += 1
        return ones_count
    
    @staticmethod
    def map_leaf_index_to_element_index(self, leaf_index):
        return 2 * leaf_index + 1 - self.count_ones(leaf_index)

    def retrieve_peaks_hashes(self, peaks_idxs):
        hashes = self.get_hashes(peaks_idxs)
        
        return hashes
    
    def get_hashes(self, indexes):
        return [self.hashes[idx] for idx in indexes]
    
    @staticmethod
    def get_elements_count(self):
        return self.elements_count

In [293]:
mmr = MMR(Store(), KeccakHasher())
mmr.append(0x0)
mmr.append(0x1)
mmr.append(0x2)
mmr.append(0x3)

bag 0
leaf_count 1
height 0
left_hash 0
right_hash 1
bag 0x0
leaf_count 3
bag 2
leaf_count 3
bag None
leaf_count 4


{'leavesCount': 4,
 'elementsCount': 4,
 'elementIndex': 4,
 'rootHash': '0x8a35acfbc15ff81a39ae7d344fd709f28e8600b4aa8c65c6b64bfe7fe36bd19b'}