In [1]:
#AVL tree: a self-balancing binary search tree
#Rotations do not change the in-order traversal of an AVL tree
#Insertion takes O(log(n)) time
#The height of an AVL tree with n nodes is O(log(n))


class Node:
    """One tree node: a stored value, links to parent/children, and a height.

    Attributes:
        left, right: child links, initially None.
        value: the payload passed to the constructor.
        parent: link to the parent node, initially None.
        height: height recorded for this node; a new node starts at 0.
    """

    def __init__(self, value):
        # A new node starts detached: no children and no parent.
        self.left = self.right = None
        self.value = value
        self.parent = None
        # A node without children has height 0.
        self.height = 0
        
#Duplicates are excluded        
class AVLTree:
    """Self-balancing binary search tree (AVL tree) holding unique values.

    Every node stores its own height (a childless node has height 0, a missing
    child counts as -1). After each insertion or deletion the ancestors of the
    touched position are inspected and, where the heights of the two subtrees
    differ by more than one, rotations restore the balance.

    Duplicates are excluded: inserting a value that is already present leaves
    the tree unchanged.
    """

    def __init__(self):
        # Root node of the tree; None while the tree is empty.
        self.root = None

    @staticmethod
    def _height(node):
        """Return ``node.height``, or -1 when ``node`` is None (missing child)."""
        return node.height if node is not None else -1

    def insert(self, value):
        """Insert ``value`` into the tree and rebalance if necessary.

        Args:
            value: the value to store; must be comparable with the values
                already in the tree via ``<`` and ``==``.

        Returns:
            None. A value that is already present is silently ignored.
        """
        if self.root is None:
            self.root = Node(value)
            return

        node = self.root
        while True:
            if value == node.value:
                # Duplicates are excluded: keeping equal keys would break the
                # strict ordering that search() relies on once rotations move
                # an equal key into a left subtree.
                return
            if value < node.value:
                if node.left is None:
                    new_node = Node(value)
                    node.left = new_node
                    break
                node = node.left
            else:
                if node.right is None:
                    new_node = Node(value)
                    node.right = new_node
                    break
                node = node.right

        new_node.parent = node
        self._inspect_insertion(new_node)

    def print_tree(self):
        """Print the tree sideways (right subtree on top), or a notice if empty."""
        if self.root is None:
            print("BST is empty")
        else:
            self._print_tree(self.root, 1)

    def _print_tree(self, node, number_of_space):
        """Print the subtree rooted at ``node`` in reverse in-order.

        Args:
            node: root of the subtree to print (must not be None).
            number_of_space: indentation for ``node``; each level below it is
                indented two more spaces.
        """
        if node.right is not None:
            self._print_tree(node.right, number_of_space + 2)

        label = " " * number_of_space + str(node.value)
        if node.parent:
            label += "(P:" + str(node.parent.value) + ") (H:" + str(node.height) + ")"
        else:
            # The root has no parent to report.
            label += "(H:" + str(node.height) + ")"
        print(label)

        if node.left is not None:
            self._print_tree(node.left, number_of_space + 2)

    def delete_node(self, value, root=None):
        """Remove ``value`` from the tree and rebalance.

        Args:
            value: the value to remove.
            root: node at which the search for ``value`` starts; defaults to
                the tree's root.

        Returns:
            True if a node holding ``value`` was found and removed,
            False otherwise.
        """
        if root is None:
            root = self.root

        node = self.search(value, root)
        if node is None:
            return False

        if node.left is not None and node.right is not None:
            # Two children: overwrite this node with its in-order successor,
            # then remove the successor (which has no left child) from the
            # right subtree. Propagate that result so every successful
            # deletion reports True.
            successor = self.min_value_node(node.right)
            node.value = successor.value
            return self.delete_node(successor.value, node.right)

        # Zero or one child: splice the (possibly missing) child into the
        # removed node's place.
        child = node.left if node.left is not None else node.right
        parent = node.parent
        if parent is None:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        if child is not None:
            child.parent = parent

        if parent is not None:
            self._inspect_deletion(parent)
        return True

    def search(self, value, startingPoint):
        """Return the node holding ``value`` in the subtree at ``startingPoint``.

        Args:
            value: the value to look for.
            startingPoint: node where the descent starts (may be None).

        Returns:
            The matching node, or None if the value is not found.
        """
        node = startingPoint
        while node is not None and node.value != value:
            node = node.left if value < node.value else node.right
        return node

    def min_value_node(self, node):
        """Return the left-most node of the subtree rooted at ``node``."""
        while node.left:
            node = node.left
        return node

    def number_of_children(self, node):
        """Return how many children (0, 1 or 2) ``node`` has."""
        count = 0
        if node.left:
            count += 1
        if node.right:
            count += 1
        return count

    # Traverse up the tree to see if any ancestor node is unbalanced.
    def _inspect_insertion(self, current_node, path=None):
        """Walk up from a newly inserted node, fixing heights and balance.

        Args:
            current_node: the node whose parent is inspected next.
            path: nodes visited so far on the way up; the last two entries
                identify the child and grandchild used for the rotation.
                A fresh list is created for each top-level call.
        """
        if path is None:
            # A mutable default ([]) would be shared by every call and every
            # tree, growing forever and keeping deleted nodes alive.
            path = []

        parent = current_node.parent
        if parent is None:
            return

        path.append(current_node)

        left_child_height = self._height(parent.left)
        right_child_height = self._height(parent.right)
        balance_factor = abs(right_child_height - left_child_height)

        if balance_factor == 0:
            # Both subtrees are equally tall, so the parent's height did not
            # change and nothing above it can be affected.
            return

        if balance_factor > 1:
            path.append(parent)
            # z = unbalanced node, y = its child on the path, x = grandchild.
            self._rebalance_node(path[-1], path[-2], path[-3])
            return

        parent.height = max(left_child_height, right_child_height) + 1
        self._inspect_insertion(parent, path)

    def _inspect_deletion(self, current_node):
        """Walk from ``current_node`` up to the root after a deletion.

        Each node on the way gets its height recomputed and is rebalanced
        when its subtrees differ in height by more than one.
        """
        node = current_node
        while node is not None:
            self._update_height(node)

            if abs(self._height(node.right) - self._height(node.left)) > 1:
                y = self.taller_child(node)
                if self._height(y.left) == self._height(y.right):
                    # Tie: pick the grandchild on the same side as y so that
                    # a single rotation is used.
                    x = y.left if node.left is y else y.right
                else:
                    x = self.taller_child(y)
                self._rebalance_node(node, y, x)

            # After a rotation, node.parent is the new root of this subtree,
            # which is inspected on the next iteration.
            node = node.parent

    def taller_child(self, node):
        """Return the child of ``node`` with the larger height.

        A missing child loses to an existing one; on equal heights the right
        child is returned.
        """
        if node.left is None:
            return node.right
        elif node.right is None:
            return node.left
        else:
            return node.left if node.left.height > node.right.height else node.right

    def _rebalance_node(self, z, y, x):
        """Rotate around ``z`` according to where ``y`` and ``x`` sit below it.

        Args:
            z: the unbalanced node.
            y: the child of ``z`` on the heavy side.
            x: the child of ``y`` that selects single vs. double rotation.

        Raises:
            Exception: if ``y`` is not a child of ``z`` or ``x`` is not a
                child of ``y``.
        """
        if z.left == y and y.left == x:
            # Left-left: single right rotation.
            self._right_rotate(z)
        elif z.right == y and y.right == x:
            # Right-right: single left rotation.
            self._left_rotate(z)
        elif z.left == y and y.right == x:
            # Left-right: rotate y left, then z right.
            self._left_rotate(y)
            self._right_rotate(z)
        elif z.right == y and y.left == x:
            # Right-left: rotate y right, then z left.
            self._right_rotate(y)
            self._left_rotate(z)
        else:
            raise Exception("x, y, z failed to be balanced")

    def _replace_in_parent(self, old, new):
        """Make ``new`` occupy the position ``old`` has under its parent."""
        new.parent = old.parent
        if old.parent is None:
            self.root = new
        elif old.parent.left == old:
            old.parent.left = new
        else:
            old.parent.right = new

    def _right_rotate(self, z):
        """Rotate right around ``z``: its left child becomes the subtree root."""
        y = z.left
        self._replace_in_parent(z, y)

        # y's former right subtree becomes z's left subtree.
        z.left = y.right
        if z.left is not None:
            z.left.parent = z
        y.right = z
        z.parent = y

        # z is now below y, so its height must be refreshed first.
        self._update_height(z)
        self._update_height(y)

    def _left_rotate(self, z):
        """Rotate left around ``z``: its right child becomes the subtree root."""
        y = z.right
        self._replace_in_parent(z, y)

        # y's former left subtree becomes z's right subtree.
        z.right = y.left
        if z.right is not None:
            z.right.parent = z
        y.left = z
        z.parent = y

        # z is now below y, so its height must be refreshed first.
        self._update_height(z)
        self._update_height(y)

    def _update_height(self, node):
        """Recompute ``node.height`` from the heights of its children."""
        node.height = max(self._height(node.left), self._height(node.right)) + 1
        
    


