## [Count Univalue Subtrees](https://leetcode.com/problems/count-univalue-subtrees/description/)

Given a binary tree, count the number of uni-value subtrees.

A Uni-value subtree means all nodes of the subtree have the same value.

Example:
```
Input:  root = [5,1,5,5,5,null,5]

              5
             / \
            1   5
           / \   \
          5   5   5

Output: 4
```
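
To make the definition concrete, here is a minimal brute-force sketch that checks every subtree directly against it. The helper names `subtree_values` and `count_unival_brute_force` are illustrative assumptions, not part of the problem statement, and the approach is deliberately naive (quadratic in the worst case). It is only meant to show why the example yields 4: the three leaves with value 5, plus the right child of the root together with its own child.

```python
class TreeNode:
    def __init__(self, x, left=None, right=None):
        self.val = x
        self.left = left
        self.right = right


def subtree_values(node):
    """Collect every value in the subtree rooted at node (hypothetical helper)."""
    if node is None:
        return []
    return [node.val] + subtree_values(node.left) + subtree_values(node.right)


def count_unival_brute_force(node):
    """Count nodes whose whole subtree contains exactly one distinct value."""
    if node is None:
        return 0
    here = 1 if len(set(subtree_values(node))) == 1 else 0
    return (here
            + count_unival_brute_force(node.left)
            + count_unival_brute_force(node.right))


# The example tree: [5,1,5,5,5,null,5]
example = TreeNode(5,
                   TreeNode(1, TreeNode(5), TreeNode(5)),
                   TreeNode(5, None, TreeNode(5)))
print(count_unival_brute_force(example))  # 4
```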

In [1]:
# Definition for a binary tree node.
class TreeNode(object):
    def __init__(self, x, left=None, right=None):
        self.val = x
        self.left = left
        self.right = right

class Solution(object):
    def countUnivalSubtrees(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        
        # The first solution (countUnivalSubtreesRecursion below) turned out to be
        # a plain post-order traversal implemented with recursion, so this version
        # tries the iterative post-order traversal instead.
        
        # Iterative post-order can be implemented with either two stacks or one stack.
        # This version uses two stacks for simplicity: once the post-order sequence
        # is built, the univalue subtrees can be validated in one more pass.
        
        # Another option is to implement post-order with one stack and check whether
        # each node roots a univalue subtree as it is visited. The post-order property
        # helps here too, because a node is visited only after its children.
        
        if not root:
            return 0
        
        stack = [root]
        postOrder = []
        
        while stack:
            node = stack.pop()
            
            # push the node onto the auxiliary stack; reading it in reverse later yields post-order
            postOrder.append(node)
            
            if node.left:
                stack.append(node.left)
            
            if node.right:
                stack.append(node.right)
        
        uniValueTrees = 0
        
        # need this map to know whether the child is univalue tree or not.
        uniValueMap = {}
        
        # a utility function to make the code more readable
        def leafNode(node):
            return (node and (not node.left and not node.right))
        
        for i in range(len(postOrder)-1, -1, -1):
            root = postOrder[i]
            if leafNode(root):
                uniValueTrees += 1
                uniValueMap[root] = True
            else:
                # since we are doing post-order, root is visited only after visiting
                # its child nodes. so we are guaranteed to have already found the univalue
                # of the child nodes
                leftIsUnival = (not root.left or (root.val == root.left.val and uniValueMap[root.left]))
                rightIsUnival = (not root.right or (root.val == root.right.val and uniValueMap[root.right]))
                if leftIsUnival and rightIsUnival:
                    uniValueTrees += 1
                    uniValueMap[root] = True
                else:
                    uniValueMap[root] = False
        
        return uniValueTrees
    
        # complexity: O(n) time, O(n) space
        # can be optimized to use a single pass instead of two passes.
            
        
    def countUnivalSubtreesRecursion(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        # univalue subtrees
        # all nodes of the subtree have the same value
        #   left and right? 
        # how is the output 4 for the given example?
        #
        #    5
        #   5 5 -> is this 3 or 1? -> 3 -> why? left and right
        #   subtrees on their own with one node each, and root
        #   is another subtree of itself with 2 univalue subtrees.
        #   Hence 3.
        #
        # if leaf node
        #       no subtrees. only the root. so consider as uni-value
        # if node.left:
        #       num_univalue_on_left = countUnivalSubtrees(node.left)
        # if node.right:
        #       num_univalue_on_right = countUnivalSubtrees(node.right)
        # 
        # if node.left exists then its val should equal the root's val and
        # if node.right exists then its val should equal the root's val
        # (and both child subtrees must themselves be univalue):
        #       then root is also a univalue subtree
        
        def traverse(root, count):
            if not root:
                return True

            # if not root.left and not root.right:
            #     # leaf node
            #     count[0] += 1
            #     return True
            # this check is not needed, as the traversal implicitly takes care of leaf nodes

            uniValueOnLeft = traverse(root.left, count)
            uniValueOnRight = traverse(root.right, count)

            if (uniValueOnLeft and uniValueOnRight) and \
                (not root.left or root.left.val == root.val) and \
                (not root.right or root.right.val == root.val):
                    count[0] += 1
                    return True
            else:
                # root value didn't match the subtree value
                return False
        
        count = [0]
        traverse(root, count)
        return count[0]

## Notes
This one turned out to be a little tricky. The question was not clear at the beginning, and I had to run through some sample cases to understand what was being asked. Once the base cases were identified, it became relatively easy to implement. I made the mistake of returning the univalues through the recursion, and learnt that such an approach may not work for all problems; for some problems, a simple traversal does the trick. For problems like this one, I should have spotted the use of post-order traversal as soon as the base cases were identified, because the univalue status of a node can only be updated after that of its child nodes. Nonetheless, this was a good learning exercise. I revisited the post-order traversal techniques (recursive, and iterative with two stacks or a single stack) and also Python dictionary basics.
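
The single-stack variant mentioned above is not implemented in the solution, so here is a rough sketch of how it might look. It assumes the usual "last visited node" trick for one-stack post-order traversal, and the names `count_unival_single_stack`, `is_unival`, and `last_visited` are illustrative choices rather than anything from the original code. Each node is checked at the moment it is finished, so no second pass over a stored post-order list is needed.

```python
class TreeNode:
    def __init__(self, x, left=None, right=None):
        self.val = x
        self.left = left
        self.right = right


def count_unival_single_stack(root):
    """Count univalue subtrees with one stack, in a single post-order pass."""
    count = 0
    is_unival = {}        # node -> bool, filled in as each node is finished
    stack = []
    node = root
    last_visited = None   # most recently finished node

    while stack or node:
        if node:
            # walk down the left spine, remembering the path
            stack.append(node)
            node = node.left
        else:
            peek = stack[-1]
            if peek.right and last_visited is not peek.right:
                # right subtree not processed yet, so descend into it first
                node = peek.right
            else:
                # both children are done, so this node can be decided now
                stack.pop()
                left_ok = peek.left is None or (
                    is_unival[peek.left] and peek.left.val == peek.val)
                right_ok = peek.right is None or (
                    is_unival[peek.right] and peek.right.val == peek.val)
                is_unival[peek] = left_ok and right_ok
                if is_unival[peek]:
                    count += 1
                last_visited = peek
    return count


# The example tree from the problem statement: [5,1,5,5,5,null,5]
example = TreeNode(5,
                   TreeNode(1, TreeNode(5), TreeNode(5)),
                   TreeNode(5, None, TreeNode(5)))
print(count_unival_single_stack(example))  # 4
```

The dictionary still takes O(n) extra space, so this sketch only removes the second pass, not the auxiliary storage.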

### Complexity
- Time - O(n)
- Space - O(n)

In [7]:
# some very basic test cases

s = Solution()

testTree = TreeNode(5)
assert (s.countUnivalSubtrees(testTree) == 1)

testTree.left = TreeNode(5)
assert (s.countUnivalSubtrees(testTree) == 2)

testTree.right = TreeNode(1)
assert (s.countUnivalSubtrees(testTree) == 2)

testTree.right = TreeNode(5, TreeNode(5))
assert (s.countUnivalSubtrees(testTree) == 4)