Labeling Taxonomy Tree of Categories
====================

In [4]:
from collections import Counter, defaultdict
import multiprocessing as mp
from math import log
from functools import partial
import pickle
import gzip
import csv

In [5]:
# Load the (category, description, count) table.
# descCatCnt[description][category_index] accumulates the count column, and
# descCnt[description] accumulates the same counts summed over categories.
# Category ids from the first column are shifted down by one before use.
catDescFile = 'data/cat_desc_2.csv'
catDescHeaders = None
descCatCnt = defaultdict(Counter)
descCnt = Counter()
with open(catDescFile) as catDescIn:
    catDescRows = csv.reader(catDescIn)
    # The first row is kept as the header (stays None for an empty file).
    catDescHeaders = next(catDescRows, None)
    if catDescHeaders is not None:
        print(catDescHeaders)
    for record in catDescRows:
        description = record[1]
        descCatCnt[description][int(record[0]) - 1] += int(record[-1])
        descCnt[description] += int(record[-1])

# Map zero-based category index -> category name.
# Every row is consumed (no header row is skipped); the first column is parsed
# as an integer id and shifted down by one, the second column is the name.
catNameFile = 'data/categoryKey.csv'
catName = {}
with open(catNameFile) as catNameIn:
    catName.update(
        (int(record[0]) - 1, record[1]) for record in csv.reader(catNameIn)
    )

['category_id', 'description', 'count']


In [6]:
# Load the pickled taxonomy tree (a sequence of node records) from disk.
# NOTE(review): pickle.load executes arbitrary code from the file; only point
# this at model files you produced yourself.
taxonomyFile = 'model/hcTree1_2.pickle.gz'
with gzip.open(taxonomyFile, 'rb') as treeStream:
    taxoTreeNodes = pickle.load(treeStream)

In [7]:
# Keep descriptions whose total occurrence count is >= $threshold$, then turn
# each kept description's per-category counts into a distribution summing to 1.
threshold = 100
# most_common() yields descriptions by descending total count, so descList
# keeps that ordering.
descList = [name for name, total in descCnt.most_common() if total >= threshold]
# Kept vs. all descriptions, then kept vs. all occurrence mass.
print('%s %s' % (len(descList), len(descCnt)))
print('%s %s' % (sum(descCnt[name] for name in descList), sum(descCnt.values())))

descSet = set(descList)
descDist = defaultdict(Counter)
for descText, catCounts in descCatCnt.items():
    if descText in descSet:
        # float() keeps the division below a true division under Python 2.
        countTotal = float(sum(catCounts.values()))
        for catIdx, catCount in catCounts.items():
            descDist[descText][catIdx] += catCount / countTotal

2467 95369
6826367 7195840


In [8]:
def JensenShannonDiv(dist1, dist2):
    """Return the Jensen-Shannon divergence between two sparse distributions.

    Both arguments are mappings from outcome to probability mass; an outcome
    missing from a mapping is treated as mass 0.  The natural logarithm is
    used.  Lookups never insert keys, so defaultdict/Counter inputs are left
    unmodified.
    """
    total = 0.0
    for outcome in set(dist1) | set(dist2):
        p = dist1.get(outcome, 0)
        q = dist2.get(outcome, 0)
        # Mass of the midpoint distribution at this outcome.
        mid = (p + q) * 0.5
        # Zero-mass terms contribute nothing and are skipped (avoids log(0)).
        for mass in (p, q):
            if mass > 0:
                total += mass * log(mass / mid)
    return total * 0.5

def evalMatch(catDist, descDist, desc):
    """Score one description against a category distribution.

    Looks up ``descDist[desc]`` and returns its Jensen-Shannon divergence from
    ``catDist``.  The lookup uses plain indexing, so with a defaultdict an
    unknown ``desc`` creates an empty entry.
    """
    descriptionDist = descDist[desc]
    return JensenShannonDiv(catDist, descriptionDist)

def matchingResult(catSet, mpPool, topNums=5):
    """Score every description in ``descList`` against a set of categories.

    The category set is turned into a uniform distribution (equal mass on each
    member), and each description's distribution from the module-level
    ``descDist`` is compared to it via ``evalMatch`` on the worker pool.

    Returns the list of scores produced by ``mpPool.map`` over ``descList``.

    NOTE(review): ``topNums`` is accepted but not used by this function.
    """
    # The division stays inside the generator so an empty catSet yields an
    # empty distribution instead of dividing by zero.
    catDist = dict((member, 1.0 / len(catSet)) for member in catSet)
    scoreDescription = partial(evalMatch, catDist, descDist)
    return mpPool.map(scoreDescription, descList)

