In [1]:
from collections.abc import MutableMapping

# Vanilla Binary Search Tree

In [2]:
class BinarySearchTreeMap(MutableMapping):
    """Sorted map implemented with an (unbalanced) binary search tree.

    Keys must be mutually comparable with ``<``, ``>`` and ``==``. The
    ``_rebalance_insert`` / ``_rebalance_delete`` / ``_rebalance_access``
    hooks do nothing in this class; subclasses override them to keep the
    tree balanced.
    """

    class _Item:
        """Key/value pair stored in a node; items compare by key only."""

        __slots__ = "_key", "_value"

        def __init__(self, k, v):
            self._key = k
            self._value = v

        def __eq__(self, other):
            # Equality is decided by the key alone (values are ignored).
            if not isinstance(other, type(self)):
                return NotImplemented
            return self._key == other._key

        def __ne__(self, other):
            return not (self == other)

        def __lt__(self, other):
            # Ordering is decided by the key alone.
            if not isinstance(other, type(self)):
                return NotImplemented
            return self._key < other._key

    class _Node:
        """Tree node holding an item and links to parent and children."""

        __slots__ = "_parent", "_left", "_right", "_item"

        def __init__(self, parent, left, right, item):
            self._parent = parent
            self._left = left
            self._right = right
            self._item = item

        def key(self):
            """Return the key of the stored item."""
            return self._item._key

        def value(self):
            """Return the value of the stored item."""
            return self._item._value

    def __init__(self):
        """Create an empty map."""
        self._root = None
        self._size = 0

    def __len__(self):
        """Return the number of items stored in the map."""
        return self._size

    def is_empty(self):
        """Return True if the map holds no items."""
        return len(self) == 0

    def is_root(self, node):
        """Return True if `node` is the root of the tree."""
        return self.root() is node

    def is_leaf(self, node):
        """Return True if `node` has no children."""
        return self.num_children(node) == 0

    def root(self):
        """Return the root node (None if the tree is empty)."""
        return self._root

    def parent(self, node):
        """Return the parent of `node` (None for the root)."""
        return node._parent

    def left(self, node):
        """Return the left child of `node` (None if absent)."""
        return node._left

    def right(self, node):
        """Return the right child of `node` (None if absent)."""
        return node._right

    def num_children(self, node):
        """Return the number of children of `node` (0, 1 or 2)."""
        count = 0
        if self.left(node) is not None:
            count += 1
        if self.right(node) is not None:
            count += 1
        return count

    def children(self, node):
        """Yield the existing children of `node`, left child first."""
        if self.left(node) is not None:
            yield self.left(node)
        if self.right(node) is not None:
            yield self.right(node)

    def sibling(self, node):
        """Return the other child of `node`'s parent (None if there is none)."""
        parent = self.parent(node)
        if parent is None:
            return None
        if node is self.left(parent):
            return self.right(parent)
        return self.left(parent)

    def _add_root(self, k, v):
        """Create the root holding (k, v); raise ValueError if a root exists."""
        if self.root() is not None:
            raise ValueError("Root exists")
        self._root = self._Node(None, None, None, self._Item(k, v))
        self._size += 1
        return self._root

    def _add_left(self, node, k, v):
        """Attach a new left child holding (k, v) to `node` and return it.

        Raises ValueError if `node` already has a left child.
        """
        if self.left(node) is not None:
            raise ValueError("Left child exists")
        node._left = self._Node(node, None, None, self._Item(k, v))
        self._size += 1
        return node._left

    def _add_right(self, node, k, v):
        """Attach a new right child holding (k, v) to `node` and return it.

        Raises ValueError if `node` already has a right child.
        """
        if self.right(node) is not None:
            raise ValueError("Right child exists")
        node._right = self._Node(node, None, None, self._Item(k, v))
        self._size += 1
        return node._right

    def _replace(self, node, k, v):
        """Overwrite the item stored at `node`; return the old (key, value)."""
        old_key = node._item._key
        old_value = node._item._value
        node._item._key = k
        node._item._value = v
        return old_key, old_value

    def _delete(self, node):
        """Unlink `node`, promoting its only child (if any) into its place.

        Returns the removed (key, value). Raises ValueError if `node` has
        two children.
        """
        if self.num_children(node) == 2:
            raise ValueError("node has two children")
        child = self.left(node) if self.left(node) is not None else self.right(node)
        if child is not None:
            child._parent = node._parent
        if node is self.root():
            self._root = child
        else:
            parent = node._parent
            if self.left(parent) is node:
                parent._left = child
            else:
                parent._right = child
        self._size -= 1
        k, v = node.key(), node.value()
        # Clear the links so the detached node cannot be used by mistake.
        node._item = node._parent = node._left = node._right = None
        return k, v

    def _subtree_search(self, node, k):
        """Search the subtree rooted at `node` for key `k`.

        Returns the node holding `k` if present; otherwise the last node
        visited on the search path (its key may be smaller or larger than k).
        """
        if k == node.key():
            return node
        elif k < node.key() and self.left(node) is not None:
            return self._subtree_search(self.left(node), k)
        elif k > node.key() and self.right(node) is not None:
            return self._subtree_search(self.right(node), k)
        return node

    def _subtree_search_first(self, node):
        """Return the leftmost (smallest-key) node of the subtree at `node`."""
        current_node = node
        while self.left(current_node) is not None:
            current_node = self.left(current_node)
        return current_node

    def _subtree_search_last(self, node):
        """Return the rightmost (largest-key) node of the subtree at `node`."""
        current_node = node
        while self.right(current_node) is not None:
            current_node = self.right(current_node)
        return current_node

    def first(self):
        """Return the node with the smallest key; ValueError if empty."""
        if self.is_empty():
            raise ValueError("Tree is empty")
        return self._subtree_search_first(self.root())

    def last(self):
        """Return the node with the largest key; ValueError if empty."""
        if self.is_empty():
            raise ValueError("Tree is empty")
        return self._subtree_search_last(self.root())

    def before(self, node):
        """Return the in-order predecessor of `node` (None if there is none).

        Also returns None when `node` itself is None.
        """
        if node is None:
            return None
        if self.left(node) is not None:
            return self._subtree_search_last(self.left(node))
        # Climb until we leave a right subtree; that ancestor is the predecessor.
        current_node = node
        parent = self.parent(current_node)
        while parent is not None and current_node is self.left(parent):
            current_node = parent
            parent = self.parent(current_node)
        return parent

    def after(self, node):
        """Return the in-order successor of `node` (None if there is none).

        Also returns None when `node` itself is None.
        """
        if node is None:
            return None
        if self.right(node) is not None:
            return self._subtree_search_first(self.right(node))
        # Climb until we leave a left subtree; that ancestor is the successor.
        current_node = node
        parent = self.parent(current_node)
        while parent is not None and current_node is self.right(parent):
            current_node = parent
            parent = self.parent(current_node)
        return parent

    def find_node(self, k):
        """Return the node for key `k`, or the last node on its search path.

        Returns None if the map is empty. The access hook is invoked on the
        returned node.
        """
        if self.is_empty():
            return None
        node = self._subtree_search(self.root(), k)
        self._rebalance_access(node)
        return node

    def find_min(self):
        """Return (key, value) with the smallest key, or None if empty."""
        if self.is_empty():
            return None
        node = self.first()
        return node.key(), node.value()

    def find_max(self):
        """Return (key, value) with the largest key, or None if empty."""
        if self.is_empty():
            return None
        node = self.last()
        return node.key(), node.value()

    def find_ge(self, k):
        """Return (key, value) with the least key >= `k`, or None if none."""
        if self.is_empty():
            return None
        node = self._subtree_search(self.root(), k)
        # The search ends either on k itself, on the nearest larger key (which
        # is then the answer), or on the nearest smaller key (whose successor
        # is the answer).
        if node.key() < k:
            node = self.after(node)
        if node is None:
            return None
        return node.key(), node.value()

    def find_range(self, start=None, stop=None):
        """Yield (key, value) pairs with start <= key < stop in key order.

        A bound of None means "unbounded" on that side.
        """
        if self.is_empty():
            return
        if start is None:
            current_node = self.first()
        else:
            current_node = self.find_node(start)
            if current_node.key() < start:
                # Search stopped on a smaller key; step to its successor.
                current_node = self.after(current_node)
        # Compare against None explicitly so falsy bounds such as 0 still count.
        while current_node is not None and (
            stop is None or current_node.key() < stop
        ):
            yield current_node.key(), current_node.value()
            current_node = self.after(current_node)

    def __getitem__(self, k):
        """Return the value stored under `k`; raise KeyError if absent."""
        if self.is_empty():
            raise KeyError(f"KeyError: '{k}'")
        node = self._subtree_search(self.root(), k)
        self._rebalance_access(node)
        if not node.key() == k:
            raise KeyError(f"KeyError: '{k}'")
        return node.value()

    def __setitem__(self, k, v):
        """Store `v` under `k`, overwriting any existing value."""
        if self.is_empty():
            leaf = self._add_root(k, v)
        else:
            node = self._subtree_search(self.root(), k)
            if node.key() == k:
                node._item._value = v
                self._rebalance_access(node)
                return
            elif node.key() < k:
                leaf = self._add_right(node, k, v)
            else:
                leaf = self._add_left(node, k, v)
        self._rebalance_insert(leaf)

    def delete(self, node):
        """Remove the item stored at `node` from the map."""
        if self.num_children(node) == 2:
            # Copy the predecessor's item here, then remove the predecessor
            # node instead (it has at most one child).
            replacement_node = self.before(node)
            self._replace(
                node, replacement_node.key(), replacement_node.value()
            )
            node = replacement_node
        parent = self.parent(node)
        self._delete(node)
        self._rebalance_delete(parent)

    def __delitem__(self, k):
        """Remove the item with key `k`; raise KeyError if absent."""
        if not self.is_empty():
            node = self._subtree_search(self.root(), k)
            if node.key() == k:
                self.delete(node)
                return
            self._rebalance_access(node)
        raise KeyError(f"KeyError: '{k}'")

    def _inorder_traversal(self, node):
        """Yield the nodes of the subtree rooted at `node` in key order."""
        if self.left(node) is not None:
            yield from self._inorder_traversal(self.left(node))
        yield node
        if self.right(node) is not None:
            yield from self._inorder_traversal(self.right(node))

    def __iter__(self):
        """Yield the keys of the map in increasing order."""
        if not self.is_empty():
            for node in self._inorder_traversal(self.root()):
                yield node.key()

    def _is_subtree_binary(self, node, min_val, max_val):
        """Check the BST ordering property on the subtree rooted at `node`.

        Every key must lie strictly between `min_val` and `max_val`; a bound
        of None means that side is unbounded.
        """
        if node is None:
            return True
        key = node.key()
        if min_val is not None and not key > min_val:
            return False
        if max_val is not None and not key < max_val:
            return False
        return self._is_subtree_binary(
            self.left(node), min_val, key
        ) and self._is_subtree_binary(self.right(node), key, max_val)

    def is_binary(self):
        """Check whether `self` is a binary search tree (BST)."""
        # None bounds (rather than +/- infinity) keep this usable for keys
        # that cannot be compared with floats, e.g. strings.
        return self._is_subtree_binary(self.root(), None, None)

    def _rebalance_insert(self, node):
        """Hook called with the newly inserted node; no-op here."""
        pass

    def _rebalance_delete(self, node):
        """Hook called with the parent of a removed node; no-op here."""
        pass

    def _rebalance_access(self, node):
        """Hook called with the node reached by a search; no-op here."""
        pass

    def _relink(self, parent, child, make_left_child):
        """Relink parent node with child node."""
        if make_left_child:
            parent._left = child
        else:
            parent._right = child
        if child is not None:
            child._parent = parent

    def _rotate(self, node):
        """Rotate node above its parent."""
        x = node
        y = self.parent(x)
        z = self.parent(y)
        if z is None:
            self._root = x
            x._parent = None
        else:
            self._relink(z, x, y is self.left(z))
        if x is self.left(y):
            # x's right subtree becomes y's left subtree; y becomes x's right child.
            self._relink(y, self.right(x), True)
            self._relink(x, y, False)
        else:
            # Mirror image of the case above.
            self._relink(y, self.left(x), False)
            self._relink(x, y, True)

    def _restructure(self, x):
        """Perform trinode restructure of node with its parent/grandparent."""
        y = self.parent(x)
        z = self.parent(y)
        if (x is self.left(y)) == (y is self.left(z)):
            # x, y, z are aligned: a single rotation of y suffices.
            self._rotate(y)
            return y
        else:
            # Zig-zag: rotate x up twice.
            self._rotate(x)
            self._rotate(x)
            return x

