# Binary Trees

In [1]:
# Tree node prototype:
class BinaryTreeNode:
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

**Question 9.1**: Test if a Binary Tree is height-balanced

In [None]:
"""Book solution
    + review comments
    
Time complexity: O(n), same as postorder traversal with n being the number of nodes in tree
Space complexity: O(h), h being the height of the tree
"""
import collections


def is_balanced_binary_tree(tree: BinaryTreeNode) -> bool:
    # A namedtuple carries both return values (balanced flag and height) together.
    BalancedStatusWithHeight = collections.namedtuple(
        'BalancedStatusWithHeight', ('balanced', 'height'))
    
    # The first field of the result says whether the subtree is balanced; if it
    # is, the second field is the height of that subtree.
    def check_balanced(tree):
        # An empty tree is balanced; its height is -1 so that a single node
        # has height 0.
        if not tree:
            return BalancedStatusWithHeight(balanced=True, height=-1)
        
        # Postorder recursion. As soon as either subtree is unbalanced, its
        # result is passed straight up and the rest of the tree is skipped.
        left_result = check_balanced(tree.left)
        if not left_result.balanced:
            return left_result
        
        right_result = check_balanced(tree.right)
        if not right_result.balanced:
            return right_result
        
        # Balanced here means the subtree heights differ by at most 1.
        is_balanced = abs(left_result.height - right_result.height) <= 1
        # Height of the subtree rooted at the current node.
        height = max(left_result.height, right_result.height) + 1
        return BalancedStatusWithHeight(is_balanced, height)
    
    return check_balanced(tree).balanced
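
As a hedged alternative to the namedtuple version, the same postorder check can be sketched with a single integer return value, using a sentinel to mean "unbalanced". This is not the book's solution. The names `Node`, `UNBALANCED`, and `is_balanced_sentinel` are made up for this illustration, and `Node` is only a stand-in for `BinaryTreeNode` so the cell runs on its own.

```python
from typing import Optional

UNBALANCED = -2  # hypothetical sentinel; real heights are always >= -1


class Node:
    """Stand-in for BinaryTreeNode so this cell is self-contained."""
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def is_balanced_sentinel(tree: Optional[Node]) -> bool:
    def height_or_unbalanced(node):
        if not node:
            return -1  # an empty tree has height -1
        left = height_or_unbalanced(node.left)
        if left == UNBALANCED:
            return UNBALANCED  # stop early and skip the right subtree
        right = height_or_unbalanced(node.right)
        if right == UNBALANCED or abs(left - right) > 1:
            return UNBALANCED
        return max(left, right) + 1

    return height_or_unbalanced(tree) != UNBALANCED


# A three-node chain leaning left is unbalanced; a root with two leaves is not.
chain = Node(1, left=Node(2, left=Node(3)))
full = Node(1, left=Node(2), right=Node(3))
print(is_balanced_sentinel(chain), is_balanced_sentinel(full))  # False True
```

The sentinel keeps the recursion to one return value, at the cost of reserving a magic number; the namedtuple version is more explicit about what is being returned.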

**Question 9.4**: Compute the LCA when nodes have parent pointers

Given two nodes in a binary tree, design an algorithm that computes their LCA (lowest common ancestor). Assume each node has a parent pointer.

*hint*: The problem is easy if both nodes are the same distance from the root.

If we don't mind using extra space, we can walk up both nodes' parent lineages in lockstep. If the two nodes are the same to begin with, that node is returned. Otherwise, each lineage keeps a set of the nodes it has visited, and the first node found in the other lineage's set is returned as the LCA.
If both walks run past the root without meeting, the nodes are not in the same tree.

- Time: O(h), where h is the height of the tree, since in the worst case the LCA is the root and both walks climb the full height. This assumes O(1) membership checks on the visited nodes.
- Space: O(h), for the visited nodes stored along each path up the tree.

In [6]:
class BTN_with_parent:
    def __init__(self, data=None, left=None, right=None, parent=None):
        self.data = data
        self.left = left
        self.right = right
        self.parent = parent
        
        
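def lowest_common_ancestor(n1: BTN_with_parent, n2: BTN_with_parent):
    if n1 is n2:
        return n1
    
    # Sets give O(1) average membership tests; nodes hash by identity.
    n1_seen, n2_seen = set(), set()
    while n1 or n2:
        # Record each node, then move it one step up its lineage.
        if n1:
            n1_seen.add(n1)
            n1 = n1.parent
        if n2:
            n2_seen.add(n2)
            n2 = n2.parent
        
        # The first node that also appears in the other lineage is the LCA.
        if n1 and (n1 is n2 or n1 in n2_seen):
            return n1
        if n2 and n2 in n1_seen:
            return n2
    # Both lineages ran out without meeting: the nodes are in different trees.
    return None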

According to the book, this is essentially the brute-force method. The better method has the same O(h) time complexity but needs only O(1) space.

To get the better space complexity, we remove the problem that the nodes might be at different depths: the deeper node first moves up until it is level with the shallower one. From there, the two nodes climb the tree in lockstep and reach the LCA at the same time.

In [7]:
def lca(node0: BTN_with_parent, node1: BTN_with_parent) -> BTN_with_parent:
    def get_depth(node):
        depth = 0
        while node.parent:
            depth += 1
            node = node.parent
        return depth
    
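    depth0, depth1 = map(get_depth, (node0, node1))
    # Make node0 the deeper of the two nodes.
    if depth1 > depth0:
        node0, node1 = node1, node0
        
    # Lift the deeper node until both nodes are at the same depth.
    diff = abs(depth1 - depth0)
    while diff:
        node0 = node0.parent
        diff -= 1
        
    # Climb in lockstep; the first node they share is the LCA.
    while node0 is not node1:
        node0, node1 = node0.parent, node1.parent
        
    return node0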
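
To make the depth-equalizing idea concrete, here is a small self-contained sketch that restates the approach and runs it on a four-node tree. `PNode` and `lca_sketch` are hypothetical names for this illustration; `PNode` stands in for `BTN_with_parent` and drops the child pointers, since only `parent` is used.

```python
from typing import Optional


class PNode:
    """Stand-in for BTN_with_parent so this cell is self-contained."""
    def __init__(self, data=None, parent=None):
        self.data = data
        self.parent = parent


def lca_sketch(a: PNode, b: PNode) -> Optional[PNode]:
    def depth(node):
        d = 0
        while node.parent:
            node, d = node.parent, d + 1
        return d

    depth_a, depth_b = depth(a), depth(b)
    # Make `a` the deeper node, then lift it to b's depth.
    if depth_b > depth_a:
        a, b, depth_a, depth_b = b, a, depth_b, depth_a
    for _ in range(depth_a - depth_b):
        a = a.parent

    # Climb in lockstep until the two pointers meet.
    while a is not b:
        a, b = a.parent, b.parent
    return a


root = PNode('root')
left = PNode('left', parent=root)
right = PNode('right', parent=root)
leaf = PNode('leaf', parent=left)
other = PNode('other tree')  # a node that belongs to a separate tree

print(lca_sketch(leaf, right).data)  # root
print(lca_sketch(leaf, left).data)   # left
print(lca_sketch(leaf, other))       # None
```

Because both pointers are at the same depth before the lockstep climb, nodes from different trees step past their roots together and meet at `None`, so the sketch returns `None` for them rather than raising.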