# What does TED (tree edit distance) do?

I'll take a deep dive into what unordered tree edit distance (TED) actually measures.

1. Compare Mike's annotations with mine.
2. Compare the trees produced by different algorithms against the ground truth.

Both comparisons are based on visualizing the optimal mapping, sorted in descending order of the cost each matched pair incurs.

Based on this, I'll decide whether unordered tree edit distance is sufficient or whether we need to consider other variants, such as:

1. Constrained Tree Edit Distance.
2. Structure-respecting Tree Edit Distance.

In [None]:
######################################################## 
## LOAD DATA and MODEL
######################################################## 

from vectorrvnn.data import *
from vectorrvnn.utils import *
from vectorrvnn.baselines import *
from vectorrvnn.trainutils import *
from vectorrvnn.interfaces import *
from more_itertools import unzip
import svgpathtools as svg
from tqdm import tqdm
import matplotlib.pyplot as plt

# Dataset of annotated drawings. getAnnotationsByName (below) pairs each item
# with data.metadata[i] by position, expecting entries like "<id>, <annotator>".
data = TripletDataset('../data/MikeAnnotations')

def getAnnotationsByName (name) :
    """Return [(annotationId, datum), ...] for every item of the global `data`
    whose metadata names `name` as the annotator.

    data.metadata[i] is split on ', ' into (annotationId, annotator). Items whose
    metadata is missing or does not split into exactly two fields are skipped
    (best effort, no error is raised).
    """
    def parseMeta (idx) :
        # (annotationId, annotator) for item idx, or None if it can't be parsed.
        try :
            annId, annotator = data.metadata[idx].split(', ')
        except Exception :
            return None
        return annId, annotator

    result = []
    for idx, datum in enumerate(data) :
        parsed = parseMeta(idx)
        if parsed is not None and parsed[1] == name :
            result.append((parsed[0], datum))
    return result

mike = getAnnotationsByName('mike')
sumit = getAnnotationsByName('sumit')

# Only drawings annotated by both people can be compared. Set comprehensions
# replace unzip(...)[0], which raises IndexError when either list is empty.
commonIds = {annId for annId, _ in mike} & {annId for annId, _ in sumit}

# annotationId -> datum restricted to the shared ids (a repeated id keeps its
# last entry, exactly as dict() over a list of pairs did).
mike = {annId: d for annId, d in mike if annId in commonIds}
sumit = {annId: d for annId, d in sumit if annId in commonIds}
svgFiles = [mike[annId].svgFile for annId in commonIds]

opts = Options().parse(testing=[
    '--batch_size', '32',
    '--checkpoints_dir', '../results',
    '--dataroot', '../data/All',
    '--embedding_size', '32', 
    '--load_ckpt', 'pattern_oneof/best_0-770.pth',                          
    '--modelcls', 'PatternGrouping',
    '--name', 'pattern_oneof',
    '--structure_embedding_size', '8',
    '--samplercls', 'DiscriminativeSampler',
    '--device', 'cuda:1',
    '--phase', 'test',
])

model = buildModel(opts)
# Infer a tree for each of Mike's drawings with the trained model, then attach
# the source document (.doc) so the tree's nodes can be rasterized later.
# The loop variable is `annId` rather than `id` so the builtin isn't shadowed
# for the rest of the notebook.
triplet = {annId: model.greedyTree(forest2tree(d)) for annId, d in mike.items()}
for annId, tree in triplet.items() :
    tree.doc = mike[annId].doc

In [None]:
######################################################## 
## VISUALIZE A TREE FOR SANITY
######################################################## 
import random

# Pick one shared drawing at random and show the three hierarchies for it:
# Mike's annotation, Sumit's annotation, and the model's (triplet) tree.
randomId = random.choice(list(commonIds))
for title, trees in (
    ("MIKE's TREE", mike),
    ("SUMIT's TREE", sumit),
    ("TRIPLET's TREE", triplet),
) :
    print(title)
    treeImageFromGraph(trees[randomId])
    plt.show()

