## Binary Search Trees (BSTs)

- Fast search, insertion, deletion
- max, min, successor, predecessor ops
- Dumey, 1952 NSA first published
- Wheeler 1957

Average case search, insert, delete: $O(\log n)$

Worst case search, insert, delete: $\Theta(n)$



In [4]:
import math, random

In [128]:
class GNode:
    """Element stored in a `Group`; it only wraps a comparable key."""

    def __init__(self, key):
        self.key = key

    def __repr__(self):
        return f"GNode( key={self.key} )"


class Group:
    """Unordered collection of `GNode`s kept in a plain Python list.

    Offers the same operations as the tree classes in this notebook
    (insert, lookup, delete, minimum, maximum, pred, succ), but every
    query scans the whole list.
    """

    def __init__(self, keys=None):
        """Create the group and insert every key from ``keys`` (any iterable)."""
        self._list = []
        # ``None`` instead of a mutable ``[]`` default: no list shared between calls.
        for k in (() if keys is None else keys):
            self.insert(k)

    def insert(self, key):
        """Append a new node for ``key`` and return it (duplicates are kept)."""
        node = GNode(key)
        self._list.append(node)
        return node

    def maximum(self):
        """Return the node with the largest key, or None if the group is empty."""
        return max(self._list, key=lambda n: n.key, default=None)

    def minimum(self):
        """Return the node with the smallest key, or None if the group is empty."""
        return min(self._list, key=lambda n: n.key, default=None)

    def pred(self, node):
        """Return the node with the largest key strictly below ``node.key``, or None."""
        best = None
        for n in self._list:
            if n.key < node.key and (best is None or n.key > best.key):
                best = n
        return best

    def succ(self, node):
        """Return the node with the smallest key strictly above ``node.key``, or None."""
        best = None
        for n in self._list:
            # Compare against the running best successor (the original read an
            # undefined name here and raised NameError).
            if n.key > node.key and (best is None or n.key < best.key):
                best = n
        return best

    def lookup(self, key):
        """Return the first node whose key equals ``key``, or None if absent."""
        return next((n for n in self._list if n.key == key), None)

    def delete(self, node):
        """Remove ``node`` from the group; raises ValueError if it is not present."""
        self._list.remove(node)

In [12]:
class BSTNode:
    """Node of an unbalanced binary search tree (key plus parent/child links)."""

    def __init__(self, key, parent=None):
        self.key = key
        self.parent = parent
        self.right = self.left = None

    def __repr__(self):
        return (
            f"BSTNode( key={self.key}, "
            f"parent={self.parent.key if self.parent else None}, "
            f"right={self.right.key if self.right else None}, "
            f"left={self.left.key if self.left else None} )"
        )


