# Binary Search Tree

![image](./images/bst_01.png)

#### Define Node class

In [4]:
# this code makes the tree that we'll traverse

class Node(object):
        
    def __init__(self,value = None):
        self.value = value
        self.left = None
        self.right = None
        
    def set_value(self,value):
        self.value = value
        
    def get_value(self):
        return self.value
        
    def set_left_child(self,left):
        self.left = left
        
    def set_right_child(self, right):
        self.right = right
        
    def get_left_child(self):
        return self.left
    
    def get_right_child(self):
        return self.right

    def has_left_child(self):
        return self.left != None
    
    def has_right_child(self):
        return self.right != None
    
    # define __repr_ to decide what a print statement displays for a Node object
    def __repr__(self):
        return f"Node({self.get_value()})"
    
    def __str__(self):
        return f"Node({self.get_value()})"


In [5]:
from collections import deque
class Queue():
    def __init__(self):
        self.q = deque()
        
    def enq(self,value):
        self.q.appendleft(value)
        
    def deq(self):
        if len(self.q) > 0:
            return self.q.pop()
        else:
            return None
    
    def __len__(self):
        return len(self.q)
    
    def __repr__(self):
        if len(self.q) > 0:
            s = "<enqueue here>\n_________________\n" 
            s += "\n_________________\n".join([str(item) for item in self.q])
            s += "\n_________________\n<dequeue here>"
            return s
        else:
            return "<queue is empty>"

#### Define insert

Let's assume that duplicates are overriden by the new node that is to be inserted.  Other options are to keep a counter of duplicate nodes, or to keep a list of duplicates nodes with the same value.