In [None]:
def __ted__ (t1, t2) : 
    """Solve unordered tree edit distance between t1 and t2 as a 0/1 ILP.

    A binary variable m_{i}_{j} marks that the i-th node of t1 is mapped to
    the j-th node of t2. Each node is mapped at most once. Two mappings
    (x -> y) and (x_ -> y_) are mutually exclusive whenever x_ is a descendant
    of x but y_ is not a descendant of y, or vice versa, so ancestry is
    preserved.

    The relabel cost d(x, y) is 0 when both nodes have the same set of
    'pathSet' entries and 1 otherwise. Insertions and deletions cost 1 each.
    The objective sum(match * (d - 2)) therefore equals
    TED - |t1| - |t2|, so minimizing it minimizes the edit distance.

    Returns
    -------
    np.ndarray of int, shape (n, m)
        Match matrix. Entry [i, j] is 1 iff list(t1.nodes)[i] is mapped to
        list(t2.nodes)[j].
    """
    n, m = t1.number_of_nodes(), t2.number_of_nodes()
    nodes1, nodes2 = list(t1.nodes), list(t2.nodes)

    # Per-node work is hoisted out of the O((n*m)^2) constraint loop below:
    # path sets for the relabel cost, and descendant sets for the ancestry
    # test. Previously descendants(), a graph traversal, ran once per pair of
    # pairs. set() guards against descendants returning a one-shot iterable.
    pathSets1 = [set(t1.nodes[x]['pathSet']) for x in nodes1]
    pathSets2 = [set(t2.nodes[y]['pathSet']) for y in nodes2]
    desc1 = [set(descendants(t1, x)) for x in nodes1]
    desc2 = [set(descendants(t2, y)) for y in nodes2]

    allPairs = list(product(range(n), range(m)))
    v = np.array([[LpVariable(f'm_{x}_{y}', cat='Binary') for y in range(m)] for x in range(n)])
    distanceMatrix = np.array([
        [0 if ps1 == ps2 else 1 for ps2 in pathSets2]
        for ps1 in pathSets1
    ])
    prob = LpProblem("TreeEditDistance", LpMinimize)
    prob += (v * (distanceMatrix - 2)).sum()
    # Each node takes part in at most one mapping.
    for x in range(n) :
        prob += v[x,:].sum() <= 1
    for y in range(m) : 
        prob += v[:,y].sum() <= 1
    # Ancestry preservation: x_ is below x iff y_ is below y. These are the
    # same two mutually exclusive cases as before, added in the same order.
    for (i, j), (i_, j_) in product(allPairs, allPairs) : 
        if (nodes1[i_] in desc1[i]) != (nodes2[j_] in desc2[j]) : 
            prob += v[i, j] + v[i_, j_] <= 1
    prob.solve()
    # varValue is a solver float that may sit within tolerance of 1
    # (e.g. 0.9999999). Round before casting so astype(int) doesn't truncate
    # a match to 0.
    values = np.vectorize(lambda var : var.varValue, otypes=[float])(v)
    return np.rint(values).astype(int)

# Compare the model's (triplet) tree with Sumit's annotation for the sampled drawing.
a = triplet[randomId]
b = sumit[randomId]
match = __ted__(t1=a, t2=b)

In [None]:
def _xrange (pos) :
    xs = [v[0] for v in pos.values()]
    return min(xs), max(xs)

def _yrange (pos) : 
    ys = [v[1] for v in pos.values()]
    return min(ys), max(ys)

def _calculateShift (pos1, pos2) :
    """Place two layouts side by side and give them a common vertical extent.

    pos2 is shifted right so its leftmost node lies at least
    ((width1 + width2) / 8) to the right of pos1's rightmost node. Then each
    layout's y values are rescaled linearly onto the shared range
    [min(my1, my2), max(My1, My2)], so both trees span the same height.

    A layout with no vertical extent (all nodes at one y, e.g. a single-node
    tree) previously raised ZeroDivisionError. Its nodes are now placed at the
    top of the shared range.

    Returns the new (pos1, pos2) dicts; the inputs are not modified in place
    (assuming dictmap builds a new mapping).
    """
    mx1, Mx1 = _xrange(pos1)
    mx2, Mx2 = _xrange(pos2)
    my1, My1 = _yrange(pos1)
    my2, My2 = _yrange(pos2)
    my = min(my1, my2)
    My = max(My1, My2)

    def rescaleY (y, lo, hi) :
        # Linearly map y from [lo, hi] onto the shared [my, My] band.
        if hi == lo :
            # Degenerate single-level layout. graphviz y grows upward, so My
            # is the top row where roots sit (TODO confirm for the layout
            # backend in use).
            return My
        return my + (My - my) * (y - lo) / (hi - lo)

    # Move tree 2 past tree 1, plus a gap of 1/8 of their combined widths.
    shift = max(Mx1 - mx2, 0) + ((Mx1 - mx1) + (Mx2 - mx2)) / 8
    pos2 = dictmap(lambda k, v : (v[0] + shift, v[1]), pos2)
    pos1 = dictmap(lambda k, v : (v[0], rescaleY(v[1], my1, My1)), pos1)
    pos2 = dictmap(lambda k, v : (v[0], rescaleY(v[1], my2, My2)), pos2)
    return pos1, pos2



