In [1]:
class TreeNode:
    """A node of an ordered, labelled tree.

    Attributes:
        value: the node's label.
        children: child ``TreeNode`` objects, ordered left to right. Starts
            empty; callers assign or extend it after construction.
    """

    def __init__(self, value):
        self.value, self.children = value, []

In [2]:
def get_postorder(tree):
    """Return the nodes of ``tree`` in post-order (children left to right, then the node).

    Falsy nodes (e.g. ``None``), whether the root or a child, are skipped, so
    ``get_postorder(None)`` is ``[]``.
    """
    ordered = []
    # Each entry is (node, expanded): unexpanded nodes push their children
    # first; expanded nodes are emitted once all their children are done.
    pending = [(tree, False)]
    while pending:
        node, expanded = pending.pop()
        if not node:
            continue
        if expanded:
            ordered.append(node)
            continue
        pending.append((node, True))
        # Reverse so the leftmost child is on top of the stack and visited first.
        pending.extend((child, False) for child in reversed(node.children))
    return ordered

In [34]:
def get_postorder_mapping(root):
    """Map 1-based post-order numbers to the nodes of the tree rooted at ``root``.

    The root, being last in post-order, gets the largest number. A falsy
    ``root`` yields an empty dict.
    """
    return dict(enumerate(get_postorder(root), start=1))

In [16]:
def get_leftmost(node):
    """Return the leftmost leaf under ``node`` by repeatedly following the first child.

    A leaf is its own leftmost leaf.
    """
    current = node
    while len(current.children) > 0:
        current = current.children[0]
    return current
def get_left_mosts(root):
    """Map each node's label to the label of its leftmost leaf.

    NOTE: the dict is keyed by ``node.value``. If several nodes share a label,
    only the one visited last in post-order is kept, so this view is lossy
    for trees with repeated labels.
    """
    label_to_leftmost = {}
    for node in get_postorder(root):
        label_to_leftmost[node.value] = get_leftmost(node).value
    return label_to_leftmost

In [23]:
def compute_keyroots(root):
    """Return the Zhang–Shasha keyroots of ``root`` as sorted 1-based post-order indices.

    A keyroot is the highest-numbered node among all nodes that share the same
    leftmost leaf. Equivalently, keyroots are the root plus every node that has
    a left sibling. Leftmost leaves are identified by their post-order position,
    not by label, so trees with repeated labels are handled correctly.

    Numbering follows post-order with children visited left to right before
    their parent (the same order as ``get_postorder``). A falsy ``root`` yields ``[]``.
    """
    # leftmost_of[k - 1] = post-order index of the leftmost leaf of node k.
    leftmost_of = []

    def _number(node):
        # Post-order walk. Returns the post-order index of node's leftmost leaf.
        first_leaf = None
        for child in node.children:
            child_leaf = _number(child)
            if first_leaf is None:
                first_leaf = child_leaf
        own_index = len(leftmost_of) + 1
        if first_leaf is None:  # a leaf is its own leftmost leaf
            first_leaf = own_index
        leftmost_of.append(first_leaf)
        return first_leaf

    if root:
        _number(root)

    keyroots = []
    seen_leaves = set()
    # Scan from the highest index down, so the first node seen for each
    # leftmost leaf is the highest one sharing it, i.e. the keyroot.
    for k in range(len(leftmost_of), 0, -1):
        leaf = leftmost_of[k - 1]
        if leaf not in seen_leaves:
            seen_leaves.add(leaf)
            keyroots.append(k)
    keyroots.reverse()
    return keyroots


In [7]:
# Tree 1:  f(d(a, c(b)), e)
a1, b1, c1, d1, e1, f1 = (TreeNode(label) for label in "abcdef")

c1.children = [b1]
d1.children = [a1, c1]
f1.children = [d1, e1]

# Tree 2:  f(c(d(a, b)), e)
a2, b2, c2, d2, e2, f2 = (TreeNode(label) for label in "abcdef")

d2.children = [a2, b2]
c2.children = [d2]
f2.children = [c2, e2]

# Post-order node lists; tree_dist reads these module-level names.
post_order_tree1 = get_postorder(f1)
post_order_tree2 = get_postorder(f2)

print("Tree 1 postorder:", [nd.value for nd in post_order_tree1])
print("Tree 2 postorder:", [nd.value for nd in post_order_tree2])



Tree 1 postorder: ['a', 'b', 'c', 'd', 'e', 'f']
Tree 2 postorder: ['a', 'b', 'd', 'c', 'e', 'f']


In [17]:
# Label -> leftmost-leaf label for each example tree, one tree per line.
left_mosts_tree1, left_mosts_tree2 = get_left_mosts(f1), get_left_mosts(f2)
print(left_mosts_tree1, left_mosts_tree2, sep="\n")

