# LinkedIn Learning
# Python Data Structures --  Trees

## 2. Modifying Trees

### A. Adding Nodes

In [38]:
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        
    def search(self, target):
        if self.data == target:
            print("Found it!")
            return self
        if self.left and self.data > target:
            return self.left.search(target)
        if self.right and self.data < target:
            return self.right.search(target)
        print("Value not found in tree")

    def add(self, data):
        if self.data == data:
            return 
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.add(data)
        else:  # Ensuring BST property
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.add(data)

    def traversePreorder(self):
        print(self.data, end=" ")
        if self.left:
            self.left.traversePreorder()
        if self.right:
            self.right.traversePreorder()
        
    def traverseInorder(self):
        if self.left:
            self.left.traverseInorder()
        print(self.data, end=" ")
        if self.right:
            self.right.traverseInorder()
        
    def traversePostorder(self):
        if self.left:
            self.left.traversePostorder() 
        if self.right:
            self.right.traversePostorder()
        print(self.data, end=" ")
        
    def height(self, h=0):
        leftHeight = self.left.height(h+1) if self.left else h
        rightHeight = self.right.height(h+1) if self.right else h
        return max(leftHeight, rightHeight)
        
    def getNodesAtDepth(self, depth, nodes=None):
        if nodes is None:
            nodes = []
        if depth == 0:
            nodes.append(self.data)
            return nodes
        if self.left:
            self.left.getNodesAtDepth(depth-1, nodes)
        if self.right:
            self.right.getNodesAtDepth(depth-1, nodes)
        return nodes

In [39]:
class Tree:
    def __init__(self, root, name=''):
        self.root = root
        self.name = name

    def search(self, target):
        return self.root.search(target)

    def add(self, data):
        self.root.add(data)

    def traversePreorder(self):
        self.root.traversePreorder()
        print()

    def traverseInorder(self):
        self.root.traverseInorder()
        print()

    def traversePostorder(self):
        self.root.traversePostorder()
        print()

    def height(self):
        return self.root.height()

    def getNodesAtDepth(self, depth):
        return self.root.getNodesAtDepth(depth)

    def _nodeToChar(self, n, spacing):
        if n is None:
            return '_' + (' ' * spacing)
        spacing = max(spacing - len(str(n)) + 1, 1)
        return str(n) + (' ' * spacing)

    def print(self, label=''):
        print(self.name + ' ' + label)
        height = self.root.height()
        spacing = 3
        width = int((2 ** height - 1) * (spacing + 1) + 1)
        offset = int((width - 1) / 2)

        for depth in range(0, height + 1):
            if depth > 0:
                print(' ' * (offset + 1) + (' ' * (spacing + 2)).join(['/' + (' ' * (spacing - 2)) + '\\'] * (2 ** (depth - 1))))
            row = self.root.getNodesAtDepth(depth, [])
            print((' ' * offset) + ''.join([self._nodeToChar(n, spacing) for n in row]))
            spacing = offset + 1
            offset = max(int(offset / 2) - 1, 0)
        print()



In [51]:
tree_1 = Tree(Node(50))
tree_1.add(25)
tree_1.add(75)

In [52]:
tree_1.print()

 
  50  
 / \
25  75  



In [53]:
tree_1.add(10)

In [54]:
tree_1.print()

 
      50  
   /     \
  25      75      
 / \     / \
10  



In [55]:
tree_1.add(76)

In [56]:
tree_1.print()

 
      50  
   /     \
  25      75      
 / \     / \
10  76  



In [57]:
tree_1.add(75)

In [58]:
tree_1.print()

 
      50  
   /     \
  25      75      
 / \     / \
10  76  



In [48]:
print("Preorder Traversal:")
tree_1.traversePreorder()

Preorder Traversal:
50 25 10 75 76 


In [49]:
print("Inorder Traversal:")
tree_1.traverseInorder()

Inorder Traversal:
10 25 50 75 76 


In [50]:
print("Postorder Traversal:")
tree_1.traversePostorder()

Postorder Traversal:
10 25 76 75 50 


### B. Deleting Nodes

