In [1]:
import numpy as np
import scipy
from sklearn import datasets

In [2]:
mnist = datasets.fetch_openml('mnist_784', version=1)
mnistX, mnistY = mnist.data.to_numpy(), mnist.target.astype('float32')

Tree node definition

In [71]:
class Node:
    def __init__(self, parent_node, children_nodes, lv, ps, label):
        self.parent = parent_node
        self.children = children_nodes
        self.level = lv
        self.radius = 2 ** lv
        self.pointset = ps
        self.label = label

Tree metrics algorithm

In [102]:
class Tree:
    def __init__(self, datax, datay):
        self.datasetx = datax
        self.datasety = datay
        self.distances = None
        self.root = None
        self.delta = 0
        self.r0 = np.random.uniform(0.5, 1)
        self.nodecounter = 0
        self.levels = {}
        # mapping data point number to node containing only this data point
        self.numtonode = {}

        # mapping data to unique number (its index in dataset)
        self.datatonum = {}
        for i in range(self.datasetx.shape[0]):
            self.datatonum[str(self.datasetx[i])] = i
            
def printwy(s, datay):
    res = "{"
    for i in s:
        res += f"({i}, {datay[i]})"
    res += "}"
    return res
    
def printtree(tree, datay):
    for i in range(tree.delta, 1, -1):
        layer = ""
        for v in tree.levels[i]:
            layer += f"lab: {v.label}, pset: {printwy(v.pointset, datay)}| "
        print(layer)

In [111]:
def calculate_distances(xs):
    n = xs.shape[0]

    print(xs)
    
    result = np.zeros((n,n))

    for i in range(n):
        for j in range(n):
            result[i][j] = scipy.spatial.distance.euclidean( xs[i], xs[j] )
    print(result)

    minres = np.inf
    for i in range(n):
        for j in range(n):
            if i != j:
                minres = min(minres, result[i][j])
                
    return result, np.max(result), minres
    

# input: data, output: tree
def treemetrics(datax, datay):
    # random permutation of data
    indices = np.random.permutation(datax.shape[0])
    datax = datax[indices]
    datay = datay[indices]

    N = datax.shape[0]
    tree = Tree(datax, datay)

    # calculate distances
    tree.distances, maxdistance, mindistance = calculate_distances(datax)
    print(mindistance, maxdistance)

    while 2 ** tree.delta < 2 * maxdistance:
        tree.delta += 1

    allpoints = []
    for i in range(N):
        allpoints.append( i )
    #for x in allpoints:
    #    print(x, datax[x])

    
    tree.root = Node( None, [], tree.delta, set(allpoints), tree.nodecounter)
    tree.nodecounter += 1
    tree.levels[tree.delta] = [tree.root]
    for i in range( tree.delta, 1, -1 ):
        tree.levels[i - 1] = []
        currentradius = (2 ** (i - 1)) * tree.r0
        
        for c in tree.levels[i]:
            S = c.pointset.copy()
            created_nodes = []
            for j in range(0, N):
                ball = set()
                for other in range(0, N):
                    if tree.distances[j][other] < currentradius:
                        ball.add(other)

                isection = ball.intersection(S).copy()
                if len(isection) > 0:
                    new_node = Node(c, [], i - 1, isection, tree.nodecounter)
                    tree.nodecounter += 1
                    tree.levels[i - 1].append(new_node)
                    S = S.difference(isection)
                    created_nodes.append( new_node )

            c.children = created_nodes

    for node in tree.levels[1]:
        elem = None
        for x in node.pointset:
            elem = x
        tree.numtonode[x] = node

    printtree(tree, datay)
    return tree

In [112]:
def treedistance(tree, a, b):
    #if str(a) not in tree.datatonum:
    #    raise error
    #if str(b) not in tree.datatonum:
    #    raise error

    #print(a)
    alabel = tree.datatonum[str(a)]
    blabel = tree.datatonum[str(b)]

    anode = tree.numtonode[alabel]
    bnode = tree.numtonode[blabel]

    distance = 0
    while anode.label != bnode.label:
        distance += 2 ** anode.level
        anode = anode.parent
        bnode = bnode.parent

    return distance * 2