class BST:
    """Unbalanced binary search tree.

    Keys smaller than a node go to its left subtree; keys greater than or
    equal to it go to the right, so duplicates are allowed.  Nothing
    rebalances the tree, so its height can reach the number of keys (e.g.
    for sorted input); every traversal below is therefore iterative to
    stay clear of Python's recursion limit.
    """

    def __init__(self, keys=None):
        """Create the tree and insert every key from ``keys`` (any iterable)."""
        self.root = None
        # ``None`` instead of a mutable ``[]`` default: no list shared between calls.
        for k in (() if keys is None else keys):
            self.insert(k)

    def __repr__(self):
        """Render the tree level by level, one list of keys per line.

        Missing children are shown as None.  Each level is twice as wide as
        the previous one, so this is only practical for shallow trees.
        """
        rows = []
        layer = [self.root]
        while any(layer):
            rows.append(str([node.key if node else None for node in layer]) + "\n")
            layer = self.build_layer(layer)
        return "".join(rows)

    def build_layer(self, prev_layer):
        """Return the children of ``prev_layer`` in order, padding gaps with None."""
        return [
            child
            for node in prev_layer
            for child in ([node.left, node.right] if node else [None, None])
        ]

    def insert(self, key):
        """Insert ``key`` as a new leaf and return the created node."""
        if self.root is None:
            self.root = BSTNode(key)
            return self.root

        node = self.root
        while True:
            if key < node.key:
                if node.left is None:
                    node.left = BSTNode(key, parent=node)
                    return node.left
                node = node.left
            else:
                # Equal keys are sent right, like larger ones.
                if node.right is None:
                    node.right = BSTNode(key, parent=node)
                    return node.right
                node = node.right

    def maximum(self, node=None):
        """Return the largest-key node under ``node`` (default: whole tree).

        Returns None when the tree is empty.
        """
        if node is None:
            node = self.root
        if node is None:
            return None
        while node.right is not None:
            node = node.right
        return node

    def minimum(self, node=None):
        """Return the smallest-key node under ``node`` (default: whole tree).

        Returns None when the tree is empty.
        """
        if node is None:
            node = self.root
        if node is None:
            return None
        while node.left is not None:
            node = node.left
        return node

    def pred(self, node):
        """Return the in-order predecessor of ``node``, or None if it is first."""
        if node.left:
            return self.maximum(node.left)
        # Climb until we leave a right subtree; that ancestor is the predecessor.
        while node.parent:
            if node is node.parent.right:
                break
            node = node.parent
        return node.parent

    def succ(self, node):
        """Return the in-order successor of ``node``, or None if it is last."""
        if node.right:
            return self.minimum(node.right)
        # Climb until we leave a left subtree; that ancestor is the successor.
        while node.parent:
            if node is node.parent.left:
                break
            node = node.parent
        return node.parent

    def lookup(self, key):
        """Return a node whose key equals ``key``, or None if there is none."""
        return self.lookup_inner(key, self.root)

    def lookup_inner(self, key, search_node):
        """Search the subtree rooted at ``search_node`` for ``key``."""
        node = search_node
        while node is not None:
            if node.key == key:
                return node
            node = node.left if key < node.key else node.right
        return None

    def shift_nodes(self, old_node, new_node):
        """Put ``new_node`` (possibly None) where ``old_node`` hangs in the tree.

        Only the link from ``old_node``'s parent (or the root reference) and
        ``new_node.parent`` are updated; ``old_node``'s children are untouched.
        """
        if old_node is self.root:
            self.root = new_node
        elif old_node is old_node.parent.right:
            old_node.parent.right = new_node
        elif old_node is old_node.parent.left:
            old_node.parent.left = new_node

        if new_node:
            new_node.parent = old_node.parent

    def delete(self, node):
        """Unlink ``node`` (a node object of this tree, not a key) from the tree."""
        # Zero or one child: the other subtree simply moves up.
        if node.left is None:
            return self.shift_nodes(node, node.right)
        if node.right is None:
            return self.shift_nodes(node, node.left)

        # Two children: the successor (leftmost node of the right subtree,
        # which has no left child) takes the deleted node's place.
        succ = self.succ(node)

        if succ is not node.right:
            # Detach the successor first, then let it adopt the right subtree.
            self.shift_nodes(succ, succ.right)
            succ.right = node.right
            succ.right.parent = succ

        self.shift_nodes(node, succ)
        succ.left = node.left
        succ.left.parent = succ

    def inorder(self, node):
        """Return the keys of the subtree rooted at ``node`` in sorted order."""
        keys = []
        stack = []
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            keys.append(node.key)
            node = node.right
        return keys

    def depth(self, node="root"):
        """Return the number of levels under ``node`` (0 for an empty subtree).

        The string default ``"root"`` selects the whole tree, so an explicit
        ``None`` can still mean "empty subtree".
        """
        if isinstance(node, str) and node == "root":
            node = self.root
        deepest = 0
        stack = [(node, 1)] if node is not None else []
        while stack:
            current, level = stack.pop()
            deepest = max(deepest, level)
            for child in (current.left, current.right):
                if child is not None:
                    stack.append((child, level + 1))
        return deepest
    

In [106]:
red   = "red"
black = "black"


