# The BSTree data structure

## Agenda

- API
- Implementation
    - Search
    - Insertion
    - Deletion
    - Iteration / Traversal
- Runtime complexity

## API

In [1]:
class BSTree:
    class Node:
        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right
            
    def __init__(self):
        self.size = 0
        self.root = None
    
    def __contains__(self, val):
        """Returns `True` if val is in this tree and `False` otherwise."""
        pass
    
    def insert(self, val):
        """Inserts `val` into this tree while maintaining BSTree properties."""
        assert val not in self
        pass    

    def __delitem__(self, val):
        """Deletes `val` from this tree while maintaining BSTree properties."""
        assert val in self
        pass
    
    def __iter__(self):
        """Returns an iterator over all the values in the tree, in ascending order."""
        pass

    def __len__(self):
        return self.size
    
    def height(self):
        """Returns the height of the root of the tree."""
        def height_rec(t):
            if not t:
                return 0
            else:
                return 1 + max(height_rec(t.left), height_rec(t.right))
        return height_rec(self.root)

    def pprint(self, width=64):
        """Attempts to pretty-print this tree's contents."""
        height = self.height()
        nodes  = [(self.root, 0)]
        prev_level = 0
        repr_str = ''
        while nodes:
            n,level = nodes.pop(0)
            if prev_level != level:
                prev_level = level
                repr_str += '\n'
            if not n:
                if level < height-1:
                    nodes.extend([(None, level+1), (None, level+1)])
                repr_str += '{val:^{width}}'.format(val='-', width=width//2**level)
            elif n:
                if n.left or level < height-1:
                    nodes.append((n.left, level+1))
                if n.right or level < height-1:
                    nodes.append((n.right, level+1))
                repr_str += '{val:^{width}}'.format(val=n.val, width=width//2**level)
        print(repr_str)

In [2]:
t = BSTree()
t.root = BSTree.Node(5,
                    left=BSTree.Node(2),
                    right=BSTree.Node(10))
t.size = 3

In [3]:
t.height()

2

In [4]:
t.pprint()

                               5                                
               2                               10               


## Implementation

### Search

In [17]:
class BSTree(BSTree):
    def __contains__(self, val):
        def contains_rec(node):
            if node is None:
                return False
            elif node.val == val:
                return True
            elif node.val < val:
                return contains_rec(node.right)
            else:
                return contains_rec(node.left)
            
        return contains_rec(self.root)
    # We can't implement __contains__ recursively since it takes a node, but
    # the starting point is independent of such node

In [18]:
t = BSTree()
t.root = BSTree.Node(5,
                    left=BSTree.Node(2),
                    right=BSTree.Node(10))
t.size = 3

In [19]:
100 in t

False

### Insertion

In [24]:
class BSTree(BSTree):
    def insert(self, val):
        def insert_rec(node): # Returns a whole new tree for use of insert
            if not node:
                return BSTree.Node(val)
            elif node.val > val:
                node.left = insert_rec(node.left)
                return node
            else: # node.val < val
                node.right = insert_rec(node.right)
                return node

        assert val not in self
        self.root = insert_rec(self.root)
        self.size += 1

In [27]:
t = BSTree()
t.insert(10)
t.insert(5)
t.insert(7)
t.insert(100)
t.pprint()

                               10                               
               5                              100               
       -               7               -               -        


In [29]:
import random
t = BSTree()
vals = list(range(5))
random.shuffle(vals)
for x in vals:
    t.insert(x)
t.pprint()

                               4                                
               2                               -                
       0               3               -               -        
   -       1       -       -       -       -       -       -    


### Removal

In [32]:
class BSTree(BSTree):
    def __delitem__(self, val):
        def delitem_rec(node):
            if node.val > val:
                node.left = delitem_rec(node.left)
                return node
            elif node.val < val:
                node.right = delitem_rec(node.right)
                return node
            else: # node.val == val, delete node
                # case 1: node is a leaf
                if not node.left and not node.right:
                    return None
                # case 2: node contains only right subtree
                elif not node.left:
                    return node.right
                # case 3: only left subtree
                elif not node.right:
                    return node.left
                # case 4: both left and right subtrees
                else:
                    # Replacement will be the rightmost node from left subtree, or
                    # leftmost node from right subtree
                    
                    pass

        assert val in self
        self.root = delitem_rec(self.root)
        self.size -= 1

In [30]:
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.insert(x)
t.pprint()

                               10                               
               5                               15               
       2               -               -               17       


In [31]:
del t[2]
t.pprint()

                               10                               
               5                               15               
       2               -               -               17       


In [33]:
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.insert(x)
t.pprint()

                               10                               
               5                               15               
       2               -               -               17       


In [34]:
del t[5]
t.pprint()

                               10                               
               2                               15               
       -               -               -               17       


In [35]:
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.insert(x)
t.pprint()

                               10                               
               5                               15               
       2               -               -               17       


In [36]:
del t[15]
t.pprint()

                               10                               
               5                               17               
       2               -               -               -        


In [37]:
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.insert(x)
t.pprint()

                               10                               
               5                               15               
       2               -               -               17       


In [38]:
del t[10]
t.pprint()

                               -                                


In [57]:
class BSTree(BSTree):
    def __delitem__(self, val):
        assert val in self
        def delitem_rec(node):
            if val < node.val:
                node.left = delitem_rec(node.left)
                return node
            elif val > node.val:
                node.right = delitem_rec(node.right)
                return node
            else:
                if not node.left and not node.right:
                    return None
                elif node.left and not node.right:
                    return node.left
                elif node.right and not node.left:
                    return node.right
                else:
                    # Replacement will be the rightmost node from left subtree
                    n = node.left
                    if not n.right:
                        node.val = n.val
                        node.left = n.left  # Promote subtree
                    else:
                        pn = n
                        n = n.right
                        while n.right: # Locate node to delete and its parent
                            pn = n
                            n = n.right
                        node.val = n.val
                        pn.right = n.left # Promote left subtree from child to parent
                        # Works since all of these are references dependent on node
                    return node
                        
        self.root = delitem_rec(self.root)
        self.size -= 1

In [40]:
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.insert(x)
t.pprint()

                               10                               
               5                               15               
       2               7               12              18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  


In [41]:
del t[15]
t.pprint()

                               10                               
               5                               12               
       2               7               -               18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  


In [42]:
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.insert(x)
t.pprint()

                               10                               
               5                               15               
       2               7               12              18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  


In [43]:
del t[5]
t.pprint()

                               10                               
               2                               15               
       1               7               12              18       
   -       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  


In [44]:
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.insert(x)
t.pprint()

                               10                               
               5                               15               
       2               7               12              18       
   1       -       -       9       -       -       -       -    
 -   -   -   -   -   -   8   -   -   -   -   -   -   -   -   -  


In [45]:
del t[10]
t.pprint()

                               9                                
               5                               15               
       2               7               12              18       
   1       -       -       8       -       -       -       -    


### Iteration / Traversal

In [46]:
class BSTree(BSTree):
    def __iter__(self):
        def iter_rec(node):
            if node:
                yield from iter_rec(node.left)
                yield node.val
                yield from iter_rec(node.right)
        return iter_rec(self.root)

In [53]:
import random
t = BSTree()
vals = list(range(20))
random.shuffle(vals)
for x in vals:
    t.insert(x)
t.pprint()
for x in t:
    print(x)

                               10                               
               5                               15               
       3               6               11              17       
   0       4       -       8       -       12      16      19   
 -   1   -   -   -   -   7   9   -   -   -   14  -   -   18  -  
- - - 2 - - - - - - - - - - - - - - - - - - 13- - - - - - - - - 
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19


### Iteration by Level

In [54]:
class BSTree(BSTree):
    def __iter__(self):
        queue = [self.root]
        while queue:
            node = queue.pop(0)
            yield node.val
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)

In [55]:
import random
t = BSTree()
vals = list(range(10))
random.shuffle(vals)
for x in vals:
    t.insert(x)
t.pprint()

                               4                                
               2                               9                
       1               3               8               -        
   0       -       -       -       7       -       -       -    
 -   -   -   -   -   -   -   -   5   -   -   -   -   -   -   -  
- - - - - - - - - - - - - - - - - 6 - - - - - - - - - - - - - - 


In [56]:
for x in t:
    print(x)

4
2
9
1
3
8
0
7
5
6


## Runtime Complexity

The runtime complexity of the search, insert, and delete methods of the binary search tree are dependent, ultimately, on the depth of their recursive implementation. The depth of recursion is in turn dependent on the height of the binary search tree.

Given $N$ nodes, the height of a binary search tree is, in the worst case = $N$

This gives us the following worst-case runtime complexities:

- Search = $O(N)$
- Insert = $O(N)$
- Delete = $O(N)$

How can we improve this runtime complexity? What should be our target runtime complexity? $O(\log{N})$.