In [117]:
print(mnistX.shape)
mytree = treemetrics(mnistX[:200], mnistY[:200])

(70000, 784)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
[[   0.         2929.47998798 2573.23609488 ... 3171.96469085
  3119.94791623 3085.62473415]
 [2929.47998798    0.         2529.97134371 ... 2715.80466897
  2950.69720575 2760.12445372]
 [2573.23609488 2529.97134371    0.         ... 2407.12234837
  2922.00256673 2435.19075228]
 ...
 [3171.96469085 2715.80466897 2407.12234837 ...    0.
  2955.22706403 2515.24233425]
 [3119.94791623 2950.69720575 2922.00256673 ... 2955.22706403
     0.         2713.61290533]
 [3085.62473415 2760.12445372 2435.19075228 ... 2515.24233425
  2713.61290533    0.        ]]
419.583126448145 3668.5373107002742
lab: 0, pset: {(0, 5.0)(1, 0.0)(2, 4.0)(3, 1.0)(4, 9.0)(5, 2.0)(6, 1.0)(7, 3.0)(8, 1.0)(9, 4.0)(10, 3.0)(11, 5.0)(12, 3.0)(13, 6.0)(14, 1.0)(15, 7.0)(16, 2.0)(17, 8.0)(18, 6.0)(19, 9.0)(20, 4.0)(21, 0.0)(22, 9.0)(23, 1.0)(24, 1.0)(25, 2.0)(26, 4.0)(27, 3.0)(28, 2.0)(29, 7.0)

In [118]:
sum_euclid_dists = np.zeros((10, 10))
sum_tree_dists = np.zeros((10, 10))
cnt_pair_dists = np.zeros((10, 10))

tree = mytree
for i in range(tree.datasetx.shape[0]):
    for j in range(tree.datasetx.shape[0]):
        if i != j:
            sum_euclid_dists[int(tree.datasety[i]), int(tree.datasety[j])] += tree.distances[i][j]
            sum_tree_dists[int(tree.datasety[i]), int(tree.datasety[j])] += treedistance(tree, tree.datasetx[i], tree.datasetx[j])
            cnt_pair_dists[int(tree.datasety[i]), int(tree.datasety[j])] += 1

average_euclid_dists = sum_euclid_dists / cnt_pair_dists
average_tree_dists = sum_tree_dists / cnt_pair_dists



In [119]:
print(average_euclid_dists)
print(average_tree_dists)
print(average_tree_dists / average_euclid_dists)

[[2513.41956796 2497.37801031 2588.83713045 2456.34367856 2522.55453533
  2603.46259613 2543.46296951 2544.67261255 2576.0122826  2446.59600198]
 [2497.37801031 2496.22998389 2562.91008561 2467.24224478 2502.30311406
  2586.60549002 2520.46347994 2519.87021424 2563.2440497  2438.57579302]
 [2588.83713045 2562.91008561 2654.75286566 2559.88600307 2599.41350625
  2673.13849515 2624.42014822 2625.96696699 2669.31572745 2542.23335043]
 [2456.34367856 2467.24224478 2559.88600307 2419.95962678 2491.43120939
  2607.03056065 2517.98931861 2537.25052289 2551.77038069 2396.82889975]
 [2522.55453533 2502.30311406 2599.41350625 2491.43120939 2555.83939832
  2605.37845106 2547.15080336 2556.21324514 2579.0966726  2459.92804668]
 [2603.46259613 2586.60549002 2673.13849515 2607.03056065 2605.37845106
  2713.67104732 2601.99130402 2617.83786058 2643.40765458 2567.35415641]
 [2543.46296951 2520.46347994 2624.42014822 2517.98931861 2547.15080336
  2601.99130402 2558.24621651 2560.86427466 2585.18960504 