In [None]:
from utils import avlNode, height

# Part 1: AVL Rotations


In [None]:
#grade
# INPUT:
# An avlNode object where there is an imbalance
# OUTPUT:
# A balanced avlNode object
# NOTE: rotate returns the 'new' root node of the subtree
# Accordingly, it is used by 'node = rotate(node)'
def rotateLeft(node):
    """Rotate the subtree rooted at ``node`` to the left; return the new root.

    Before:      node                After:        pivot
                /    \                            /     \
               A    pivot                      node      C
                    /   \                     /    \
                   B     C                   A      B

    Precondition: ``node.right`` (the pivot) is not None.
    Usage: ``node = rotateLeft(node)`` -- the caller re-attaches the returned
    root to whatever previously pointed at ``node``.

    Every node whose parent changes (the pivot, ``node`` and the moved inner
    subtree B) gets its ``parent`` link updated, so parent links agree with
    child links after the rotation.
    NOTE(review): no cached height field is updated here; this assumes
    ``utils.height`` computes heights from the children -- confirm.
    """
    pivot = node.right
    inner = pivot.left  # subtree B changes owner: pivot -> node

    node.right = inner
    if inner is not None:
        inner.parent = node

    # The pivot takes node's place, including node's former parent link
    # (getattr: tolerate nodes that were built without a parent attribute).
    pivot.parent = getattr(node, "parent", None)
    pivot.left = node
    node.parent = pivot
    return pivot

In [None]:
#grade
# INPUT:
# An avlNode object where there is an imbalance
# OUTPUT:
# A balanced avlNode object
# NOTE: rotate returns the 'new' root node of the subtree
# Accordingly, it is used by 'node = rotate(node)'
def rotateRight(node):
    """Rotate the subtree rooted at ``node`` to the right; return the new root.

    Before:        node              After:    pivot
                  /    \                      /     \
               pivot    C                    A      node
               /   \                                /    \
              A     B                              B      C

    Precondition: ``node.left`` (the pivot) is not None.
    Usage: ``node = rotateRight(node)`` -- the caller re-attaches the returned
    root to whatever previously pointed at ``node``.

    Every node whose parent changes (the pivot, ``node`` and the moved inner
    subtree B) gets its ``parent`` link updated, so parent links agree with
    child links after the rotation.
    NOTE(review): no cached height field is updated here; this assumes
    ``utils.height`` computes heights from the children -- confirm.
    """
    pivot = node.left
    inner = pivot.right  # subtree B changes owner: pivot -> node

    node.left = inner
    if inner is not None:
        inner.parent = node

    # The pivot takes node's place, including node's former parent link
    # (getattr: tolerate nodes that were built without a parent attribute).
    pivot.parent = getattr(node, "parent", None)
    pivot.right = node
    node.parent = pivot
    return pivot



In [None]:
#grade
# INPUT:
# A avlNode object where there is an inbalance
# OUTPUT:
# A balanced avlNode object
# NOTE: rotate returns the 'new' root node of the subtree
# Accordingly, it is used by 'node = rotate(node)'
def rotateLeftRight(node):
    """Double rotation for a left-right imbalance; return the new subtree root.

    First rotates ``node.left`` to the left and re-links the result as
    ``node.left``, then rotates ``node`` itself to the right.
    Usage: ``node = rotateLeftRight(node)``.
    """
    straightened = rotateLeft(node.left)
    node.left = straightened
    return rotateRight(node)
    


In [None]:
#grade
# INPUT:
# A avlNode object where there is an inbalance
# OUTPUT:
# A balanced avlNode object
# NOTE: rotate returns the 'new' root node of the subtree
# Accordingly, it is used by 'node = rotate(node)'
def rotateRightLeft(node):
    """Double rotation for a right-left imbalance; return the new subtree root.

    First rotates ``node.right`` to the right and re-links the result as
    ``node.right``, then rotates ``node`` itself to the left.
    Usage: ``node = rotateRightLeft(node)``.
    """
    straightened = rotateRight(node.right)
    node.right = straightened
    return rotateLeft(node)


# Part 2: AVL Rebalancing


