The BSTree data structure
=========================



## Agenda



-   API
-   Implementation
    -   Search
    -   Addition
    -   Removal
    -   Iteration / Traversal



## API



In [1]:
class BSTree:
    """A binary search tree of distinct, mutually comparable values.

    Every value in a node's left subtree is smaller than the node's value and
    every value in its right subtree is larger.
    """

    class Node:
        """One tree node: a value plus optional left/right child nodes."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

    def __init__(self):
        self.size = 0     # number of values stored in the tree
        self.root = None  # root Node, or None when the tree is empty

    def add(self, val):
        """Adds `val` to this tree while maintaining BSTree properties.

        Raises AssertionError if `val` is already in the tree.
        """
        assert(val not in self)
        # `self.Node` (a bare `Node` is not in scope here) also lets a subclass
        # that overrides the nested Node class get its own node type.
        new_node = self.Node(val)
        if self.root is None:
            self.root = new_node
            self.size += 1
            return
        # Walk down iteratively so a degenerate (list-shaped) tree cannot
        # exceed Python's recursion limit.
        node = self.root
        while True:
            if node.val > val:
                if node.left is None:
                    node.left = new_node
                    break
                node = node.left
            elif node.val < val:
                if node.right is None:
                    node.right = new_node
                    break
                node = node.right
            else:
                return  # already present; only reachable when asserts are disabled
        self.size += 1

    def __contains__(self, val):
        """Returns `True` if val is in this tree and `False` otherwise."""
        node = self.root
        while node is not None:
            if node.val == val:
                return True
            # Smaller values live in the left subtree, larger in the right.
            node = node.left if node.val > val else node.right
        return False  # explicit False rather than an implicit None

    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties.

        Raises AssertionError if `val` is not in the tree.
        """
        assert(val in self)

        def rec_del(parent, isleft, t, val):
            # `parent` is t's parent (None for the root); `isleft` says which
            # side of `parent` t hangs on.
            if t.val > val:
                rec_del(t, True, t.left, val)
            elif t.val < val:
                rec_del(t, False, t.right, val)
            elif t.left is not None and t.right is not None:
                # Two children: overwrite with the largest value in the left
                # subtree, then delete that value from the left subtree, where
                # its node has no right child (so it is an easy case).
                pred = t.left
                while pred.right is not None:
                    pred = pred.right
                t.val = pred.val
                rec_del(t, True, t.left, pred.val)
            elif t.left is not None:
                # Only a left child: pull the child's contents up into t.
                child = t.left
                t.val, t.left, t.right = child.val, child.left, child.right
            elif t.right is not None:
                # Only a right child: pull the child's contents up into t.
                child = t.right
                t.val, t.left, t.right = child.val, child.left, child.right
            elif parent is None:
                # A leaf that is also the root: the tree becomes empty.
                self.root = None
            elif isleft:
                parent.left = None
            else:
                parent.right = None

        rec_del(None, None, self.root, val)
        self.size -= 1

    def __iter__(self): ### in order
        """Returns an iterator over all the values in the tree, in ascending order."""
        def rec_iter(r):
            if r:
                yield from rec_iter(r.left)
                yield r.val
                yield from rec_iter(r.right)

        yield from rec_iter(self.root)

    def __len__(self):
        """Returns the number of values stored in the tree."""
        return self.size

    def pprint(self, width=64):
        """Attempts to pretty-print this tree's contents.

        Prints one line per level; each position is centred in a slot of
        ``width // 2**level`` characters and empty positions print as '-'.
        """
        from collections import deque

        height = self.height()
        nodes = deque([(self.root, 0)])  # breadth-first queue of (node or None, level)
        prev_level = 0
        repr_str = ''
        while nodes:
            n, level = nodes.popleft()
            if prev_level != level:
                prev_level = level
                repr_str += '\n'
            if n is None:
                # Placeholder children keep the lower levels aligned.
                if level < height-1:
                    nodes.extend([(None, level+1), (None, level+1)])
                repr_str += '{val:^{width}}'.format(val='-', width=width//2**level)
            else:
                if n.left or level < height-1:
                    nodes.append((n.left, level+1))
                if n.right or level < height-1:
                    nodes.append((n.right, level+1))
                repr_str += '{val:^{width}}'.format(val=n.val, width=width//2**level)
        print(repr_str)

    def height(self):
        """Returns the height of the longest branch of the tree (0 when empty)."""
        def height_rec(t):
            if not t:
                return 0
            else:
                return 1 + max(height_rec(t.left), height_rec(t.right))
        return height_rec(self.root)

In [2]:
# Hand-build a three-node tree: 5 at the root, 2 on its left, 10 on its right.
t = BSTree()
t.size = 3
t.root = BSTree.Node(5, BSTree.Node(2), BSTree.Node(10))

In [3]:
# Print the hand-built three-node tree one level per line.
t.pprint()

                               5                                
               2                               10               


In [1]:
# Length of the longest branch: 2 for the three-node tree built above.
t.height()

## Implementation



### Search



In [1]:
class BSTree(BSTree):
    def __contains__(self, val):
        """Returns `True` if `val` is in this tree and `False` otherwise."""
        node = self.root
        while node is not None:
            if node.val == val:
                return True
            # BST ordering: smaller values live in the left subtree.
            node = node.left if node.val > val else node.right
        return False

In [1]:
# Rebuild the same three-node tree (5 with children 2 and 10) on the new class.
t = BSTree()
t.size = 3
t.root = BSTree.Node(5, BSTree.Node(2), BSTree.Node(10))

In [4]:
# Membership test via __contains__; 5 is the root value.
5 in t

True

### Addition



In [1]:
class BSTree(BSTree):
    def add(self, val):
        """Adds `val` to this tree while maintaining BSTree properties.

        Raises AssertionError if `val` is already in the tree.
        """
        assert(val not in self)
        new_node = self.Node(val)
        if self.root is None:
            self.root = new_node
            self.size += 1
            return
        # Walk down from the root, going left for smaller values and right
        # for larger ones, until an empty slot is found.
        node = self.root
        while True:
            if node.val > val:
                if node.left is None:
                    node.left = new_node
                    break
                node = node.left
            elif node.val < val:
                if node.right is None:
                    node.right = new_node
                    break
                node = node.right
            else:
                return  # already present; only reachable when asserts are disabled
        self.size += 1

In [1]:
import random

# Insert 0..4 in a random order: the tree's shape changes from run to run.
vals = list(range(5))
random.shuffle(vals)
t = BSTree()
for x in vals:
    t.add(x)
t.pprint()

### Removal



In [1]:
class BSTree(BSTree):
    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties.

        Raises AssertionError if `val` is not in the tree.
        """
        assert(val in self)
        # Find the node holding `val`, remembering its parent.
        parent, node = None, self.root
        while node.val != val:
            parent = node
            node = node.left if node.val > val else node.right

        # Reduce the two-child case to a simple one: copy the largest value of
        # the left subtree into `node`, then remove that predecessor node
        # instead. The predecessor has no right child.
        if node.left is not None and node.right is not None:
            parent, pred = node, node.left
            while pred.right is not None:
                parent, pred = pred, pred.right
            node.val = pred.val
            node = pred

        # Simple cases: `node` has at most one child, which takes its place
        # in the parent (a leaf is replaced by None).
        child = node.left if node.left is not None else node.right
        if parent is None:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        self.size -= 1

In [1]:
# Resulting shape: 10 at the root; 5 (left child 2) on the left,
# 15 (right child 17) on the right.
t = BSTree()
for x in (10, 5, 15, 2, 17):
    t.add(x)
t.pprint()

In [1]:
# 2 is a leaf (left child of 5), so it can simply be detached.
del t[2]
t.pprint()

In [1]:
# Resulting shape: 10 at the root; 5 (left child 2) on the left,
# 15 (right child 17) on the right.
t = BSTree()
for x in (10, 5, 15, 2, 17):
    t.add(x)
t.pprint()

In [1]:
# 5 has a single (left) child, 2, which takes its place.
del t[5]
t.pprint()

In [1]:
# Resulting shape: 10 at the root; 5 (left child 2) on the left,
# 15 (right child 17) on the right.
t = BSTree()
for x in (10, 5, 15, 2, 17):
    t.add(x)
t.pprint()

In [1]:
# 15 has a single (right) child, 17, which takes its place.
del t[15]
t.pprint()

In [1]:
# Resulting shape: 10 at the root; 5 (left child 2) on the left,
# 15 (right child 17) on the right.
t = BSTree()
for x in (10, 5, 15, 2, 17):
    t.add(x)
t.pprint()

In [1]:
# 10 is the root and has two children -- the hardest removal case.
del t[10]
t.pprint()

In [1]:
class BSTree(BSTree):
    def __delitem__(self, val):
        """Removes `val` from this tree while maintaining BSTree properties.

        Raises AssertionError if `val` is not in the tree.
        """
        # fully working delete
        assert(val in self)

        def remove(node, val):
            # Returns the root of `node`'s subtree after removing `val` from it.
            if node.val > val:
                node.left = remove(node.left, val)
            elif node.val < val:
                node.right = remove(node.right, val)
            elif node.left is None:
                return node.right      # leaf or right-only child
            elif node.right is None:
                return node.left       # left-only child
            else:
                # Two children: take the largest value of the left subtree,
                # then remove that value from the left subtree.
                pred = node.left
                while pred.right is not None:
                    pred = pred.right
                node.val = pred.val
                node.left = remove(node.left, pred.val)
            return node

        self.root = remove(self.root, val)
        self.size -= 1

In [1]:
# Resulting shape: 10 at the root; 5 (children 2 and 7) on the left,
# 15 (children 12 and 18) on the right.
t = BSTree()
for x in (10, 5, 2, 7, 9, 8, 1, 15, 12, 18):
    t.add(x)
t.pprint()

In [1]:
# 15 has two children (12 and 18).
del t[15]
t.pprint()

In [1]:
# Resulting shape: 10 at the root; 5 (children 2 and 7) on the left,
# 15 (children 12 and 18) on the right.
t = BSTree()
for x in (10, 5, 2, 7, 9, 8, 1, 15, 12, 18):
    t.add(x)
t.pprint()

In [1]:
# 5 has two children (2 and 7), each with subtrees of their own.
del t[5]
t.pprint()

In [1]:
# Resulting shape: 10 at the root; 5 (children 2 and 7) on the left,
# 15 (children 12 and 18) on the right.
t = BSTree()
for x in (10, 5, 2, 7, 9, 8, 1, 15, 12, 18):
    t.add(x)
t.pprint()

In [1]:
# Remove the root, 10, which has two children.
del t[10]
t.pprint()

### Iteration / Traversal



In [1]:
class BSTree(BSTree):
    def __iter__(self):
        """Returns an iterator over all the values in the tree, in ascending order."""
        # Iterative in-order traversal: the explicit stack holds the ancestors
        # whose value has not been yielded yet.
        stack = []
        node = self.root
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node.val
            node = node.right

In [1]:
import random

# Insert 0..19 in a random order; the in-order traversal below should
# print them back in ascending order regardless of insertion order.
vals = list(range(20))
random.shuffle(vals)
t = BSTree()
for x in vals:
    t.add(x)
for x in t:
    print(x)

In [None]:
class AVLTree(BSTree):
    class Node:
        """Tree node holding a value, optional children, and a balance factor `bf`."""

        def __init__(self, val, left=None, right=None, bf=0):
            self.val, self.left, self.right = val, left, right
            self.bf = bf

        def rotate_right(self):
            """Rotates this subtree to the right and returns its new root.

            The old left child becomes the root and this node becomes its
            right child. NOTE(review): `bf` is not updated here -- confirm the
            caller is expected to fix balance factors.
            """
            pivot = self.left
            # RHS is evaluated first, so pivot.right is read before it is overwritten.
            self.left, pivot.right = pivot.right, self
            return pivot



In [None]:
# AVLTree inherits BSTree.__init__, which takes no arguments.
t = AVLTree()