<a href="https://colab.research.google.com/github/Anirudh-Raghav/Braille-I/blob/master/data-structures/trees.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import random
from collections import deque

# Binary search tree along with traversal methods

In [2]:
class Node:
    """A binary-search-tree node holding one value plus child and parent links."""

    def __init__(self, value):
        self.value = value
        # All links start detached; the tree code wires them up on insertion.
        self.left = self.right = self.parent = None

In [3]:
class Tree:
    """Unbalanced binary search tree operations over ``Node`` objects.

    The class holds no state itself: every method takes the root of the
    subtree to work on. Mutating methods (``insert``, ``insert_arr``,
    ``delete``, ``invert``) return the possibly-new root of that subtree,
    and callers must keep using the returned value. Keys must be unique and
    comparable with ``<`` / ``>``. Parent pointers are kept up to date.
    """

    def create_node(self, value):
        """Return a new, unattached node holding *value*."""
        return Node(value)

    def insert(self, node, value):
        """Insert *value* into the subtree rooted at *node* and return its root.

        A new node is returned when *node* is ``None``.

        Raises:
            ValueError: if *value* is already present in the subtree.
        """
        if node is None:
            return self.create_node(value)

        if value < node.value:
            node.left = self.insert(node.left, value)
            node.left.parent = node
        elif value > node.value:
            node.right = self.insert(node.right, value)
            node.right.parent = node
        else:
            raise ValueError('Tree only accepts unique keys')

        return node

    def insert_arr(self, root, arr):
        """Insert every value of *arr*, in order, into the tree at *root*.

        *root* may be ``None``; the first value then becomes the root.
        *arr* is not modified. Returns the root (``None`` only when *root*
        is ``None`` and *arr* is empty).
        """
        for value in arr:
            root = self.insert(root, value)
        return root

    def search(self, node, value):
        """Return the node holding *value* in the subtree at *node*, or ``None``."""
        if node is None or node.value == value:
            return node
        if value < node.value:
            return self.search(node.left, value)
        return self.search(node.right, value)

    def delete(self, node, value):
        """Remove *value* from the subtree rooted at *node*.

        Returns the new root of that subtree, which is ``None`` if it became
        empty. Removing a value that is not present leaves the tree unchanged.
        """
        if node is None:
            return None

        if value < node.value:
            node.left = self.delete(node.left, value)
            if node.left is not None:
                node.left.parent = node
        elif value > node.value:
            node.right = self.delete(node.right, value)
            if node.right is not None:
                node.right.parent = node
        elif node.left is None or node.right is None:
            # Zero or one child: splice this node out by promoting its only
            # child (or None). The caller links the result back in.
            child = node.left if node.left is not None else node.right
            if child is not None:
                child.parent = node.parent
            return child
        else:
            # Two children: copy in the in-order predecessor (the largest key
            # in the left subtree), then delete that predecessor. It has no
            # right child, so the recursive call hits the splice case above.
            pred = node.left
            while pred.right is not None:
                pred = pred.right
            node.value = pred.value
            node.left = self.delete(node.left, pred.value)
            if node.left is not None:
                node.left.parent = node

        return node

    def invert(self, node):
        """Mirror the subtree rooted at *node* in place and return *node*.

        Afterwards an in-order traversal yields descending order, so
        ``insert``, ``search`` and ``delete`` no longer apply to the result.
        """
        if node is None:
            return None
        # Both recursive calls must see the ORIGINAL children, so swap in a
        # single tuple assignment.
        node.left, node.right = self.invert(node.right), self.invert(node.left)
        return node

    def inorder_traverse(self, node):
        """Return the values of the subtree in left-node-right order (sorted)."""
        if node is None:
            return []
        return self.inorder_traverse(node.left) + [node.value] + self.inorder_traverse(node.right)

    def preorder_traverse(self, node):
        """Return the values of the subtree in node-left-right order."""
        if node is None:
            return []
        return [node.value] + self.preorder_traverse(node.left) + self.preorder_traverse(node.right)

    def postorder_traverse(self, node):
        """Return the values of the subtree in left-right-node order."""
        if node is None:
            return []
        return self.postorder_traverse(node.left) + self.postorder_traverse(node.right) + [node.value]

    def levelorder_traverse(self, root):
        """Return the values of the tree breadth-first, level by level."""
        result = []
        queue = deque() if root is None else deque([root])
        while queue:
            node = queue.popleft()  # O(1), unlike list.pop(0)
            result.append(node.value)
            for child in (node.left, node.right):
                if child is not None:
                    queue.append(child)
        return result

    def get_node_height(self, node):
        """Return the number of nodes on the longest downward path from *node*.

        ``None`` has height 0 and a leaf has height 1.
        """
        if node is None:
            return 0
        return 1 + max(self.get_node_height(node.left), self.get_node_height(node.right))