In [60]:
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        
    def search(self, target):
        if self.data == target:
            print("Found it!")
            return self
        if self.left and self.data > target:
            return self.left.search(target)
        if self.right and self.data < target:
            return self.right.search(target)
        print("Value not found in tree")

    def add(self, data):
        if self.data == data:
            return 
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.add(data)
        else:  # Ensuring BST property
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.add(data)

    def traversePreorder(self):
        print(self.data, end=" ")
        if self.left:
            self.left.traversePreorder()
        if self.right:
            self.right.traversePreorder()
        
    def traverseInorder(self):
        if self.left:
            self.left.traverseInorder()
        print(self.data, end=" ")
        if self.right:
            self.right.traverseInorder()
        
    def traversePostorder(self):
        if self.left:
            self.left.traversePostorder() 
        if self.right:
            self.right.traversePostorder()
        print(self.data, end=" ")
        
    def height(self, h=0):
        leftHeight = self.left.height(h+1) if self.left else h
        rightHeight = self.right.height(h+1) if self.right else h
        return max(leftHeight, rightHeight)
        
    def getNodesAtDepth(self, depth, nodes=None):
        if nodes is None:
            nodes = []
        if depth == 0:
            nodes.append(self.data)
            return nodes
        if self.left:
            self.left.getNodesAtDepth(depth-1, nodes)
        if self.right:
            self.right.getNodesAtDepth(depth-1, nodes)
        return nodes

    def findMin(self):
        if self.left:
            return self.left.findMin()
        return self.data
    
    def delete(self, target):
        if self.data == target: 
            # Do the deletion here
            if self.left and self.right:
                # RTFM
                minValue = self.right.findMin()
                self.data = minValue
                self.right = self.right.delete(minValue)
                return self
            else:
                return self.left or self.right
                
        if self.right and target > self.data:
            self.right = self.right.delete(target)
        if self.left and target < self.data:
            self.left = self.left.delete(target)
        return self

In [91]:
class Tree:
    def __init__(self, root, name=''):
        self.root = root
        self.name = name

    def search(self, target):
        return self.root.search(target)

    def add(self, data):
        self.root.add(data)

    #def traversePreorder(self):
       # self.root.traversePreorder()
       # print()

    def traverseInorder(self):
        self.root.traverseInorder()
        print()

    #def traversePostorder(self):
        #self.root.traversePostorder()
        #print()

    def height(self):
        return self.root.height()

    def getNodesAtDepth(self, depth):
        return self.root.getNodesAtDepth(depth)

    def _nodeToChar(self, n, spacing):
        if n is None:
            return '_' + (' ' * spacing)
        spacing = max(spacing - len(str(n)) + 1, 1)
        return str(n) + (' ' * spacing)

    def print(self, label=''):
        print(self.name + ' ' + label)
        height = self.root.height()
        spacing = 3
        width = int((2 ** height - 1) * (spacing + 1) + 1)
        offset = int((width - 1) / 2)

        for depth in range(0, height + 1):
            if depth > 0:
                print(' ' * (offset + 1) + (' ' * (spacing + 2)).join(['/' + (' ' * (spacing - 2)) + '\\'] * (2 ** (depth - 1))))
            row = self.root.getNodesAtDepth(depth, [])
            print((' ' * offset) + ''.join([self._nodeToChar(n, spacing) for n in row]))
            spacing = offset + 1
            offset = max(int(offset / 2) - 1, 0)
        print()

    def delete(self, target):
        self.root = self.root.delete(target)


In [92]:
tree_2 = Tree(Node(50))

In [93]:
tree_2.root.left = Node(25)
tree_2.root.right = Node(75)

In [94]:
tree_2.root.right.left = Node(67)
tree_2.root.right.right = Node(100)

In [95]:
tree_2.root.right.right.right = Node(120)
tree_2.root.right.right.left = Node(80)

In [96]:
tree_2.root.right.right.left.right = Node(92)

In [97]:
tree_2.print()

 
                              50  
               /                             \
              25                              75                              
       /             \                 /             \
      67              100             
   /     \         /     \         /     \         /     \
  80      120     
 / \     / \     / \     / \     / \     / \     / \     / \
92  



In [98]:
print("Inorder Traversal of BST:")
tree_2.traverseInorder()

Inorder Traversal of BST:
25 50 67 75 80 92 100 120 


In [99]:
tree_2.delete(75)

In [100]:
tree_2.print()

 
              50  
       /             \
      25              80              
   /     \         /     \
  67      100     
 / \     / \     / \     / \
92  120 



In [101]:
tree_2.delete(50)

In [102]:
tree_2.print()

 
              67  
       /             \
      25              80              
   /     \         /     \
  100     
 / \     / \     / \     / \
92  120 



### C. Detecting Unbalanced tree