In [2]:
# Insertion demo. 11, 7, 3 arrive in descending order and form a left-left
# chain under 11, so inserting 3 triggers a single right rotation that makes
# 7 the root. 1 and 5 then attach under 3 without any further rotation.
avl = AVLTree()
avl.insert(11)
avl.insert(7)
avl.insert(3)
avl.insert(1)
avl.insert(5)
avl.print_tree()

   11(P:7) (H:0)
 7(H:2)
     5(P:3) (H:0)
   3(P:7) (H:1)
     1(P:3) (H:0)


In [3]:
# 0 goes below 1, making the left subtree of 7 two levels taller than its
# right subtree (left-left case): a right rotation at 7 makes 3 the new root.
avl.insert(0)
avl.print_tree()

     11(P:7) (H:0)
   7(P:3) (H:1)
     5(P:7) (H:0)
 3(H:2)
   1(P:3) (H:1)
     0(P:1) (H:0)


In [4]:
# 4 becomes the left child of 5; heights along the path grow but every node
# stays balanced, so no rotation happens.
avl.insert(4)
avl.print_tree()

     11(P:7) (H:0)
   7(P:3) (H:2)
     5(P:7) (H:1)
       4(P:5) (H:0)
 3(H:3)
   1(P:3) (H:1)
     0(P:1) (H:0)


