In [179]:
class avlNode:
    def __init__(self, key, parent = None, height = 0):
        self.key = key
        self.parent = parent
        self.left = None
        self.right = None
        self.balanceFactor = 0 #bf = 结点的右子树的深度减去左子树的深度。新叶子=0
        self.height = height
    
    def setKey(self, key):
        self.key = key
    
    def setLeft(self, leftChild):
        self.left = leftChild
    
    def setRight(self, rightChild):
        self.right = rightChild
        
    def isLeaf(self):
        if self.left is None and self.right is None:
            return True
        else:
            return False
    
    def __iter__(self):
        if self.left is not None:
            for i in self.left:
                yield i
        yield self
        if self.right is not None:
            for i in self.right:
                yield i
                
    def getParent(self):
        return self.parent
    
    def isLeftChild(self):
        if self == self.getParent().left:
            return True
        else:
            return False
    
    def modifyBF(self, newBF):
        self.balanceFactor = newBF
        

In [180]:
class avlTree:
    def __init__(self):
        self.root = None   
        self.size = 0

    def height(self, node):
        if node is None:
            return -1
        else:
            return node.height
    def setHeight(self, node):
        node.height = max(self.height(node.left), self.height(node.right)) + 1
    def insert(self, key):
        if self.root is None:
            self.root = avlNode(key)
        else:
            self._insert(key, self.root)
        self.size += 1
    
    def _insert(self, key, node):
        if key == node.key:
            raise
        elif key > node.key:
            if node.right is None:
                node.right = avlNode(key, node)
                self.printBF()
                self.updateHeight(node)
            else:
                self._insert(key, node.right)
        else:
            if node.left is None:
                node.left = avlNode(key, node)
                self.printBF()
                self.updateHeight(node)
            else:
                self._insert(key, node.left)
                    
    def get(self, key):
        if self.root is None:
            return None
        else:
            return self._get(key, self.root)
    
    def _get(self, key, node):
        if key == node.key:
            return node
        elif key > node.key:
            if node.right is None:
                return None
            else:
                return self._get(key, node.right)
        else:
            if node.left is None:
                return None
            else:
                return self._get(key, node.left)
            
    def updateHeight(self, node):
        leftHeight = self.height(node.left)
        rightHeight = self.height(node.right)
        node.balanceFactor = rightHeight - leftHeight
        node.height = max(leftHeight, rightHeight) + 1
        if abs(node.balanceFactor) == 2:
            print('rebalancing', node.key)
            self.print()
            self.rebalance(node)
            self.updateHeight(node)
        if node.parent is not None:
            self.updateHeight(node.parent)            

    def _rebalance(self, node):
        if (self.height(node.right) - self.height(node.left)) == 2:
            if node.right.balanceFactor == 1:
                replacement = self.leftRotate(node)
            elif node.right.balanceFactor == -1:
                replacement = self.rightAndLeft(node)
            else:
                raise 
        elif (self.height(node.right) - self.height(node.left)) == -2:
            if node.left.balanceFactor == -1:
                replacement = self.rightRotate(node)
            elif node.left.balanceFactor == 1:
                replacement = self.leftAndRight(node)
            else:
                raise 
        return replacement
    
    def rebalance(self, node):
        if  node == self.root:
            replacement = self._rebalance(node)
            self.root = replacement
            self.root.parent = None
        else:
            parent = node.getParent()
            isLeftChild = node.isLeftChild()
            replacement = self._rebalance(node)
            if isLeftChild:
                parent.left = replacement
            else:
                parent.right = replacement
            if replacement is not None:
                replacement.parent = parent
            
    def delete(self, key):
        nodeToDelete = self.get(key)
        if nodeToDelete is None:
            return None
        elif  nodeToDelete == self.root:
            if self.size == 1:
                self.root = None
                self.size = 0
            else:
                replacement = self._delete(nodeToDelete)
                self.root = replacement
                self.size -= 1
        else:
            parent = nodeToDelete.getParent()
            replacement = self._delete(nodeToDelete)
            if nodeToDelete.isLeftChild():
                parent.left = replacement
            else:
                parent.right = replacement
            if replacement is not None:
                replacement.parent = parent
            self.size -= 1
            
    def _delete(self, nodeToDelete):
        if nodeToDelete.right is not None:
            replacement = self.findMinVal(nodeToDelete.right)
            replacement.parent.left = None #replacement is minimal value, i.e. left leaf of its parent. so delete it
            replacement.left = nodeToDelete.left
            
            self.rebalance(replacement)
            self.setHeight(replacement)
            return replacement #方便连接
        
        elif nodeToDelete.left is not None:
            replacement = self.findMaxVal(nodeToDelete.left) 
            replacement.parent.right = None #replacement is max value, i.e. right leaf of its parent. so delete it
            replacement.right = nodeToDelete.right
            
            self.rebalance(replacement)
            self.setHeight(replacement)
            return replacement #方便连接
        
        else:
            return None
        
            
    def findMinVal(self, node):
        if node.left is not None:
            return self.findMinVal(self, node.left)
        else:
            return node
    
    def findMaxVal(self, node):
        if node.right is not None:
            return self.findMaxVal(self, node.right)
        else:
            return node
    
    def __iter__(self):
        return self.root.__iter__()
    
    def print(self):
        print([i.key for i in self])
    
    def printBF(self):
        print([i.balanceFactor for i in self])
    

    def leftRotate(self, oldRoot):
        newRoot = oldRoot.right

        oldRoot.right = newRoot.left
        self.setParent(oldRoot.right, oldRoot)

        newRoot.left = oldRoot
        self.setParent(oldRoot, newRoot)
        
        self.setHeight(oldRoot)
        self.setHeight(newRoot)

        return newRoot

    def rightRotate(self, oldRoot):
        newRoot = oldRoot.left

        oldRoot.left = newRoot.right
        self.setParent(oldRoot.left, oldRoot)

        newRoot.right = oldRoot
        self.setParent(oldRoot, newRoot)
        
        self.setHeight(oldRoot)
        self.setHeight(newRoot)
        
        return newRoot

    def leftAndRight(self, root):
        root.left = self.leftRotate(root.left)
        return self.rightRotate(root)

    def rightAndLeft(self, root):
        root.right = self.rightRotate(root.right)
        return self.leftRotate(root)
    
    def setParent(self, child, parent):
        if child is not None:
            child.parent = parent

