In [2]:
class BSTNode:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None

In [3]:
#Sample BST

root = BSTNode(40)
root.left = BSTNode(20)
root.right = BSTNode(80)
root.left.left = BSTNode(8)
root.left.right = BSTNode(30)
root.right.left = BSTNode(60)
root.right.right = BSTNode(100)
root.right.right.right = BSTNode(120)

InOrder traversal of BST - gives ascending order

In [4]:
def printBSTAsc(root):
    if root == None:
        return
    printBSTAsc(root.left)
    print(root.data, end = " ")
    printBSTAsc(root.right)
    
printBSTAsc(root)

8 20 30 40 60 80 100 120 

Search in BST

In [5]:
#Recursive Solution - Time: O(height), Space: O(height)
def searchBST(root, key):
    if root == None:
        return False
    else:
        if key == root.data:
            return True
        elif key < root.data:
            return searchBST(root.left, key)
        else:
            return searchBST(root.right, key)

print(searchBST(root, 100))
print(searchBST(root, 120))
print(searchBST(root, 1))
print(searchBST(root, 500))

True
True
False
False


In [6]:
#Iterative Solution - Time: O(height), Space: O(1)
def searchBSTIter(root, key):
    temp = root
    while temp != None:
        if key == temp.data:
            return True
        elif key < temp.data:
            temp = temp.left
        else:
            temp = temp.right
            
    return False

print(searchBSTIter(root, 100))
print(searchBSTIter(root, 120))
print(searchBSTIter(root, 1))
print(searchBSTIter(root, 500 ))

True
True
False
False


Insert a node into BST (Assume no duplicates allowed)

In [7]:
#Iterative solution: Time: O(height), Space: O(1)
def insertBST(root, key):
    newNode = BSTNode(key)
    temp = root
    parent = root
    while temp != None:
        parent = temp
        if key == temp.data:
            print(f"Node with value {key} already exists")
            return root
        elif key < temp.data:
            temp = temp.left
        else:
            temp = temp.right
            
    if parent == None:
        return newNode
    if key < parent.data:
        parent.left = newNode
    elif key > parent.data:
        parent.right = newNode
    return root

insertBST(root, 100)
printBSTAsc(root)
print()
insertBST(root, 75)
printBSTAsc(root)
print()
insertBST(root, 25)
printBSTAsc(root)

Node with value 100 already exists
8 20 30 40 60 80 100 120 
8 20 30 40 60 75 80 100 120 
8 20 25 30 40 60 75 80 100 120 

In [8]:
#Recursive Solution: Time: O(height), Space: O(height)
def insertBSTRec(root, key):
    if root == None:
        root = BSTNode(key)
        return root
    else:
        if key == root.data:
            print(f"Node with value {key} already exists")
            return root
        elif key < root.data:
            root.left = insertBSTRec(root.left, key)
            return root
        else:
            root.right = insertBSTRec(root.right, key)
            return root

root1 = None
nodeList = [10, 15, 5, 7, 12, 7, 18, 9]
for node in nodeList:
    root1 = insertBSTRec(root1, node)
printBSTAsc(root1)     

Node with value 7 already exists
5 7 9 10 12 15 18 

Delete a node from BST

In [9]:
#Time: O(height), Space: O(height)
def deleteNodeBST(root, key):
    if root == None:
        print(f"Key {key} not in the BST")
        return None
    else:
        if key == root.data:
            if root.left == None and root.right == None:
                root = None
                return root
            elif root.left == None and root.right != None:
                return root.right
            elif root.left != None and root.right == None:
                return root.left
            else:
                successorData = InOrderSuccessor(root.right)
                root.data = successorData
                root.right = deleteNodeBST(root.right, successorData)
        
        elif key < root.data:
            root.left = deleteNodeBST(root.left, key)
            return root
        else:
            root.right = deleteNodeBST(root.right, key)
            return root

def InOrderSuccessor(root):
    if root == None:
        return
    curr = root
    while curr.left != None:
        curr = curr.left
    return curr.data

printBSTAsc(root1)
print()
deleteNodeBST(root1, 9)
printBSTAsc(root1)
print()
deleteNodeBST(root1, 5)
printBSTAsc(root1)
print()
deleteNodeBST(root1, 10)
printBSTAsc(root1)
print()

5 7 9 10 12 15 18 
5 7 10 12 15 18 
7 10 12 15 18 
7 12 15 18 


In [10]:
root2 = BSTNode(10)
insertBST(root2, 7)
insertBST(root2, 15)
printBSTAsc(root2)
print()
deleteNodeBST(root2, 10)
printBSTAsc(root2)

7 10 15 
7 15 

Floor of a value in BST

In [None]:
#Naive Solution: Time - O(Nodes), Space - O(height)
import math
def floorBST(root, key, floorVal = -math.inf):
    if root == None:
        return None if floorVal == -math.inf else floorVal
    
    if (root.data <= key) and (root.data > floorVal):
        floorVal = root.data
    floorVal = floorBST(root.left, key, floorVal)
    floorVal = floorBST(root.right, key, floorVal)
    return floorVal
        
print(floorBST(root,79))
print(floorBST(root,7))
print(floorBST(root,19))
print(floorBST(root,121))
    

75
None
8
120


In [None]:
#Optimal Approach: Time - O(height), Space - O(1). A small modification to this program - I am returning the floor node rather than the floor value
def floorBST1(root, key):
    curr = root
    floor = None
    while curr != None:
        if key == curr.data:
            return curr
        elif key < curr.data:
            curr = curr.left
        else:
            floor = curr
            curr = curr.right
            
    return floor

temp = floorBST1(root, 79)
if temp:
    print(temp.data)
    temp = floorBST1(root, 79)