class RBTNode:
    """Node of a red-black tree: key, colour and parent/child links."""

    def __init__(self, key, parent=None, colour=None):
        self.key = key
        self.parent = parent
        self.colour = colour
        self.right = self.left = None

    def __repr__(self):
        return (
            f"RBTNode( key={self.key}, "
            f"colour={self.colour}, "
            f"parent={self.parent.key if self.parent else None}, "
            f"right={self.right.key if self.right else None}, "
            f"left={self.left.key if self.left else None} )"
        )

    @property
    def grandparent(self):
        """Parent of this node's parent, or None when this node has no parent."""
        return self.parent.parent if self.parent is not None else None

    @property
    def uncle(self):
        """The grandparent's other child (may be None); None without a grandparent."""
        grandparent = self.grandparent
        if grandparent is None:
            return None
        if grandparent.right is self.parent:
            return grandparent.left
        return grandparent.right


class RBT:
    """Red-black tree: a binary search tree rebalanced on insert and delete.

    Missing children are represented by None and count as black.  Keys equal
    to a node's key are placed in its right subtree, so duplicates are allowed.
    """

    def __init__(self, keys=None):
        """Create the tree and insert every key from ``keys`` (any iterable)."""
        self.root = None
        # ``None`` instead of a mutable ``[]`` default: no list shared between calls.
        for k in (() if keys is None else keys):
            self.insert(k)

    def insert(self, key):
        """Insert ``key`` and return the node created for it."""
        if self.root is None:
            self.root = RBTNode(key, colour=black)
            return self.root

        node = self.root
        while True:
            if key < node.key:
                if node.left is None:
                    new_node = node.left = RBTNode(key, parent=node, colour=red)
                    break
                node = node.left
            else:
                if node.right is None:
                    new_node = node.right = RBTNode(key, parent=node, colour=red)
                    break
                node = node.right

        # Keep our own reference: the rotations in fix() may move the new node
        # away from ``node.left`` / ``node.right``.
        self.fix(new_node)
        return new_node

    def fix(self, node):
        """Fix any Red-Black Tree insert violations.

        Args:
            node: the node that was inserted.
        """
        while node != self.root and node.parent.colour == red:

            if node.parent == node.parent.parent.left:
                uncle = node.parent.parent.right
                if uncle and uncle.colour == red:
                    # Red uncle: recolour and continue from the grandparent.
                    node.parent.colour = black
                    uncle.colour = black
                    node.parent.parent.colour = red
                    node = node.parent.parent
                else:
                    if node == node.parent.right:
                        # Inner child: rotate it to the outside first.
                        node = node.parent
                        self.rotate_left(node)
                    node.parent.colour = black
                    node.parent.parent.colour = red
                    self.rotate_right(node.parent.parent)

            else:
                # Mirror image of the branch above.
                uncle = node.parent.parent.left
                if uncle and uncle.colour == red:
                    node.parent.colour = black
                    uncle.colour = black
                    node.parent.parent.colour = red
                    node = node.parent.parent
                else:
                    if node == node.parent.left:
                        node = node.parent
                        self.rotate_right(node)
                    node.parent.colour = black
                    node.parent.parent.colour = red
                    self.rotate_left(node.parent.parent)

        self.root.colour = black

    def fix_(self, node):
        """Alternative insert fix-up written in terms of ``uncle``/``grandparent``.

        Not called by ``insert``; kept as a second formulation of ``fix``.
        """
        if node is self.root:
            node.colour = black
            return
        if node.parent.colour == black:
            node.colour = red
            return
        if node.parent.colour == red:
            if node.uncle and node.uncle.colour == red:
                node.colour = red
                node.parent.colour = black
                node.uncle.colour = black
                self.fix(node.grandparent)
                # Done: without this return the now-black uncle would also
                # trigger the rotation case below.
                return
            if node.uncle is None or node.uncle.colour == black:
                # True when the node is an "inner" grandchild (zig-zag shape).
                node_towards_uncle = (
                    (node.key >= node.parent.key) ^
                    (node.parent.key >= node.grandparent.key)
                )
                parent_gt_grandpa = node.parent.key >= node.grandparent.key
                if node_towards_uncle:
                    node.colour = black
                    node.grandparent.colour = red
                    grandparent = node.grandparent
                    (self.rotate_right if parent_gt_grandpa else self.rotate_left)(node.parent)
                    (self.rotate_left if parent_gt_grandpa else self.rotate_right)(grandparent)
                if not node_towards_uncle:
                    node.colour = red
                    node.parent.colour = black
                    node.grandparent.colour = red
                    (self.rotate_left if parent_gt_grandpa else self.rotate_right)(node.grandparent)

    def maximum(self, node=None):
        """Return the largest-key node under ``node`` (default: whole tree), or None if empty."""
        if node is None:
            node = self.root
        if node is None:
            return None
        while node.right is not None:
            node = node.right
        return node

    def minimum(self, node=None):
        """Return the smallest-key node under ``node`` (default: whole tree), or None if empty."""
        if node is None:
            node = self.root
        if node is None:
            return None
        while node.left is not None:
            node = node.left
        return node

    def pred(self, node):
        """Return the in-order predecessor of ``node``, or None if it is first."""
        if node.left:
            return self.maximum(node.left)
        while node.parent:
            if node is node.parent.right:
                break
            node = node.parent
        return node.parent

    def succ(self, node):
        """Return the in-order successor of ``node``, or None if it is last."""
        if node.right:
            return self.minimum(node.right)
        while node.parent:
            if node is node.parent.left:
                break
            node = node.parent
        return node.parent

    def lookup(self, key):
        """Return a node whose key equals ``key``, or None if there is none."""
        return self.lookup_inner(key, self.root)

    def lookup_inner(self, key, search_node):
        """Search the subtree rooted at ``search_node`` for ``key``."""
        node = search_node
        while node is not None:
            if node.key == key:
                return node
            node = node.left if key < node.key else node.right
        return None

    def shift_nodes(self, old_node, new_node):
        """Put ``new_node`` (possibly None) where ``old_node`` hangs in the tree."""
        if old_node is self.root:
            self.root = new_node
        elif old_node is old_node.parent.right:
            old_node.parent.right = new_node
        elif old_node is old_node.parent.left:
            old_node.parent.left = new_node

        if new_node:
            new_node.parent = old_node.parent

    def rotate_left(self, node):
        """Rotate left around ``node``: its right child takes its place."""
        assert node.right is not None

        pivot = node.right
        # The pivot's left subtree becomes node's right subtree.
        node.right = pivot.left
        if pivot.left is not None:
            pivot.left.parent = node

        pivot.parent = node.parent
        if node.parent is None:
            self.root = pivot
        elif node.parent.left is node:
            node.parent.left = pivot
        else:
            node.parent.right = pivot

        pivot.left = node
        node.parent = pivot

    def rotate_right(self, node):
        """Rotate right around ``node``: its left child takes its place."""
        assert node.left is not None

        pivot = node.left
        # The pivot's right subtree becomes node's left subtree.
        node.left = pivot.right
        if pivot.right is not None:
            pivot.right.parent = node

        pivot.parent = node.parent
        if node.parent is None:
            self.root = pivot
        elif node.parent.left is node:
            node.parent.left = pivot
        else:
            node.parent.right = pivot

        pivot.right = node
        node.parent = pivot

    @staticmethod
    def _is_black(node):
        """Missing children (None) count as black."""
        return node is None or node.colour == black

    def delete(self, node):
        """Unlink ``node`` (a node object of this tree) and restore the colouring.

        The structural part is the usual BST deletion.  If the node that was
        physically removed or moved was black, ``_delete_fix`` repairs the
        black-height of the affected path.
        """
        removed_colour = node.colour
        if node.left is None:
            moved_up, moved_up_parent = node.right, node.parent
            self.shift_nodes(node, node.right)
        elif node.right is None:
            moved_up, moved_up_parent = node.left, node.parent
            self.shift_nodes(node, node.left)
        else:
            # Two children: the successor replaces the node and inherits its
            # colour, so the colour that leaves the tree is the successor's.
            succ = self.minimum(node.right)
            removed_colour = succ.colour
            moved_up = succ.right
            if succ is node.right:
                moved_up_parent = succ
            else:
                moved_up_parent = succ.parent
                self.shift_nodes(succ, succ.right)
                succ.right = node.right
                succ.right.parent = succ

            self.shift_nodes(node, succ)
            succ.left = node.left
            succ.left.parent = succ
            succ.colour = node.colour

        if removed_colour == black:
            self._delete_fix(moved_up, moved_up_parent)

    def _delete_fix(self, x, parent):
        """Restore the red-black properties after a black node was removed.

        ``x`` is the (possibly None) node now occupying the removed position
        and ``parent`` its parent; the parent is passed explicitly because a
        None child carries no parent link.
        """
        while x is not self.root and self._is_black(x):
            if x is parent.left:
                sibling = parent.right
                if sibling.colour == red:
                    # Red sibling: rotate so that the sibling becomes black.
                    sibling.colour = black
                    parent.colour = red
                    self.rotate_left(parent)
                    sibling = parent.right
                if self._is_black(sibling.left) and self._is_black(sibling.right):
                    # Push the missing black one level up.
                    sibling.colour = red
                    x, parent = parent, parent.parent
                else:
                    if self._is_black(sibling.right):
                        # Make the sibling's far child red.
                        sibling.left.colour = black
                        sibling.colour = red
                        self.rotate_right(sibling)
                        sibling = parent.right
                    sibling.colour = parent.colour
                    parent.colour = black
                    sibling.right.colour = black
                    self.rotate_left(parent)
                    x = self.root
            else:
                # Mirror image of the branch above.
                sibling = parent.left
                if sibling.colour == red:
                    sibling.colour = black
                    parent.colour = red
                    self.rotate_right(parent)
                    sibling = parent.left
                if self._is_black(sibling.left) and self._is_black(sibling.right):
                    sibling.colour = red
                    x, parent = parent, parent.parent
                else:
                    if self._is_black(sibling.left):
                        sibling.right.colour = black
                        sibling.colour = red
                        self.rotate_left(sibling)
                        sibling = parent.left
                    sibling.colour = parent.colour
                    parent.colour = black
                    sibling.left.colour = black
                    self.rotate_right(parent)
                    x = self.root

        if x is not None:
            x.colour = black

    def inorder(self, node):
        """Return the keys of the subtree rooted at ``node`` in sorted order."""
        if node is None:
            return []

        return [ *self.inorder(node.left), node.key, *self.inorder(node.right) ]

    def depth(self, node="root"):
        """Return the number of levels under ``node`` (0 for an empty subtree).

        The string default ``"root"`` selects the whole tree, so an explicit
        ``None`` can still mean "empty subtree".
        """
        if isinstance(node, str) and node == "root":
            node = self.root
        if node is None:
            return 0
        return max( self.depth(node.left), self.depth(node.right) ) + 1
    

