In [1]:
class Node:
    """A single vertex of a binary search tree.

    Stores a payload ``value`` plus links to the left and right children
    and to the parent node (``None`` when the node is the root).
    """

    def __init__(self, value, parent_node= None):
        self.value = value
        # Children start empty; the tree attaches them on insertion.
        self.right_node = self.left_node = None
        self.parent_node = parent_node

class BinarySearchTree:
    """Unbalanced binary search tree built from ``Node`` objects.

    Values smaller than a node's value live in its left subtree; values
    greater than or equal to it live in its right subtree, so duplicates are
    kept. Every node keeps a ``parent_node`` back-link, which removal uses to
    unlink nodes. All walks are iterative, so degenerate (list-shaped) trees,
    e.g. from sorted insertions, do not hit Python's recursion limit.
    """

    def __init__(self):
        # An empty tree has no root node.
        self.root = None

    def insert(self, value):
        """Insert ``value`` into the tree (equal values go to the right)."""
        if self.root is None:
            self.root = Node(value)
        else:
            self.insert_node(value, self.root)

    def insert_node(self, value, node):
        """Insert ``value`` into the non-empty subtree rooted at ``node``."""
        while True:
            if value < node.value:
                if node.left_node is None:
                    node.left_node = Node(value, node)
                    return
                node = node.left_node
            else:
                # Equal values go right, so duplicates are preserved.
                if node.right_node is None:
                    node.right_node = Node(value, node)
                    return
                node = node.right_node

    def get_min(self):
        """Return the smallest value in the tree, or ``None`` if it is empty."""
        if self.root is None:
            return None
        current_node = self.root
        while current_node.left_node is not None:
            current_node = current_node.left_node
        return current_node.value

    def get_max(self):
        """Return the largest value in the tree, or ``None`` if it is empty."""
        if self.root is None:
            return None
        current_node = self.root
        while current_node.right_node is not None:
            current_node = current_node.right_node
        return current_node.value

    def traverse(self):
        """Print every value in ascending (in-order) order, one per line."""
        if self.root is not None:
            self.traverse_in_order(self.root)

    def traverse_in_order(self, node):
        """Print the values of the subtree rooted at ``node`` in order."""
        stack = []
        current_node = node
        while stack or current_node is not None:
            # Descend as far left as possible, remembering the path back up.
            while current_node is not None:
                stack.append(current_node)
                current_node = current_node.left_node
            current_node = stack.pop()
            print(current_node.value)
            current_node = current_node.right_node

    def remove(self, node):
        """Remove one occurrence of a value from the tree.

        Args:
            node: the *value* to remove (the parameter name is kept for
                backward compatibility). Values not in the tree are ignored.
        """
        if self.root is not None:
            self.remove_node(node, self.root)

    def remove_node(self, value, node):
        """Remove one node holding ``value`` from the subtree rooted at ``node``.

        Prints a short message naming the deletion case that applied.
        Returns ``None``; does nothing when ``value`` is not found.
        """
        # Locate the node that holds ``value``.
        while node is not None:
            if value < node.value:
                node = node.left_node
            elif value > node.value:
                node = node.right_node
            else:
                break
        if node is None:
            return None

        if node.left_node is not None and node.right_node is not None:
            print("Removing node with two children.")
            # Swap values with the in-order predecessor (maximum of the left
            # subtree), then delete the predecessor node instead. It has no
            # right child, so it is handled by one of the simple cases below.
            predecessor = self.get_predecessor(node.left_node)
            node.value, predecessor.value = predecessor.value, node.value
            node = predecessor

        # From here on ``node`` has at most one child.
        if node.left_node is None and node.right_node is None:
            print(f"Leaf node - {node.value} - has been removed.")
            child = None
        elif node.left_node is None:
            print("Removing a node with single right child.")
            child = node.right_node
        else:
            print("Removing a node with single  child.")
            child = node.left_node
        self._replace_in_parent(node, child)
        return None

    def _replace_in_parent(self, node, child):
        """Put ``child`` (possibly ``None``) where ``node`` hangs in the tree."""
        parent = node.parent_node
        if parent is None:
            # ``node`` was the root, so its only child (if any) becomes the root.
            self.root = child
        elif parent.left_node is node:
            parent.left_node = child
        else:
            parent.right_node = child
        if child is not None:
            # Keep the promoted child's back-link pointing at its new parent.
            child.parent_node = parent
        # Detach the removed node so it no longer references the tree.
        node.parent_node = node.left_node = node.right_node = None

    def get_predecessor(self, node):
        """Return the right-most (maximum) node of the subtree rooted at ``node``."""
        while node.right_node is not None:
            node = node.right_node
        return node