In [3]:
bst = BinarySearchTreeMap()

In [4]:
# Looking up a key in an empty map raises KeyError (see the output below).
bst['imad']

KeyError: "KeyError: 'imad'"

In [5]:
bst['imad'] = 1

In [6]:
len(bst)

1

In [7]:
[k for k in bst]

['imad']

In [8]:
del bst['imad']

In [9]:
len(bst)

0

In [10]:
# Insert the keys 0..9 in increasing order, each mapped to itself.
for i in range(10):
    bst[i] = i

In [11]:
len(bst)

10

In [12]:
[k for k in bst]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [13]:
bst[9]

9

In [14]:
bst.find_min(), bst.find_max()

((0, 0), (9, 9))

In [15]:
bst.find_ge(9)

(9, 9)

In [16]:
for p in bst.find_range():
    print(p)

(0, 0)
(1, 1)
(2, 2)
(3, 3)
(4, 4)
(5, 5)
(6, 6)
(7, 7)
(8, 8)
(9, 9)


In [17]:
9 in bst

True

In [18]:
list(bst.items())

[(0, 0),
 (1, 1),
 (2, 2),
 (3, 3),
 (4, 4),
 (5, 5),
 (6, 6),
 (7, 7),
 (8, 8),
 (9, 9)]

In [19]:
del bst[5]

