Flattening the Taxonomy Tree
============================

In [1]:
from collections import Counter, defaultdict
import pickle
import gzip
import csv

In [2]:
# Load files
# Category names, keyed by the CSV id column minus one (zero-based keys).
catNameFile = 'data/categoryKey.csv'
catName = {}
# The csv module requires newline='' so that line breaks inside quoted
# fields are parsed correctly.
with open(catNameFile, newline='') as fin:
    for row in csv.reader(fin):
        if not row:
            # csv.reader yields [] for blank lines; skip them instead of
            # failing with IndexError on row[0].
            continue
        catName[int(row[0]) - 1] = row[1]

# Load taxonomy tree
# NOTE: pickle.load can execute arbitrary code; only load model files that
# come from a trusted source.
taxonomyFile = 'model/hcTree1_2.pickle.gz'
with gzip.open(taxonomyFile) as fin:
    taxoTreeNodes = pickle.load(fin)

# Load label info: a pickled (matching matrix, description list) pair.
treeLabelFile = 'model/hcTree1_matching_2.pickle.gz'
with gzip.open(treeLabelFile) as fin:
    matchingMtx, descList = pickle.load(fin)

In [3]:
# Find best matching for each description.
# Rows of matchingMtx are internal tree nodes and columns are descriptions;
# lower values are better. For every column, remember the row with the
# smallest value. Only a strictly smaller value replaces the current choice,
# so ties go to the earliest row.
descBest = [0] * len(descList)
for nodeIdx, scoreRow in enumerate(matchingMtx):
    for descIdx, score in enumerate(scoreRow):
        currentBest = matchingMtx[descBest[descIdx]][descIdx]
        if score < currentBest:
            descBest[descIdx] = nodeIdx

In [4]:
# Find best matching for each node: the column index of the smallest value
# in each row (list.index returns the first occurrence, so ties go to the
# earliest description).
nodeBest = []
for scoreRow in matchingMtx:
    lowestScore = min(scoreRow)
    nodeBest.append(scoreRow.index(lowestScore))

In [5]:
# Designate node description: node i keeps its best description d only when
# the match is mutual, i.e. d's best node is i as well; otherwise None.
nodeDesc = []
for nodeIdx, descIdx in enumerate(nodeBest):
    isMutual = descBest[descIdx] == nodeIdx
    nodeDesc.append(descIdx if isMutual else None)

In [8]:
def buildFlattenedTree_recur(taxoTreeNodes, nodeDesc, topNode,
                             names=None, descriptions=None):
    """Flatten the subtree rooted at ``topNode`` into a tuple of dict nodes.

    A leaf (both child slots None) becomes ``{'name': names[node],
    'children': []}``. An internal node whose entry in ``nodeDesc`` is not
    None becomes a single dict named by that description, with the
    flattened children of both subtrees as its ``children`` tuple. An
    internal node without a description is dissolved: the flattened forests
    of its two children are concatenated and returned in its place.

    Args:
        taxoTreeNodes: Sequence of node records. ``record[-3]`` and
            ``record[-2]`` are the two child indices, both None for a leaf.
        nodeDesc: Per-internal-node description index or None. Internal node
            ``k`` is looked up at position ``k - len(names)``.
        topNode: Index into ``taxoTreeNodes`` of the subtree root.
        names: Leaf names indexed by node id. Defaults to the module-level
            ``catName``.
        descriptions: Description strings indexed by ``nodeDesc`` values.
            Defaults to the module-level ``descList``.

    Returns:
        A tuple of dicts with ``'name'`` and ``'children'`` keys.
    """
    # Fall back to the notebook globals so existing callers keep working.
    if names is None:
        names = catName
    if descriptions is None:
        descriptions = descList

    nodeRec = taxoTreeNodes[topNode]
    s1, s2 = nodeRec[-3], nodeRec[-2]
    if s1 is None and s2 is None:
        return ({'name': names[topNode], 'children': []},)

    # Internal nodes are numbered after the leaves, hence the offset.
    descIdx = nodeDesc[topNode - len(names)]
    # NOTE(review): recursion depth equals the subtree height, so a very
    # unbalanced tree deeper than sys.getrecursionlimit() (default 1000)
    # would raise RecursionError.
    children = (
        buildFlattenedTree_recur(taxoTreeNodes, nodeDesc, s1, names, descriptions)
        + buildFlattenedTree_recur(taxoTreeNodes, nodeDesc, s2, names, descriptions)
    )
    if descIdx is None:
        return children
    return ({'name': descriptions[descIdx], 'children': children},)
    
def buildFlattenedTree(taxoTreeNodes, nodeDesc):
    """Return the flattened taxonomy as a single root dict.

    The last record in ``taxoTreeNodes`` is treated as the root. If
    flattening the root yields exactly one node, that node is returned.
    Otherwise the resulting forest is wrapped under a synthetic
    ``'ROOT'`` node.
    """
    rootIdx = len(taxoTreeNodes) - 1
    forest = buildFlattenedTree_recur(taxoTreeNodes, nodeDesc, rootIdx)
    if len(forest) != 1:
        return {'name': 'ROOT', 'children': forest}
    (soleRoot,) = forest
    return soleRoot

In [9]:
# Flatten the whole tree, starting from its root (the last node record).
flattenedTree = buildFlattenedTree(taxoTreeNodes=taxoTreeNodes, nodeDesc=nodeDesc)

In [10]:
import json
# Serialize the flattened tree to JSON under vis/.
with open('vis/hcTree1_label_flat_2.json', 'w') as fout:
    fout.write(json.dumps(flattenedTree))