In [2]:
# Create an empty tree; its root stays None until the first insert.
bst = BinarySearchTree()

In [3]:
# After the root (20) and 45 on its right, each smaller value chains further left.
for value in (20, 45, 19, 18, 17, 14):
    bst.insert(value)

In [4]:
# Smallest stored value: the left-most node reachable from the root.
bst.get_min()

14

In [5]:
# Largest stored value: the right-most node reachable from the root.
bst.get_max()

45

In [6]:
# In-order traversal prints the stored values in ascending order.
bst.traverse()

14
17
18
19
20
45


In [7]:
# 14 is the smallest value and has no children, so this exercises leaf removal.
bst.remove(14)

Leaf node - 14 - has been removed.


## Example: comparing two binary search trees

Write an efficient algorithm that compares two binary search trees. The method returns `True` if the trees are identical (the same shape, with equal values in every corresponding node) and `False` otherwise. Visiting each node pair at most once gives an $O(n)$ comparison.

In [8]:
class TreeComparator:
    """Structural equality check for two binary (search) trees."""

    def compare(self, node1, node2):
        """Return True if the subtrees rooted at ``node1`` and ``node2`` are identical.

        Identical means the same shape and equal (``==``) values at every
        corresponding position. Two empty subtrees (``None``) are identical.
        Uses an explicit stack, so very deep trees do not hit the recursion limit.

        Args:
            node1, node2: root nodes (objects with ``value``, ``left_node`` and
                ``right_node`` attributes) or ``None``.
        """
        pending = [(node1, node2)]
        while pending:
            first, second = pending.pop()
            if first is None or second is None:
                # Both empty -> this position matches; exactly one empty -> mismatch.
                if first is not second:
                    return False
                continue
            # Compare by value, not identity: equal ints/strings may be distinct objects.
            if first.value != second.value:
                return False
            pending.append((first.right_node, second.right_node))
            pending.append((first.left_node, second.left_node))
        return True
        
class Node:
    """A single vertex of a binary search tree.

    Stores a payload ``value`` plus links to the left and right children
    and to the parent node (``None`` when the node is the root).
    """

    def __init__(self, value, parent_node= None):
        self.value = value
        # Children start empty; the tree attaches them on insertion.
        self.right_node = self.left_node = None
        self.parent_node = parent_node