In [103]:
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        
    def search(self, target):
        if self.data == target:
            print("Found it!")
            return self
        if self.left and self.data > target:
            return self.left.search(target)
        if self.right and self.data < target:
            return self.right.search(target)
        print("Value not found in tree")

    def add(self, data):
        if self.data == data:
            return 
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.add(data)
        else:  # Ensuring BST property
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.add(data)

    def traversePreorder(self):
        print(self.data, end=" ")
        if self.left:
            self.left.traversePreorder()
        if self.right:
            self.right.traversePreorder()
        
    def traverseInorder(self):
        if self.left:
            self.left.traverseInorder()
        print(self.data, end=" ")
        if self.right:
            self.right.traverseInorder()
        
    def traversePostorder(self):
        if self.left:
            self.left.traversePostorder() 
        if self.right:
            self.right.traversePostorder()
        print(self.data, end=" ")
        
    def height(self, h=0):
        leftHeight = self.left.height(h+1) if self.left else h
        rightHeight = self.right.height(h+1) if self.right else h
        return max(leftHeight, rightHeight)
        
    def getNodesAtDepth(self, depth, nodes=None):
        if nodes is None:
            nodes = []
        if depth == 0:
            nodes.append(self.data)
            return nodes
        if self.left:
            self.left.getNodesAtDepth(depth-1, nodes)
        if self.right:
            self.right.getNodesAtDepth(depth-1, nodes)
        return nodes

    def findMin(self):
        if self.left:
            return self.left.findMin()
        return self.data
    
    def delete(self, target):
        if self.data == target: 
            # Do the deletion here
            if self.left and self.right:
                # RTFM
                minValue = self.right.findMin()
                self.data = minValue
                self.right = self.right.delete(minValue)
                return self
            else:
                return self.left or self.right
                
        if self.right and target > self.data:
            self.right = self.right.delete(target)
        if self.left and target < self.data:
            self.left = self.left.delete(target)
        return self

    def isBalanced(self):
        leftHeight = self.left.height() + 1 if self.left else 0
        rightHeight = self.right.height() + 1 if self.right else 0
        return abs(leftHeight - rightHeight) < 2

In [104]:
class Tree:
    def __init__(self, root, name=''):
        self.root = root
        self.name = name

    def search(self, target):
        return self.root.search(target)

    def add(self, data):
        self.root.add(data)

    #def traversePreorder(self):
       # self.root.traversePreorder()
       # print()

    def traverseInorder(self):
        self.root.traverseInorder()
        print()

    #def traversePostorder(self):
        #self.root.traversePostorder()
        #print()

    def height(self):
        return self.root.height()

    def getNodesAtDepth(self, depth):
        return self.root.getNodesAtDepth(depth)

    def _nodeToChar(self, n, spacing):
        if n is None:
            return '_' + (' ' * spacing)
        spacing = max(spacing - len(str(n)) + 1, 1)
        return str(n) + (' ' * spacing)

    def print(self, label=''):
        print(self.name + ' ' + label)
        height = self.root.height()
        spacing = 3
        width = int((2 ** height - 1) * (spacing + 1) + 1)
        offset = int((width - 1) / 2)

        for depth in range(0, height + 1):
            if depth > 0:
                print(' ' * (offset + 1) + (' ' * (spacing + 2)).join(['/' + (' ' * (spacing - 2)) + '\\'] * (2 ** (depth - 1))))
            row = self.root.getNodesAtDepth(depth, [])
            print((' ' * offset) + ''.join([self._nodeToChar(n, spacing) for n in row]))
            spacing = offset + 1
            offset = max(int(offset / 2) - 1, 0)
        print()

    def delete(self, target):
        self.root = self.root.delete(target)


In [105]:
tree_3 = Tree(Node(50), 'An Unbalanced Tree')
tree_3.root.left = Node(25)
tree_3.root.right = Node(75)
tree_3.root.right.right = Node(100)
tree_3.root.right.right.right = Node(150)
tree_3.print()

An Unbalanced Tree 
              50  
       /             \
      25              75              
   /     \         /     \
  100     
 / \     / \     / \     / \
150 



In [106]:
print(tree_3.root.isBalanced())

False


In [107]:
print(tree_3.root.left.isBalanced())

True


In [108]:
print(tree_3.root.right.isBalanced())

False


### D. Challenge 2 --> Adding a balance Indicator to the printed tree