In [20]:
len(bst)

9

In [21]:
5 in bst

False

In [22]:
[k for k in bst]

[0, 1, 2, 3, 4, 6, 7, 8, 9]

In [23]:
bst.is_binary()

True

| Operation | Running Time |
| :-: | :-: |
| k in T | O(h) |
| T[k], T[k] = v | O(h) |
| T.delete(p), del T[k] | O(h) |
| T.find_node(k) | O(h) |
| T.first(), T.last(), T.find_min(), T.find_max() | O(h) |
| T.before(p), T.after(p) | O(h) |
| T.find_lt(k), T.find_le(k), T.find_gt(k), T.find_ge(k) | O(h) |
| T.find_range(start, stop) | O(s+h) |
| iter(T), reversed(T) | O(n) |

# AVL Trees

In [24]:
class AVLTreeMap(BinarySearchTreeMap):
    """Sorted map implementation using an AVL tree."""

    class _Node(BinarySearchTreeMap._Node):
        """Tree node that additionally caches the height of its subtree."""

        # The base node class declares __slots__; declare the extra field too
        # so instances stay dict-free.
        __slots__ = ("_height",)

        def __init__(self, parent, left, right, item):
            super().__init__(parent, left, right, item)
            # Placeholder; the real height is set by _recompute_height.
            self._height = 0

        def left_height(self):
            """Return the cached height of the left subtree (0 if absent)."""
            return self._left._height if self._left else 0

        def right_height(self):
            """Return the cached height of the right subtree (0 if absent)."""
            return self._right._height if self._right else 0

    def _recompute_height(self, node):
        """Refresh `node`'s cached height from its children's heights."""
        node._height = 1 + max(node.left_height(), node.right_height())

    def _is_balanced(self, node):
        """Return True if `node`'s subtree heights differ by at most 1."""
        return abs(node.left_height() - node.right_height()) <= 1

    def _tall_child(self, node, favor_left=False):
        """Return the taller child of `node`.

        On a height tie the left child is returned when `favor_left` is
        true, otherwise the right child.
        """
        if node.left_height() + (1 if favor_left else 0) > node.right_height():
            return self.left(node)
        return self.right(node)

    def _tall_grandchild(self, node):
        """Return the taller child of `node`'s taller child.

        Ties at the grandchild level are broken toward the same side as the
        child, so the three nodes are aligned.
        """
        child = self._tall_child(node)
        alignment = child is self.left(node)
        return self._tall_child(child, alignment)

    def _rebalance(self, node):
        """Restore heights and balance while walking from `node` to the root.

        Stops early once an ancestor's height is unchanged. A None `node`
        is a no-op.
        """
        current_node = node
        while current_node:
            # Height of the node being examined *before* it is updated; the
            # walk can stop as soon as an update leaves it unchanged.
            old_height = current_node._height
            if not self._is_balanced(current_node):
                current_node = self._restructure(
                    self._tall_grandchild(current_node)
                )
                self._recompute_height(self.left(current_node))
                self._recompute_height(self.right(current_node))
            self._recompute_height(current_node)
            if old_height == current_node._height:
                current_node = None
            else:
                current_node = self.parent(current_node)

    def _rebalance_insert(self, node):
        """Rebalance upward from the newly inserted `node`."""
        return self._rebalance(node)

    def _rebalance_delete(self, node):
        """Rebalance upward from the parent of a removed node."""
        return self._rebalance(node)