def treeMatchVis (t1, t2, matchMatrix) : 
    """Draw the node matching between t1 and t2 on a new 20x5in, 200-dpi figure.

    Returns the (figure, axes) pair so callers can further annotate or save it.
    """
    figure, axes = plt.subplots(figsize=(20, 5), dpi=200)
    treeMatchVisOnAxis(t1, t2, matchMatrix, figure, axes)
    return figure, axes

def _drawTreeThumbnails (tree, pos, nodeColors, ax, zoom, imSize) :
    """Place a raster of each node's SVG path subset at that node's layout position.

    Nodes present in nodeColors are composited with COLOR_MAP[nodeColors[node]].
    All other nodes use color [1, 1, 1].
    """
    for node in tree :
        subsetDoc = subsetSvg(tree.doc, tree.nodes[node]['pathSet'])
        img = rasterize(subsetDoc, imSize, imSize)
        color = COLOR_MAP[nodeColors[node]] if node in nodeColors else [1, 1, 1]
        img = alphaComposite(img, module=np, color=color)
        imagebox = OffsetImage(img, zoom=zoom)
        imagebox.image.axes = ax
        ax.add_artist(AnnotationBbox(imagebox, pos[node], pad=0))


def treeMatchVisOnAxis (t1, t2, matchMatrix, fig, ax, prefix=('1-', '2-')) :
    """Draw t1 and t2 side by side on `ax`, tinting matched node pairs alike.

    Parameters
    ----------
    t1, t2 : trees with a `.doc` attribute and a 'pathSet' node attribute.
    matchMatrix : array indexed [i, j] over list(t1.nodes) x list(t2.nodes).
        Entries equal to 1 mark matched pairs, as returned by __ted__.
    fig, ax : matplotlib figure/axes to draw into. fig's width and dpi set the
        thumbnail zoom.
    prefix : kept for backward compatibility and currently unused. It only
        named nodes of a merged match graph that was built but never drawn.

    Side effect: sets graph['nodesep'] = 1 on t1 and t2 before layout.
    """
    t1.graph['nodesep'] = 1
    t2.graph['nodesep'] = 1
    pos1 = graphviz_layout(t1, prog='dot')
    pos2 = graphviz_layout(t2, prog='dot')
    pos1, pos2 = _calculateShift(pos1, pos2)

    nodes1, nodes2 = list(t1.nodes), list(t2.nodes)

    # Give each matched pair the next color of a repeating 4-color palette.
    # Colors are kept in plain dicts rather than on copies of the graphs, so
    # any pre-existing 'color' node attribute can't tint an unmatched node.
    palette = cycle(['red', 'yellow', 'green', 'blue'])
    colors1, colors2 = dict(), dict()
    for (i, j), color in zip(zip(*np.where(matchMatrix == 1)), palette) :
        colors1[nodes1[i]] = color
        colors2[nodes2[j]] = color

    # Same draw order as before: nodes of both trees, then edges of both.
    for tree, pos in ((t1, pos1), (t2, pos2)) :
        nx.draw_networkx_nodes(tree, pos, ax=ax, node_size=0.5)
    for tree, pos in ((t1, pos1), (t2, pos2)) :
        nx.draw_networkx_edges(tree, pos, ax=ax, arrowsize=1)

    # horizontal space for one tree in pixels
    pixX = fig.get_figwidth() * fig.dpi / 4
    # max number of images that'll be side by side
    sideBySideIms = max(maxNodesByLevel(t1), maxNodesByLevel(t2)) + 1
    # raster size
    imSize = 128
    # zoom for each view
    zoom = pixX / (sideBySideIms * imSize * 3) 

    _drawTreeThumbnails(t1, pos1, colors1, ax, zoom, imSize)
    _drawTreeThumbnails(t2, pos2, colors2, ax, zoom, imSize)

    ax.axis('off')


# Render the solved mapping between the two trees for the sampled drawing.
fig, ax = treeMatchVis(t1=a, t2=b, matchMatrix=match)
plt.show()