# 834. Sum of Distances in Tree

There is an undirected connected tree with `n` nodes labeled from `0` to `n - 1` and `n - 1` edges.

You are given the integer `n` and the array `edges` where <code>edges[i] = [a<sub>i</sub>, b<sub>i</sub>]</code> indicates that there is an edge between nodes <code>a<sub>i</sub></code> and <code>b<sub>i</sub></code> in the tree.

Return an array `answer` of length `n` where `answer[i]` is the sum of the distances between the <code>i<sup>th</sup></code> node in the tree and all other nodes.

<https://leetcode.com/problems/sum-of-distances-in-tree/description/?envType=daily-question&envId=2024-04-28>

**Constraints:**
* <code>1 <= n <= 3 * 10<sup>4</sup></code>
* `edges.length == n - 1`
* `edges[i].length == 2`
* <code>0 <= a<sub>i</sub>, b<sub>i</sub> < n</code>
* <code>a<sub>i</sub> != b<sub>i</sub></code>
* The given input represents a valid tree.

Example 1:  

<img src="./Images/834-1.png" width="200">

> **Input:** n = 6, edges = [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]  
> **Output:** [8, 12, 6, 10, 10, 10]  
> **Explanation:** The tree is shown above.  
> We can see that dist(0, 1) + dist(0, 2) + dist(0, 3) + dist(0, 4) + dist(0, 5)  
> equals 1 + 1 + 2 + 2 + 2 = 8.  
> Hence, answer[0] = 8, and so on.
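
To make the definition of `answer[i]` concrete, here is a brute-force sketch that runs one BFS per node and sums the resulting distances. It is only a reference for checking small inputs: it costs $O(n^2)$ and is too slow for $n = 3 \cdot 10^4$. The function name `brute_force_sum_of_distances` is hypothetical and not part of the solution below.

```python
from collections import deque
from typing import List

def brute_force_sum_of_distances(n: int, edges: List[List[int]]) -> List[int]:
    # Build an adjacency list for the undirected tree.
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    answer = []
    for source in range(n):
        # BFS from `source`; in a tree the BFS depth is the unique path length.
        dist = [-1] * n
        dist[source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for nxt in graph[node]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        answer.append(sum(dist))  # dist[source] is 0, so it adds nothing
    return answer

print(brute_force_sum_of_distances(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```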

Example 2:  

<img src="./Images/834-2.png" width="50">

> **Input:** n = 1, edges = []  
> **Output:** [0]

Example 3:  

<img src="./Images/834-3.png" width="100">

> **Input:** n = 2, edges = [[1, 0]]  
> **Output:** [1, 1]

In [1]:
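from typing import List
from collections import defaultdict
import sys

# Both DFS passes recurse once per tree level. A path-shaped tree with n = 3 * 10^4
# nodes exceeds CPython's default recursion limit (1000), so raise it.
sys.setrecursionlimit(10 ** 5)

class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        # Step 1: build an adjacency list for the undirected tree.
        graph = defaultdict(list)
        for x, y in edges:
            graph[x].append(y)
            graph[y].append(x)

        # record[v]: number of nodes in the subtree rooted at v (v included).
        # result[v]: after dfs1, sum of distances from v to its descendants;
        #            after dfs2, sum of distances from v to every node.
        record, result = [1] * n, [0] * n

        # Step 2 (post-order): each child subtree contributes its own distance sum
        # plus one extra edge for each of its record[child] nodes.
        def dfs1(node, parent):
            for child in graph[node]:
                if child != parent:
                    dfs1(child, node)
                    record[node] += record[child]
                    result[node] += result[child] + record[child]

        # Step 3 (pre-order): moving the root from node to child brings the
        # record[child] nodes of child's subtree one step closer and pushes the
        # other n - record[child] nodes one step farther.
        def dfs2(node, parent):
            for child in graph[node]:
                if child != parent:
                    result[child] = result[node] - record[child] + (n - record[child])
                    dfs2(child, node)

        dfs1(0, -1)
        dfs2(0, -1)
        return result

    def display(self, n: int, edges: List[List[int]]) -> None:
        result = self.sumOfDistancesInTree(n, edges)
        print(f"Result: {result}")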

In [2]:
# Example 1

n = 6
edges = [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]
Solution().display(n, edges)

Result: [8, 12, 6, 10, 10, 10]


In [3]:
# Example 2

n = 1
edges = []
Solution().display(n, edges)

Result: [0]


In [4]:
# Example 3

n = 2
edges = [[1, 0]]
Solution().display(n, edges)

Result: [1, 1]


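**Idea:**  
* Step 1: Build an adjacency list from `edges`.
* Step 2: Root the tree at node `0` and run a post-order DFS that computes, for every node, its subtree size `record[node]` and the sum of distances from the node to its descendants `result[node]`. After this pass, `result[0]` is already the answer for the root.
* Step 3: Run a pre-order DFS that re-roots the tree: for each child, `result[child] = result[node] - record[child] + (n - record[child])`, because the `record[child]` nodes in the child's subtree move one step closer and the remaining `n - record[child]` nodes move one step farther.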
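
The two passes can be written as recurrences, using the same `record` and `result` arrays as the code, with $u$ a node and $c$ one of its children when the tree is rooted at node $0$. The first two lines describe the post-order pass; the third is the re-rooting update applied top-down:

$$
\begin{aligned}
\text{record}[u] &= 1 + \sum_{c \in \text{children}(u)} \text{record}[c] \\
\text{result}[u] &= \sum_{c \in \text{children}(u)} \big(\text{result}[c] + \text{record}[c]\big) \\
\text{result}[c] &= \text{result}[u] - \text{record}[c] + \big(n - \text{record}[c]\big)
\end{aligned}
$$

Checked against Example 1 ($n = 6$): $\text{record}[2] = 4$, $\text{result}[2] = 3$, $\text{record}[1] = 1$, $\text{result}[1] = 0$, so

$$
\begin{aligned}
\text{result}[0] &= (0 + 1) + (3 + 4) = 8 \\
\text{result}[2] &= 8 - 4 + (6 - 4) = 6 \\
\text{result}[1] &= 8 - 1 + (6 - 1) = 12 \\
\text{result}[3] &= 6 - 1 + (6 - 1) = 10
\end{aligned}
$$

which matches the expected output `[8, 12, 6, 10, 10, 10]`.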

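**Time Complexity:** $O(V + E) = O(n)$, since a tree has $E = n - 1$ edges and each DFS visits every node once and scans every edge twice.  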
<img src="./Images/834-4.png" width="500">
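
Because both passes are recursive, their depth equals the tree height, which reaches $n - 1 = 29999$ on a path. A minimal sketch of the same two passes without recursion follows. It assumes a BFS order from node `0` is an acceptable substitute for the DFS, since the recurrences only need parents to precede children. The name `sum_of_distances_iterative` is hypothetical and not part of the original solution.

```python
from typing import List

def sum_of_distances_iterative(n: int, edges: List[List[int]]) -> List[int]:
    graph = [[] for _ in range(n)]
    for x, y in edges:
        graph[x].append(y)
        graph[y].append(x)

    # BFS from node 0: `order` lists parents before their children.
    # Appending to `order` while iterating over it makes the list act as the queue.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order = [0]
    for node in order:
        for nxt in graph[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                order.append(nxt)

    record, result = [1] * n, [0] * n

    # Pass 1 (children before parents): subtree sizes and distance sums to descendants.
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            record[p] += record[node]
            result[p] += result[node] + record[node]

    # Pass 2 (parents before children): re-root from parent to child.
    for node in order[1:]:
        p = parent[node]
        result[node] = result[p] - record[node] + (n - record[node])

    return result

print(sum_of_distances_iterative(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```

Reversing the BFS order gives a children-before-parents sequence that replaces the post-order recursion, and the forward order replaces the pre-order recursion, so the time stays $O(n)$ without touching the recursion limit.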