In [None]:
Given a binary tree, return all duplicate subtrees. For each kind of duplicate subtree, you only need to return the root node of any one of them.

Two trees are duplicates if they have the same structure and the same node values.

Example 1:

        1
       / \
      2   3
     /   / \
    4   2   4
       /
      4
The following are two duplicate subtrees:

      2
     /
    4
and

    4

In [101]:
class TreeNode:
    """A binary-tree node holding a value and optional left/right children."""

    def __init__(self, val, left=None, right=None):
        # Store the payload and both child links in a single tuple assignment.
        self.val, self.left, self.right = val, left, right

# Build the example tree from the problem statement:
#         1
#        / \
#       2   3
#      /   / \
#     4   2   4
#        /
#       4
root = TreeNode(
    1,
    left=TreeNode(2, left=TreeNode(4)),
    right=TreeNode(
        3,
        left=TreeNode(2, left=TreeNode(4)),
        right=TreeNode(4),
    ),
)

Two nodes root duplicate subtrees if they have the same value, identical left subtrees, and identical right subtrees.
We can therefore walk the tree and associate every node with an identifier built from its value and the identifiers of its left and right children; two nodes get equal identifiers exactly when their subtrees are duplicates.

In the function below, I walk the tree and generate an id for every node.

In [98]:
nodeHash = dict()
def walk(root, counts=None):
    """Return a structural id for the subtree at `root` and tally it.

    The id is the nested tuple ``(val, left_id, right_id)``, with ``None``
    standing in for a missing child. Two subtrees therefore get equal ids
    exactly when they have the same shape and the same values. As a side
    effect, the id is also stored on each visited node as ``node.id``.

    Args:
        root: subtree root, or None.
        counts: dict mapping id -> number of occurrences. Defaults to the
            module-level ``nodeHash``, which persists across calls. Pass a
            fresh ``{}`` to tally a single traversal in isolation.

    Returns:
        The id tuple for ``root``, or None when ``root`` is None.
    """
    if root is None:
        return None
    if counts is None:
        counts = nodeHash
    # Children are hashed first, so ids are built bottom-up.
    root.id = (root.val, walk(root.left, counts), walk(root.right, counts))
    # Keep count of the number of times the same id is generated.
    counts[root.id] = counts.get(root.id, 0) + 1
    return root.id

In [99]:
# Walk the tree and generate a hash (id) for every node.
# Reset the shared counter first so re-running this cell does not double the counts.
nodeHash.clear()
walk(root)

(1,
 (2, (4, None, None), None),
 (3, (2, (4, None, None), None), (4, None, None)))

In [100]:
# Ids whose count is above 1 belong to subtrees that occur more than once.
dict(nodeHash)

{(4, None, None): 3,
 (2, (4, None, None), None): 2,
 (3, (2, (4, None, None), None), (4, None, None)): 1,
 (1,
  (2, (4, None, None), None),
  (3, (2, (4, None, None), None), (4, None, None))): 1}

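Code to find duplicate node values

I traverse the tree breadth-first (BFS) and count the number of times each node value is encountered. Note that this detects repeated values, not repeated subtrees: two nodes with the same value but different children are also reported.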

In [119]:
def find_duplicate_bfs(root):
    """Print every node value that appears more than once, scanning breadth-first.

    Note: this counts repeated *values*, not repeated subtrees. Two nodes
    with the same value but different children are still reported. Values
    are printed in the order they are first reached by the BFS.

    Args:
        root: tree root, or None (prints nothing).
    """
    from collections import deque

    if root is None:
        return
    count = dict()
    # deque.popleft() is O(1); list.pop(0) shifts the whole list on every call.
    worklist = deque([root])
    while worklist:
        node = worklist.popleft()
        if node.left is not None:
            worklist.append(node.left)
        if node.right is not None:
            worklist.append(node.right)
        count[node.val] = count.get(node.val, 0) + 1

    for k, seen in count.items():
        if seen > 1:
            print(k)

In [120]:
# Print, one per line, each value that the breadth-first scan sees more than once.
find_duplicate_bfs(root=root)

2
4


In [117]:
count=dict()
def find_duplicate_dfs(root, counts=None):
    """Tally how often each node value occurs, visiting nodes in pre-order DFS.

    Like ``find_duplicate_bfs``, this counts repeated *values*, not repeated
    subtrees.

    Args:
        root: subtree root, or None.
        counts: dict that accumulates value -> occurrences. Defaults to the
            module-level ``count``, which persists across calls. Pass a fresh
            ``{}`` to count a single traversal.

    Returns:
        The ``counts`` dict that was updated.
    """
    if counts is None:
        counts = count
    if root is None:
        return counts
    counts[root.val] = counts.get(root.val, 0) + 1
    find_duplicate_dfs(root.left, counts)
    find_duplicate_dfs(root.right, counts)
    return counts
    

In [118]:
# Start from an empty tally so re-running this cell does not double the counts
# (which would wrongly report every value as a duplicate).
count.clear()
find_duplicate_dfs(root)
for k, seen in count.items():
    if seen > 1:
        print(k)

2
4


Track the path from the root node to each leaf node.

In [124]:
def track_path(root):
    """Print the list of root-to-leaf value paths, left subtree first.

    For the example tree this prints ``[[1, 2, 4], [1, 3, 2, 4], [1, 3, 4]]``.
    Prints nothing and returns None when ``root`` is None.
    """
    if root is None:
        return None
    leafPath = []

    # Private helper instead of the global name ``walk``: this notebook also
    # binds ``walk`` to a one-argument subtree-hashing function. Depending on
    # the global breaks whenever that cell is (re)run after the path helper.
    def _collect(node, path):
        here = path + [node.val]
        if node.left is None and node.right is None:
            leafPath.append(here)
            return
        if node.left is not None:
            _collect(node.left, here)
        if node.right is not None:
            _collect(node.right, here)

    _collect(root, [])
    print(leafPath)
    

In [125]:
def walk(node, path, leafPath):
    """Append to `leafPath` every root-to-leaf value list below `node`.

    `path` holds the values from the overall root down to, but excluding,
    `node`; it is never mutated. `node` must not be None.

    NOTE(review): this redefines the module-level name ``walk`` that the
    earlier hashing helper ``walk(root)`` uses. Calling ``walk(root)`` after
    this cell has run raises TypeError.
    """
    here = path + [node.val]
    children = [child for child in (node.left, node.right) if child is not None]
    if not children:
        # Leaf: record the completed root-to-leaf path.
        leafPath.append(here)
        return
    for child in children:
        walk(child, here, leafPath)

In [126]:
track_path(root)

[[1, 2, 4], [1, 3, 2, 4], [1, 3, 4]]


In [42]:
def dfs(root):
    """Print each value of the subtree at `root` in pre-order (node, left, right)."""
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        print(node.val)
        # Push right first so the whole left subtree is printed before it.
        stack.append(node.right)
        stack.append(node.left)

In [45]:
# The notebook has no import cell; without this the cell fails on a fresh kernel.
import collections

def find_duplicate_subtrees(root):
    """Return one root node for each kind of subtree that occurs more than once.

    Each distinct ``(value, left_id, right_id)`` triple is mapped to a small
    integer id, so two subtrees share an id exactly when they have the same
    shape and values. Ids are assigned bottom-up (children before parent).
    A node is recorded the moment its id is seen for the second time.
    """
    # A missing key is assigned the dict's current size as its id: 0, 1, 2, ...
    trees = collections.defaultdict()
    trees.default_factory = trees.__len__
    # Local counter: the old global name ``count`` clobbered the dict that
    # find_duplicate_dfs accumulates into.
    subtree_count = collections.Counter()
    duplicates = []

    def lookup(node):
        if node is None:
            return None
        uid = trees[node.val, lookup(node.left), lookup(node.right)]
        subtree_count[uid] += 1
        if subtree_count[uid] == 2:
            duplicates.append(node)
        return uid

    lookup(root)
    return duplicates

# Print each duplicate subtree in pre-order.
ans = find_duplicate_subtrees(root)
for node in ans:
    dfs(node)
    

4