In [7]:
from utils.btree import Btree, Node


# Build a B-tree from the project-local utils.btree module.
# NOTE(review): the meaning of the first argument (5) is defined by Btree,
# presumably the tree's order / minimum degree - confirm in utils.btree.
btree = Btree(5, root=Node())


# Fixed seed so the same 500 integer keys are generated on every run.
random.seed(100)
nums = [int(random.uniform(0, 100)) for _ in range(500)]


# Load all keys (duplicates included) into the tree.
for n in nums:
    btree.insert(n)
    

def search( btree, key, node=None ):
    """Look for ``key`` in a B-tree, starting at ``node`` (default: the root).

    Each node is expected to expose ``keys``, ``children`` and ``is_leaf``,
    as the nodes from utils.btree are used here.  Child ``i`` is searched
    when ``keys[i]`` is the first key greater than ``key``; the last child is
    searched when no key is greater.

    Returns:
        ``(node, index)`` with ``node.keys[index] == key``, or None if the
        key is not in the tree.
    """
    if node is None:
        node = btree.root

    child_index = -1   # last child unless a larger key is found below
    for i, nodekey in enumerate(node.keys):
        # Trace of every key inspected; the recorded cell output relies on it.
        print(i, nodekey)
        if nodekey == key:
            return (node, i)
        if nodekey > key:
            child_index = i
            break

    # A leaf has nowhere to descend to; this also covers a key larger than
    # every key in the leaf, which previously indexed into the leaf's children.
    if node.is_leaf:
        return None
    return search( btree, key, node.children[child_index] )
    