In [113]:
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        
    def search(self, target):
        if self.data == target:
            print("Found it!")
            return self
        if self.left and self.data > target:
            return self.left.search(target)
        if self.right and self.data < target:
            return self.right.search(target)
        print("Value not found in tree")

    def add(self, data):
        if self.data == data:
            return 
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.add(data)
        else:  # Ensuring BST property
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.add(data)

    def traversePreorder(self):
        print(self.data, end=" ")
        if self.left:
            self.left.traversePreorder()
        if self.right:
            self.right.traversePreorder()
        
    def traverseInorder(self):
        if self.left:
            self.left.traverseInorder()
        print(self.data, end=" ")
        if self.right:
            self.right.traverseInorder()
        
    def traversePostorder(self):
        if self.left:
            self.left.traversePostorder() 
        if self.right:
            self.right.traversePostorder()
        print(self.data, end=" ")
        
    def height(self, h=0):
        leftHeight = self.left.height(h+1) if self.left else h
        rightHeight = self.right.height(h+1) if self.right else h
        return max(leftHeight, rightHeight)
        
    def getNodesAtDepth(self, depth, nodes=[]):
        if depth == 0:
            nodes.append(self)
            return nodes
            
        if self.left:
            self.left.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
            
        if self.right:
            self.right.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
        return nodes

    def findMin(self):
        if self.left:
            return self.left.findMin()
        return self.data
    
    def delete(self, target):
        if self.data == target: 
            # Do the deletion here
            if self.left and self.right:
                # RTFM
                minValue = self.right.findMin()
                self.data = minValue
                self.right = self.right.delete(minValue)
                return self
            else:
                return self.left or self.right
                
        if self.right and target > self.data:
            self.right = self.right.delete(target)
        if self.left and target < self.data:
            self.left = self.left.delete(target)
        return self

    def isBalanced(self):
        leftHeight = self.left.height() + 1 if self.left else 0
        rightHeight = self.right.height() + 1 if self.right else 0
        return abs(leftHeight - rightHeight) < 2

    def toStr(self):
        if not self.isBalanced():
            return str(self.data)+'*'
        return str(self.data)

In [114]:
class Tree:
    def __init__(self, root, name=''):
        self.root = root
        self.name = name

    def search(self, target):
        return self.root.search(target)

    def add(self, data):
        self.root.add(data)

    #def traversePreorder(self):
       # self.root.traversePreorder()
       # print()

    def traverseInorder(self):
        self.root.traverseInorder()
        print()

    #def traversePostorder(self):
        #self.root.traversePostorder()
        #print()

    def height(self):
        return self.root.height()

    def getNodesAtDepth(self, depth):
        return self.root.getNodesAtDepth(depth)

    def _nodeToChar(self, n, spacing):
        if n is None:
            return '_' + (' ' * spacing)
        spacing = spacing - len(n.toStr()) + 1
        return n.toStr() + (' ' * spacing)

    def print(self, label=''):
        print(self.name + ' ' + label)
        height = self.root.height()
        spacing = 3
        width = int((2 ** height - 1) * (spacing + 1) + 1)
        offset = int((width - 1) / 2)

        for depth in range(0, height + 1):
            if depth > 0:
                print(' ' * (offset + 1) + (' ' * (spacing + 2)).join(['/' + (' ' * (spacing - 2)) + '\\'] * (2 ** (depth - 1))))
            row = self.root.getNodesAtDepth(depth, [])
            print((' ' * offset) + ''.join([self._nodeToChar(n, spacing) for n in row]))
            spacing = offset + 1
            offset = max(int(offset / 2) - 1, 0)
        print()

    def delete(self, target):
        self.root = self.root.delete(target)


In [115]:
tree_4 = Tree(Node(50), 'An Unbalanced Tree')
tree_4.root.left = Node(25)
tree_4.root.right = Node(75)
tree_4.root.right.right = Node(100)
tree_4.root.right.right.right = Node(150)
tree_4.print()

An Unbalanced Tree 
              50* 
       /             \
      25              75*             
   /     \         /     \
  _       _       _       100     
 / \     / \     / \     / \
_   _   _   _   _   _   _   150 



## 3. Rebalancing Trees

### A. Rotating Trees in Python

In [116]:
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        
    def search(self, target):
        if self.data == target:
            print("Found it!")
            return self
        if self.left and self.data > target:
            return self.left.search(target)
        if self.right and self.data < target:
            return self.right.search(target)
        print("Value not found in tree")

    def add(self, data):
        if self.data == data:
            return 
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.add(data)
        else:  # Ensuring BST property
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.add(data)

    def traversePreorder(self):
        print(self.data, end=" ")
        if self.left:
            self.left.traversePreorder()
        if self.right:
            self.right.traversePreorder()
        
    def traverseInorder(self):
        if self.left:
            self.left.traverseInorder()
        print(self.data, end=" ")
        if self.right:
            self.right.traverseInorder()
        
    def traversePostorder(self):
        if self.left:
            self.left.traversePostorder() 
        if self.right:
            self.right.traversePostorder()
        print(self.data, end=" ")
        
    def height(self, h=0):
        leftHeight = self.left.height(h+1) if self.left else h
        rightHeight = self.right.height(h+1) if self.right else h
        return max(leftHeight, rightHeight)
        
    def getNodesAtDepth(self, depth, nodes=[]):
        if depth == 0:
            nodes.append(self)
            return nodes
            
        if self.left:
            self.left.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
            
        if self.right:
            self.right.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
        return nodes

    def findMin(self):
        if self.left:
            return self.left.findMin()
        return self.data
    
    def delete(self, target):
        if self.data == target: 
            # Do the deletion here
            if self.left and self.right:
                # RTFM
                minValue = self.right.findMin()
                self.data = minValue
                self.right = self.right.delete(minValue)
                return self
            else:
                return self.left or self.right
                
        if self.right and target > self.data:
            self.right = self.right.delete(target)
        if self.left and target < self.data:
            self.left = self.left.delete(target)
        return self

    def isBalanced(self):
        leftHeight = self.left.height() + 1 if self.left else 0
        rightHeight = self.right.height() + 1 if self.right else 0
        return abs(leftHeight - rightHeight) < 2

    def toStr(self):
        if not self.isBalanced():
            return str(self.data)+'*'
        return str(self.data)