In [6]:
class Tree():
    def __init__(self):
        self.root = None
        
    def set_root(self,value):
        self.root = Node(value)
        
    def get_root(self):
        return self.root
    
    def compare(self, node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node 
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1
    
    """
    define insert here
    can use a for loop (try one or both ways)
    """
    def insert_with_loop(self, new_value):
        # initialize new node 
        node = Node(new_value)
        
        if not self.root:
            self.root = node
        else:
            # iterate until we reach a leaf node 
            curr = self.root 
            prev = self.root
            
            while curr: 
                if curr.value == new_value:
                    # replace duplicate with newest duplicate value
                    # don't set it to the new node because this will break
                    # the current nodes connection to its children 
                    curr.value = new_value
                    return 
                elif new_value < curr.value: 
                    # traverse left 
                    prev = curr
                    curr = curr.left
                else: 
                    # traverse right 
                    prev = curr
                    curr = curr.right 
                    
            # prev should point to our new nodes parent 
            if new_value < prev.value:
                prev.left = node 
            else:
                prev.right = node 

    """
    define insert here (can use recursion)
    try one or both ways
    """  
    def insert_with_recursion(self, value):
        if not self.root:
            self.root = Node(value)
        else:
            self.recurse(self.root, value)
        
        
    def recurse(self, node, value):
        if not node:
            return 
        
        # if we have a duplicate, replace with latest node 
        if node.value == value:
            node.value = value
        # traverse left 
        elif value < node.value:
            if node.left:
                self.recurse(node.left, value)
            else:
                node.left = Node(value)
        # traverse right 
        else:
            if node.right:
                self.recurse(node.right, value)
            else: 
                node.right = Node(value)
                    
    def __repr__(self):
        level = 0
        q = Queue()
        visit_order = list()
        node = self.get_root()
        q.enq( (node,level) )
        while(len(q) > 0):
            node, level = q.deq()
            if node == None:
                visit_order.append( ("<empty>", level))
                continue
            visit_order.append( (node, level) )
            if node.has_left_child():
                q.enq( (node.get_left_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

            if node.has_right_child():
                q.enq( (node.get_right_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

        s = "Tree\n"
        previous_level = -1
        for i in range(len(visit_order)):
            node, level = visit_order[i]
            if level == previous_level:
                s += " | " + str(node) 
            else:
                s += "\n" + str(node)
                previous_level = level

                
        return s


In [7]:
tree = Tree()
tree.insert_with_loop(5)
tree.insert_with_loop(6)
tree.insert_with_loop(4)
tree.insert_with_loop(2)
tree.insert_with_loop(5) # insert duplicate
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


In [8]:
tree = Tree()
tree.insert_with_recursion(5)
tree.insert_with_recursion(6)
tree.insert_with_recursion(4)
tree.insert_with_recursion(2)
tree.insert_with_recursion(5) # insert duplicate
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Search

Define a search function that takes a value, and returns true if a node containing that value exists in the tree, otherwise false.

In [33]:
class Tree():
    def __init__(self):
        self.root = None
        
    def set_root(self,value):
        self.root = Node(value)
        
    def get_root(self):
        return self.root
    
    def compare(self,node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node 
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1
    
    def insert(self,new_value):
        new_node = Node(new_value)
        node = self.get_root()
        if node == None:
            self.root = new_node
            return
        
        while(True):
            comparison = self.compare(node, new_node)
            if comparison == 0:
                # override with new node
                node = new_node
                break # override node, and stop looping
            elif comparison == -1:
                # go left
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(new_node)
                    break #inserted node, so stop looping
            else: #comparison == 1
                # go right
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(new_node)
                    break # inserted node, so stop looping
                    
    """
    implement search
    """
    def search(self, value):
        if not self.root:
            return False 
        
        #return self.search_recurse(self.root, value)

        curr = self.root
    
        while curr:
            if curr.value == value:
                return True
            elif curr.value > value:
                if curr.left:
                    curr = curr.left
                else:
                    return False
            else:
                if curr.right:
                    curr = curr.right
                else:
                    return False
                
                
    def search_recurse(self, node, value):
        if not node:
            return False 
        
        if node.value == value:
            return True 
        
        return self.search_recurse(node.left, value) or self.search_recurse(node.right, value)
                    
    def __repr__(self):
        level = 0
        q = Queue()
        visit_order = list()
        node = self.get_root()
        q.enq( (node,level) )
        while(len(q) > 0):
            node, level = q.deq()
            if node == None:
                visit_order.append( ("<empty>", level))
                continue
            visit_order.append( (node, level) )
            if node.has_left_child():
                q.enq( (node.get_left_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

            if node.has_right_child():
                q.enq( (node.get_right_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

        s = "Tree\n"
        previous_level = -1
        for i in range(len(visit_order)):
            node, level = visit_order[i]
            if level == previous_level:
                s += " | " + str(node) 
            else:
                s += "\n" + str(node)
                previous_level = level

                
        return s
    
#     def delete(self, value):
#         if self.root:                
#             parent = None 
#             curr = self.root 
            
#             # traverse the tree until the end OR until we find the value we want to remove
#             while curr and curr.value != value:
#                 parent = curr 
                
#                 # traverse left
#                 if value < curr.value:
#                     curr = curr.left 
#                 else:
#                     # traverse right 
#                     curr = curr.right 
                
#             # exited because curr.value == value 
#             if curr:
#                 # CASE 1: node has no children 
#                 if not curr.left and not curr.right:
#                     # check which parent pointer needs to be updated, and update
#                     if parent.left == curr: 
#                         parent.left = None 
#                     else:
#                         parent.right = None 
#                     curr = None 
#                 # CASE 2: node has only one child 
#                 # has only left child: 
#                 elif curr.left and not curr.right:
#                     # if node to delete is on the left, point parent.left to the child node
#                     if value < parent.value:
#                         parent.left = curr.left
#                     else:
#                         # node to delete is on the right, point parent.right to child node 
#                         parent.right = curr.left                     
#                     curr = None 
#                 # has only right child 
#                 elif not curr.left and curr.right:
#                     # if node to delete is on the left, point parent.left to the child node
#                     if value < parent.value:
#                         parent.left = curr.right
#                     else:
#                         # node to delete is on the right, point parent.right to child node 
#                         parent.right = curr.right                     
#                     curr = None                     
#                 # CASE 3: node has two child nodes 
#                 else:
#                     # find next highest value 
#                     next_heightest_node_parent = curr
#                     next_heightest_node = curr.right 
                    
#                     while next_heightest_node.left: 
#                         next_heightest_node_parent = next_heightest_node
#                         next_heightest_node = next_heightest_node.left
                        
#                     # update curr value to next height value 
#                     curr.value = next_heightest_node.value 
                    
#                     # special case where removing root 
#                     if next_heightest_node_parent == tree.root:
#                         next_heightest_node_parent.right = next_heightest_node.right
#                     else:
#                         # update parent.left point to next heighest nodes right child
#                         # it doesn't have a left child, so that's not a consideration 
#                         next_heightest_node_parent.left = next_heightest_node.right 
#                     next_heightest_node = None 
    
    def delete(self, value):
        return self.d_recurse(self.root, value)
    
    def d_recurse(self, node, value):
        if not node:
            return
        
        if value < node.value:
            node.left = self.d_recurse(node.left, value)
        elif value > node.value:
            node.right = self.d_recurse(node.right, value)
        else:
            if not node.left:
                temp = node.right
                node = None
            elif not node.right:
                temp = node.left
                node = None
            else:
                min_val_node = node.right
                while min_val_node.left:
                    min_val_node = min_val_node.left

                node.value = min_val_node.value
                node.right = self.d_recurse(node.right, min_val_node.value)
    

In [34]:
tree = Tree()
tree.insert(5)
tree.insert(6)
tree.insert(4)
tree.insert(2)

print(f"""
search for 8: {tree.search(8)}
search for 2: {tree.search(2)}
""")
print(tree)

# tree.delete(5)
# tree.delete(6)
tree.delete(2)
# tree.delete(4)
# tree.delete(20)
print(tree)


search for 8: False
search for 2: True

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>
Tree

Node(5)
<empty> | Node(6)
<empty> | <empty>


## Bonus: deletion

Try implementing deletion yourself.  You can also check out this explanation [here](https://www.geeksforgeeks.org/binary-search-tree-set-2-delete/)

## Solution notebook
The solution for insertion and search is [here](04 binary_search_tree_solution.ipynb)