In [2]:
from __future__ import annotations
from typing import Callable, List, Optional, Tuple

In [3]:
class MerkleTree:
    """
    Binary Merkle tree built over a list of strings.

    Every leaf holds hash_func(string) for one input string, in input order.
    Every internal node holds hash_func(left.value + right.value). All leaves
    sit on the first level L for which 2 ** (L - 1) >= len(strings); slots that
    receive no string keep the empty value "" and are filled in with a copy of
    their left sibling's value when the parent is computed.
    """

    def __init__(self, strings: List[str], hash_func: Callable[[str], str]):
        """
        Arguments :
        strings   : The set of strings S
        hash_func : Function used to hash leaf strings and concatenated child values
        """
        self.root: Optional[Node] = None
        self.strings = strings
        self.hash_func = hash_func
        self.leaf_nodes: List[Node] = []
        self.create_tree()

    def create_tree(self) -> None:
        """Creates (or re-creates) the Merkle Tree from self.strings"""
        # Start from an empty leaf list so that calling this method again does
        # not leave leaves of a previous build in front of the new ones.
        self.leaf_nodes = []
        # Create the root node (level 1)
        self.root = Node(None, self.hash_func)
        # Recurse on a copy: add_child_nodes consumes the list it is given
        self.add_child_nodes(self.root, 1, list(self.strings))

    def add_child_nodes(self, parent: Node, level: int, remaining_strings: List[str]) -> Optional[str]:
        """
        Add children to a node recursively. After all child nodes are added the node value
        is set to the hash of the concatenation of the left and right child values.
        Arguments :
            parent            : Node to add children to
            level             : Level of the node (the root is level 1)
            remaining_strings : Strings not yet assigned to a leaf; consumed from the front
        Returns: The node value ("" when the subtree received no string)
        """
        if len(self.strings) > 2 ** (level - 1):
            # Internal node: always gets a left and a right child first
            left = Node(parent, self.hash_func)
            right = Node(parent, self.hash_func)
            parent.left_child = left
            parent.right_child = right
            self.add_child_nodes(left, level + 1, remaining_strings)
            self.add_child_nodes(right, level + 1, remaining_strings)
            if not right.value:
                # Empty right subtree: reuse the left value as the right value
                right.value = left.value
                if not left.value:
                    # Both subtrees are empty: drop the children and report
                    # this node as empty to the caller
                    parent.left_child = None
                    parent.right_child = None
                    return ""
            parent.value = self.hash_func(left.value + right.value)
        elif remaining_strings:
            # Leaf node: take the next remaining string and store its hash
            string = remaining_strings.pop(0)
            parent.left_child = None
            parent.right_child = None
            parent.value = self.hash_func(string)
            self.leaf_nodes.append(parent)
        return parent.value

    def dfs(self, current_node=None):
        """Print the node values in depth-first, pre-order (node, left, right)"""
        if current_node is None:
            current_node = self.root

        print(current_node.value)

        if current_node.left_child is not None:
            self.dfs(current_node.left_child)
        if current_node.right_child is not None:
            self.dfs(current_node.right_child)

    def get_root(self) -> str:
        """
        Returns :
        root : Value of the root of the Merkle Tree ("" if there is no root node)
        """
        return self.root.value if self.root is not None else ""

    def get_proof_for(self, item: str) -> Optional[List[Tuple[str, str]]]:
        """
        Returns:
            result: None if the item is not part of the leaves of the tree.
                    Otherwise a list of (sibling value, side) tuples ordered from
                    the leaf up to the root, where side is "i" when the sibling
                    is the left child and "d" when it is the right child.
        """
        # Hash the item once instead of once per leaf inspected
        target = self.hash_func(item)

        # First leaf whose stored hash matches the item
        current_node = next(
            (leaf for leaf in self.leaf_nodes if leaf.value == target), None
        )
        # If the leaf does not exist return None
        if current_node is None:
            return None

        # Walk up to the root collecting the sibling at every level
        proof_items = []
        while current_node.parent is not None:
            parent = current_node.parent
            if parent.left_child is not None and parent.left_child is not current_node:
                proof_items.append((parent.left_child.value, "i"))
            if parent.right_child is not None and parent.right_child is not current_node:
                proof_items.append((parent.right_child.value, "d"))
            current_node = parent
        return proof_items


class Node:
    """A single tree node: a value plus links to its parent and two children."""

    def __init__(self, parent: Optional[Node], hash_func: Callable[[str], str], value: str = ""):
        """
        Arguments :
        parent    : Parent of the node (None for a node without a parent)
        hash_func : Hash function kept on the node
        value     : Initial value of the node, empty by default
        """
        # Links to the neighbouring nodes; children start out unset
        self.parent: Optional[Node] = parent
        self.left_child: Optional[Node] = None
        self.right_child: Optional[Node] = None
        # Data carried by the node
        self.value = value
        self.hash_func = hash_func



In [4]:
def verify(
    root: str,
    item: str,
    proof: List[Tuple[str, str]],
    hash_func: Callable[[str], str],
) -> bool:
    """
    Check a Merkle proof by recomputing the root from the item.

    Arguments :
        root:  The root of a merkle tree
        item:  An arbitrary string
        proof: An alleged proof that item is part of a Merkle
               tree with root root, as (sibling value, side) tuples
               ordered from the leaf upwards. Side "i" means the sibling
               is concatenated before the running hash, "d" after it.
        hash_func: An arbitrary hash function
    Returns :
        correct: whether the proof is correct or not. A proof containing
                 a side marker other than "i" or "d" is rejected.
    """
    current_string = hash_func(item)
    for sibling, pos in proof:
        if pos == "i":
            current_string = hash_func(sibling + current_string)
        elif pos == "d":
            current_string = hash_func(current_string + sibling)
        else:
            # Malformed entry: reject instead of silently skipping it, so a
            # proof padded with junk entries is never accepted.
            return False

    return current_string == root

In [5]:
# Demo: build a tree over nine strings using a toy "hash" that just appends
# "h", which keeps every node value human-readable.
tree = MerkleTree(["s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "s9"], (lambda x: x + "h"))
item = "s9"
proof = tree.get_proof_for(item)
# Compare against None instead of relying on truthiness: a tree with a single
# leaf yields an empty (but valid) proof list, which would be falsy.
if proof is not None and verify(tree.get_root(), item, proof, tree.hash_func):
    print("The proof is correct")

The proof is correct