# Run the library's own search and the local search() for the same key so the
# two results can be compared side by side.
btree.search(btree.root, 7), search(btree, 7)

0 53
0 8
0 0
1 1
2 3
3 4
4 5
0 6
1 7


((<utils.btree.Node at 0x710185aad0>, 1),
 (<utils.btree.Node at 0x710185aad0>, 1))

In [169]:
class HTNode:
    """Key/value entry of a hash table; the value defaults to the key itself."""

    def __init__(self, key, value=None):
        self.key = key
        # The original stored ``key`` in both branches, discarding the value.
        self.value = key if value is None else value


# Successive table sizes used when the table grows (each roughly double the last).
array_sizes = [
    23,
    47,
    97,
    197,
    397,
    797,
    1597,
    3203,
    6421,
    12853,
    25717,
    51437,
    102877,
    205759,
    411527,
    882377,
    1441049,
]


class HashTableSlow:
    """Hash table for numeric keys that resolves collisions by chaining.

    Every slot of ``_array`` holds a list of `HTNode`s.  ``load`` counts the
    stored entries; once it exceeds 2/3 of the slot count the table is
    rebuilt at the next size from ``array_sizes``.
    """

    def __init__(self):
        self.upsizes = 0
        self._array_size = array_sizes[self.upsizes]
        self._array = [ [] for _ in range(self._array_size) ]
        self.load = 0

    def _hash(self, key):
        """Map a numeric ``key`` to a slot index in ``[0, _array_size)``.

        Multiplication method: the fractional part of ``phi * key`` is scaled
        to the table size.
        """
        phi = 1.618033988749895

        return int( self._array_size * ((phi * key) % 1) )

    def insert( self, key, value=None ):
        """Store ``value`` under ``key``, overwriting an existing entry.

        A ``value`` of None stores the key itself, as `HTNode` does.
        """
        chain = self._array[self._hash(key)]
        for node in chain:
            if node.key == key:
                node.value = key if value is None else value
                return

        chain.append( HTNode(key, value) )
        self.load += 1
        if self.load / self._array_size > 2/3:
            self.upsize()
    __setitem__ = insert

    def upsize(self):
        """Rebuild the table with the next larger slot count."""
        self.upsizes += 1
        entries = [ node for slot in self._array for node in slot ]

        if self.upsizes < len(array_sizes):
            self._array_size = array_sizes[self.upsizes]
        else:
            # Past the precomputed sizes: keep roughly doubling instead of
            # failing with IndexError.
            self._array_size = 2 * self._array_size + 1
        self._array = [ [] for _ in range(self._array_size) ]
        self.load = 0

        for node in entries:
            self[node.key] = node.value

    def delete( self, key ):
        """Remove the entry for ``key``; does nothing if the key is absent."""
        chain = self._array[self._hash(key)]
        for sub_index, node in enumerate(chain):
            if node.key == key:
                chain.pop(sub_index)
                self.load -= 1
                return
    __delitem__ = delete

    def lookup( self, key ):
        """Return the `HTNode` stored under ``key``, or None if absent."""
        for node in self._array[self._hash(key)]:
            if node.key == key:
                return node
        return None
    __getitem__ = lookup