In [10]:
#grade
# INPUT:
# An avlNode object which is the root of the current subtree
# OUTPUT:
# An avlNode that is the root of the current subtree (after any rotation)
# Rebalance should check for balance and -- if necessary -- perform the correct rotation
# NOTE: If the tree is balanced, rebalance should return the input node unchanged.
def rebalance(node):
    """Restore the AVL balance condition at ``node``; return the subtree root.

    The balance factor is ``height(node.right) - height(node.left)``.  When it
    falls outside [-1, 1] the matching rotation is applied:

    * right-heavy, right child not left-heavy -> rotateLeft
    * right-heavy, right child left-heavy     -> rotateRightLeft
    * left-heavy,  left child not right-heavy -> rotateRight
    * left-heavy,  left child right-heavy     -> rotateLeftRight

    If the subtree is already balanced (or ``node`` is None) ``node`` is
    returned unchanged.  Usage: ``node = rebalance(node)``.
    NOTE(review): relies on ``utils.height`` giving an empty subtree a smaller
    height than a leaf (e.g. -1 vs 0) -- confirm.
    """
    if node is None:
        return node

    balance = height(node.right) - height(node.left)

    if balance > 1:
        right = node.right
        # A right child with balance 0 cannot follow a single insertion, but
        # can follow a deletion; a single rotation is the correct fix there.
        if height(right.right) >= height(right.left):
            return rotateLeft(node)
        return rotateRightLeft(node)

    if balance < -1:
        left = node.left
        # Mirror case: a left child with balance 0 also takes a single rotation.
        if height(left.left) >= height(left.right):
            return rotateRight(node)
        return rotateLeftRight(node)

    return node

The code below is a reproduction of the setup the autograder will use to grade your rebalance. Changes you make to the code won't break the autograder but will likely make it very difficult to debug any issues you have (so don't change it)!

In [None]:
# Class avl, used for testing the code below; finish the rebalance function before testing.
class avl:
    """Minimal AVL tree wrapper mirroring the autograder's test harness.

    Insertion is a plain recursive BST insert; on the way back up every node
    on the insertion path is passed through ``rebalance``.
    """

    def __init__(self, root=None):
        self.root = root

    # INPUT:
    # A int 'key' containing the numeric ID of the gene being inserted
    # A float 'val' containing the expression of the gene
    # OUTPUT:
    # Nothing
    def insert(self, key, val):
        self.root = self.insert_helper(self.root, key, val)

    def insert_helper(self, node, key, val):
        # Returns the (possibly rotated) root of this subtree after insertion.
        if node is None:
            return avlNode(key, val)

        if key < node.key:  # look left
            node.left = self.insert_helper(node.left, key, val)
        else:  # look right (equal keys also go right)
            node.right = self.insert_helper(node.right, key, val)

        return rebalance(node)

    def height(self):
        return self.height_helper(self.root)

    def height_helper(self, node):
        # An empty subtree has height -1, so a single leaf has height 0.
        if node is None:
            return -1

        return 1 + max(self.height_helper(node.left), self.height_helper(node.right))

    # INPUT:
    # A integer storing the key of the key, val pair
    # OUTPUT:
    # The corresponding val for the key, or None if the key is absent
    # NOTE: You do not need to edit this function but will probably need to use it
    def find(self, key):
        # find_helper dereferences its node immediately, so guard the empty tree.
        if self.root is None:
            return None
        n = self.find_helper(self.root, key)
        if n is not None:
            return n.val
        return None

    def find_helper(self, node, key):
        # Iterative-style BST descent written recursively; node is never None.
        nkey = node.key
        if nkey > key:
            if node.left:
                return self.find_helper(node.left, key)
            else:
                return None
        elif nkey < key:
            if node.right:
                return self.find_helper(node.right, key)
            else:
                return None
        else:
            return node

    def printTree(self):
        if self.root is None:
            print("(empty)\n")
            return

        root_height = self.height()
        print_matrix_width = (4 << root_height) - 3
        print_matrix_height = 2 * root_height + 1

        # One row per level plus one connector row between levels; 4 columns of slack.
        output = [[' '] * (print_matrix_width + 4) for _ in range(print_matrix_height)]

        self.pt_helper(self.root, output, 0, 0, print_matrix_width)

        for row in output:
            print("".join(row))

    def pt_helper(self, root, output, left, top, curr_width):
        # Draws `root` centered in columns [left, left + curr_width) of row `top`,
        # branch characters on row top+1, and recurses into the children two rows down.
        val = [char for char in str(root.key)]
        curr_width = int(curr_width)

        left_start_shift = int(1 - (len(val)-1) / 2)
        i = 0
        while i < len(val) and left + curr_width/2 + i < len(output[top]):
            output[top][int(left + curr_width/2 + left_start_shift + i)] = val[i]
            i += 1

        branchOffset = (curr_width+3) >> 3
        branchOffset = int(branchOffset)

        center = int(left + curr_width/2)
        leftcenter = int(left + (curr_width/2 -1)/2)
        rightcenter = int(left + curr_width/2 +2 + (curr_width/2 -1)/2)

        if root.left is not None:
            branch_pos = center - branchOffset + 1

            for pos in range(int(center+left_start_shift - 2), branch_pos, -1):
                output[top][pos] = '_'

            output[top+1][branch_pos] = '/'

            for pos in range(branch_pos-1, leftcenter + 2, -1):
                output[top+1][pos] = '_'

            self.pt_helper(root.left, output, left, top+2, curr_width/2 - 1)

        if root.right is not None:
            branch_pos = center + branchOffset + 1

            for pos in range(center+left_start_shift + len(val) + 1, branch_pos, 1):
                output[top][pos] = '_'

            output[top+1][branch_pos] = '\\'

            for pos in range(branch_pos+1, rightcenter, 1):
                output[top+1][pos] = '_'

            self.pt_helper(root.right, output, left + curr_width/2 + 2, top+2, curr_width/2 - 1)

