In [23]:
class Node:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None
        self.height = 1


In [24]:
class AVLTree:
    def __init__(self):
        self.root = None

    def search(self, target):
        #print("search")
        current = self.root
        while current:
            if current.value == target:
                return current
            elif current.value > target:
                current = current.left
            else:
                current = current.right
        return None 

    def traversal(self,node):
        #print("traversal")
        if node:
            print(node.value, end=" ")
            self.traversal(node.left)
            self.traversal(node.right)

    def insert(self,node,value):
        #print("insert")

        if not node:
            return Node(value)
        
        # Insert left or right
        if value < node.value:
            node.left = self.insert(node.left,value)
        else:
            node.right = self.insert(node.right,value)

        # Updates height to calculate balance
        node.height = 1 + max(self.getHeight(node.left), self.getHeight(node.right))
        balance = self.getBalance(node)

        # Check if require balance
        if balance > 1:
            if value <  node.left.value:
                return self.rotateRight(node)
            else:
                return self.rotateLeftRight(node)
        if balance < -1:
            if value > node.right.value:
                return self.rotateLeft(node)
            else:
                return self.rotateRightLeft(node)
            
        return node

    def delete(self,node,value):
        #print("delete")
        if not node:
            return None
        if value < node.value:
            node.left = self.delete(node.left,value)
        elif value > node.value:
            node.right = self.delete(node.right,value)
        else:
            if not node.left or not node.right:
                if not node.left:
                    return node.right
                else:
                    return node.left
            else:
                child = self.findMin(node.right)
                node.value = child.value
                node.right = self.delete(node.right,child.value)

        node.height = 1 + max(self.getHeight(node.left),self.getHeight(node.right))
        balance = self.getBalance(node)

        if balance > 1:
            if self.getBalance(node.left) >= 0:
                return self.rotateRight(node)
            else:
                return self.rotateLeftRight(node)
        if balance < -1:
            if self.getBalance(node.right) <= 0:
                return self.rotateLeft(node)
            else:
                node.right = self.rotateRight(Node)
                return self.rotateRightLeft(node)
        
        return node

    def getHeight(self, node):
        #print("Get Height")
        return node.height if node else 0

    def getBalance(self,node):
        #print("Get Balance")
        return self.getHeight(node.left) - self.getHeight(node.right) if node else 0

    def findMin(self,node):
        while node.left:
            node = node.left
        return node

    def rotateLeft(self,x):
        #print("Left Rotation")
        y = x.right
        tmp = y.left
        y.left = x
        x.right = tmp

        x.height = 1 + max(self.getHeight(x.left), self.getHeight(x.right))
        y.height = 1 + max(self.getHeight(y.left), self.getHeight(y.right))

        return y

    def rotateRight(self,y):
        #print("Right Rotation")
        x = y.left
        tmp = x.right
        x.right = y
        y.left = tmp

        x.height = 1 + max(self.getHeight(x.left), self.getHeight(x.right))
        y.height = 1 + max(self.getHeight(y.left), self.getHeight(y.right))

        return x

    def rotateLeftRight(self,node):
        node.left = self.rotateLeft(node.left)
        return self.rotateRight(node)

    def rotateRightLeft(self, node):
        node.right = self.rotateRight(node.right)
        return self.rotateLeft(node)

    
def printTree(node, level=0):
    if node is not None:
        print('  ' * int(level) + str(node.value))
        printTree(node.left, level + 1)
        printTree(node.right, level + 1)

In [25]:
avl_tree = AVLTree()

values_to_insert = [10, 5, 15, 3, 7, 12, 17,99,34,68,22,50,1]



In [26]:

for value in values_to_insert:
    avl_tree.root = avl_tree.insert(avl_tree.root, value)


#avl_tree.traversal(avl_tree.root)
printTree(avl_tree.root)


15
  5
    3
      1
    10
      7
      12
  50
    34
      17
        22
    68
      99


In [27]:
values_to_delete = [5, 15]
for value in values_to_delete:
    avl_tree.delete(avl_tree.root,value)

#avl_tree.traversal(avl_tree.root)
printTree(avl_tree.root)

17
  7
    3
      1
    10
      12
  50
    34
      22
    68
      99