In [117]:
class Tree:
    def __init__(self, root, name=''):
        self.root = root
        self.name = name

    def search(self, target):
        return self.root.search(target)

    def add(self, data):
        self.root.add(data)

    #def traversePreorder(self):
       # self.root.traversePreorder()
       # print()

    def traverseInorder(self):
        self.root.traverseInorder()
        print()

    #def traversePostorder(self):
        #self.root.traversePostorder()
        #print()

    def height(self):
        return self.root.height()

    def getNodesAtDepth(self, depth):
        return self.root.getNodesAtDepth(depth)

    def _nodeToChar(self, n, spacing):
        if n is None:
            return '_' + (' ' * spacing)
        spacing = spacing - len(n.toStr()) + 1
        return n.toStr() + (' ' * spacing)

    def print(self, label=''):
        print(self.name + ' ' + label)
        height = self.root.height()
        spacing = 3
        width = int((2 ** height - 1) * (spacing + 1) + 1)
        offset = int((width - 1) / 2)

        for depth in range(0, height + 1):
            if depth > 0:
                print(' ' * (offset + 1) + (' ' * (spacing + 2)).join(['/' + (' ' * (spacing - 2)) + '\\'] * (2 ** (depth - 1))))
            row = self.root.getNodesAtDepth(depth, [])
            print((' ' * offset) + ''.join([self._nodeToChar(n, spacing) for n in row]))
            spacing = offset + 1
            offset = max(int(offset / 2) - 1, 0)
        print()

    def delete(self, target):
        self.root = self.root.delete(target)


In [118]:
def rotateRight(root):
    pivot = root.left 
    reattachNode = pivot.right 
    root.left = reattachNode
    pivot.right = root
    return pivot

In [119]:
unbalancedLeftLeft = Tree(Node(30), 'UNBALANCED LEFT LEFT')
unbalancedLeftLeft.root.left = Node(20)
unbalancedLeftLeft.root.left.right = Node(21)
unbalancedLeftLeft.root.left.left = Node(10)
unbalancedLeftLeft.root.left.left.left = Node(9)
unbalancedLeftLeft.root.left.left.right = Node(11)
unbalancedLeftLeft.print()

UNBALANCED LEFT LEFT 
              30* 
       /             \
      20              _               
   /     \         /     \
  10      21      _       _       
 / \     / \     / \     / \
9   11  _   _   _   _   _   _   



In [120]:
unbalancedLeftLeft.root = rotateRight(unbalancedLeftLeft.root)

In [121]:
unbalancedLeftLeft.print()

UNBALANCED LEFT LEFT 
      20  
   /     \
  10      30      
 / \     / \
9   11  21  _   



In [122]:
def rotateLeft(root):
    pivot = root.right
    reattachNode = pivot.left
    root.right = reattachNode
    pivot.left = root
    return pivot

In [127]:
unbalancedRightRight = Tree(Node(10), 'UNBALANCED RIGHT RIGHT')
unbalancedRightRight.root.right = Node(20)
unbalancedRightRight.root.right.left = Node(19)
unbalancedRightRight.root.right.right = Node(30)
unbalancedRightRight.root.right.right.left = Node(29)
unbalancedRightRight.root.right.right.right = Node(31)
unbalancedRightRight.print()

UNBALANCED RIGHT RIGHT 
              10* 
       /             \
      _               20              
   /     \         /     \
  _       _       19      30      
 / \     / \     / \     / \
_   _   _   _   _   _   29  31  



In [128]:
unbalancedRightRight.root = rotateLeft(unbalancedRightRight.root)
unbalancedRightRight.print()

UNBALANCED RIGHT RIGHT 
      20  
   /     \
  10      30      
 / \     / \
_   19  29  31  



In [129]:
unbalancedLeftRight = Tree(Node(30), 'UNBALANCED LEFT RIGHT')
unbalancedLeftRight.root.right = Node(31)
unbalancedLeftRight.root.left = Node(10)
unbalancedLeftRight.root.left.right = Node(20)
unbalancedLeftRight.root.left.left = Node(9)
unbalancedLeftRight.root.left.right.left = Node(19)
unbalancedLeftRight.root.left.right.right = Node(21)
unbalancedLeftRight.print()