{'a': 'a', 'b': 'b', 'c': 'b', 'd': 'a', 'e': 'e', 'f': 'a'}
{'a': 'a', 'b': 'b', 'd': 'a', 'c': 'a', 'e': 'e', 'f': 'a'}


In [24]:
# Keyroots (1-based post-order indices) for each example tree, one tree per line.
keyroots_t1, keyroots_t2 = compute_keyroots(f1), compute_keyroots(f2)
print(keyroots_t1, keyroots_t2, sep="\n")

[3, 5, 6]
[2, 5, 6]


In [36]:
# 1-based post-order number -> node, for both trees (the original tree_dist reads these).
postorder_mapping_t1, postorder_mapping_t2 = get_postorder_mapping(f1), get_postorder_mapping(f2)
for postorder_num, node in postorder_mapping_t1.items():
    print(f"Node {node.value}: Postorder number {postorder_num}")
# Post-order number 1 is always the leftmost leaf of the whole tree.
print(postorder_mapping_t1[1].value)

Node a: Postorder number 1
Node b: Postorder number 2
Node c: Postorder number 3
Node d: Postorder number 4
Node e: Postorder number 5
Node f: Postorder number 6
a


In [5]:
def cost_function(node1, node2):
    """Unit edit cost between two nodes, either of which may be ``None``.

    - ``node1 is None``: insertion of ``node2``, cost 1
    - ``node2 is None``: deletion of ``node1``, cost 1
    - otherwise: relabel, 0 if the labels are equal, else 1
    """
    if node1 is None or node2 is None:
        return 1
    same_label = node1.value == node2.value
    return 0 if same_label else 1

In [25]:
def count_nodes(root):
    """Return the number of nodes in the tree rooted at ``root`` (0 for a falsy root)."""
    if not root:
        return 0
    return 1 + sum(count_nodes(child) for child in root.children)
m, n = count_nodes(f1), count_nodes(f2)
print(m, n)
# Subtree-distance table indexed by 1-based post-order numbers
# [node in tree 1][node in tree 2], so row and column 0 stay unused.
treedist = [[0 for _ in range(n + 1)] for _ in range(m + 1)]


6 6


In [63]:
def _leftmost_positions(post_order):
    """Return ``lm`` with ``lm[k - 1]`` = 1-based post-order index of node ``k``'s leftmost leaf.

    Leaves are matched by identity (not by label), so repeated labels do not collide.
    """
    position = {id(node): k for k, node in enumerate(post_order, start=1)}
    leftmost = []
    for node in post_order:
        leaf = node
        while leaf.children:
            leaf = leaf.children[0]
        leftmost.append(position[id(leaf)])
    return leftmost


def tree_dist(i, j, cost_function, post_order1=None, post_order2=None, dist=None):
    """Zhang–Shasha inner step for the keyroot pair ``(i, j)``.

    Computes forest distances between the post-order ranges ``l(i)..i`` of
    tree 1 and ``l(j)..j`` of tree 2. It stores into ``dist`` the distance of
    every subtree pair whose roots lie on the leftmost paths of ``i`` and
    ``j``. Keyroot pairs must be processed in increasing order, so that the
    ``dist`` entries read in the forest case are already filled.

    Args:
        i, j: 1-based post-order indices of keyroots in tree 1 / tree 2.
        cost_function: ``cost(a, b)`` edit cost. ``a is None`` means
            inserting ``b``; ``b is None`` means deleting ``a``.
        post_order1, post_order2: post-order node lists of the two trees.
            Default to the notebook globals ``post_order_tree1`` / ``post_order_tree2``.
        dist: table of at least ``(len(post_order1)+1) x (len(post_order2)+1)``,
            filled in place. Defaults to the notebook global ``treedist``.

    Returns:
        The distance between subtree ``i`` and subtree ``j`` (also stored in ``dist[i][j]``).
    """
    if post_order1 is None:
        post_order1 = post_order_tree1
    if post_order2 is None:
        post_order2 = post_order_tree2
    if dist is None:
        dist = treedist

    lm1 = _leftmost_positions(post_order1)
    lm2 = _leftmost_positions(post_order2)
    li = lm1[i - 1]  # l(i)
    lj = lm2[j - 1]  # l(j)

    # forestdist[a][b] = distance between the forest of post-order nodes
    # li .. li+a-1 (tree 1) and the forest of nodes lj .. lj+b-1 (tree 2).
    rows = i - li + 1
    cols = j - lj + 1
    forestdist = [[0] * (cols + 1) for _ in range(rows + 1)]

    # Base cases: delete / insert every node of the forest prefix.
    for a in range(1, rows + 1):
        forestdist[a][0] = forestdist[a - 1][0] + cost_function(post_order1[li + a - 2], None)
    for b in range(1, cols + 1):
        forestdist[0][b] = forestdist[0][b - 1] + cost_function(None, post_order2[lj + b - 2])

    for a in range(1, rows + 1):
        x = li + a - 1  # absolute 1-based post-order index in tree 1
        node1 = post_order1[x - 1]
        for b in range(1, cols + 1):
            y = lj + b - 1  # absolute 1-based post-order index in tree 2
            node2 = post_order2[y - 1]
            delete = forestdist[a - 1][b] + cost_function(node1, None)
            insert = forestdist[a][b - 1] + cost_function(None, node2)
            if lm1[x - 1] == li and lm2[y - 1] == lj:
                # Both forests are whole trees rooted at x and y: match them directly.
                relabel = forestdist[a - 1][b - 1] + cost_function(node1, node2)
                forestdist[a][b] = min(delete, insert, relabel)
                dist[x][y] = forestdist[a][b]
            else:
                # Forest case: the forest preceding subtree x holds l(x) - l(i)
                # nodes in this table's relative numbering (likewise for y).
                # It must be relative to l(i); an absolute index is only
                # correct when l(i) == 1.
                prefix_a = lm1[x - 1] - li
                prefix_b = lm2[y - 1] - lj
                forestdist[a][b] = min(
                    delete,
                    insert,
                    forestdist[prefix_a][prefix_b] + dist[x][y],
                )

    return dist[i][j]



