source: [LeetCode](https://leetcode.com/problems/count-complete-tree-nodes/?envType=study-plan-v2&envId=top-interview-150)

# 🌲 Quick Review - LeetCode 222: Count Complete Tree Nodes
**Difficulty:** Easy (but requires a clever optimization)  
**Topic:** Complete Binary Trees / Binary Search / Tree Properties  

---

## 📌 Problem Summary
Given the root of a **complete binary tree**, return the number of nodes.

A complete tree has these properties:
- Every level is completely filled except possibly the last  
- Nodes in the last level are filled **from left to right**  
- A tree of height h (counted in edges, so a lone root has height 0) has between `2^h` and `2^(h+1) - 1` nodes

The algorithm must run in **less than O(n)** time, so a **full traversal is not allowed**.
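
The height bound above can be checked numerically. This is a minimal sketch; `height_of_complete` and `node_range` are hypothetical helper names, and height is counted in edges. Because `2^h <= n < 2^(h+1)`, a complete tree with `n >= 1` nodes has height `floor(log2 n)`, which Python gives cheaply through `int.bit_length`:

```python
from typing import Tuple


def height_of_complete(n: int) -> int:
    """Height in edges (floor(log2 n)) of a complete binary tree with n >= 1 nodes."""
    if n < 1:
        raise ValueError("n must be at least 1")
    # bit_length() - 1 is the index of the highest set bit, i.e. floor(log2 n).
    return n.bit_length() - 1


def node_range(h: int) -> Tuple[int, int]:
    """Fewest and most nodes a complete binary tree of height h (in edges) can have."""
    return 2 ** h, 2 ** (h + 1) - 1


# Sanity check: every size falls inside the range predicted for its height.
for n in range(1, 1025):
    lo, hi = node_range(height_of_complete(n))
    assert lo <= n <= hi
```

For example, `height_of_complete(6)` is `2` and `node_range(2)` is `(4, 7)`. This logarithmic height is what makes each height computation O(log n).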

---

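## 🚀 Key Idea  
In a complete tree, compare the heights of the root's two subtrees, each measured down its leftmost path:

### ✔ If the left and right subtrees have equal heights:
The last level reaches into the right subtree, so the left subtree is a **perfect** tree.

### ✔ If the heights differ:
The last level ends inside the left subtree, so the right subtree is a **perfect** tree (one level shorter).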

This lets us count the perfect subtree in O(1) with the formula

`nodes in a perfect tree with h levels = 2^h - 1`

and then recursively count the remaining subtree, which is itself complete.
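
The two cases can be checked by brute force on small trees. This is a minimal, self-contained sketch. It assumes a LeetCode-style `TreeNode` and builds complete trees with heap indexing (the children of index `i` sit at `2i + 1` and `2i + 2`). `build_complete`, `leftmost_levels`, `size`, and `check_key_idea` are hypothetical helper names. For each tree size, the code measures both subtree heights along the leftmost path and confirms that the subtree predicted to be perfect really has `2^h - 1` nodes:

```python
class TreeNode:
    """Assumed LeetCode-style node."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_complete(n):
    """Complete tree with n nodes via heap indexing (None when n == 0)."""
    nodes = [TreeNode(i) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def leftmost_levels(node):
    """Number of levels along the leftmost path (0 for an empty subtree)."""
    h = 0
    while node:
        h += 1
        node = node.left
    return h


def size(node):
    """O(n) reference count, used only to verify the claim."""
    return 0 if node is None else 1 + size(node.left) + size(node.right)


def check_key_idea(n):
    """True if the subtree predicted to be perfect has 2^h - 1 nodes (n >= 1)."""
    root = build_complete(n)
    hl = leftmost_levels(root.left)
    hr = leftmost_levels(root.right)
    if hl == hr:
        # Equal heights: the left subtree should be perfect with hl levels.
        return size(root.left) == 2 ** hl - 1
    # Different heights: the right subtree should be perfect with hr levels.
    return size(root.right) == 2 ** hr - 1
```

`all(check_key_idea(n) for n in range(1, 300))` evaluates to `True`. The O(n) `size` function serves only as the reference check here, not as part of the solution.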

---

## 📝 Short Approach (Interview-Ready)

1. **Define a helper** that computes the height of a subtree by walking from its root down to its bottom-left node.  
   - Takes O(h) time, and h = O(log n) because the tree is complete.

2. For each node:
   - Compute the height of its left subtree (`hl`)  
   - Compute the height of its right subtree (`hr`)

3. If equal (`hl == hr`):
   - Left subtree is perfect → it contributes `2^hl - 1` nodes, plus 1 for the current node  
   - Recurse into the right subtree

4. If not equal:
   - Right subtree is perfect → it contributes `2^hr - 1` nodes, plus 1 for the current node  
   - Recurse into the left subtree

5. Total complexity:  
   - Each height computation: O(log n)  
   - Recursion depth: O(log n), since each call moves one level down  
   - Total time: **O((log n)²)**, which satisfies the sub-O(n) requirement
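
The O((log n)²) bound can be observed directly by instrumenting the recursion to count how many nodes the height helper walks. This is a rough sketch under LeetCode-like assumptions: a `TreeNode` with `left`/`right`, and complete trees built by heap indexing. `build_complete`, `count_with_steps`, and its `steps` counter are hypothetical additions used only for measurement:

```python
class TreeNode:
    """Assumed LeetCode-style node."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_complete(n):
    """Complete tree with n nodes via heap indexing (None when n == 0)."""
    nodes = [TreeNode(i) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0] if nodes else None


def count_with_steps(root):
    """Return (node_count, steps): steps = nodes visited by the height helper."""
    steps = 0

    def left_levels(node):
        nonlocal steps
        h = 0
        while node:
            steps += 1  # one node touched on the leftmost path
            h += 1
            node = node.left
        return h

    def count(node):
        if node is None:
            return 0
        hl = left_levels(node.left)
        hr = left_levels(node.right)
        if hl == hr:
            # Left subtree is perfect: 2^hl - 1 nodes, plus 1 for `node`.
            return (1 << hl) + count(node.right)
        # Right subtree is perfect: 2^hr - 1 nodes, plus 1 for `node`.
        return (1 << hr) + count(node.left)

    return count(root), steps
```

On a perfect tree with 65,535 nodes (16 levels), `count_with_steps(build_complete(65535))` returns `(65535, 240)`. The height helper touches 240 nodes, while a full traversal would visit all 65,535. In general, with L = `n.bit_length()` levels, the step count is at most L(L - 1).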

---

## 🧑‍💻 Template Code (Optimized, O((log n)²))

```python
class Solution:
    def countNodes(self, root):
        if not root:
            return 0
        
        def get_height(node):
            # Levels along the leftmost path (0 for an empty subtree).
            h = 0
            while node:
                h += 1
                node = node.left
            return h

        left_h = get_height(root.left)
        right_h = get_height(root.right)

        if left_h == right_h:
            # Left subtree is perfect: 2^left_h - 1 nodes, plus 1 for root.
            return (1 << left_h) + self.countNodes(root.right)
        else:
            # Right subtree is perfect: 2^right_h - 1 nodes, plus 1 for root.
            return (1 << right_h) + self.countNodes(root.left)
```

## My Solution

```python
# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution(object):
    def countNodes(self, root):
        """
        :type root: Optional[TreeNode]
        :rtype: int
        """
        def find_lr_height(node):
            # Levels along the leftmost and rightmost paths of `node`'s subtree.
            # In a complete tree, the subtree is perfect exactly when they match.
            left_height = 0
            right_height = 0
            curr_left = node
            curr_right = node
            while curr_left:
                curr_left = curr_left.left
                left_height += 1
            while curr_right:
                curr_right = curr_right.right
                right_height += 1
            return left_height, right_height

        n = 0
        curr = root
        while curr:
            n += 1  # count the current node
            ll, lr = find_lr_height(curr.left)
            if ll == lr:
                # Left subtree is perfect: add its 2^ll - 1 nodes, continue right.
                n += 2**ll - 1
                curr = curr.right
            else:
                # Left subtree is not perfect, so the right subtree is:
                # add its 2^rl - 1 nodes, continue left.
                rl, _ = find_lr_height(curr.right)
                n += 2**rl - 1
                curr = curr.left
        return n
```