Labeling Taxonomy Tree of Categories
====================

In [1]:
from collections import Counter, defaultdict
import multiprocessing as mp
from math import log
from functools import partial
import pickle
import gzip
import csv

In [2]:
# Load the raw category/description counts and the category names.
# cat_desc.csv: one header row, then (category_id, description, count) rows.
# Category ids are shifted by -1 on read (they appear to be 1-based in the files).
catDescFile = 'data/cat_desc.csv'
descCatCnt = defaultdict(Counter)  # description -> Counter(category index -> count)
descCnt = Counter()                # description -> total count across categories
with open(catDescFile) as fin:
    reader = csv.reader(fin)
    # The first row is the header; echo it so the column layout is visible.
    catDescHeaders = next(reader, None)
    if catDescHeaders is not None:
        print(catDescHeaders)
    for row in reader:
        desc, count = row[1], int(row[-1])
        descCatCnt[desc][int(row[0]) - 1] += count
        descCnt[desc] += count

catNameFile = 'data/categoryKey.csv'
with open(catNameFile) as fin:
    # Every row is read as (category_id, name); no header row is skipped here.
    catName = dict((int(row[0]) - 1, row[1]) for row in csv.reader(fin))

['category_id', 'description', 'count']


In [3]:
# Load the taxonomy tree: a gzip-compressed pickle holding a sequence of node tuples.
# NOTE(review): judging from how later cells index it, node[0] looks like the number
# of leaf categories under a node (1 for a leaf) and node[-3] / node[-2] look like the
# indices of its two children -- confirm against the code that produced hcTree1.
# pickle.load can execute arbitrary code; only load model files you produced yourself.
taxonomyFile = 'model/hcTree1.pickle.gz'
with gzip.open(taxonomyFile, 'rb') as treeStream:
    taxoTreeNodes = pickle.load(treeStream)

In [4]:
# Keep descriptions with at least `threshold` occurrences, then turn each kept
# description's per-category counts into a probability distribution.
threshold = 10
# most_common() orders by descending count, so descList is most-frequent first.
descList = [d for d, c in descCnt.most_common() if c >= threshold]
# Kept vs. all distinct descriptions, then kept vs. all occurrences.
print('%d %d' % (len(descList), len(descCnt)))
print('%d %d' % (sum(descCnt[d] for d in descList), sum(descCnt.values())))

descSet = set(descList)
descDist = defaultdict(Counter)  # description -> Counter(category index -> probability)
for desc, catCounts in descCatCnt.items():
    if desc in descSet:
        total = float(sum(catCounts.values()))
        for cat, cnt in catCounts.items():
            descDist[desc][cat] += cnt / total

8313 84922
4982844 5115137


In [5]:
def JensenShannonDiv(dist1, dist2):
    """Return the Jensen-Shannon divergence between two discrete distributions.

    Each argument maps outcome -> probability; an outcome absent from one
    mapping is treated as probability 0 there. With M = (P + Q) / 2 this is
    0.5 * (KL(P || M) + KL(Q || M)), using the natural log, so for normalized
    inputs the result lies in [0, ln 2]; 0 means identical distributions.
    Zero-probability terms contribute nothing (0 * log 0 is taken as 0).
    """
    total = 0.0
    for key in set(dist1) | set(dist2):
        p = dist1.get(key, 0)
        q = dist2.get(key, 0)
        mid = (p + q) * 0.5
        # P's term is added before Q's for each outcome.
        for share in (p, q):
            if share > 0:
                total += share * log(share / mid)
    return 0.5 * total

def evalMatch(catDist, descDist, desc):
    """Return the Jensen-Shannon divergence between `catDist` and the
    distribution stored under `desc` in `descDist` (lower = closer match).

    The argument order lets the two lookup tables be bound with
    functools.partial, leaving `desc` as the single mapped argument.
    """
    targetDist = descDist[desc]
    return JensenShannonDiv(catDist, targetDist)

def matchingResult(catSet, mpPool, topNums=5):
    """Score every kept description against a uniform distribution over `catSet`.

    Parameters
    ----------
    catSet : collection of category indices (e.g. the leaf categories under one
        taxonomy node); the node is modeled as a uniform distribution over them.
    mpPool : multiprocessing.Pool used to evaluate the descriptions in parallel.
    topNums : unused and ignored; kept only so existing calls keep working.
        Callers rank the returned scores themselves.

    Returns
    -------
    list of float
        Jensen-Shannon divergence for each description, aligned with the global
        `descList` (entry k scores descList[k]); lower means a closer match.

    Raises
    ------
    ValueError
        If `catSet` is empty (no uniform distribution exists over nothing).
    """
    if not catSet:
        raise ValueError('catSet must contain at least one category')
    weight = 1.0 / len(catSet)
    catDist = {c: weight for c in catSet}
    # Send each worker only the distributions it scores. Binding the whole
    # `descDist` table into the mapped callable would re-pickle all of it for
    # every task chunk, for every taxonomy node.
    dists = [descDist[desc] for desc in descList]
    return mpPool.map(partial(JensenShannonDiv, catDist), dists)