In [47]:
import numpy as np

# Demo of a single keyroot pair: i=3 in tree 1, j=6 in tree 2.
# NOTE: in a fresh top-to-bottom run, no other keyroot pair has run yet. The
# forest case therefore reads treedist entries that are still 0, and the
# values printed here are illustrative only; main_loop computes the real table.
tree_dist(3, 6, cost_function)
print(np.asarray(treedist))

[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 1 0 2 3 0 5]
 [0 2 0 2 2 0 4]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]]


In [72]:
treedist = [[0] * (n + 1) for _ in range(m + 1)]
def main_loop(T1, T2, cost_fn):
    """Compute the Zhang–Shasha tree edit distance between ``T1`` and ``T2``.

    ``tree_dist`` reads its inputs from module-level names:
    ``post_order_tree1``/``post_order_tree2`` and
    ``postorder_mapping_t1``/``postorder_mapping_t2``. It also writes into the
    module-level ``treedist`` table. This function rebinds all of them from
    ``T1``/``T2`` first, so the result reflects the trees passed in rather
    than whatever trees those globals last described.

    Args:
        T1, T2: root nodes of the two trees.
        cost_fn: ``cost_fn(a, b)`` edit cost. ``None`` as ``a`` means insert,
            ``None`` as ``b`` means delete.

    Returns:
        The edit distance between ``T1`` and ``T2``. The full subtree-distance
        table is left in the global ``treedist``, indexed by 1-based post-order numbers.
    """
    global post_order_tree1, post_order_tree2
    global postorder_mapping_t1, postorder_mapping_t2, treedist

    post_order_tree1 = get_postorder(T1)
    post_order_tree2 = get_postorder(T2)
    postorder_mapping_t1 = get_postorder_mapping(T1)
    postorder_mapping_t2 = get_postorder_mapping(T2)
    # Fresh, correctly sized table, so other trees or repeated calls never see stale entries.
    treedist = [[0] * (len(post_order_tree2) + 1) for _ in range(len(post_order_tree1) + 1)]

    if not post_order_tree1 or not post_order_tree2:
        # An empty side has no keyroots: the distance is deleting/inserting every node.
        return (sum(cost_fn(node, None) for node in post_order_tree1)
                + sum(cost_fn(None, node) for node in post_order_tree2))

    keyroots_T2 = compute_keyroots(T2)  # loop-invariant, compute once
    # Keyroots come back in increasing post-order, so smaller subtrees are filled first.
    for i in compute_keyroots(T1):
        for j in keyroots_T2:
            tree_dist(i, j, cost_fn)
    return treedist[-1][-1]


In [73]:
# Run the keyroot double loop for the two example trees. tree_dist writes the
# subtree distances into the global `treedist` table as a side effect.
res = main_loop(f1, f2, cost_function)

In [77]:
# Convert the DP table to a NumPy array for display.
# NOTE(review): this rebinds `treedist` from a list of lists to an ndarray. The
# final-result cell relies on the ndarray form (`treedist[-1,-1]` tuple indexing).
treedist = np.array(treedist)
treedist

array([[0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 2, 3, 1, 5],
       [0, 1, 0, 2, 3, 1, 5],
       [0, 2, 1, 2, 2, 2, 4],
       [0, 3, 3, 1, 2, 4, 4],
       [0, 1, 1, 3, 4, 0, 5],
       [0, 5, 5, 3, 3, 5, 2]])

## Final result: tree edit distance between Tree 1 and Tree 2

In [78]:
# Both roots are last in post-order, so the bottom-right cell is the
# distance between the two whole trees.
print(treedist[-1,-1])

2
