In [None]:
# 222 - Count Complete Tree Nodes
"""
    Given the root of a complete binary tree, return the number of nodes in the tree.

    According to Wikipedia, every level, except possibly the last, is completely filled in a complete binary tree, 
    and all nodes in the last level are as far left as possible. It can have between 1 and 2^h nodes inclusive at the last level h.

    Design an algorithm that runs in less than O(n) time complexity.
"""

In [None]:
# My accepted solution I:
#
# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution(object):
    def countNodes(self, root):
        """
        Count the nodes of a binary tree by visiting every node exactly once.

        :type root: Optional[TreeNode]
        :rtype: int
        """
        # An empty subtree contributes no nodes.
        if root is None:
            return 0
        # This node plus everything in both subtrees. Recursing into both
        # children unconditionally (instead of returning 1 as soon as
        # root.left is None) keeps the count correct even when the tree is
        # not complete; for complete trees the result is unchanged.
        return 1 + self.countNodes(root.left) + self.countNodes(root.right)
        
# Time complexity: this function recursively visits every node exactly once, so time complexity = O(n),
#                  where n is the number of nodes in the binary tree.
# Space complexity: the recursion depth equals the height h of the binary tree, so space complexity = O(h),
#                   - for a complete binary tree (the input of this problem), h = O(log(n)), so space is O(log(n))
#                   - for an arbitrary binary tree, h can reach O(n) when the tree is skewed

In [None]:
# My accepted solution II:
#
# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution(object):
    def countNodes(self, root):
        """
        Count the nodes of a complete binary tree in O(log(n)^2) time.

        If the leftmost depth of the left subtree equals the rightmost depth
        of the right subtree, the tree rooted at ``root`` is perfect and its
        size follows from a closed formula. Otherwise both subtrees are
        counted recursively. This shortcut is only valid because the input is
        assumed to be a complete binary tree.

        :type root: Optional[TreeNode]
        :rtype: int
        """
        def getHeight(node, alongLeft=True):
            """Return the number of nodes on the path from ``node`` that
            always follows the left child (alongLeft=True) or the right
            child (alongLeft=False)."""
            height = 0
            while node is not None:
                node = node.left if alongLeft else node.right
                height += 1
            return height

        if root is None:
            return 0

        leftHeight = getHeight(root.left, alongLeft=True)
        rightHeight = getHeight(root.right, alongLeft=False)
        if leftHeight == rightHeight:
            # Perfect binary tree with h = leftHeight + 1 full levels
            # (counting the root), so it holds 2^h - 1 nodes.
            return (1 << (leftHeight + 1)) - 1
        # Not perfect: count both subtrees. In a complete tree at least one
        # of them is perfect and returns right after its height check.
        return 1 + self.countNodes(root.left) + self.countNodes(root.right)
            
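# Time complexity: for a complete binary tree the height is h = O(log(n)), so each getHeight() call costs O(log(n)).
#                  at every node, at least one of the two subtrees is perfect, so that recursive call returns right
#                  after its height check; only the other branch keeps descending, so the recursion goes O(log(n)) levels deep.
#                  at each level, getHeight() is called a constant number of times, which is O(h) = O(log(n)) per level,
#                  so time complexity = O(log(n)) * O(log(n)) = O(log(n)^2)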
# Space complexity: in worst case, the recursion descends O(log(n)) levels (tree height)
#                   each recursive call consumes constant space O(1)
#                   for this reason, space complexity = O(log(n))