In [6]:
# For every taxonomy node, collect the set of leaf category indices beneath it.
# A leaf (node[0] == 1) maps to its own index; an internal node is the union of
# the sets of its two children, referenced by node[-3] and node[-2].
# NOTE(review): this requires every child index to be smaller than its parent's
# index (children listed before parents); otherwise the lookup raises IndexError.
nodeCatSet = []
for idx, node in enumerate(taxoTreeNodes):
    if node[0] != 1:
        leftSet, rightSet = nodeCatSet[node[-3]], nodeCatSet[node[-2]]
        nodeCatSet.append(leftSet | rightSet)
    else:
        nodeCatSet.append({idx})

In [7]:
# Score every internal taxonomy node against every kept description.
# The first len(catName) nodes (leaves) are skipped, so matchingMtx[j] holds the
# scores for node j + len(catName), each row aligned with descList.
matchingMtx = []
mpPool = mp.Pool()
try:
    for i, s in enumerate(nodeCatSet):
        if i < len(catName):
            continue
        matchingMtx.append(matchingResult(s, mpPool))
        if i % 10 == 0:
            print('Processing node # %d' % i)
except BaseException:
    # Failed or interrupted: stop the workers now instead of leaking them.
    mpPool.terminate()
    mpPool.join()
    raise
# Normal completion: let the workers exit, then reap them.
mpPool.close()
mpPool.join()

Processing node # 1060
Processing node # 1070
Processing node # 1080
Processing node # 1090
Processing node # 1100
Processing node # 1110
Processing node # 1120
Processing node # 1130
Processing node # 1140
Processing node # 1150
Processing node # 1160
Processing node # 1170
Processing node # 1180
Processing node # 1190
Processing node # 1200
Processing node # 1210
Processing node # 1220
Processing node # 1230
Processing node # 1240
Processing node # 1250
Processing node # 1260
Processing node # 1270
Processing node # 1280
Processing node # 1290
Processing node # 1300
Processing node # 1310
Processing node # 1320
Processing node # 1330
Processing node # 1340
Processing node # 1350
Processing node # 1360
Processing node # 1370
Processing node # 1380
Processing node # 1390
Processing node # 1400
Processing node # 1410
Processing node # 1420
Processing node # 1430
Processing node # 1440
Processing node # 1450
Processing node # 1460
Processing node # 1470
Processing node # 1480
Processing 

In [8]:
# Persist the node-by-description score matrix together with the description
# order it is indexed by: matchingMtx[j][k] scores descList[k] against taxonomy
# node j + len(catName).
# Protocol 2 is binary -- far smaller and faster for a large float matrix than
# Python 2's default text protocol 0 -- and is readable by Python 2.3+ and 3.x.
outputFileName = 'model/hcTree1_matching.pickle.gz'
with gzip.open(outputFileName, 'wb') as fout:
    pickle.dump((matchingMtx, descList), fout, protocol=2)

In [9]:
# For the last ten scored nodes, print: row index in matchingMtx, the three
# lowest-divergence (descList index, score) pairs, and the node's raw tuple.
nShown = 10
start = max(len(matchingMtx) - nShown, 0)
for i in range(start, len(matchingMtx)):
    matches = matchingMtx[i]
    ranked = sorted(enumerate(matches), key=lambda pair: pair[1])
    print('%s %s %s' % (i, ranked[:3], taxoTreeNodes[i + len(catName)]))

1044 [(13, 0.2976502394543247), (41, 0.3010517825879155), (257, 0.3099856637631657)] (111, 387.7966897485003, 2090, 2078, 1.355638759831728)
1045 [(28, 0.1978522147397928), (32, 0.23214088997786936), (179, 0.2346394239716983)] (178, 607.6720768606814, 2094, 2081, 1.4489559291407494)
1046 [(1, 0.173466627081497), (473, 0.21309967310275804), (19, 0.23389380727071363)] (338, 1183.4705633776248, 2097, 2095, 1.4828614263031925)
1047 [(77, 0.259677044209342), (464, 0.26111684795197004), (2, 0.2623498516980681)] (158, 545.3158629555846, 2093, 2091, 1.5192155034762984)
1048 [(13, 0.24853072160782283), (4, 0.29209246076648837), (41, 0.30304914993645443)] (213, 738.3682339170988, 2099, 2096, 1.6038876597063956)
1049 [(1, 0.19126950823748), (473, 0.2328489047404419), (40, 0.2501917696708961)] (406, 1407.5512860989402, 2101, 2092, 1.646253936452292)
1050 [(28, 0.1654963893934989), (17, 0.22331021333304563), (16, 0.24790620281322429)] (278, 941.8621685482683, 2100, 2098, 1.6657312592580618)
1051 [(