In [9]:
# Build the set of member indices for each node in the taxonomy tree.
# A node whose first field equals 1 gets a singleton set holding its own
# index (presumably a leaf of size 1 -- TODO confirm against the tree format).
# Any other node gets the union of the sets of the two nodes referenced by its
# third-from-last and second-to-last fields, which must already have been
# built (i.e. they appear earlier in taxoTreeNodes).
nodeCatSet = []
for nodeIdx, treeNode in enumerate(taxoTreeNodes):
    if treeNode[0] != 1:
        firstChildSet = nodeCatSet[treeNode[-3]]
        secondChildSet = nodeCatSet[treeNode[-2]]
        members = firstChildSet | secondChildSet
    else:
        members = {nodeIdx}
    nodeCatSet.append(members)

In [10]:
# Score every retained description against each taxonomy node whose index is
# at least len(catName); lower indices are skipped (presumably the leaf
# categories -- TODO confirm).  Row k of matchingMtx therefore belongs to
# node k + len(catName), with one score per entry of descList.
matchingMtx = []
mpPool = mp.Pool()
try:
    for i, s in enumerate(nodeCatSet):
        if i < len(catName):
            continue
        matchingMtx.append(matchingResult(s, mpPool))
        if i % 10 == 0:
            print('Processing node # %d' % i)
finally:
    # Always shut the pool down, even if scoring raised, and wait for the
    # workers to exit so no processes are left behind in the kernel.
    mpPool.close()
    mpPool.join()

Processing node # 1060
Processing node # 1070
Processing node # 1080
Processing node # 1090
Processing node # 1100
Processing node # 1110
Processing node # 1120
Processing node # 1130
Processing node # 1140
Processing node # 1150
Processing node # 1160
Processing node # 1170
Processing node # 1180
Processing node # 1190
Processing node # 1200
Processing node # 1210
Processing node # 1220
Processing node # 1230
Processing node # 1240
Processing node # 1250
Processing node # 1260
Processing node # 1270
Processing node # 1280
Processing node # 1290
Processing node # 1300
Processing node # 1310
Processing node # 1320
Processing node # 1330
Processing node # 1340
Processing node # 1350
Processing node # 1360
Processing node # 1370
Processing node # 1380
Processing node # 1390
Processing node # 1400
Processing node # 1410
Processing node # 1420
Processing node # 1430
Processing node # 1440
Processing node # 1450
Processing node # 1460
Processing node # 1470
Processing node # 1480
Processing 

In [11]:
# Persist the score matrix together with the description list that defines
# its column order, as a gzip-compressed pickle.
outputFileName = 'model/hcTree1_matching_2.pickle.gz'
matchingPayload = (matchingMtx, descList)
with gzip.open(outputFileName, 'wb') as gzOut:
    pickle.dump(matchingPayload, gzOut)

In [12]:
# For the last ten rows of matchingMtx, print the row index, the three
# (description index, score) pairs with the lowest scores, and the taxonomy
# node that row corresponds to.
firstShown = max(len(matchingMtx) - 10, 0)
for rowIdx in range(firstShown, len(matchingMtx)):
    ranked = sorted(enumerate(matchingMtx[rowIdx]), key=lambda pair: pair[1])
    treeNode = taxoTreeNodes[rowIdx + len(catName)]
    print('%s %s %s' % (rowIdx, ranked[:3], treeNode))

1044 [(36, 0.21092763776631712), (5, 0.24845651094555796), (32, 0.2598474946069811)] (146, 490.3942341824215, 2093, 2082, 1.4091328311912004)
1045 [(113, 0.2687346359153511), (663, 0.2687987387364557), (3, 0.27367946786346814)] (153, 579.3230369955083, 2091, 2084, 1.4137517336856464)
1046 [(242, 0.31379005737822885), (1211, 0.32613460227887836), (924, 0.3333149510470036)] (193, 714.9127632251251, 2092, 2087, 1.433201053068268)
1047 [(1, 0.2233342186382508), (25, 0.24177475390999076), (52, 0.25781710888014536)] (299, 1147.5219788328343, 2098, 2090, 1.5685058835640402)
1048 [(6, 0.2934790771991496), (2, 0.3074549618783569), (15, 0.3164779961744379)] (182, 649.6046455931771, 2097, 2095, 1.5777812695297295)
1049 [(29, 0.2885663219793484), (11, 0.3019995033967244), (28, 0.3091142640140115)] (275, 1018.0086530626759, 2101, 2096, 1.6356117380077726)
1050 [(2, 0.20801532485730864), (4, 0.22050368367692116), (6, 0.27532229758736826)] (335, 1228.9276825886855, 2103, 2100, 1.7357812048139905)
105