In [21]:
class Node:
    """One node of an AVL tree.

    Holds a value, a back-pointer to its parent (``None`` for the root),
    links to its two children and the cached height of the subtree rooted
    here. A freshly created node is a leaf: no children, height 0.
    """

    def __init__(self, data, parent):
        # New nodes are always leaves.
        self.left_node = self.right_node = None
        self.height = 0
        self.data, self.parent = data, parent
    
class AVLtree:
    """Self-balancing binary search tree (AVL tree) built from ``Node`` objects.

    Every node caches its height: a leaf has height 0 and an absent child
    counts as -1. After each insertion or removal, the path from the changed
    position up to the root is walked. Heights are refreshed along the way,
    and any node whose subtree heights differ by more than one is fixed with
    a single or double rotation.

    Values greater than a node's value go to its right subtree; all others,
    including equal values, go to its left subtree.
    """

    def __init__(self):
        # Empty until the first insertion.
        self.root = None

    def insert_data(self, data):
        """Insert ``data`` into the tree, printing a trace line first."""
        print("inserted node: ", data)

        if self.root:
            self.insert(data, self.root)
        else:
            self.root = Node(data, None)

    def insert(self, data, node):
        """Insert ``data`` below ``node`` and rebalance up to the root.

        Walks down to the empty child slot where ``data`` belongs and attaches
        a new leaf there. It then rebalances in a single bottom-up pass,
        starting from the new leaf's parent.
        """
        while True:
            if data > node.data:
                if node.right_node is None:
                    node.right_node = Node(data, node)
                    break
                node = node.right_node
            else:
                if node.left_node is None:
                    node.left_node = Node(data, node)
                    break
                node = node.left_node

        # One walk to the root refreshes every ancestor's height and rotates
        # where needed; repeating it per level would only redo the same work.
        self.handle_violation(node)

    def remove_data(self, data):
        """Remove one node holding ``data``; do nothing if it is absent."""
        if self.root:
            self.remove(data, self.root)

    def remove(self, data, node):
        """Remove one occurrence of ``data`` from the subtree rooted at ``node``.

        If ``data`` is not found (or ``node`` is None) the tree is unchanged.
        A node with two children has its value swapped with its in-order
        predecessor, and the predecessor's node is unlinked instead; that node
        has no right child. The tree is then rebalanced from the unlinked
        node's parent up to the root.
        """
        # Locate the node holding ``data``.
        while node is not None:
            if data < node.data:
                node = node.left_node
            elif data > node.data:
                node = node.right_node
            else:
                break
        if node is None:
            return  # value not present

        if node.left_node is not None and node.right_node is not None:
            predecessor = self.get_predecessor(node.left_node)
            node.data, predecessor.data = predecessor.data, node.data
            node = predecessor

        # ``node`` now has at most one child, which takes its place.
        child = node.left_node if node.left_node is not None else node.right_node
        parent = node.parent
        self._replace_in_parent(node, child)
        self.handle_violation(parent)

    def get_predecessor(self, node):
        """Return the right-most (largest) node of the subtree rooted at ``node``."""
        while node.right_node:
            node = node.right_node
        return node

    def handle_violation(self, node):
        """Refresh heights and fix imbalances from ``node`` up to the root."""
        while node is not None:
            self._update_height(node)
            self.violation_helper(node)
            # After a rotation ``node.parent`` is the new subtree root, so the
            # walk continues upward from there.
            node = node.parent

    def violation_helper(self, node):
        """Restore the AVL balance of ``node`` with a single or double rotation."""
        balance = self.calc_balance(node)

        if balance > 1:
            # Left-heavy. If the left child leans right (left-right case),
            # rotate it left first so a single right rotation suffices.
            if self.calc_balance(node.left_node) < 0:
                self.rotate_left(node.left_node)
            self.rotate_right(node)
        elif balance < -1:
            # Right-heavy; mirror image of the case above (right-left case).
            if self.calc_balance(node.right_node) > 0:
                self.rotate_right(node.right_node)
            self.rotate_left(node)

    def calc_height(self, node):
        """Return the cached height of ``node``, or -1 for an absent node."""
        if node is None:
            return -1
        return node.height

    def calc_balance(self, node):
        """Return left height minus right height of ``node`` (0 for None)."""
        if node:
            return self.calc_height(node.left_node) - self.calc_height(node.right_node)
        else:
            return 0

    def _update_height(self, node):
        """Recompute ``node.height`` from its children's cached heights."""
        node.height = max(self.calc_height(node.left_node),
                          self.calc_height(node.right_node)) + 1

    def _replace_in_parent(self, old, new):
        """Hang ``new`` (may be None) where ``old`` hangs: under its parent or as root."""
        parent = old.parent
        if new is not None:
            new.parent = parent
        if parent is None:
            self.root = new
        elif parent.left_node is old:
            parent.left_node = new
        else:
            parent.right_node = new

    def rotate_left(self, node):
        """Rotate the subtree rooted at ``node`` to the left.

        ``node``'s right child (which must exist) becomes the subtree root and
        ``node`` becomes its left child. The pivot's former left subtree moves
        under ``node``. Heights of both nodes are recomputed.
        """
        pivot = node.right_node
        self._replace_in_parent(node, pivot)

        node.right_node = pivot.left_node
        if pivot.left_node is not None:
            pivot.left_node.parent = node
        pivot.left_node = node
        node.parent = pivot

        # ``node`` is now below ``pivot``, so update it first.
        self._update_height(node)
        self._update_height(pivot)

        print("rotated left on node: ", node.data)

    def rotate_right(self, node):
        """Rotate the subtree rooted at ``node`` to the right.

        ``node``'s left child (which must exist) becomes the subtree root and
        ``node`` becomes its right child. The pivot's former right subtree
        moves under ``node``. Heights of both nodes are recomputed.
        """
        pivot = node.left_node
        self._replace_in_parent(node, pivot)

        node.left_node = pivot.right_node
        if pivot.right_node is not None:
            pivot.right_node.parent = node
        pivot.right_node = node
        node.parent = pivot

        # ``node`` is now below ``pivot``, so update it first.
        self._update_height(node)
        self._update_height(pivot)

        print("rotated right on node: ", node.data)

    def inorder_traversal(self):
        """Print all values in ascending (in-order) order, one per line."""
        self.inorder(self.root)

    def inorder(self, node):
        """Print the subtree rooted at ``node`` in order; no-op for None."""
        if node is None:
            return
        self.inorder(node.left_node)
        print(node.data)
        self.inorder(node.right_node)
        
            
            
        

                    
                    
                
            
            
        

In [25]:
if __name__ == "__main__":
    # Build a small tree; these insertions trigger rebalancing rotations,
    # which are printed as they happen.
    tree = AVLtree()
    for value in (6, 1, 2, 3, 4):
        tree.insert_data(value)

    tree.inorder_traversal()

    # Remove a node that has two children, then show the resulting tree.
    tree.remove_data(2)
    tree.inorder_traversal()
    print("root of the tree is: ", tree.root.data)

inserted node:  6
inserted node:  1
inserted node:  2
rotated left on node:  1
rotated right on node:  6
inserted node:  3
inserted node:  4
rotated left on node:  3
rotated right on node:  6
1
2
3
4
6
rotated left on node:  1
1
3
4
6
root of the tree is:  4
