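Labeling the Taxonomy Tree of Categories
========================================

Each node of the hierarchical category tree is scored against every frequent description. The score is the Jensen-Shannon divergence between a uniform distribution over the node's categories and the description's normalized category distribution, minus a small log-frequency bonus. Lower scores are better matches.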

In [3]:
from collections import Counter, defaultdict
import multiprocessing as mp
from math import log
from functools import partial
import pickle
import gzip
import csv
from math import log

In [4]:
# Load (category, description) co-occurrence counts.
# The first CSV row is a header; every later row is read as
# (category_id, description, ..., count). Category ids are shifted by -1,
# presumably from 1-based in the file to 0-based indices (verify against the data).
catDescFile = 'data/cat_desc_wn.csv'
catDescHeaders = None
descCatCnt = defaultdict(Counter)   # description -> Counter{category index: count}
descCnt = Counter()                 # description -> total count over all categories
with open(catDescFile) as fin:
    reader = csv.reader(fin)
    catDescHeaders = next(reader, None)
    if catDescHeaders is not None:
        print(catDescHeaders)
    for row in reader:
        catIdx = int(row[0]) - 1
        count = int(row[-1])
        descCatCnt[row[1]][catIdx] += count
        descCnt[row[1]] += count

# Category index -> human-readable name. This file is read without a header row.
catNameFile = 'data/categoryKey.csv'
catName = {}
with open(catNameFile) as fin:
    catName.update((int(row[0]) - 1, row[1]) for row in csv.reader(fin))

['category_id', 'description', 'count']


In [5]:
# Load the hierarchical-clustering taxonomy tree (a pickled list of node tuples).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted model files.
taxonomyFile = 'model/hcTree1_original.pickle.gz'
with gzip.open(taxonomyFile, 'rb') as treeStream:
    taxoTreeNodes = pickle.load(treeStream)

In [6]:
# Keep descriptions whose total occurrence count is at least `threshold`,
# then normalize each kept description's per-category counts into a
# probability distribution over category indices.
threshold = 10
descList = [d for d, c in descCnt.most_common() if c >= threshold]
# Sanity check: kept vs. total descriptions, and kept vs. total occurrences.
print('%s %s' % (len(descList), len(descCnt)))
print('%s %s' % (sum(descCnt[d] for d in descList), sum(descCnt.values())))

descSet = set(descList)
descDist = defaultdict(Counter)   # description -> Counter{category index: probability}
for desc, dist in descCatCnt.items():
    if desc in descSet:
        total = float(sum(dist.values()))
        for cat, cnt in dist.items():
            descDist[desc][cat] += cnt / total

2917 2917
4450878 4450878


In [7]:
def JensenShannonDiv(dist1, dist2):
    """Return the Jensen-Shannon divergence (natural log) between two distributions.

    Each argument maps outcome -> probability. An outcome missing from one
    mapping counts as probability 0 there. Zero-probability terms are skipped
    (0 * log 0 is taken as 0), so disjoint supports give a finite result:
    log(2) when both inputs are normalized.
    """
    def term(prob, mid):
        # One outcome's contribution to KL(P || M) with M the midpoint distribution.
        return prob * log(prob / mid)

    union = set(dist1.keys()) | set(dist2.keys())
    acc = 0.0
    for key in union:
        p, q = dist1.get(key, 0), dist2.get(key, 0)
        mid = (p + q) * 0.5
        if p > 0:
            acc += term(p, mid)
        if q > 0:
            acc += term(q, mid)
    return acc * 0.5

def evalMatch(catDist, descDist, desc):
    """Score how well description `desc` matches the category distribution `catDist`.

    The score is JensenShannonDiv(catDist, descDist[desc]) minus
    0.01 * log(total count of `desc`). More frequent descriptions therefore get
    slightly lower scores. Callers in this notebook sort scores ascending, so
    lower means a better match. The count comes from the global `descCnt`.
    """
    divergence = JensenShannonDiv(catDist, descDist[desc])
    frequencyBonus = 0.01 * log(descCnt[desc])
    return divergence - frequencyBonus

def matchingResult(catSet, mpPool=None, topNums=5):
    """Score every description in `descList` against a uniform distribution over `catSet`.

    Parameters:
        catSet: non-empty set of category indices (the same keys used in `descDist`).
        mpPool: a multiprocessing pool used to parallelize scoring. When None,
            the descriptions are scored serially in this process.
        topNums: unused, kept only for backward compatibility. The full score
            list is always returned; callers rank it themselves.

    Returns:
        list of `evalMatch` scores aligned index-for-index with `descList`.

    Raises:
        ValueError: if `catSet` is empty (a uniform distribution over no
            categories is undefined).
    """
    if not catSet:
        raise ValueError('matchingResult: catSet must be non-empty')
    weight = 1.0 / len(catSet)
    catDist = {c: weight for c in catSet}
    mpKernel = partial(evalMatch, catDist, descDist)
    if mpPool is None:
        return [mpKernel(desc) for desc in descList]
    # Pool.map preserves input order, so the result stays aligned with descList.
    return mpPool.map(mpKernel, descList)

In [8]:
# For every taxonomy node, build the set of category indices it covers.
# A node with node[0] == 1 is treated as a leaf covering only its own index.
# Presumably node[0] is the node's size and leaf indices coincide with category
# indices; verify against the tree builder. Any other node covers the union of
# its children node[-3] and node[-2], which must appear earlier in the list.
nodeCatSet = []
for idx, node in enumerate(taxoTreeNodes):
    if node[0] != 1:
        leftChild, rightChild = node[-3], node[-2]
        nodeCatSet.append(nodeCatSet[leftChild] | nodeCatSet[rightChild])
        continue
    nodeCatSet.append({idx})

In [9]:
# Score every description against every taxonomy node; row i of matchingMtx
# holds the scores for node i, aligned with descList.
# NOTE(review): workers look up evalMatch/descCnt in this kernel's globals, which
# presumably relies on the `fork` start method; confirm on spawn-only platforms.
N_WORKERS = 2
matchingMtx = []
mpPool = mp.Pool(processes=N_WORKERS)
try:
    for i, s in enumerate(nodeCatSet):
        matchingMtx.append(matchingResult(s, mpPool))
        if i % 10 == 0:
            print('Processing node # %d' % i)
except BaseException:
    # On any error or KeyboardInterrupt, kill the workers instead of leaking them.
    mpPool.terminate()
    raise
else:
    mpPool.close()
finally:
    # Reap the worker processes in both the success and the error path.
    mpPool.join()

Processing node # 0
Processing node # 10
Processing node # 20
Processing node # 30
Processing node # 40
Processing node # 50
Processing node # 60
Processing node # 70
Processing node # 80
Processing node # 90
Processing node # 100
Processing node # 110
Processing node # 120
Processing node # 130
Processing node # 140
Processing node # 150
Processing node # 160
Processing node # 170
Processing node # 180
Processing node # 190
Processing node # 200
Processing node # 210
Processing node # 220
Processing node # 230
Processing node # 240
Processing node # 250
Processing node # 260
Processing node # 270
Processing node # 280
Processing node # 290
Processing node # 300
Processing node # 310
Processing node # 320
Processing node # 330
Processing node # 340
Processing node # 350
Processing node # 360
Processing node # 370
Processing node # 380
Processing node # 390
Processing node # 400
Processing node # 410
Processing node # 420
Processing node # 430
Processing node # 440
Processing node # 450

In [10]:
# Persist the per-node score matrix together with descList, whose order defines
# the meaning of each score column.
outputFileName = 'model/hcTree1_matching.pickle.gz'
payload = (matchingMtx, descList)
with gzip.open(outputFileName, 'wb') as fout:
    pickle.dump(payload, fout)

In [11]:
# Inspect the last 10 nodes (the latest merges, presumably nearest the root).
# For each, print its index, the three best (lowest-score) description indices
# with their scores, and the raw node tuple.
numShown = 10
for i in range(max(0, len(matchingMtx) - numShown), len(matchingMtx)):
    ranked = sorted(enumerate(matchingMtx[i]), key=lambda pair: pair[1])
    print('%s %s %s' % (i, ranked[:3], taxoTreeNodes[i]))

2099 [(13, 0.1878392814615142), (39, 0.20250548276190292), (226, 0.2341227006916385)] (111, 387.7966897485003, 2090, 2078, 1.355638759831728)
2100 [(31, 0.09454327942168135), (3, 0.11754990830052049), (30, 0.12877143423798257)] (178, 607.6720768606814, 2094, 2081, 1.4489559291407494)
2101 [(1, 0.05069109038694025), (19, 0.1249572038462993), (38, 0.13617335809629266)] (338, 1183.4705633776248, 2097, 2095, 1.4828614263031925)
2102 [(2, 0.14042195510470953), (75, 0.1663088977197052), (28, 0.16874270898014262)] (158, 545.3158629555846, 2093, 2091, 1.5192155034762984)
2103 [(13, 0.13829276925021142), (4, 0.17283897065962012), (39, 0.20444265859002345)] (213, 738.3682339170988, 2099, 2096, 1.6038876597063956)
2104 [(1, 0.06848487911946285), (0, 0.14337959848083326), (19, 0.14865818149487703)] (406, 1407.5512860989402, 2101, 2092, 1.646253936452292)
2105 [(31, 0.06271856656553661), (17, 0.11360435737091924), (215, 0.11511002789850311)] (278, 941.8621685482683, 2100, 2098, 1.6657312592580618)
