# Given a binary tree, write a function that sums the depths of all nodes in every subtree, where each node's depth is measured from the root of the subtree being considered.

In [1]:
class BinaryTree:
    """A binary tree node holding a value and optional left/right children."""

    def __init__(self, value):
        # Payload stored at this node.
        self.value = value
        # Children start empty; callers attach BinaryTree nodes afterwards.
        self.left = self.right = None
        
        
# Sample tree used by every solution below (see the diagram in the next cell).
a = BinaryTree(1)
a.left, a.right = BinaryTree(2), BinaryTree(3)
a.left.left, a.left.right = BinaryTree(4), BinaryTree(5)
a.right.left, a.right.right = BinaryTree(6), BinaryTree(7)
a.left.left.left, a.left.left.right = BinaryTree(8), BinaryTree(9)

In [None]:
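r"""
        1
       / \
      2   3
     / \ / \
    4  5 6  7
   / \  
  8  9 
  
  Answer: 26
  Explanation: For each node, sum the depths of the nodes in the subtree
               rooted at it. Node 1 contributes 16 (2 nodes at depth 1,
               4 at depth 2 and 2 at depth 3: 2 + 8 + 6), node 2 contributes
               6, node 3 contributes 2, node 4 contributes 2, and nodes
               5, 6, 7, 8 & 9 (leaves) contribute 0.
               So the sum is 16 + 6 + 2 + 2 = 26.
  
"""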

In [2]:
# Naive sol.. O(n * log(n))T average (balanced tree), O(n^2)T worst case
# (skewed tree) / O(h)S
# h is height of the tree

def allKindsOfNodeDepths_1(root):
    """Return the sum, over every node, of the depths within that node's subtree.

    Walks the tree with an explicit stack and, for each node, calls
    ``nodeDepths`` to sum the depths of all nodes below it. Each such call
    costs time proportional to the subtree size, so the total cost is the sum
    of all subtree sizes: O(n log n) for a balanced tree, but O(n^2) for a
    degenerate (linked-list-shaped) tree.

    Returns 0 for an empty tree (``root is None``).
    """
    sumOfAllDepths = 0
    stack = [root]

    while stack:
        node = stack.pop()

        # Children are pushed unconditionally, so empty slots show up here.
        if node is None:
            continue

        sumOfAllDepths += nodeDepths(node)

        stack.append(node.left)
        stack.append(node.right)

    return sumOfAllDepths

def nodeDepths(node, depth = 0):
    """Return the sum of depths of every node in the subtree rooted at *node*.

    *node* itself is counted at *depth* (0 by default), its children at
    ``depth + 1``, and so on. An empty subtree (``node is None``) yields 0.
    """
    if node is None:
        return 0

    total = depth
    for child in (node.left, node.right):
        total += nodeDepths(child, depth + 1)
    return total

In [3]:
allKindsOfNodeDepths_1(root=a)  # expected: 26

26

In [4]:
# Naive recursive sol.. O(n * log(n))T average (balanced tree), O(n^2)T
# worst case (skewed tree) / O(h)S
# h is height of the tree

def allKindsOfNodeDepths_2(root):
    """Return the sum, over every node, of the depths within that node's subtree.

    Recursively adds ``nodeDepths(root)`` (the depth sum of the subtree rooted
    at *root*) to the totals of the left and right subtrees. Each
    ``nodeDepths`` call costs time proportional to its subtree size, so the
    total is O(n log n) for a balanced tree and O(n^2) for a degenerate one.

    Returns 0 for an empty tree (``root is None``).
    """
    if root is None:
        return 0
    
    return allKindsOfNodeDepths_2(root.left) + allKindsOfNodeDepths_2(root.right) + nodeDepths(root)

def nodeDepths(node, depth = 0):
    """Sum the depths of all nodes under (and including) *node*.

    *node* is treated as sitting at *depth*; each level below adds one.
    Returns 0 when *node* is None.
    """
    if node is not None:
        childDepth = depth + 1
        leftSum = nodeDepths(node.left, childDepth)
        rightSum = nodeDepths(node.right, childDepth)
        return depth + leftSum + rightSum
    return 0

In [5]:
allKindsOfNodeDepths_2(root=a)  # expected: 26

26

In [6]:
# Optimal sol.. O(n)T / O(n)S

def allKindsOfNodeDepths_3(root):
    """Return the sum, over every node, of the depths within that node's subtree.

    Three linear passes:
      1. ``addNodeCounts`` records each subtree's size.
      2. ``addNodeDepths`` records each subtree's depth sum using those sizes.
      3. ``sumAllNodeDepths`` adds up the per-subtree depth sums.

    Returns 0 for an empty tree (``root is None``), matching the other
    solutions in this notebook.
    """
    if root is None:
        # The helpers dereference node.left/node.right without a None check,
        # so an empty tree must be handled here.
        return 0

    nodeCounts = {}
    addNodeCounts(root, nodeCounts)

    # Named depthSums (not nodeDepths) to avoid shadowing the nodeDepths function.
    depthSums = {}
    addNodeDepths(root, depthSums, nodeCounts)

    return sumAllNodeDepths(root, depthSums)

def addNodeCounts(node, nodeCounts):
    """Record in *nodeCounts* the size of the subtree rooted at each node.

    Visits the subtree rooted at *node* (which must not be None) and maps
    every node to the number of nodes in its own subtree, itself included.
    A node's entry is inserted before its children's entries.
    """
    nodeCounts[node] = 1

    for child in (node.left, node.right):
        if child is not None:
            addNodeCounts(child, nodeCounts)
            nodeCounts[node] += nodeCounts[child]
        
        
def addNodeDepths(node, nodeDepths, nodeCounts):
    """Record in *nodeDepths* each subtree's sum of node depths.

    For every node in the subtree rooted at *node* (which must not be None),
    stores the sum of depths of the nodes below it, measured from that node.
    *nodeCounts* must already map every node to its subtree size. A node's
    entry is inserted before its children's entries.
    """
    nodeDepths[node] = 0

    for child in (node.left, node.right):
        if child is None:
            continue
        addNodeDepths(child, nodeDepths, nodeCounts)
        # Seen from *node*, every node in the child's subtree is one level
        # deeper than seen from the child: add the child's depth sum plus
        # one per node in that subtree.
        nodeDepths[node] += nodeDepths[child] + nodeCounts[child]
        
        
def sumAllNodeDepths(node, nodeDepths):
    """Return the total of ``nodeDepths[n]`` over every node ``n`` under *node*.

    The subtree rooted at *node* is traversed recursively; an empty subtree
    (``node is None``) contributes 0. Raises KeyError if a visited node has
    no entry in *nodeDepths*.
    """
    if node is None:
        return 0

    childrenTotal = sum(
        sumAllNodeDepths(child, nodeDepths) for child in (node.left, node.right)
    )
    return childrenTotal + nodeDepths[node]

In [7]:
allKindsOfNodeDepths_3(root=a)  # expected: 26

26

In [8]:
# Optimal sol.. O(n)T / O(h)S
# h is height of tree

def allKindsOfNodeDepths_4(root):
    """Return the sum, over every node, of the depths within that node's subtree.

    Delegates to ``getTreeInfo``, which visits each node once (children
    before parent), and returns the ``sumOfAllDepths`` it reports for *root*.
    """
    rootInfo = getTreeInfo(root)
    return rootInfo.sumOfAllDepths

def getTreeInfo(tree):
    """Summarise the subtree rooted at *tree* as a TreeInfo.

    The returned TreeInfo holds:
      * numNodesInTree - number of nodes in the subtree;
      * sumOfDepths    - sum of node depths measured from *tree*;
      * sumOfAllDepths - sumOfDepths of *tree* plus that of every subtree
                         nested inside it.
    An empty subtree (``tree is None``) yields ``TreeInfo(0, 0, 0)``.
    """
    if tree is None:
        return TreeInfo(0, 0, 0)

    left = getTreeInfo(tree.left)
    right = getTreeInfo(tree.right)

    # Relative to *tree*, each node of a child subtree is one level deeper than
    # relative to that child, so it adds 1 on top of the child's depth sum.
    depthSum = (left.sumOfDepths + left.numNodesInTree) + (right.sumOfDepths + right.numNodesInTree)

    return TreeInfo(
        1 + left.numNodesInTree + right.numNodesInTree,
        depthSum,
        depthSum + left.sumOfAllDepths + right.sumOfAllDepths,
    )

class TreeInfo:
    """Per-subtree aggregates produced by getTreeInfo."""

    def __init__(self, numNodesInTree, sumOfDepths, sumOfAllDepths):
        # Node count, depth sum from this subtree's root, and the running
        # total of depth sums over this subtree and all subtrees inside it.
        self.numNodesInTree, self.sumOfDepths, self.sumOfAllDepths = (
            numNodesInTree,
            sumOfDepths,
            sumOfAllDepths,
        )

In [9]:
allKindsOfNodeDepths_4(root=a)  # expected: 26

26