# Balanced BSTs

## AVL Trees

In [4]:
class avl_trees(Tree):
    """AVL (height-balanced) variant of ``Tree``.

    After each insert or delete, the whole tree is rebalanced bottom-up,
    recomputing subtree heights with ``get_node_height`` instead of caching
    them. That keeps the code short, but every operation visits all nodes,
    far more work than the O(log n) of a height-caching AVL tree.

    Rotations can replace the root, so every mutating method returns the new
    root and callers must keep using that return value.
    """

    def rotate_left(self, node):
        """Rotate the subtree at *node* to the left and return its new root.

        The new root takes over *node*'s parent pointer. Linking it into that
        parent's ``left``/``right`` field is the caller's job.
        """
        temp = node.right
        temp.left, node.right = node, temp.left
        temp.parent = node.parent
        node.parent = temp
        if node.right is not None:
            node.right.parent = node
        return temp

    def rotate_right(self, node):
        """Rotate the subtree at *node* to the right and return its new root.

        The new root takes over *node*'s parent pointer. Linking it into that
        parent's ``left``/``right`` field is the caller's job.
        """
        temp = node.left
        temp.right, node.left = node, temp.right
        temp.parent = node.parent
        node.parent = temp
        if node.left is not None:
            node.left.parent = node
        return temp

    def _balance_factor(self, node):
        # Left height minus right height: > 0 means left-heavy.
        return self.get_node_height(node.left) - self.get_node_height(node.right)

    def avl_test(self, node):
        """Return True if *node*'s subtree heights differ by at most one."""
        return abs(self._balance_factor(node)) <= 1

    def avl_balance(self, node):
        """Rebalance the subtree rooted at *node* bottom-up and return its new root.

        Each unbalanced node gets a single or double rotation after its
        children have been rebalanced. This restores AVL balance, assuming
        the tree was balanced before the most recent single insert or delete.
        """
        if node is None:
            return None

        # Rebalance the children first and re-link whatever they return.
        node.left = self.avl_balance(node.left)
        if node.left is not None:
            node.left.parent = node
        node.right = self.avl_balance(node.right)
        if node.right is not None:
            node.right.parent = node

        if self.avl_test(node):
            return node

        if self._balance_factor(node) > 1:
            # Left-heavy. If the left child leans right (left-right case),
            # first turn it into a left-left case.
            if self._balance_factor(node.left) < 0:
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)

        # Right-heavy. Symmetric right-left case.
        if self._balance_factor(node.right) > 0:
            node.right = self.rotate_right(node.right)
        return self.rotate_left(node)

    def avl_insert(self, root, value):
        """Insert *value* (root may be ``None``) and return the rebalanced root.

        Raises:
            ValueError: if *value* is already present (from ``Tree.insert``).
        """
        root = self.insert(root, value)
        return self.avl_balance(root)

    def avl_insert_arr(self, root, arr):
        """AVL-insert every value of *arr*, in order; return the final root.

        *arr* is not modified. *root* may be ``None``.
        """
        for value in arr:
            root = self.avl_insert(root, value)
        return root

    def avl_delete(self, root, value):
        """Delete *value* and return the rebalanced root (``None`` if empty).

        Relies on ``self.delete`` returning the new root of the tree.
        """
        root = self.delete(root, value)
        return self.avl_balance(root)