# Test case

In [181]:
a = avlTree()
a.insert(15)
for i in [9,12,24,34,80,99]:
    print(a.root.key)
    a.print()
    a.printBF()
    a.insert(i)
a.print()
a.printBF()

15
[15]
[0]
[0, 0]
15
[9, 15]
[0, -1]
[0, 0, -1]
rebalancing 15
[9, 12, 15]
12
[9, 12, 15]
[1, 0, 0]
[1, 0, 0, 0]
12
[9, 12, 15, 24]
[1, 1, 1, 0]
[1, 1, 1, 0, 0]
rebalancing 15
[9, 12, 15, 24, 34]
12
[9, 12, 15, 24, 34]
[1, 1, 0, 0, 0]
[1, 1, 0, 0, 0, 0]
rebalancing 12
[9, 12, 15, 24, 34, 80]
24
[9, 12, 15, 24, 34, 80]
[1, 0, 0, 0, 1, 0]
[1, 0, 0, 0, 1, 0, 0]
rebalancing 34
[9, 12, 15, 24, 34, 80, 99]
[9, 12, 15, 24, 34, 80, 99]
[1, 0, 0, 0, 0, 0, 0]


In [170]:
a.root.parent

In [171]:
a.root.right.key

15

In [172]:
a.insert(24)

inserted 24 under 15
[1, 0, 0, 0]
15 1
12 1


In [173]:
a.root.right.isLeftChild()

False

In [174]:
a.insert(34)

inserted 34 under 24
[1, 1, 1, 0, 0]
24 1
15 2
rebalancing 15
[9, 12, 15, 24, 34]
replacing 15 with 24

parent 12 new right child is 24
15 0
24 0
12 1
24 0
12 1


In [150]:
a.root.key

12

In [151]:
a.root.right.key

15