In [252]:
# Sentinels stored directly in the slot array (compared by identity).
empty = "empty"       # slot has never held an entry
deleted = "deleted"   # tombstone: an entry was removed here, probing continues


class HTNode:
    """Key/value entry of a hash table; the value defaults to the key itself."""

    def __init__(self, key, value=None):
        self.key = key
        self.value = key if value is None else value


# Successive table sizes used when the table grows (each roughly double the last).
array_sizes = [
    23,
    47,
    97,
    197,
    397,
    797,
    1597,
    3203,
    6421,
    12853,
    25717,
    51437,
    102877,
    205759,
    411527,
    882377,
    1441049,
    2454587,
]


def _next_prime(n):
    """Return the smallest prime >= ``n`` (plain trial division)."""
    candidate = max(n, 2)
    while any(candidate % d == 0 for d in range(2, int(candidate ** 0.5) + 1)):
        candidate += 1
    return candidate


class HashTable:
    """Open-addressing hash table for numeric keys, probed by double hashing.

    ``_array`` holds an `HTNode`, ``empty`` or ``deleted`` in every slot.
    ``load`` counts the slots that are not ``empty`` (live entries plus
    tombstones), which is what bounds probe lengths; once it exceeds 2/3 of
    the slot count the live entries are re-inserted into a larger array.
    """

    def __init__(self):
        self.upsizes = 0
        self._array_size = array_sizes[self.upsizes]
        self._array = [ empty for _ in range(self._array_size) ]
        self.load = 0

    def _hash(self, key):
        """Start index in ``[0, _array_size)``: fractional part of ``phi * key``, scaled."""
        phi = 1.618033988749895

        return int( self._array_size * ((phi * key) % 1) )

    def _hash2(self, key):
        """Probe step, never zero: between 1 and about half the table size."""
        phi = 1.618033988749895

        return int( self._array_size/2 * ((phi * (key+.1)) % 1) ) + 1

    def _probe(self, key):
        """Yield the slot indices to inspect for ``key``, at most one per slot.

        Standard double hashing: start at ``_hash(key)`` and advance by the
        constant step ``_hash2(key)``.  (The original advanced by a growing
        multiple of the step, which revisits slots and is not guaranteed to
        reach an empty one.)  A step smaller than a prime table size reaches
        every slot; the sizes are assumed to be prime.
        """
        index = self._hash(key)
        step = self._hash2(key)
        for _ in range(self._array_size):
            yield index
            index = (index + step) % self._array_size

    def insert( self, key, value=None ):
        """Store ``value`` under ``key``, replacing an existing entry.

        A ``value`` of None stores the key itself (see `HTNode`).
        """
        target = None   # slot to write: existing entry, first tombstone, or empty slot
        for index in self._probe(key):
            slot = self._array[index]
            if slot is empty:
                if target is None:
                    # No tombstone to recycle: a fresh slot becomes occupied.
                    target = index
                    self.load += 1
                break
            if slot is deleted:
                if target is None:
                    target = index      # remember it, but keep looking for the key
            elif slot.key == key:
                target = index          # overwrite in place
                break
        else:
            if target is None:
                # Every probed slot holds another key: grow and retry.
                self.upsize()
                return self.insert(key, value)

        self._array[target] = HTNode(key, value)
        if self.load / self._array_size > 2/3:
            self.upsize()
    __setitem__ = insert

    def upsize(self):
        """Rebuild the table with the next larger slot count, dropping tombstones."""
        self.upsizes += 1
        entries = [ node for node in self._array
                   if node is not empty and node is not deleted ]

        if self.upsizes < len(array_sizes):
            self._array_size = array_sizes[self.upsizes]
        else:
            # Past the precomputed sizes: keep roughly doubling instead of
            # failing with IndexError.
            self._array_size = _next_prime(2 * self._array_size + 1)
        self._array = [ empty for _ in range(self._array_size) ]
        self.load = 0

        for node in entries:
            self[node.key] = node.value

    def delete( self, key ):
        """Replace the entry for ``key`` with a tombstone; no-op if absent.

        ``load`` is left unchanged because the slot stays non-empty.
        """
        for index in self._probe(key):
            slot = self._array[index]
            if slot is empty:
                return
            if slot is not deleted and slot.key == key:
                self._array[index] = deleted
                return
    __delitem__ = delete

    def lookup( self, key ):
        """Return the `HTNode` stored under ``key``, or None if absent."""
        for index in self._probe(key):
            slot = self._array[index]
            if slot is empty:
                return None
            if slot is not deleted and slot.key == key:
                return slot
        return None
    __getitem__ = lookup