In [25]:
bst = AVLTreeMap()

In [26]:
bst['imad']

KeyError: "KeyError: 'imad'"

In [27]:
bst['imad'] = 1

In [28]:
len(bst)

1

In [29]:
[k for k in bst]

['imad']

In [30]:
del bst['imad']

In [31]:
len(bst)

0

In [32]:
for i in range(10):
    bst[i] = i

In [33]:
len(bst)

10

In [34]:
[k for k in bst]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [35]:
bst[9]

9

In [36]:
bst.find_min(), bst.find_max()

((0, 0), (9, 9))

In [37]:
bst.find_ge(9)

(9, 9)

In [38]:
for p in bst.find_range():
    print(p)

(0, 0)
(1, 1)
(2, 2)
(3, 3)
(4, 4)
(5, 5)
(6, 6)
(7, 7)
(8, 8)
(9, 9)


In [39]:
9 in bst

True

In [40]:
list(bst.items())

[(0, 0),
 (1, 1),
 (2, 2),
 (3, 3),
 (4, 4),
 (5, 5),
 (6, 6),
 (7, 7),
 (8, 8),
 (9, 9)]

In [41]:
[bst[k] for k in bst]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [42]:
del bst[5]

In [43]:
len(bst)

9

In [44]:
5 in bst