UNBALANCED LEFT RIGHT 
              30* 
       /             \
      10              31              
   /     \         /     \
  9       20      _       _       
 / \     / \     / \     / \
_   _   19  21  _   _   _   _   



In [130]:
unbalancedLeftRight.root.left = rotateLeft(unbalancedLeftRight.root.left)
unbalancedLeftRight.root = rotateRight(unbalancedLeftRight.root)
unbalancedLeftRight.print()

UNBALANCED LEFT RIGHT 
      20  
   /     \
  10      30      
 / \     / \
9   19  21  31  



In [131]:
unbalancedRightLeft = Tree(Node(30), 'UNBALANCED RIGHT LEFT')
unbalancedRightLeft.root.left = Node(31)
unbalancedRightLeft.root.right = Node(10)
unbalancedRightLeft.root.right.left = Node(20)
unbalancedRightLeft.root.right.right = Node(9)
unbalancedRightLeft.root.right.left.right = Node(19)
unbalancedRightLeft.root.right.left.left = Node(21)
unbalancedRightLeft.print()

UNBALANCED RIGHT LEFT 
              30* 
       /             \
      31              10              
   /     \         /     \
  _       _       20      9       
 / \     / \     / \     / \
_   _   _   _   21  19  _   _   



In [132]:
unbalancedRightLeft.root.right = rotateRight(unbalancedRightLeft.root.right)
unbalancedRightLeft.root = rotateLeft(unbalancedRightLeft.root)
unbalancedRightLeft.print()

UNBALANCED RIGHT LEFT 
      20  
   /     \
  30      10      
 / \     / \
31  21  19  9   



### B. Fixing a tree with multiple points of Imbalance

In [134]:
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        
    def search(self, target):
        if self.data == target:
            print("Found it!")
            return self
        if self.left and self.data > target:
            return self.left.search(target)
        if self.right and self.data < target:
            return self.right.search(target)
        print("Value not found in tree")

    def add(self, data):
        if self.data == data:
            return 
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
            else:
                self.left.add(data)
        else:  # Ensuring BST property
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.add(data)

    def traversePreorder(self):
        print(self.data, end=" ")
        if self.left:
            self.left.traversePreorder()
        if self.right:
            self.right.traversePreorder()
        
    def traverseInorder(self):
        if self.left:
            self.left.traverseInorder()
        print(self.data, end=" ")
        if self.right:
            self.right.traverseInorder()
        
    def traversePostorder(self):
        if self.left:
            self.left.traversePostorder() 
        if self.right:
            self.right.traversePostorder()
        print(self.data, end=" ")
        
    def height(self, h=0):
        leftHeight = self.left.height(h+1) if self.left else h
        rightHeight = self.right.height(h+1) if self.right else h
        return max(leftHeight, rightHeight)
        
    def getNodesAtDepth(self, depth, nodes=[]):
        if depth == 0:
            nodes.append(self)
            return nodes
            
        if self.left:
            self.left.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
            
        if self.right:
            self.right.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
        return nodes

    def findMin(self):
        if self.left:
            return self.left.findMin()
        return self.data
    
    def delete(self, target):
        if self.data == target: 
            # Do the deletion here
            if self.left and self.right:
                # RTFM
                minValue = self.right.findMin()
                self.data = minValue
                self.right = self.right.delete(minValue)
                return self
            else:
                return self.left or self.right
                
        if self.right and target > self.data:
            self.right = self.right.delete(target)
        if self.left and target < self.data:
            self.left = self.left.delete(target)
        return self

    def isBalanced(self):
        leftHeight = self.left.height() + 1 if self.left else 0
        rightHeight = self.right.height() + 1 if self.right else 0
        return abs(leftHeight - rightHeight) < 2

    def toStr(self):
        if not self.isBalanced():
            return str(self.data)+'*'
        return str(self.data)

    def getLeftRightHeightDifference(self):
        leftHeight = self.left.height() + 1 if self.left else 0
        rightHeight = self.right.height() + 1 if self.right else 0
        return leftHeight - rightHeight
        
    def fixImbalanceIfExists(self):
        if self.getLeftRightHeightDifference() > 1:
            # Left Imbalance
            if self.left.getLeftRightHeightDifference() > 0:
                # Left Left Imbalance
                return rotateRight(self)
            else:
                # Left Right Imbalance
                self.left = rotateLeft(self.left)
                return rotateRight(self)
        elif self.getLeftRightHeightDifference() < -1:
            #Right Imbalance
            if self.right.getLeftRightHeightDifference() < 0:
                # Right Right Imbalance
                return rotateLeft(self)
            else:
                # Right Left Imbalance
                self.right = rotateRight(self.right)
                return rotateLeft(self)
        return self

    def rebalance(self):
        if self.left:
            self.left.rebalance()
            self.left = self.left.fixImbalanceIfExists()
        if self.right:
            self.right.rebalance()
            self.right = self.right.fixImbalanceIfExists()