In [4]:
# Smoke test: store the value 1 under key 0, then read the node back
# through __getitem__ and show its value.
h = HashTable()
h[0] = 1
h[0].value

1

In [132]:
class PHTNode:
    """Key/value entry of `PythonHashTable`; the value defaults to the key itself."""

    def __init__(self, key, value=None):
        self.key = key
        self.value = key if value is None else value


class PythonHashTable:
    """Baseline for the benchmarks: the same interface on top of a built-in dict."""

    def __init__(self):
        self.d = {}

    def insert(self, key, value=None):
        """Store a node for ``key`` (value defaults to the key), replacing any old one."""
        # Use this cell's own node class rather than the HTNode of another cell.
        self.d[key] = PHTNode(key, value)

    def lookup(self, key):
        """Return the node stored under ``key``; raises KeyError if absent."""
        return self.d[key]

In [14]:
def primes(limit=2e3):
    """Print a roughly doubling sequence of primes below ``limit``.

    Primes are found by trial division against the primes found so far.  A
    prime is printed only when it is more than twice the previously printed
    one, which yields candidate sizes for a hash table that doubles on growth.

    Args:
        limit: exclusive upper bound for the primes considered.
    """
    found = []
    printed = 0
    candidate = 2

    while candidate < limit:
        root = candidate ** .5
        is_prime = True
        for p in found:
            if p > root:
                # ``found`` is ascending, so no later prime can divide either.
                break
            if candidate % p == 0:
                is_prime = False
                break

        if is_prime:
            found.append(candidate)
            if candidate > printed * 2:
                print(candidate)
                printed = candidate
        candidate += 1

