In [30]:
import logging
from hexbytes import HexBytes
from web3 import Web3
from eth_abi import encode
%run mmr_helpers.ipynb

# Configure the root logger for this notebook: records at ERROR level and
# above, each line prefixed with a timestamp and the level name.
# (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.ERROR,
)

class InvalidProof(Exception):
    """Signals that a proof did not lead to any of the supplied peaks."""

    def __init__(self, *args):
        # Fall back to a descriptive message so that ``str(exc)`` is not empty
        # when the exception is raised without arguments.
        super().__init__(*(args or ("Invalid proof",)))
        # NOTE: logs at construction time, even if the caller later catches it.
        logging.error("Invalid proof")

class IndexOutOfBounds(Exception):
    """Signals that a requested MMR index lies outside the valid range."""

    def __init__(self, *args):
        # Fall back to a descriptive message so that ``str(exc)`` is not empty
        # when the exception is raised without arguments.
        super().__init__(*(args or ("Index out of bounds",)))
        # NOTE: logs at construction time, even if the caller later catches it.
        logging.error("Index out of bounds")

class InvalidRoot(Exception):
    """Signals that peaks and size do not hash to the expected root."""

    def __init__(self, *args):
        # Fall back to a descriptive message so that ``str(exc)`` is not empty
        # when the exception is raised without arguments.
        super().__init__(*(args or ("Invalid root",)))
        # NOTE: logs at construction time, even if the caller later catches it.
        logging.error("Invalid root")

class InvalidPeaksArrayLength(Exception):
    """Signals that a peaks array has an unexpected length."""

    def __init__(self, *args):
        # Fall back to a descriptive message so that ``str(exc)`` is not empty
        # when the exception is raised without arguments.
        super().__init__(*(args or ("Invalid peaks array length",)))
        # NOTE: logs at construction time, even if the caller later catches it.
        logging.error("Invalid peaks array length")