In [135]:
class Tree:
    def __init__(self, root, name=''):
        self.root = root
        self.name = name

    def search(self, target):
        return self.root.search(target)

    def add(self, data):
        self.root.add(data)

    #def traversePreorder(self):
       # self.root.traversePreorder()
       # print()

    def traverseInorder(self):
        self.root.traverseInorder()
        print()

    #def traversePostorder(self):
        #self.root.traversePostorder()
        #print()

    def height(self):
        return self.root.height()

    def getNodesAtDepth(self, depth):
        return self.root.getNodesAtDepth(depth)

    def _nodeToChar(self, n, spacing):
        if n is None:
            return '_' + (' ' * spacing)
        spacing = spacing - len(n.toStr()) + 1
        return n.toStr() + (' ' * spacing)

    def print(self, label=''):
        print(self.name + ' ' + label)
        height = self.root.height()
        spacing = 3
        width = int((2 ** height - 1) * (spacing + 1) + 1)
        offset = int((width - 1) / 2)

        for depth in range(0, height + 1):
            if depth > 0:
                print(' ' * (offset + 1) + (' ' * (spacing + 2)).join(['/' + (' ' * (spacing - 2)) + '\\'] * (2 ** (depth - 1))))
            row = self.root.getNodesAtDepth(depth, [])
            print((' ' * offset) + ''.join([self._nodeToChar(n, spacing) for n in row]))
            spacing = offset + 1
            offset = max(int(offset / 2) - 1, 0)
        print()

    def delete(self, target):
        self.root = self.root.delete(target)

    def rebalance(self):
        self.root.rebalance()
        self.root = self.root.fixImbalanceIfExists()

In [136]:
def rotateRight(root):
    pivot = root.left 
    reattachNode = pivot.right 
    root.left = reattachNode
    pivot.right = root
    return pivot

In [137]:
def rotateLeft(root):
    pivot = root.right
    reattachNode = pivot.left
    root.right = reattachNode
    pivot.left = root
    return pivot

In [138]:
tree_5 = Tree(Node(50))
tree_5.root.left = Node(25)
tree_5.root.right = Node(75)
tree_5.root.left.left = Node(10)
tree_5.root.left.right = Node(35)
tree_5.root.left.right.left = Node(30)
tree_5.root.left.left.left = Node(5)
tree_5.root.left.left.right = Node(13)

tree_5.root.left.left.left.left = Node(2)
tree_5.root.left.left.left.left.left = Node(1)
tree_5.print()

 
                                                              50* 
                               /                                                             \
                              25*                                                             75                                                              
               /                             \                                 /                             \
              10*                             35                              _                               _                               
       /             \                 /             \                 /             \                 /             \
      5*              13              30              _               _               _               _               _               
   /     \         /     \         /     \         /     \         /     \         /     \         /     \         /     \
  2       _       _       _       _       _   

In [142]:
tree_5.rebalance()
tree_5.print()

 
              25  
       /             \
      10              50              
   /     \         /     \
  2       13      35      75      
 / \     / \     / \     / \
1   5   _   _   30  _   _   _   



### C. Challenge 3 --> Smarter Automated Rebalancing