primes()

2
5
11
23
47
97
197
397
797
1597


In [239]:
import random

random.seed(100)
nums = [random.uniform(0, 100) for _ in range(500)]


def benchmark( DataStructure ):
    """Exercise insert, lookup, pred and maximum on an ordered data structure.

    The structure is built from the module-level ``nums``; then 100 random
    keys (seed 100) are inserted one by one.  After each insert the new key
    is looked up, its predecessor is queried, and the current maximum key is
    added to the returned total.
    """
    random.seed(100)

    structure = DataStructure( nums )
    key_sum = 0

    rounds_left = 100
    while rounds_left:
        rounds_left -= 1
        sample = random.uniform(0, 100)

        structure.insert(sample)
        found = structure.lookup(sample)
        before = structure.pred(found)
        # Deletion is switched off by the constant False.
        if before and False:
            structure.delete(before)
        key_sum = key_sum + structure.maximum().key

    return key_sum


def simple_benchmark( DataStructure ):
    """Insert 10000 random keys (seed 100), looking each one up right away.

    Returns the sum of the ``key`` attributes of the looked-up nodes.
    """
    random.seed(100)

    structure = DataStructure()
    key_sum = 0

    rounds_left = 10000
    while rounds_left:
        rounds_left -= 1
        sample = random.uniform(0, 100)

        structure.insert(sample)
        key_sum = key_sum + structure.lookup(sample).key

    return key_sum


keys = [random.uniform(0, 100) for _ in range(50000)]

def insert_benchmark( DataStructure ):
    """Create an empty ``DataStructure`` and insert every module-level key.

    Returns the populated structure so that it can be reused for lookups.
    """
    populated = DataStructure()

    for sample in keys:
        populated.insert(sample)

    return populated
    

def lookup_benchmark( data_structure ):
    """Look up every module-level key and return the sum of the found nodes' keys."""
    key_sum = 0

    for sample in keys:
        found = data_structure.lookup(sample)
        key_sum = key_sum + found.key

    return key_sum

In [120]:
%%time
#benchmark( RBT )
# The benchmark function is defined as simple_benchmark; the camelCase
# spelling is not defined anywhere in this notebook.
simple_benchmark( HashTable )

CPU times: user 107 ms, sys: 3.96 ms, total: 111 ms
Wall time: 111 ms


501321.48750235647

In [279]:
%%time

# Time the bulk insert for one implementation at a time; uncomment the line
# for the structure to measure. The result is kept for the lookup timing.
#g = insert_benchmark( Group )
#bst = insert_benchmark( BST )
#rbt = insert_benchmark( RBT )
ht = insert_benchmark( HashTable )
#pht = insert_benchmark( PythonHashTable )

CPU times: user 539 ms, sys: 11.9 ms, total: 550 ms
Wall time: 557 ms


In [278]:
%%time
# Time looking up every benchmark key in the populated table.
lookup_benchmark( ht )

CPU times: user 106 ms, sys: 43 µs, total: 106 ms
Wall time: 103 ms


2499363.539716179