False

In [45]:
[k for k in bst]

[0, 1, 2, 3, 4, 6, 7, 8, 9]

In [46]:
bst.is_binary()

True

| Operation | Running Time |
| :-: | :-: |
| k in T | O(logn) |
| T[k], T[k] = v | O(logn) |
| T.delete(p), del T[k] | O(logn) |
| T.find_node(k) | O(logn) |
| T.first(), T.last(), T.find_min(), T.find_max() | O(logn) |
| T.before(p), T.after(p) | O(logn) |
| T.find_lt(k), T.find_le(k), T.find_gt(k), T.find_ge(k) | O(logn) |
| T.find_range(start, stop) | O(s+logn) |
| iter(T), reversed(T) | O(n) |

In [47]:
bst._is_balanced(bst.root())

True

# Splay Trees

In [48]:
class SplayTreeMap(BinarySearchTreeMap):
    """Sorted map backed by a splay tree (a self-adjusting BST)."""

    def _splay(self, node):
        """Move `node` up to the root via zig, zig-zig and zig-zag steps."""
        while self.root() is not node:
            above = self.parent(node)
            top = self.parent(above)

            if top:
                parent_on_left = above == self.left(top)
                node_on_left = node == self.left(above)
                if parent_on_left == node_on_left:
                    # zig-zig: the parent goes up first, then the node
                    self._rotate(above)
                else:
                    # zig-zag: first of the node's two rotations
                    self._rotate(node)

            # Every case ends by rotating the node itself; with no
            # grandparent (zig) this is the only rotation performed.
            self._rotate(node)

    def _rebalance_access(self, node):
        """Splay the node reached by a search."""
        self._splay(node)

    def _rebalance_insert(self, node):
        """Splay the newly inserted node."""
        self._splay(node)

    def _rebalance_delete(self, node):
        """Splay the parent of the removed node, when there is one."""
        if not node:
            return
        self._splay(node)

In [49]:
bst = SplayTreeMap()

In [50]:
bst['imad']

KeyError: "KeyError: 'imad'"

In [51]:
bst['imad'] = 1

In [52]:
len(bst)

