# The BSTree data structure

## Agenda

- API
- Implementation
    - Addition
    - Search
    - Removal
    - Iteration / Traversal

## API

In [35]:
class BSTree:
    """Binary search tree holding unique, mutually comparable values.

    This is the API skeleton: `add`, `__contains__`, `__delitem__` and
    `__iter__` are placeholders that do nothing yet, while `__len__`,
    `pprint` and `height` are fully functional.
    """

    class Node:
        """A tree node: one value plus links to the left and right subtrees."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

    def __init__(self):
        # `size` is maintained by hand by whichever code links/unlinks nodes.
        self.size = 0
        self.root = None

    def add(self, val):
        """Adds `val` to this tree while maintaining BSTree properties."""
        assert(val not in self)
        pass

    def __contains__(self, val):
        """Returns `True` if val is in this tree and `False` otherwise."""
        pass

    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties."""
        assert(val in self)
        pass

    def __iter__(self):
        """Returns an iterator over all the values in the tree, in ascending order."""
        pass

    def __len__(self):
        """Returns the number of values recorded in `self.size`."""
        return self.size

    def pprint(self, width=64):
        """Attempts to pretty-print this tree's contents.

        The tree is printed level by level (root first). Each level is split
        into equally wide cells (`width // 2**level` characters each) with the
        value centered in its cell; missing nodes are shown as '-'. An empty
        tree prints a single '-'.

        Args:
            width: total character width of each printed level.
        """
        # Local import keeps this block self-contained; a deque gives O(1)
        # pops from the front of the breadth-first work queue.
        from collections import deque

        height = self.height()
        pending = deque([(self.root, 0)])
        rows = []        # finished levels, one string per level
        cells = []       # cells of the level currently being assembled
        prev_level = 0
        while pending:
            n, level = pending.popleft()
            if prev_level != level:
                # First node of a new level: close off the previous row.
                prev_level = level
                rows.append(''.join(cells))
                cells = []
            # Above the deepest level, absent children are still queued as
            # placeholders so every row keeps its full set of columns.
            pad_children = level < height-1
            if not n:
                if pad_children:
                    pending.extend([(None, level+1), (None, level+1)])
                label = '-'
            else:
                if n.left or pad_children:
                    pending.append((n.left, level+1))
                if n.right or pad_children:
                    pending.append((n.right, level+1))
                label = n.val
            cells.append('{val:^{width}}'.format(val=label, width=width//2**level))
        rows.append(''.join(cells))
        print('\n'.join(rows))

    def height(self):
        """Returns the height of the longest branch of the tree.

        Height is counted in nodes: an empty tree has height 0 and a tree
        consisting of only a root has height 1.
        """
        def height_rec(t):
            if not t:  # empty tree returns 0
                return 0
            return 1 + max(height_rec(t.left), height_rec(t.right))
        return height_rec(self.root)

In [None]:
# Hand-build a five-node tree (bypassing `add`) to try out pprint/height.
t = BSTree()
t.root = BSTree.Node(5,
                     left=BSTree.Node(2),
                     right=BSTree.Node(10,
                                       left=BSTree.Node(7),
                                       right=BSTree.Node(15)))
t.size = 5  # must equal the number of nodes linked above so len(t) is right

In [None]:
t.pprint()

In [None]:
t.height()

## Implementation

### Addition

In [36]:
class BSTree(BSTree):
    def add(self, val):
        """Adds `val` to this tree while maintaining BSTree properties.

        Descends from the root, going left when `val` is smaller than the
        current node's value and right otherwise, and links a new node at the
        first empty position reached.

        Raises:
            AssertionError: if `val` is already in the tree.
        """
        # Check the precondition once, up front, rather than inside the
        # recursive helper where it would repeat a full membership search at
        # every level of the descent.
        assert(val not in self)

        def add_rec(t):
            # Returns the (possibly new) root of the subtree after insertion.
            if t is None:
                return BSTree.Node(val)
            if val < t.val:
                t.left = add_rec(t.left)
            else:  # val > t.val
                t.right = add_rec(t.right)
            return t

        self.root = add_rec(self.root)
        self.size += 1

In [3]:
t = BSTree()
for x in [10,5,8,12,3,7]:
    t.add(x)

In [4]:
t.pprint()

                               10                               
               5                               12               
       3               8               -               -        
   -       -       7       -       -       -       -       -    


In [9]:
import random
t = BSTree()
vals = list(range(5))
random.shuffle(vals)
for x in vals:
    t.add(x)
t.pprint()

                               2                                
               1                               4                
       0               -               3               -        


In [37]:
import random
# Build a tree from the odd numbers 1, 3, ..., 9 inserted in random order.
t = BSTree()
vals = list(range(1, 10, 2))
random.shuffle(vals)
for x in vals:
    t.add(x)

# Every inserted value must be found, and no even number may be found.
# NOTE: these checks need a working `__contains__`. With the placeholder from
# the API skeleton (its body is just `pass`, so it returns None) `x in t` is
# always False and the first assertion fails.
assert(all(x in t for x in range(1, 10, 2)))
assert(all(x not in t for x in range(0, 12, 2)))

AssertionError: 

### Search

In [38]:
class BSTree(BSTree):
    def __contains__(self, val):
        """Returns `True` if val is in this tree and `False` otherwise.

        Walks a single root-to-leaf path, using the ordering of the values to
        decide at each node which subtree could still hold `val`.
        """
        node = self.root
        while node is not None:
            if node.val == val:
                return True
            # Larger targets can only be in the right subtree, smaller ones
            # only in the left.
            node = node.right if node.val < val else node.left
        return False
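'''
Keep the O(height) tree descent (O(log N) when the tree is balanced)
rather than a linear scan over the iterator, which is O(N):
for x in self:
    if val == x:
        return True
return False
'''

'\nKeep the O(height) tree descent (O(log N) when the tree is balanced)\nrather than a linear scan over the iterator, which is O(N):\nfor x in self:\n    if val == x:\n        return True\nreturn False\n'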

In [None]:
t = BSTree()
t.root = BSTree.Node(5,
                    left = BSTree.Node(2),
                    right = BSTree.Node(10))
t.size = 3

In [None]:
5 in t

In [None]:
2 in t

In [None]:
10 in t

In [None]:
-1 in t

### Removal

In [22]:
class BSTree(BSTree):
    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties.

        Partial version: only handles the relatively simple cases where the
        node holding `val` has at most one child.

        Raises:
            AssertionError: if `val` is not in the tree.
            NotImplementedError: if the node holding `val` has two children.
                The tree and its size are left untouched in that case.
        """
        assert(val in self)

        # deal with relatively simple cases first!
        def delitem_rec(t):
            # Returns the root of the subtree `t` after `val` has been removed.
            if val < t.val:
                t.left = delitem_rec(t.left)
            elif val > t.val:
                t.right = delitem_rec(t.right)
            else:  # val == t.val (node containing the value we want to remove)
                if not t.left and not t.right:
                    return None
                elif t.left and not t.right:
                    return t.left
                elif t.right and not t.left:
                    return t.right
                else:
                    # Fail loudly instead of silently keeping the value while
                    # the size bookkeeping below still shrinks.
                    raise NotImplementedError(
                        'removing a node with two children is not supported yet')
            return t

        self.root = delitem_rec(self.root)
        self.size -= 1

In [23]:
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.add(x)
t.pprint()

                               10                               
               5                               15               
       2               -               -               17       


In [24]:
del t[2]
t.pprint()

                               10                               
               5                               15               
       -               -               -               17       


In [25]:
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.add(x)
del t[5]
t.pprint()

                               10                               
               2                               15               
       -               -               -               17       


In [26]:
t = BSTree()
for x in [10, 5, 15, 2, 17]:
    t.add(x)
del t[15]
t.pprint()

                               10                               
               5                               17               
       2               -               -               -        


In [39]:
class BSTree(BSTree):
    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties.

        A node with at most one child is replaced by that child (or by
        nothing). A node with two children takes over the value of its
        in-order predecessor (the largest value in its left subtree), and the
        predecessor node is unlinked instead.

        Raises:
            AssertionError: if `val` is not in the tree.
        """
        assert(val in self)

        def remove_from(node):
            # Returns the root of the subtree `node` after `val` is removed.
            if val < node.val:
                node.left = remove_from(node.left)
                return node
            if val > node.val:
                node.right = remove_from(node.right)
                return node

            # `node` holds the value to remove.
            if not node.left:
                return node.right  # also covers the leaf case (returns None)
            if not node.right:
                return node.left

            # Two children: find the rightmost node of the left subtree,
            # remembering its parent so it can be unlinked.
            parent, pred = node, node.left
            while pred.right:
                parent, pred = pred, pred.right
            if parent is node:
                # The left child itself is the predecessor.
                node.left = pred.left
            else:
                parent.right = pred.left
            node.val = pred.val
            return node

        self.root = remove_from(self.root)
        self.size -= 1

In [None]:
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.add(x)
t.pprint()

In [None]:
del t[15]
t.pprint()

In [None]:
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.add(x)
del t[5]
t.pprint()

In [None]:
# Delete the root (10), which has two children. The largest value in its left
# subtree is 9, reached via 5 -> 7 -> 9, and 9 has a left child (8) that must
# be re-attached in 9's old position.
t = BSTree()
for x in [10, 5, 2, 7, 9, 8, 1, 15, 12, 18]:
    t.add(x)
del t[10]
t.pprint()

In [41]:
# Delete the root (7), which has two children. Its left child (4) has no right
# child, so 4 itself is the largest value in the left subtree and becomes the
# new root value.
t = BSTree()
for x in [7,4,10,8,2,3,12,9]:
    t.add(x)
del t[7]
t.pprint()

                               4                                
               2                               10               
       -               3               8               12       
   -       -       -       -       -       9       -       -    


### Iteration / Traversal

In [30]:
class BSTree(BSTree):
    def __iter__(self):
        """Returns an iterator over all the values in the tree, in ascending order.

        Performs an in-order traversal (left subtree, node, right subtree)
        with an explicit stack of nodes whose value is still to be yielded.
        """
        waiting = []
        node = self.root
        while waiting or node is not None:
            # Slide down to the leftmost node, remembering the path taken.
            while node is not None:
                waiting.append(node)
                node = node.left
            node = waiting.pop()
            yield node.val
            node = node.right

In [31]:
import random
t = BSTree()
vals = list(range(20))
random.shuffle(vals)
print(vals)
for x in vals:
    t.add(x)
for x in t:
    print(x)

[3, 2, 1, 14, 16, 6, 8, 17, 11, 13, 19, 10, 12, 5, 0, 4, 15, 9, 7, 18]
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
