# AVL Tree

In [69]:
class Node(object):
    """One AVL tree node: a key, its cached height, and two child links."""

    def __init__(self, data):
        # A new node is always created as a leaf: no children and height 0
        # (the tree counts an empty subtree as height -1).
        self.leftChild, self.rightChild = None, None
        self.height = 0
        self.data = data


class AVL(object):
    """Self-balancing binary search tree (AVL tree).

    Each node caches its height (a leaf has height 0; an empty subtree
    counts as -1). After every insert or removal, the nodes on the modified
    path are re-balanced with single or double rotations, so the heights of
    any node's two subtrees differ by at most one. Keys equal to a node's
    key are inserted into its right subtree.
    """

    def __init__(self):
        self.root = None

    def remove(self, data):
        """Remove one node holding ``data``; do nothing if it is absent."""
        if self.root:
            self.root = self.removeNode(data, self.root)

    def removeNode(self, data, node):
        """Remove ``data`` from the subtree rooted at ``node``.

        Returns the (possibly new) root of the re-balanced subtree, or None
        if the subtree became empty.
        """
        if not node:
            return node

        if data < node.data:
            node.leftChild = self.removeNode(data, node.leftChild)
        elif data > node.data:
            node.rightChild = self.removeNode(data, node.rightChild)
        else:
            # Zero or one child: splice the node out by returning its only
            # child (None for a leaf).
            if not node.leftChild:
                return node.rightChild
            if not node.rightChild:
                return node.leftChild
            # Two children: copy in the in-order predecessor (the largest key
            # of the left subtree), then delete that key from the left subtree.
            predecessor = self.getPredecessor(node.leftChild)
            node.data = predecessor.data
            node.leftChild = self.removeNode(predecessor.data, node.leftChild)

        return self._rebalance(node)

    def getPredecessor(self, node):
        """Return the node with the largest key in the subtree rooted at ``node``."""
        while node.rightChild:
            node = node.rightChild
        return node

    def insert(self, data):
        """Insert ``data`` into the tree (equal keys go to the right)."""
        self.root = self.insertNode(data, self.root)

    def insertNode(self, data, node):
        """Insert ``data`` below ``node``; return the re-balanced subtree root."""
        if not node:
            return Node(data)

        if data < node.data:
            node.leftChild = self.insertNode(data, node.leftChild)
        else:
            node.rightChild = self.insertNode(data, node.rightChild)

        return self.violation(data, node)

    def calcHeight(self, node):
        """Return the cached height of ``node``, or -1 for an empty subtree."""
        if not node:
            return -1

        return node.height

    def calcBalance(self, node):
        """Return height(left) - height(right); 0 for an empty subtree.

        Positive means left-heavy, negative means right-heavy.
        """
        if not node:
            return 0

        return self.calcHeight(node.leftChild) - self.calcHeight(node.rightChild)

    def _updateHeight(self, node):
        """Recompute ``node.height`` from its children's cached heights."""
        node.height = max(self.calcHeight(node.leftChild), self.calcHeight(node.rightChild)) + 1

    def rotateRight(self, node):
        """Rotate right around ``node``; return the new root (its old left child)."""
        newRoot = node.leftChild
        node.leftChild = newRoot.rightChild
        newRoot.rightChild = node

        # ``node`` is now a child of ``newRoot``, so its height must be
        # refreshed first.
        self._updateHeight(node)
        self._updateHeight(newRoot)

        return newRoot

    def rotateLeft(self, node):
        """Rotate left around ``node``; return the new root (its old right child)."""
        newRoot = node.rightChild
        node.rightChild = newRoot.leftChild
        newRoot.leftChild = node

        # ``node`` is now a child of ``newRoot``, so its height must be
        # refreshed first.
        self._updateHeight(node)
        self._updateHeight(newRoot)

        return newRoot

    def _rebalance(self, node):
        """Refresh ``node``'s height and restore the AVL invariant at ``node``.

        The rotation is chosen from the balance factor of the heavy child.
        This works both after insertion (including duplicate keys) and after
        removal. Returns the new root of this subtree.
        """
        self._updateHeight(node)
        balance = self.calcBalance(node)

        if balance > 1:
            # case3 (left right): the left child leans right, so first turn it
            # into the left-left shape.
            if self.calcBalance(node.leftChild) < 0:
                node.leftChild = self.rotateLeft(node.leftChild)
            # case1 (left left)
            return self.rotateRight(node)

        if balance < -1:
            # case4 (right left): the right child leans left, so first turn it
            # into the right-right shape.
            if self.calcBalance(node.rightChild) > 0:
                node.rightChild = self.rotateRight(node.rightChild)
            # case2 (right right)
            return self.rotateLeft(node)

        return node

    def violation(self, data, node):
        """Re-balance ``node`` after ``data`` was inserted beneath it.

        ``data`` is kept for interface compatibility. The rotation is chosen
        from the children's balance factors rather than key comparisons,
        which keeps the tree balanced when duplicate keys are inserted.
        Returns the new root of this subtree.
        """
        return self._rebalance(node)

    def traverse(self):
        """Print every key in ascending order, one per line, each followed by a space."""
        if self.root:
            self.traverseInOrder(self.root)

    def traverseInOrder(self, node):
        """Print the keys of the subtree rooted at ``node`` in order."""
        if node.leftChild:
            self.traverseInOrder(node.leftChild)

        print("%s " % node.data)

        if node.rightChild:
            self.traverseInOrder(node.rightChild)

# Demo: build a small tree, remove the (eventual) root key, print in order.
avl = AVL()
for value in (20, 30, 80):
    avl.insert(value)
avl.remove(30)
avl.traverse()
            

20 
80 