In [5]:
# 6 becomes the right child of 5; 5 now has two children of equal height, so
# its height is unchanged and the upward inspection stops there.
avl.insert(6)
avl.print_tree()

     11(P:7) (H:0)
   7(P:3) (H:2)
       6(P:5) (H:0)
     5(P:7) (H:1)
       4(P:5) (H:0)
 3(H:3)
   1(P:3) (H:1)
     0(P:1) (H:0)


In [6]:
# 8 becomes the left child of 11, then 10 the right child of 8. That is the
# left-right case at 11 (left rotation at 8, then right rotation at 11).
# The resulting tree is not printed in this cell.
avl.insert(8)
avl.insert(10)

In [7]:
# Build a fresh tree for the deletion walkthrough. The values are inserted
# level by level (root, its children, then grandchildren), so the tree stays
# balanced throughout and no rotation is triggered while building it.
print("========= test case for deletion =========")
avl = AVLTree()
avl.insert(3)
avl.insert(1)
avl.insert(7)
avl.insert(0)
avl.insert(5)
avl.insert(10)

avl.insert(4)
avl.insert(6)
avl.insert(8)
avl.insert(11)

avl.print_tree()

       11(P:10) (H:0)
     10(P:7) (H:1)
       8(P:10) (H:0)
   7(P:3) (H:2)
       6(P:5) (H:0)
     5(P:7) (H:1)
       4(P:5) (H:0)
 3(H:3)
   1(P:3) (H:1)
     0(P:1) (H:0)


In [8]:
# 7 has two children: its value is replaced by its in-order successor 8, and
# the original 8 leaf is removed from under 10.
avl.delete_node(7)
avl.print_tree()

       11(P:10) (H:0)
     10(P:8) (H:1)
   8(P:3) (H:2)
       6(P:5) (H:0)
     5(P:8) (H:1)
       4(P:5) (H:0)
 3(H:3)
   1(P:3) (H:1)
     0(P:1) (H:0)