1

In [53]:
[k for k in bst]

['imad']

In [54]:
del bst['imad']

In [55]:
len(bst)

0

In [56]:
for i in range(10):
    bst[i] = i

In [57]:
len(bst)

10

In [58]:
[k for k in bst]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [59]:
del bst[9]

In [60]:
list(bst.items())

[(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8)]

In [61]:
[bst[k] for k in bst]

[0, 1, 2, 3, 4, 5, 6, 7, 8]

In [62]:
bst.find_min(), bst.find_max()

((0, 0), (8, 8))

In [63]:
# Key 9 was deleted above and 8 is now the largest key, so this returns None
# (hence no output is displayed).
bst.find_ge(9)

In [64]:
for p in bst.find_range():
    print(p)

(0, 0)
(1, 1)
(2, 2)
(3, 3)
(4, 4)
(5, 5)
(6, 6)
(7, 7)
(8, 8)


In [65]:
9 in bst

False

In [66]:
del bst[5]

In [67]:
len(bst)

8

In [68]:
5 in bst

False

In [69]:
[k for k in bst]

[0, 1, 2, 3, 4, 6, 7, 8]

In [70]:
bst.is_binary()

True

| Operation | Running Time |
| :-: | :-: |
| k in T | O(logn)* |
| T[k], T[k] = v | O(logn)* |
| T.delete(p), del T[k] | O(logn)* |
| T.find_node(k) | O(logn)* |
| T.first(), T.last(), T.find_min(), T.find_max() | O(logn)* |
| T.before(p), T.after(p) | O(logn)* |
| T.find_lt(k), T.find_le(k), T.find_gt(k), T.find_ge(k) | O(logn)* |
| T.find_range(start, stop) | O(s+logn)* |
| iter(T), reversed(T) | O(n) |

A splay tree has O(logn) *amortized* running time for the fundamental operations (marked with * in the table above). It does not guarantee an O(logn) upper bound on any single operation, because we may be unlucky and hit the O(n) worst case. However, using amortized analysis, we can show that the operations take O(logn) amortized time.

# Red-Black Trees

In [71]:
class RedBlackTreeMap(BinarySearchTreeMap):
    """Sorted map using implementation of red-black tree."""

    class _Node(BinarySearchTreeMap._Node):
        """Tree node that additionally stores its color."""

        # The base node class declares __slots__; declare the extra field too
        # so instances stay dict-free.
        __slots__ = ("_red",)

        def __init__(self, parent, left, right, item):
            super().__init__(parent, left, right, item)
            self._red = True  # all new nodes will be red nodes before rebalancing

    def _set_red(self, node):
        """Color `node` red."""
        node._red = True

    def _set_black(self, node):
        """Color `node` black."""
        node._red = False

    def _set_color(self, node, make_red):
        """Color `node` red if `make_red` is true, otherwise black."""
        node._red = make_red

    def _is_red(self, node):
        """Return True if `node` exists and is red (None counts as black)."""
        return node is not None and node._red

    def _is_red_leaf(self, node):
        """Return True if `node` is a red node without children."""
        return self._is_red(node) and self.is_leaf(node)

    def _get_red_child(self, node):
        """Return a red child of `node` (left one first), or None."""
        for child in self.children(node):
            if self._is_red(child):
                return child
        return None

    def _rebalance_insert(self, node):
        """Resolve a possible double-red violation at the new red `node`."""
        if self.is_root(node):  # root should be black
            self._set_black(node)
        else:
            parent = self.parent(node)
            # check if there is double-red violation
            if self._is_red(parent):
                uncle = self.sibling(parent)
                if self._is_red(uncle):
                    grand_parent = self.parent(parent)
                    self._set_black(parent)
                    self._set_black(uncle)
                    self._set_red(grand_parent)
                    # keep going up until no double-red violation or reach the root (O(logn))
                    self._rebalance_insert(grand_parent)
                else:
                    # It only happens once; if needed, in any insertion (O(1))
                    middle = self._restructure(node)
                    self._set_red(self.left(middle))
                    self._set_red(self.right(middle))
                    self._set_black(middle)

    def _rebalance_delete(self, node):
        """Restore the red-black properties after a removal.

        `node` is the parent of the removed node (None if the root itself
        was removed).
        """
        if len(self) == 1:
            # A single remaining node is the root and must be black.
            self._set_black(self.root())
        elif node:
            n = self.num_children(node)
            if n == 1:
                child = next(self.children(node))

                # If the remaining child is a red leaf, the removed node was
                # a red leaf and nothing needs fixing; otherwise a black leaf
                # was removed and its side now has a black deficit.
                if not self._is_red_leaf(child):
                    self._fix_black_deficit(node, child)

            # The removed node was black and its red leaf child was promoted
            # into its place: recolor that promoted child black. It must be
            # identified as a red *leaf*, since the sibling may be a red
            # internal node.
            elif n == 2:
                if self._is_red_leaf(self.left(node)):
                    self._set_black(self.left(node))
                else:
                    self._set_black(self.right(node))

    def _fix_black_deficit(self, z, y):
        """Resolve a black deficit at `z`, where `y` is the heavier child."""
        if not self._is_red(y):
            # case 1: y node is black and has one red child
            x = self._get_red_child(y)
            if x:
                old_color = self._is_red(z)
                middle = self._restructure(x)
                # The new subtree root takes z's old color; its children are black.
                self._set_color(middle, old_color)
                self._set_black(self.left(middle))
                self._set_black(self.right(middle))

            # case 2: y node is black and has 2 black children (or None)
            else:
                self._set_red(y)
                if self._is_red(z):
                    self._set_black(z)
                elif not self.is_root(z):
                    # The deficit moves up one level.
                    self._fix_black_deficit(self.parent(z), self.sibling(z))
        # case 3: y node is red
        else:
            # rotate
            self._rotate(y)
            # recolor
            self._set_black(y)
            self._set_red(z)
            # re-apply algorithm on z's child
            if z is self.left(y):
                self._fix_black_deficit(z, self.right(z))
            else:
                self._fix_black_deficit(z, self.left(z))