In [157]:
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        
    def search(self, target):
        if self.data == target:
            print("Found it!")
            return self
        if self.left and self.data > target:
            return self.left.search(target)
        if self.right and self.data < target:
            return self.right.search(target)
        print("Value not found in tree")


    def traversePreorder(self):
        print(self.data, end=" ")
        if self.left:
            self.left.traversePreorder()
        if self.right:
            self.right.traversePreorder()
        
    def traverseInorder(self):
        if self.left:
            self.left.traverseInorder()
        print(self.data, end=" ")
        if self.right:
            self.right.traverseInorder()
        
    def traversePostorder(self):
        if self.left:
            self.left.traversePostorder() 
        if self.right:
            self.right.traversePostorder()
        print(self.data, end=" ")
        
    def height(self, h=0):
        leftHeight = self.left.height(h+1) if self.left else h
        rightHeight = self.right.height(h+1) if self.right else h
        return max(leftHeight, rightHeight)
        
    def getNodesAtDepth(self, depth, nodes=[]):
        if depth == 0:
            nodes.append(self)
            return nodes
            
        if self.left:
            self.left.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
            
        if self.right:
            self.right.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
        return nodes

    def findMin(self):
        if self.left:
            return self.left.findMin()
        return self.data
    
    def delete(self, target):
        if self.data == target: 
            # Do the deletion here
            if self.left and self.right:
                # RTFM
                minValue = self.right.findMin()
                self.data = minValue
                self.right = self.right.delete(minValue)
                return self
            else:
                return self.left or self.right
                
        if self.right and target > self.data:
            self.right = self.right.delete(target)
        if self.left and target < self.data:
            self.left = self.left.delete(target)
        return self.fixImbalanceIfExists()

    def isBalanced(self):
        leftHeight = self.left.height() + 1 if self.left else 0
        rightHeight = self.right.height() + 1 if self.right else 0
        return abs(leftHeight - rightHeight) < 2

    def toStr(self):
        if not self.isBalanced():
            return str(self.data)+'*'
        return str(self.data)

    def getLeftRightHeightDifference(self):
        leftHeight = self.left.height() + 1 if self.left else 0
        rightHeight = self.right.height() + 1 if self.right else 0
        return leftHeight - rightHeight
        
    def add(self, data):
        if self.data == data:
            return 
        if data < self.data:
            if self.left is None:
                self.left = Node(data)
                self.left = self.left.fixImbalanceIfExists()
            else:
                self.left.add(data)
        else:  # Ensuring BST property
            if self.right is None:
                self.right = Node(data)
            else:
                self.right.add(data)
                self.right = self.right.fixImbalanceIfExists()
        
    def fixImbalanceIfExists(self):
        if self.getLeftRightHeightDifference() > 1:
            # Left Imbalance
            if self.left.getLeftRightHeightDifference() > 0:
                # Left Left Imbalance
                return rotateRight(self)
            else:
                # Left Right Imbalance
                self.left = rotateLeft(self.left)
                return rotateRight(self)
        elif self.getLeftRightHeightDifference() < -1:
            #Right Imbalance
            if self.right.getLeftRightHeightDifference() < 0:
                # Right Right Imbalance
                return rotateLeft(self)
            else:
                # Right Left Imbalance
                self.right = rotateRight(self.right)
                return rotateLeft(self)
        return self

    def rebalance(self):
        if self.left:
            self.left.rebalance()
            self.left = self.left.fixImbalanceIfExists()
        if self.right:
            self.right.rebalance()
            self.right = self.right.fixImbalanceIfExists()

In [159]:
class Tree:
    def __init__(self, root, name=''):
        self.root = root
        self.name = name

    def search(self, target):
        return self.root.search(target)

    

    def traversePreorder(self):
        self.root.traversePreorder()
        print()

    def traverseInorder(self):
        self.root.traverseInorder()
        print()

    def traversePostorder(self):
        self.root.traversePostorder()
        print()

    def height(self):
        return self.root.height()

    def getNodesAtDepth(self, depth):
        return self.root.getNodesAtDepth(depth)

    def _nodeToChar(self, n, spacing):
        if n is None:
            return '_' + (' ' * spacing)
        spacing = spacing - len(n.toStr()) + 1
        return n.toStr() + (' ' * spacing)

    def print(self, label=''):
        print(self.name + ' ' + label)
        height = self.root.height()
        spacing = 3
        width = int((2 ** height - 1) * (spacing + 1) + 1)
        offset = int((width - 1) / 2)

        for depth in range(0, height + 1):
            if depth > 0:
                print(' ' * (offset + 1) + (' ' * (spacing + 2)).join(['/' + (' ' * (spacing - 2)) + '\\'] * (2 ** (depth - 1))))
            row = self.root.getNodesAtDepth(depth, [])
            print((' ' * offset) + ''.join([self._nodeToChar(n, spacing) for n in row]))
            spacing = offset + 1
            offset = max(int(offset / 2) - 1, 0)
        print()
        
    def add(self, data):
        self.root.add(data)
        self.root = self.root.fixImbalanceIfExists()

    def delete(self, target):
        self.root = self.root.delete(target)
        self.root = self.root.fixImbalanceIfExists()

    def rebalance(self):
        self.root.rebalance()
        self.root = self.root.fixImbalanceIfExists()

In [160]:
tree_6 = Tree(Node(50))
tree_6.add(25)
tree_6.add(75)
tree_6.add(10)
tree_6.add(35)
tree_6.add(30)
tree_6.add(5)
tree_6.add(2)
tree_6.add(1)

In [161]:
tree_6.print()

 
                              25  
               /                             \
              10*                             35                              
       /             \                 /             \
      5*              _               30              50              
   /     \         /     \         /     \         /     \
  2       _       _       _       _       _       _       75      
 / \     / \     / \     / \     / \     / \     / \     / \
1   _   _   _   _   _   _   _   _   _   _   _   _   _   _   _   



In [162]:
tree_6.delete(50)
tree_6.delete(75)
tree_6.print()

 
              10  
       /             \
      5*              30              
   /     \         /     \
  2       _       25      35      
 / \     / \     / \     / \
1   _   _   _   _   _   _   _   