class BinarySearchTree:
    """Unbalanced binary search tree built from ``Node`` objects.

    Values smaller than a node's value live in its left subtree; values
    greater than or equal to it live in its right subtree, so duplicates are
    kept. Every node keeps a ``parent_node`` back-link, which removal uses to
    unlink nodes. All walks are iterative, so degenerate (list-shaped) trees,
    e.g. from sorted insertions, do not hit Python's recursion limit.
    """

    def __init__(self):
        # An empty tree has no root node.
        self.root = None

    def insert(self, value):
        """Insert ``value`` into the tree (equal values go to the right)."""
        if self.root is None:
            self.root = Node(value)
        else:
            self.insert_node(value, self.root)

    def insert_node(self, value, node):
        """Insert ``value`` into the non-empty subtree rooted at ``node``."""
        while True:
            if value < node.value:
                if node.left_node is None:
                    node.left_node = Node(value, node)
                    return
                node = node.left_node
            else:
                # Equal values go right, so duplicates are preserved.
                if node.right_node is None:
                    node.right_node = Node(value, node)
                    return
                node = node.right_node

    def get_min(self):
        """Return the smallest value in the tree, or ``None`` if it is empty."""
        if self.root is None:
            return None
        current_node = self.root
        while current_node.left_node is not None:
            current_node = current_node.left_node
        return current_node.value

    def get_max(self):
        """Return the largest value in the tree, or ``None`` if it is empty."""
        if self.root is None:
            return None
        current_node = self.root
        while current_node.right_node is not None:
            current_node = current_node.right_node
        return current_node.value

    def traverse(self):
        """Print every value in ascending (in-order) order, one per line."""
        if self.root is not None:
            self.traverse_in_order(self.root)

    def traverse_in_order(self, node):
        """Print the values of the subtree rooted at ``node`` in order."""
        stack = []
        current_node = node
        while stack or current_node is not None:
            # Descend as far left as possible, remembering the path back up.
            while current_node is not None:
                stack.append(current_node)
                current_node = current_node.left_node
            current_node = stack.pop()
            print(current_node.value)
            current_node = current_node.right_node

    def remove(self, node):
        """Remove one occurrence of a value from the tree.

        Args:
            node: the *value* to remove (the parameter name is kept for
                backward compatibility). Values not in the tree are ignored.
        """
        if self.root is not None:
            self.remove_node(node, self.root)

    def remove_node(self, value, node):
        """Remove one node holding ``value`` from the subtree rooted at ``node``.

        Prints a short message naming the deletion case that applied.
        Returns ``None``; does nothing when ``value`` is not found.
        """
        # Locate the node that holds ``value``.
        while node is not None:
            if value < node.value:
                node = node.left_node
            elif value > node.value:
                node = node.right_node
            else:
                break
        if node is None:
            return None

        if node.left_node is not None and node.right_node is not None:
            print("Removing node with two children.")
            # Swap values with the in-order predecessor (maximum of the left
            # subtree), then delete the predecessor node instead. It has no
            # right child, so it is handled by one of the simple cases below.
            predecessor = self.get_predecessor(node.left_node)
            node.value, predecessor.value = predecessor.value, node.value
            node = predecessor

        # From here on ``node`` has at most one child.
        if node.left_node is None and node.right_node is None:
            print(f"Leaf node - {node.value} - has been removed.")
            child = None
        elif node.left_node is None:
            print("Removing a node with single right child.")
            child = node.right_node
        else:
            print("Removing a node with single  child.")
            child = node.left_node
        self._replace_in_parent(node, child)
        return None

    def _replace_in_parent(self, node, child):
        """Put ``child`` (possibly ``None``) where ``node`` hangs in the tree."""
        parent = node.parent_node
        if parent is None:
            # ``node`` was the root, so its only child (if any) becomes the root.
            self.root = child
        elif parent.left_node is node:
            parent.left_node = child
        else:
            parent.right_node = child
        if child is not None:
            # Keep the promoted child's back-link pointing at its new parent.
            child.parent_node = parent
        # Detach the removed node so it no longer references the tree.
        node.parent_node = node.left_node = node.right_node = None

    def get_predecessor(self, node):
        """Return the right-most (maximum) node of the subtree rooted at ``node``."""
        while node.right_node is not None:
            node = node.right_node
        return node

In [9]:
bst1 = BinarySearchTree()
bst2 = BinarySearchTree()

# Same insertion sequence into both trees -> same shape and same values.
for tree in (bst1, bst2):
    for value in (10, 5, 15):
        tree.insert(value)

In [10]:
# bst1 and bst2 were built from the same insertion sequence, so they should
# compare as identical.
compare = TreeComparator()
compare.compare(bst1.root, bst2.root)

True

In [11]:
# Build a third tree whose root (45) differs from bst2's root (10).
# All inserts target bst3; the earlier `bst` tree is left untouched.
bst3 = BinarySearchTree()

bst3.insert(45)
bst3.insert(20)
bst3.insert(34)

In [12]:
# bst2's root holds 10 while bst3's root holds 45, so the trees differ.
compare.compare(bst2.root, bst3.root)

False