In [8]:
# Unlike previous labs, I am not giving you a manual test for every rotation.
# You can use the autograder to check, but you are encouraged to manually check
# different values and different rotations on the provided tree.

# NOTE: This tree is already balanced! Most of your tests will also be on already balanced AVL trees.
# When rebalance is called on a balanced tree nothing happens, but when you call a SPECIFIC rotation
# your code should still perform that rotation.

# In other words, the check for balance should happen in rebalance (not in the rotations themselves).
t10 = avlNode(10, 0, None, avlNode(12, 0))
t25 = avlNode(25, 0)
t13 = avlNode(13, 0, t10, t25)
t51 = avlNode(51, 0, avlNode(42, 0), avlNode(99, 0))
t38 = avlNode(38, 0, t13, t51)

rt = avl(t38)
rt.printTree()
''' The original provided tree looks like this: 
           _ 38 __              
         __/       \___          
      13               51        
     /   \_           /   \      
  10       25      42      99    
    \                            
     12                          
'''

# The expected tree below is the result of a RIGHT rotation at the root (38),
# so call rotateRight here (rotateLeft would make 51 the root instead).
rt.root = rotateRight(t38)
''' If I call rotateRight I will produce the following tree:
            _ 13 __              
         __/       \___          
      10               38        
         \_           /   \      
           12      25      51    
                           / \   
                         42   99 
'''
rt.printTree()

In [9]:
# Testing the AVL insert function (your rebalance)
at = avl()
at.insert(5,3)
at.insert(2,2)
at.insert(3,4)
at.insert(4,4)
at.printTree()

# Inserting 3 under 5 -> 2 creates a left-right case at 5, fixed by rotateLeftRight.
'''
        3
     /   \_
   2        5
           /
          4
'''


# Inserting 7 makes 3 right-heavy (balance 2) with a right-heavy child 5,
# so a single rotateLeft at 3 promotes 5 to the root.
at.insert(6,6)
at.insert(7,7)
at.printTree()
'''
        5
     /   \_
   3        6
  / \        \
 2    4      7

'''

               2 __              
                   \___          
                        3        
                          \      
                            5    
                           /     
                          4      
                               2 ______                          
                                       \_______                  
                                                3 __             
                                                    \__          
                                                        5        
                                                      /   \_     
                                                    4        6   
                                                              \  
                                                               7 


'\n        5\n     /   \\_\n   3        6\n  / \\         2    4      10\n\n'