In [9]:
# 5 has two children (4 and 6): it takes the value of its successor 6, and
# the 6 leaf is removed, leaving 4 as the only child.
avl.delete_node(5)
avl.print_tree()

       11(P:10) (H:0)
     10(P:8) (H:1)
   8(P:3) (H:2)
     6(P:8) (H:1)
       4(P:6) (H:0)
 3(H:3)
   1(P:3) (H:1)
     0(P:1) (H:0)


In [10]:
# 0 is a leaf. Removing it leaves the root 3 right-heavy by two; the children
# of 8 are equally tall, so a single left rotation at 3 makes 8 the new root.
avl.delete_node(0)
avl.print_tree()

     11(P:10) (H:0)
   10(P:8) (H:1)
 8(H:3)
     6(P:3) (H:1)
       4(P:6) (H:0)
   3(P:8) (H:2)
     1(P:3) (H:0)


In [11]:
# 10 has a single child (11), which takes its place. The root 8 is then
# left-heavy with the taller grandchild 6 on the inside (left-right case):
# left rotation at 3 followed by right rotation at 8 makes 6 the root.
avl.delete_node(10)
avl.print_tree()

     11(P:8) (H:0)
   8(P:6) (H:1)
 6(H:2)
     4(P:3) (H:0)
   3(P:6) (H:1)
     1(P:3) (H:0)


In [12]:
# 8 has a single child (11), which is spliced into its place; the tree stays
# balanced, so no rotation is needed.
avl.delete_node(8)
avl.print_tree()

   11(P:6) (H:0)
 6(H:2)
     4(P:3) (H:0)
   3(P:6) (H:1)
     1(P:3) (H:0)


In [13]:
# 11 is a leaf. Without it the root 6 has only a left subtree of height 1;
# the children of 3 are equally tall, so a single right rotation at 6 makes
# 3 the root.
avl.delete_node(11)
avl.print_tree()

   6(P:3) (H:1)
     4(P:6) (H:0)
 3(H:2)
   1(P:3) (H:0)


In [14]:
# 6 has a single child (4), which replaces it under the root.
avl.delete_node(6)
avl.print_tree()

   4(P:3) (H:0)
 3(H:1)
   1(P:3) (H:0)


In [15]:
# The root 3 has two children: it takes the value of its successor 4 and the
# 4 leaf is removed.
avl.delete_node(3)
avl.print_tree()

 4(H:1)
   1(P:4) (H:0)


In [16]:
# The root 4 has a single child (1), which becomes the new root.
avl.delete_node(4)
avl.print_tree()

 1(H:0)


In [17]:
# Deleting the last remaining node empties the tree; print_tree() then
# reports the empty state instead of printing nodes.
avl.delete_node(1)
avl.print_tree()

BST is empty


In [18]:
# Second deletion scenario: a tree whose left subtree (under 5) is one level
# deeper than its right subtree (under 9). The insertion order keeps every
# node balanced, so no rotation happens while building it.
print("========= test case for deletion =========")
avl = AVLTree()
avl.insert(8)
avl.insert(5)
avl.insert(9)
avl.insert(4)
avl.insert(6)
avl.insert(10)

avl.insert(3)
avl.insert(7)


avl.print_tree()

     10(P:9) (H:0)
   9(P:8) (H:1)
 8(H:3)
       7(P:6) (H:0)
     6(P:5) (H:1)
   5(P:8) (H:2)
     4(P:5) (H:1)
       3(P:4) (H:0)


In [19]:
# 10 is a leaf. Removing it leaves the root 8 left-heavy by two; both children
# of 5 have the same height, so the same-side grandchild (4) is chosen and a
# single right rotation at 8 makes 5 the new root.
avl.delete_node(10)
avl.print_tree()

     9(P:8) (H:0)
   8(P:5) (H:2)
       7(P:6) (H:0)
     6(P:8) (H:1)
 5(H:3)
   4(P:5) (H:1)
     3(P:4) (H:0)
