/
catch_labels.py
93 lines (68 loc) · 1.97 KB
/
catch_labels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from collections import Counter
import json
import numpy as np
import re, sys
from tqdm import tqdm
# CLI flag: 1 -> reuse the previously saved label counts ('labels.npy'),
# anything else -> recompute them from the raw pages.
with_file = int(sys.argv[1])

# Node ids of the 5k sampled wiki pages.
sample_nodes = np.load('wiki_data/wiki_node_5k.npy')

# Load the three JSON lookup tables with context managers so the file
# handles are closed (the original open(...).read() leaked them).
# wiki_dic: title -> raw page markup (used as wiki_dic[i_t[node]] below).
with open('wiki_data/wiki_dic.json') as fh:
    wiki_dic = json.load(fh)
# t_i: title -> node index; i_t: node index -> title (presumably inverse
# mappings — names suggest it, verify against the files).
with open('wiki_data/title_index.json') as fh:
    t_i = json.load(fh)
with open('wiki_data/index_title.json') as fh:
    i_t = json.load(fh)

# Raw string: '\[' in a plain literal is an invalid escape (SyntaxWarning
# on modern Python). Matches the category tags [[Category:...]] in a page.
label_pattern = re.compile(r'\[\[Category:(.*?)\]\]')
if with_file != 1:
    # First pass: collect every category label appearing on sampled pages,
    # then rank labels by frequency and persist the ranking.
    t_l = {}
    # Build the membership set once: `x in numpy_array` is an O(n) scan,
    # so testing it inside the loop was accidentally quadratic.
    sample_set = set(sample_nodes)
    print('catching labels...')
    for node in tqdm(sample_nodes):
        if node not in i_t:
            continue
        page = wiki_dic[i_t[node]]
        # Extract the page's <id> tag; skip pages without one.
        # (Renamed from `id`, which shadowed the builtin.)
        page_ids = re.findall("<id>(.*?)</id>", page)
        if not page_ids:
            continue
        page_id = page_ids[0]
        if page_id in sample_set:
            labels = label_pattern.findall(page)
            t_l[i_t[node]] = labels
    # Flatten all per-title label lists and sort labels by frequency,
    # most common first.
    ls = [s for v in tqdm(t_l.values()) for s in v]
    counts = sorted(Counter(ls).items(), reverse=True, key=lambda kv: kv[1])
    np.save('labels', counts)
else:
    # Reuse the ranking saved by a previous run; let the user pick a
    # rank window [si, ti) of labels to keep.
    counts = np.load('labels.npy')
    print(counts[:50])
    si = int(input('start index: '))
    ti = int(input('end index: '))
    counts = counts[si:ti]
    print(counts)
# Keep only the label names; drop the frequency column.
counts = np.array(counts)[:,0]
# print(counts)
print('labeling nodes...')
# Second pass: assign each sampled title its first category label that is
# in the selected label set. Hoist both membership structures out of the
# loop — `x in numpy_array` is an O(n) scan per test otherwise.
sample_set = set(sample_nodes)
selected_labels = set(counts)
t_l = {}
for node in tqdm(sample_nodes):
    if node not in i_t:
        continue
    title = i_t[node]
    page = wiki_dic[title]
    # Page <id> tag; skip pages without one. (Renamed from `id`, which
    # shadowed the builtin.)
    page_ids = re.findall("<id>(.*?)</id>", page)
    if not page_ids:
        continue
    if page_ids[0] in sample_set:
        # First matching label wins: each title gets at most one label.
        for label in label_pattern.findall(page):
            if label in selected_labels:
                t_l[title] = label
                break
print('number of labeled data:', len(t_l))
fname = 'wiki_data/title_labels_5k.json'
with open(fname, 'w') as g:
    json.dump(t_l, g)
#