In [1]:
class Node:
    """One vertex of a binary tree.

    Stores a payload (``data``), a back-link to the node it hangs from
    (``parent``) and two child slots that start out empty.
    """

    def __init__(self, data, parent):
        # Payload and back-link are supplied by the caller.
        self.data, self.parent = data, parent
        # A freshly created node has no children yet.
        self.right_node = self.left_node = None

In [2]:
class BinarySearchTree:
    """Unbalanced binary search tree built from parent-linked ``Node`` objects.

    Ordering rule used by every method: values smaller than a node's value
    go into its left subtree; all other values (duplicates included) go into
    its right subtree.
    """

    def __init__(self):
        # Topmost node of the tree; None while the tree is empty.
        self.root = None

    def remove(self, data):
        """Remove one node holding ``data`` from the tree.

        An empty tree is left untouched and no error is raised.

        Raises:
            ValueError: if the tree is not empty and ``data`` is not in it.
        """
        if self.root is not None:
            self.remove_node(data, self.root)

    def insert(self, data):
        """Add ``data`` to the tree, creating the root if the tree is empty."""
        if self.root is None:
            self.root = Node(data, None)
        else:
            self.insert_node(data, self.root)

    def insert_node(self, data, node):
        """Insert ``data`` somewhere below ``node`` (an existing node).

        The descent is a loop rather than recursion so that a degenerate,
        list-shaped tree cannot exhaust the interpreter's recursion limit.
        """
        while True:
            if data < node.data:
                # Smaller values belong to the left subtree.
                if node.left_node is None:
                    node.left_node = Node(data, node)
                    return
                node = node.left_node
            else:
                # Equal or larger values belong to the right subtree.
                if node.right_node is None:
                    node.right_node = Node(data, node)
                    return
                node = node.right_node

    def remove_node(self, data, node):
        """Remove one node holding ``data`` from the subtree rooted at ``node``.

        The surrounding links (child slots, ``parent`` back-links and
        ``self.root``) are repaired in place, so the return value may be
        ignored when ``node`` is part of this tree.

        Returns:
            The root of the subtree after the removal (``None`` if the
            subtree became empty).

        Raises:
            ValueError: if ``node`` is ``None`` or ``data`` is not found
                in the subtree.
        """
        subtree_root = node
        parent = None

        # Locate the node to delete, remembering the node we came from.
        while node is not None:
            if data < node.data:
                parent, node = node, node.left_node
            elif data > node.data:
                parent, node = node, node.right_node
            else:
                break
        if node is None:
            raise ValueError("Node not found in the tree")

        if parent is None:
            # The target is the subtree root itself; its own back-link tells
            # us where (if anywhere) it is attached.
            parent = node.parent

        if node.left_node is not None and node.right_node is not None:
            # Two children: copy the in-order successor's value into this
            # node and delete the successor instead. The successor is the
            # left-most node of the right subtree, so it has no left child.
            # Its parent is tracked while walking rather than read from the
            # back-link, so the splice does not depend on back-link accuracy.
            successor_parent = node
            successor = node.right_node
            while successor.left_node is not None:
                successor_parent, successor = successor, successor.left_node
            node.data = successor.data
            node, parent = successor, successor_parent

        # From here on ``node`` has at most one child; splice it out.
        child = node.left_node if node.left_node is not None else node.right_node
        self._replace_child(parent, node, child)

        return child if node is subtree_root else subtree_root

    def _replace_child(self, parent, old, new):
        """Make ``new`` occupy the place ``old`` had under ``parent``."""
        if new is not None:
            new.parent = parent
        if parent is None:
            # No parent: ``old`` can only be replaced if it is the tree root.
            if old is self.root:
                self.root = new
        elif parent.left_node is old:
            parent.left_node = new
        elif parent.right_node is old:
            parent.right_node = new

    def get_predecessor(self, node):
        """Return the left-most (smallest) node of the subtree at ``node``.

        Despite the historical name, when this is called with a node's right
        child the result is that node's in-order *successor*. Returns
        ``None`` when ``node`` is ``None``.
        """
        if node is None:
            return None

        while node.left_node is not None:
            node = node.left_node

        return node

    def traverse(self, node, traversal_type='in_order'):
        """Print the values of the subtree at ``node``, one per line.

        ``traversal_type`` is ``'pre_order'``, ``'in_order'`` (default) or
        ``'post_order'``. Any other value prints a "not supported" message
        for a non-empty subtree and visits nothing.
        """
        if node is None:
            return

        if traversal_type == 'pre_order':
            # node, left subtree, right subtree
            print(node.data)
            self.traverse(node.left_node, traversal_type)
            self.traverse(node.right_node, traversal_type)
        elif traversal_type == 'in_order':
            # left subtree, node, right subtree
            self.traverse(node.left_node, traversal_type)
            print(node.data)
            self.traverse(node.right_node, traversal_type)
        elif traversal_type == 'post_order':
            # left subtree, right subtree, node
            self.traverse(node.left_node, traversal_type)
            self.traverse(node.right_node, traversal_type)
            print(node.data)
        else:
            print(f"Traversal type {traversal_type!s} is not supported.")

        


In [7]:
# Build a sample tree from a fixed sequence of keys.
tree = BinarySearchTree()
for key in (5, 3, 6, 1, 2, 4, 7, 10):
    tree.insert(key)

# Print the tree in every supported order, with a separator after each listing.
for order in ('pre_order', 'in_order', 'post_order'):
    tree.traverse(tree.root, order)
    print("----")

# Delete two keys one after the other and print the in-order listing after
# each deletion to confirm the remaining keys are still sorted.
for key in (5, 7):
    tree.remove_node(key, tree.root)
    tree.traverse(tree.root, 'in_order')
    print("----")



5
3
1
2
4
6
7
10
----
1
2
3
4
5
6
7
10
----
2
1
4
3
10
7
6
5
----
1
2
3
4
6
7
10
----
1
2
3
4
6
10
----