In [72]:
bst = RedBlackTreeMap()

In [73]:
bst['imad']

KeyError: "KeyError: 'imad'"

In [74]:
bst['imad'] = 1

In [75]:
len(bst)

1

In [76]:
[k for k in bst]

['imad']

In [77]:
del bst['imad']

In [78]:
len(bst)

0

In [79]:
for i in range(10):
    bst[i] = i

In [80]:
len(bst)

10

In [81]:
[k for k in bst]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [82]:
bst[9]

9

In [83]:
bst.find_min(), bst.find_max()

((0, 0), (9, 9))

In [84]:
bst.find_ge(9)

(9, 9)

In [85]:
for p in bst.find_range():
    print(p)

(0, 0)
(1, 1)
(2, 2)
(3, 3)
(4, 4)
(5, 5)
(6, 6)
(7, 7)
(8, 8)
(9, 9)


In [86]:
9 in bst

True

In [87]:
list(bst.items())

[(0, 0),
 (1, 1),
 (2, 2),
 (3, 3),
 (4, 4),
 (5, 5),
 (6, 6),
 (7, 7),
 (8, 8),
 (9, 9)]

In [88]:
[bst[k] for k in bst]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [89]:
del bst[5]

In [90]:
len(bst)

9

In [91]:
5 in bst

False

In [92]:
[k for k in bst]

[0, 1, 2, 3, 4, 6, 7, 8, 9]

In [93]:
bst.is_binary()

True

- Search takes O(logn).
- Insertion takes O(logn) for the search, plus O(logn) for recolorings and at most one restructuring, after which it terminates.
- Deletion takes O(logn) for the search, plus O(logn) for recolorings and at most two restructurings (case 3, when y is red), after which it terminates.

| Operation | Running Time |
| :-: | :-: |
| k in T | O(logn) |
| T[k], T[k] = v | O(logn) |
| T.delete(p), del T[k] | O(logn) |
| T.find_node(k) | O(logn) |
| T.first(), T.last(), T.find_min(), T.find_max() | O(logn) |
| T.before(p), T.after(p) | O(logn) |
| T.find_lt(k), T.find_le(k), T.find_gt(k), T.find_ge(k) | O(logn) |
| T.find_range(start, stop) | O(s+logn) |
| iter(T), reversed(T) | O(n) |