class StatelessMmr:
    """Stateless Merkle Mountain Range (MMR) operations over keccak-256.

    The class stores nothing: callers pass the current peaks, the MMR size
    (``elements_count``, the total number of nodes) and the root, and receive
    the updated values back. Roots commit to the size as
    ``keccak(size as 32 big-endian bytes || bag_peaks(peaks))``. Node indices
    are 1-based.

    Node hashes are built by concatenating byte strings, so elements and
    peaks are presumably 32-byte values (e.g. ``HexBytes``) — TODO confirm.
    """

    @staticmethod
    def append(elem, peaks, last_elements_count, last_root):
        """Append ``elem``; return ``(new_elements_count, new_root)``.

        Validation and raised exceptions are those of ``do_append``.
        """
        updated_elements_count, new_root, _ = StatelessMmr.do_append(
            elem, peaks, last_elements_count, last_root
        )
        return updated_elements_count, new_root

    @staticmethod
    def append_with_peaks_retrieval(elem, peaks, last_elements_count, last_root):
        """Append ``elem``; return ``(new_elements_count, new_root, new_peaks)``."""
        updated_elements_count, new_root, updated_peaks = StatelessMmr.do_append(
            elem, peaks, last_elements_count, last_root
        )
        return updated_elements_count, new_root, updated_peaks

    @staticmethod
    def multi_append(elems, peaks, last_elements_count, last_root):
        """Append every element of ``elems`` in order.

        Returns ``(elements_count, root)`` after the last append. With empty
        ``elems`` the inputs are returned unchanged.
        """
        elements_count, root, _ = StatelessMmr.multi_append_with_peaks_retrieval(
            elems, peaks, last_elements_count, last_root
        )
        return elements_count, root

    @staticmethod
    def multi_append_with_peaks_retrieval(elems, peaks, last_elements_count, last_root):
        """Append every element of ``elems`` in order, threading the peaks.

        Returns ``(elements_count, root, peaks)`` after the last append. With
        empty ``elems`` the inputs are returned unchanged.
        """
        elements_count, root, updated_peaks = last_elements_count, last_root, peaks
        for elem in elems:
            elements_count, root, updated_peaks = StatelessMmr.do_append(
                elem, updated_peaks, elements_count, root
            )
        return elements_count, root, updated_peaks

    @staticmethod
    def multi_append_with_precomputed_peaks(elems, all_peaks, last_elements_count, last_root):
        """Append ``elems`` using caller-supplied peaks for each step.

        ``all_peaks[i]`` must be the peaks of the MMR just before ``elems[i]``
        is appended; each step re-validates them against the running root.
        Returns ``(elements_count, root)``.
        """
        elements_count, root = last_elements_count, last_root
        for i, elem in enumerate(elems):
            elements_count, root = StatelessMmr.append(
                elem, all_peaks[i], elements_count, root
            )
        return elements_count, root

    @staticmethod
    def verify_proof(index, value, proof, peaks, elements_count, root):
        """Check that ``value`` sits at MMR ``index`` under ``root``.

        Returns ``None`` on success.

        Raises:
            IndexOutOfBounds: if ``index`` is not in ``1..elements_count``.
            InvalidRoot: if ``peaks`` and ``elements_count`` do not hash to
                ``root``.
            InvalidProof: if folding ``value`` up ``proof`` does not reach one
                of ``peaks``.
        """
        # Indices are 1-based (generate_proof below also rejects 0), so 0 is
        # out of range rather than something to hand to the helpers.
        if index < 1 or index > elements_count:
            raise IndexOutOfBounds()

        computed_root = StatelessMmr.compute_root(peaks, elements_count.to_bytes(32, 'big'))
        if computed_root != root:
            raise InvalidRoot()

        top_peak = StatelessMmr.get_proof_top_peak(0, value, index, proof)
        if not StatelessMmrHelpers.array_contains(top_peak, peaks):
            raise InvalidProof()

    @staticmethod
    def compute_root(peaks, size):
        """Return ``keccak(size_bytes || bag_peaks(peaks))``.

        ``size`` is normally the 32-byte big-endian MMR size used everywhere
        in this class. A plain non-negative ``int`` is now encoded the same
        way; ``Web3.to_bytes`` would have produced a minimal-length encoding
        whose root never matches those from ``do_append``. Any other value is
        converted with ``Web3.to_bytes`` as before.

        Raises:
            InvalidPeaksArrayLength: if ``peaks`` is empty.
        """
        bagged_peaks = StatelessMmr.bag_peaks(peaks)
        if isinstance(size, int) and not isinstance(size, bool):
            size_bytes = size.to_bytes(32, 'big')
        else:
            size_bytes = Web3.to_bytes(size)
        return Web3.keccak(size_bytes + bagged_peaks)

    @staticmethod
    def bag_peaks(peaks):
        """Fold all peaks into one hash, right to left.

        A single peak is returned as-is. Otherwise the last two peaks are
        hashed together, and each earlier peak is then hashed in front of the
        accumulator.

        Raises:
            InvalidPeaksArrayLength: if ``peaks`` is empty.
        """
        if len(peaks) < 1:
            raise InvalidPeaksArrayLength()
        if len(peaks) == 1:
            return peaks[0]

        bags = Web3.keccak(peaks[-2] + peaks[-1])
        for peak in reversed(peaks[:-2]):
            bags = Web3.keccak(peak + bags)
        return bags

    @staticmethod
    def do_append(elem, peaks, last_elements_count, last_root):
        """Append ``elem`` to the MMR described by the other arguments.

        Returns:
            ``(updated_elements_count, new_root, updated_peaks)``.

        Raises:
            InvalidPeaksArrayLength: if ``len(peaks)`` differs from the peak
                count implied by ``last_elements_count``.
            InvalidRoot: if ``peaks`` and ``last_elements_count`` do not hash
                to ``last_root``.

        For the very first element (``last_elements_count == 0``), ``peaks``
        and ``last_root`` are ignored.
        """
        elements_count = last_elements_count + 1
        if last_elements_count == 0:
            root0 = elem
            # For a 32-byte elem this equals compute_root([elem], size 1).
            first_root = Web3.keccak(encode(['uint256', 'bytes32'], [1, root0]))
            new_peaks = [root0]
            return elements_count, first_root, new_peaks

        leaf_count = StatelessMmrHelpers.mmr_size_to_leaf_count(last_elements_count)
        # Expected peak count = number of set bits in the leaf count.
        number_of_peaks = StatelessMmrHelpers.count_ones(leaf_count)
        if len(peaks) != number_of_peaks:
            raise InvalidPeaksArrayLength()

        # Reject peaks that do not belong to the claimed previous state.
        computed_root = StatelessMmr.compute_root(peaks, last_elements_count.to_bytes(32, 'big'))
        if computed_root != last_root:
            raise InvalidRoot()

        append_peaks = StatelessMmrHelpers.new_arr_with_elem(peaks, elem)
        # Presumably the number of trailing peaks the new leaf merges with — TODO confirm.
        append_no_merges = StatelessMmrHelpers.leaf_count_to_append_no_merges(leaf_count)
        updated_peaks = StatelessMmr.append_perform_merging(append_peaks, append_no_merges)

        # The new leaf plus one new parent node per merge.
        updated_elements_count = elements_count + append_no_merges
        new_root = StatelessMmr.compute_root(updated_peaks, updated_elements_count.to_bytes(32, 'big'))
        return updated_elements_count, new_root, updated_peaks

    @staticmethod
    def append_perform_merging(peaks, no_merges):
        """Merge the last ``no_merges + 1`` entries of ``peaks`` into one node.

        Folds right to left, ABI-encoding each earlier peak with the
        accumulated hash. Returns a new list: the untouched leading peaks
        followed by the merged hash. The input sequence (list or tuple) is
        not modified.

        Raises:
            InvalidPeaksArrayLength: if ``peaks`` has fewer than
                ``no_merges + 1`` entries. This guards against negative
                indices silently wrapping around.
        """
        peaks_len = len(peaks)
        if no_merges >= peaks_len:
            raise InvalidPeaksArrayLength()

        acc_hash = peaks[-1]
        for i in range(no_merges):
            left = peaks[peaks_len - i - 2]
            acc_hash = Web3.keccak(encode(['bytes32', 'bytes32'], [left, acc_hash]))
        return [*peaks[:peaks_len - no_merges - 1], acc_hash]

    @staticmethod
    def get_proof_top_peak(height, hash_value, elements_count, proof):
        """Fold ``hash_value`` up its inclusion path and return the top hash.

        Args:
            height: starting height. Kept for interface compatibility; it does
                not influence the result.
            hash_value: hash of the node being proven.
            elements_count: 1-based MMR index of that node (despite the name).
            proof: sibling hashes ordered from the bottom of the tree upward.
        """
        leaf_index = StatelessMmrHelpers.mmr_index_to_leaf_index(elements_count)
        for sibling in proof:
            if leaf_index % 2 == 1:
                # Odd position: the current node is the right child.
                hash_value = Web3.keccak(sibling + hash_value)
            else:
                hash_value = Web3.keccak(hash_value + sibling)
            leaf_index //= 2
        return hash_value

    @staticmethod
    def get_node_hash(index, peaks, height):
        """Return a hash for node ``index`` at ``height`` of an implicit binary tree.

        At height 0, the position from ``mmr_index_to_leaf_index(index)``
        indexes into ``peaks``. Positions past the end yield
        ``index.to_bytes(32, 'big')`` as a placeholder. Higher nodes hash the
        concatenation of children ``2 * index`` and ``2 * index + 1``.

        NOTE(review): this mixes tree-position indexing (``index * 2``) with
        MMR indexing (``mmr_index_to_leaf_index``) and reads ``peaks`` as if
        they were leaves. Verify before relying on it.
        """
        if height == 0:
            leaf_index = StatelessMmrHelpers.mmr_index_to_leaf_index(index)
            if leaf_index < len(peaks):
                return peaks[leaf_index]
            return index.to_bytes(32, 'big')

        left_index = index * 2
        right_index = left_index + 1
        left_hash = StatelessMmr.get_node_hash(left_index, peaks, height - 1)
        right_hash = StatelessMmr.get_node_hash(right_index, peaks, height - 1)
        return Web3.keccak(left_hash + right_hash)

    @staticmethod
    def generate_proof(index, peaks, elements_count):
        """Build a list of sibling hashes for the node at MMR ``index``.

        NOTE(review): this does not look able to produce proofs that
        ``verify_proof`` accepts. The issues below should be confirmed
        against the intent before use:

        - Only peaks are available, and ``get_node_hash`` reads them as
          leaves and substitutes placeholder bytes for missing nodes.
        - ``elements_count`` is treated both as an MMR size
          (``mmr_size_to_leaf_count``) and as a leaf count
          (``leaf_count_to_mmr_size``).
        - ``index`` (an MMR index) is bounded by a leaf count.
        - The list is collected bottom-up and then reversed, whereas
          ``get_proof_top_peak`` consumes siblings bottom-up.

        Raises:
            IndexOutOfBounds: if ``index`` < 1 or exceeds
                ``mmr_size_to_leaf_count(elements_count)``.
        """
        logging.debug('index %s', index)
        if index < 1 or index > StatelessMmrHelpers.mmr_size_to_leaf_count(elements_count):
            raise IndexOutOfBounds()

        leaf_index = StatelessMmrHelpers.mmr_index_to_leaf_index(index)
        logging.debug('leaf_index %s', leaf_index)

        # Loop-invariant bound, computed once.
        sibling_limit = StatelessMmrHelpers.leaf_count_to_mmr_size(elements_count)
        proof = []
        current_index = leaf_index
        height = 0
        while current_index > 0:
            sibling_index = current_index ^ 1
            if sibling_index < sibling_limit:
                proof.append(StatelessMmr.get_node_hash(sibling_index, peaks, height))
            current_index //= 2
            height += 1

        return proof[::-1]  # Reverse the order of sibling hashes