temp = floorBST1(root, 19)
if temp:
    print(temp.data)
temp = floorBST1(root, 7)
if temp:
    print(temp.data)
temp = floorBST1(root, 121)
if temp:
    print(temp.data)


75
8
120


Ceil of a value in BST

In [None]:
#Optimal Approach: Time - O(height), Space - O(1)
def ceilBST(root, key):
    ceil = None
    curr = root
    while curr != None:
        if key == curr.data:
            ceil = key
            return ceil
        elif key < curr.data:
            ceil = curr.data
            curr = curr.left
        else:
            curr = curr.right
    
    return ceil

print(ceilBST(root,50))
print(ceilBST(root,7))
print(ceilBST(root,19))
print(ceilBST(root,121))

60
8
20
None


AVL Tree

In [None]:
class AVLNode:
    def __init__(self, key):
        self.data = key
        self.height = 1
        self.left = None
        self.right = None
        
class AVLTree:
    def __init__(self):
        self.root = None
        
    def insert(self, key):
        self.root = self.insertUtil(self.root, key)
        
    def delete(self, key):
        self.root = self.deleteUtil(self.root, key)
    
    def search(self, key):
        if self.root == None:
            print("Node with key {key} not found")
            return False
        curr = self.root
        while curr:
            if key == curr.data:
                return True
            elif key < curr.data:
                curr = curr.left
            else:
                curr = curr.right
        return False
    
    def printTree(self):
        self.printTreeUtil(self.root)
            
    #Helper functions
    def printTreeUtil(self, root):
        if root:
            self.printTreeUtil(root.left)
            print(root.data, end = " ")
            self.printTreeUtil(root.right)
            
    def getHeight(self, node):
        return node.height if node else 0
    
    def getBalanceFactor(self, node):
        return (self.getHeight(node.left) - self.getHeight(node.right)) if node else 0
    
    def leftRotate(self, nodeY):
        nodeX = nodeY.right
        nodeXLST = nodeX.left
        nodeX.left = nodeY
        nodeY.right = nodeXLST
        nodeY.height = 1 + max(self.getHeight(nodeY.left), self.getHeight(nodeY.right))
        nodeX.height = 1 + max(self.getHeight(nodeX.left), self.getHeight(nodeX.right))
        return nodeX
    
    def rightRotate(self, nodeX):
        nodeY = nodeX.left
        nodeYRST = nodeY.right
        nodeY.right = nodeX
        nodeX.left = nodeYRST
        nodeX.height = 1 + max(self.getHeight(nodeX.left), self.getHeight(nodeX.right))
        nodeY.height = 1 + max(self.getHeight(nodeY.left), self.getHeight(nodeY.right))
        return nodeY
    
    def inorderSuccessor(self, node):
        curr = node
        while curr.left != None:
            curr = curr.left
        return curr.data
    
    def insertUtil(self, root, key):
        if not root:
            return AVLNode(key)
        if key < root.data:
            root.left = self.insertUtil(root.left, key)
        elif key > root.data:
            root.right = self.insertUtil(root.right, key)
        else:
            return root
        
        root.height = 1 + max(self.getHeight(root.left), self.getHeight(root.right))
        balanceFactor = self.getBalanceFactor(root)
        
        if balanceFactor > 1:                                   #Left heavy imbalances (i.e left-left or left-right)
            if key < root.left.data:
                return self.rightRotate(root)
            
            if key > root.left.data:
                root.left = self.leftRotate(root.left)
                return self.rightRotate(root)
        
        if balanceFactor < -1:                                  #Right heavy imbalances (i.e right-right or right-left)
            if key > root.right.data:
                return self.leftRotate(root)
            
            if key < root.right.data:
                root.right = self.rightRotate(root.right)
                return self.leftRotate(root)
            
        return root
    
    def deleteUtil(self, root, key):
        if not root:
            return root
        
        if key < root.data:
            root.left = self.deleteUtil(root.left, key)
        elif key > root.data:
            root.right = self.deleteUtil(root.right, key)
        else:
            if root.left == None:
                return root.right
            elif root.right == None:
                return root.left
            else:
                inorderSuccessorData = self.inorderSuccessor(root.right)
                root.data = inorderSuccessorData
                root.right = self.deleteUtil(root.right, inorderSuccessorData)
                return root
        
        root.height = 1 + max(self.getHeight(root.left), self.getHeight(root.right))
        balanceFactor = self.getBalanceFactor(root)
        
        if balanceFactor > 1:                                       #Left heavy imbalances (i.e left-left or left-right)
            if self.getBalanceFactor(root.left) >= 0:
                return self.rightRotate(root)
            if self.getBalanceFactor(root.left) < 0:
                root.left = self.leftRotate(root.left)
                return self.rightRotate(root)
            
        if balanceFactor < -1:                                       #Right heavy imbalances (i.e right-right or right-left)
            if self.getBalanceFactor(root.right) <= 0:
                return self.leftRotate(root)
            if self.getBalanceFactor(root.right) > 0:
                root.right = self.rightRotate(root.right)
                return self.leftRotate(root)
            
        return root     

In [27]:
tree1 = AVLTree()
for key in [1,2,3,4,5,6,7,8]:
    tree1.insert(key)

tree1.printTree()
print(f"\nHeight: {tree1.getHeight(tree1.root)}")
print("Search 15:", tree1.search(15))
print("Search 7:", tree1.search(7))
tree1.delete(6)
tree1.printTree()
print(f"\nHeight: {tree1.getHeight(tree1.root)}")

1 2 3 4 5 6 7 8 
Height: 4
Search 15: False
Search 7: True
1 2 3 4 5 7 8 
Height: 4