## Red-black Trees

Adapted from the [red-black tree implementation on Qvault](https://qvault.io/python/red-black-tree-python/).
(Not yet fully understood.)

In [5]:
class RBNode:
    """A red-black tree node; it starts out black and unlinked."""

    def __init__(self, val):
        self.val = val
        self.red = False  # colour flag: True = red, False = black
        self.parent = self.left = self.right = None

In [6]:
class RBTree:
    """Red-black tree storing unique, mutually comparable values.

    Empty children are represented by one shared black sentinel node,
    ``self.nil``, instead of ``None``. The root's ``parent`` is ``None``.
    Only insertion, lookup and printing are implemented; there is no deletion.
    """

    def __init__(self):
        # Shared black sentinel with no children. Traversals stop at it by
        # identity, so its placeholder val (0) is never used as data.
        self.nil = RBNode(0)
        self.nil.red = False
        self.nil.left = None
        self.nil.right = None
        self.root = self.nil

    def insert(self, val):
        """Insert *val*. If it is already present, do nothing."""
        new_node = RBNode(val)
        new_node.parent = None
        new_node.left = self.nil
        new_node.right = self.nil
        new_node.red = True  # new node must be red

        # Ordinary binary-search descent to find the attachment point.
        parent = None
        current = self.root
        while current is not self.nil:
            parent = current
            if new_node.val < current.val:
                current = current.left
            elif new_node.val > current.val:
                current = current.right
            else:
                return  # duplicate key: ignore

        # Set the parent and hang the new node off it.
        new_node.parent = parent
        if parent is None:
            self.root = new_node
        elif new_node.val < parent.val:
            parent.left = new_node
        else:
            parent.right = new_node

        self.fix_insert(new_node)

    def fix_insert(self, new_node):
        """Restore the red-black properties after inserting red *new_node*.

        While the node and its parent are both red, it either recolours
        (red uncle, then continues from the grandparent) or rotates (black
        uncle). Finally the root is forced black.
        """
        while new_node is not self.root and new_node.parent.red:
            if new_node.parent is new_node.parent.parent.right:
                u = new_node.parent.parent.left  # uncle
                if u.red:
                    # Red uncle: push the blackness down from the grandparent.
                    u.red = False
                    new_node.parent.red = False
                    new_node.parent.parent.red = True
                    new_node = new_node.parent.parent
                else:
                    if new_node is new_node.parent.left:
                        # Inner child: rotate it into the outer position first.
                        new_node = new_node.parent
                        self.rotate_right(new_node)
                    new_node.parent.red = False
                    new_node.parent.parent.red = True
                    self.rotate_left(new_node.parent.parent)
            else:
                u = new_node.parent.parent.right  # uncle
                if u.red:
                    u.red = False
                    new_node.parent.red = False
                    new_node.parent.parent.red = True
                    new_node = new_node.parent.parent
                else:
                    if new_node is new_node.parent.right:
                        new_node = new_node.parent
                        self.rotate_left(new_node)
                    new_node.parent.red = False
                    new_node.parent.parent.red = True
                    self.rotate_right(new_node.parent.parent)
        self.root.red = False

    def exists(self, val):
        """Return the node holding *val*, or ``self.nil`` if it is absent."""
        curr = self.root
        while curr is not self.nil and val != curr.val:
            if val < curr.val:
                curr = curr.left
            else:
                curr = curr.right
        return curr

    def rotate_left(self, x):
        """Rotate left at *x*: its right child *y* takes *x*'s place."""
        y = x.right
        x.right = y.left
        if y.left is not self.nil:
            y.left.parent = x

        y.parent = x.parent
        if x.parent is None:
            self.root = y
        elif x is x.parent.left:
            x.parent.left = y
        else:
            x.parent.right = y
        y.left = x
        x.parent = y

    def rotate_right(self, x):
        """Rotate right at *x*: its left child *y* takes *x*'s place."""
        y = x.left
        x.left = y.right
        if y.right is not self.nil:
            y.right.parent = x

        y.parent = x.parent
        if x.parent is None:
            self.root = y
        elif x is x.parent.right:
            x.parent.right = y
        else:
            x.parent.left = y
        y.right = x
        x.parent = y

    def __repr__(self):
        """Render the tree as one line per node, in ascending order.

        Each line is ``'-' * 4 * depth + '> ' + value + ' ' + 'r'/'b'``.
        An empty tree renders as ``''``.
        """
        lines = []
        self._render(self.root, lines, 0)
        return '\n'.join(lines)

    def _render(self, node, lines, level):
        # In-order walk that stops at the sentinel by identity, so a stored
        # value of 0 is still shown.
        if node is self.nil:
            return
        self._render(node.left, lines, level + 1)
        lines.append('-' * 4 * level + '> ' +
                     str(node.val) + ' ' + ('r' if node.red else 'b'))
        self._render(node.right, lines, level + 1)


def print_tree(node, lines, level=0):
    """Append a sideways, in-order rendering of the subtree at *node* to *lines*.

    Each node becomes one line: ``4 * level`` dashes, ``'> '``, its value,
    then ``r`` (red) or ``b`` (black). Values appear in ascending order, and
    indentation shows depth.

    The walk stops at ``None`` and at sentinel leaves. A sentinel is
    recognised structurally as a node with neither child. ``RBTree.insert``
    gives every real node the tree's sentinel as both children, and the
    sentinel's own children are ``None``. Stored values are therefore never
    used as a stop marker, so a node holding 0 is still printed.
    """
    if node is None or (node.left is None and node.right is None):
        return
    print_tree(node.left, lines, level + 1)
    colour = 'r' if node.red else 'b'
    lines.append('-' * 4 * level + '> ' + str(node.val) + ' ' + colour)
    print_tree(node.right, lines, level + 1)


def get_nums(num, seed=1):
    """Return *num* pseudo-random integers drawn uniformly from ``1..num-1``.

    Uses a private ``random.Random(seed)``. The output is reproducible and
    identical to what seeding the global generator with *seed* would give,
    but the module-level ``random`` state is left untouched.

    Raises:
        ValueError: when ``num == 1``, because the range ``1..0`` is empty.
    """
    rng = random.Random(seed)
    return [rng.randint(1, num - 1) for _ in range(num)]


def main():
    """Insert the keys 1 through 50 into a fresh red-black tree and print it."""
    rb = RBTree()
    for key in range(1, 50 + 1):
        rb.insert(key)
    print(rb)


main()

----------------> 1 b
------------> 2 b
----------------> 3 b
--------> 4 b
----------------> 5 b
------------> 6 b
----------------> 7 b
----> 8 b
----------------> 9 b
------------> 10 b
----------------> 11 b
--------> 12 b
----------------> 13 b
------------> 14 b
----------------> 15 b
> 16 b
----------------> 17 b
------------> 18 b
----------------> 19 b
--------> 20 b
----------------> 21 b
------------> 22 b
----------------> 23 b
----> 24 b
--------------------> 25 b
----------------> 26 b
--------------------> 27 b
------------> 28 b
--------------------> 29 b
----------------> 30 b
--------------------> 31 b
--------> 32 r
------------------------> 33 b
--------------------> 34 b
------------------------> 35 b
----------------> 36 r
------------------------> 37 b
--------------------> 38 b
------------------------> 39 b
------------> 40 b
------------------------> 41 b
--------------------> 42 b
------------------------> 43 b
----------------> 44 r
------------------------>