Skip to content

Commit 3864a74

Browse files
authored
Add files via upload
1 parent dbe0244 commit 3864a74

File tree

1 file changed

+94
-87
lines changed

1 file changed

+94
-87
lines changed

BPE.py

Lines changed: 94 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def word_split_for_bpe(word, space_symbol='</w>'):
2626

2727

2828
# word frequency 추출.
29-
def get_word_frequency_dict_from_document(path, space_symbol='</w>', top_k=None):
29+
def get_word_frequency_dict_from_document(path, space_symbol='</w>'):
3030
word_frequency_dict = {}
3131

3232
with open(path, 'r', encoding='utf-8') as f:
@@ -40,27 +40,16 @@ def get_word_frequency_dict_from_document(path, space_symbol='</w>', top_k=None)
4040

4141
for word in sentence.split():
4242
# "abc" => "a b c space_symbol"
43-
split_word = word_split_for_bpe(word, space_symbol)
44-
43+
word = word_split_for_bpe(word, space_symbol)
44+
4545
# word frequency
46-
if split_word in word_frequency_dict:
47-
word_frequency_dict[split_word] += 1
46+
if word in word_frequency_dict:
47+
word_frequency_dict[word] += 1
4848
else:
49-
word_frequency_dict[split_word] = 1
49+
word_frequency_dict[word] = 1
50+
51+
return word_frequency_dict
5052

51-
if top_k is None:
52-
return word_frequency_dict
53-
54-
else:
55-
# top_k frequency word
56-
sorted_word_frequency_list = sorted(
57-
word_frequency_dict.items(), # ('key', value) pair
58-
key=lambda x:x[1], # x: ('key', value), and x[1]: value
59-
reverse=True
60-
) # [('a', 3), ('b', 2), ... ]
61-
top_k_word_frequency_dict = dict(sorted_word_frequency_list[:top_k])
62-
63-
return top_k_word_frequency_dict
6453

6554

6655
# merge two dictionary
@@ -116,45 +105,22 @@ def merge_bpe_word(best_pair_and_word_frequency_list):
116105

117106

118107

119-
120-
# from bpe to idx
121-
def make_bpe2idx(word_frequency_list):
122-
bpe2idx = {
123-
'</p>':0,
124-
'UNK':1,
125-
'</g>':2, #go
126-
'</e>':3 #eos
127-
}
128-
idx2bpe = {
129-
0:'</p>',
130-
1:'UNK',
131-
2:'</g>', #go
132-
3:'</e>' #eos
133-
}
134-
idx = 4
135-
136-
for word, _ in word_frequency_list: # word, freq
137-
for bpe in word.split():
138-
# bpe가 bpe2idx에 없는 경우만 idx 부여.
139-
if bpe not in bpe2idx:
140-
bpe2idx[bpe] = idx
141-
idx2bpe[idx] = bpe
142-
idx += 1
143-
return bpe2idx, idx2bpe
144-
145-
146108
def merge_a_word(merge_info, word, cache={}):
147109
# merge_info: list
148110
# word: "c e m e n t </w>" => "ce m e n t<\w>" 되어야 함.
149111

150-
if len(word.split()) == 1:
112+
#if len(word.split()) == 1:
113+
if word.count(' ') == 0:
151114
return word
152115

153116
if word in cache:
154117
return cache[word]
155118
else:
156119
bpe_word = word
157120
for info in merge_info:
121+
if bpe_word.count(' ') == 0:
122+
break
123+
158124
bigram = re.escape(' '.join(info))
159125
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
160126

@@ -166,6 +132,44 @@ def merge_a_word(merge_info, word, cache={}):
166132
return bpe_word
167133

168134

135+
def make_bpe2idx(word_frequency_list, npy_path):
    """Build vocabulary index mappings from BPE-segmented word frequencies.

    Args:
        word_frequency_list: iterable of (word, freq) pairs where word is a
            space-separated string of BPE tokens, e.g. ('B e it r a g</w>', 8).
        npy_path: directory prefix for the vocabulary dump; 'sorted_voca.txt'
            is appended by plain string concatenation, so it must end with a
            path separator.  NOTE(review): assumes the directory already
            exists — confirm the caller creates it beforehand.

    Returns:
        (bpe2idx, idx2bpe): forward and inverse token/index dictionaries.
        Indices 0-3 are reserved: '</p>' (pad), 'UNK', '</g>' (go),
        '</e>' (eos).  Remaining tokens are numbered from 4 upward in
        descending frequency order.

    Side effect:
        Writes 'sorted_voca.txt' under npy_path, one "<token> <freq>" per line.
    """
    # Aggregate per-token frequencies across all segmented words.
    # (Renamed from the original's misleading `word_frequency_dict`: the keys
    # here are individual BPE tokens, not whole words.)
    token_frequency = {}
    for word, freq in word_frequency_list:
        for bpe in word.split():
            # dict.get replaces the original's explicit membership test.
            token_frequency[bpe] = token_frequency.get(bpe, 0) + freq

    # Most frequent tokens first; ties keep first-seen order (sort is stable,
    # dict preserves insertion order).  The original's tuple() wrapper around
    # .items() was an unnecessary materialization — sorted() accepts any
    # iterable directly.
    sorted_voca = sorted(token_frequency.items(), key=lambda x: x[1], reverse=True)

    bpe2idx = {
        '</p>': 0,
        'UNK': 1,
        '</g>': 2,  # go
        '</e>': 3,  # eos
    }
    idx2bpe = {
        0: '</p>',
        1: 'UNK',
        2: '</g>',  # go
        3: '</e>',  # eos
    }
    idx = 4

    # Dump the sorted vocabulary and assign indices in the same pass, exactly
    # as the original did (file format: "<token> <freq>\n").
    with open(npy_path + 'sorted_voca.txt', 'w', encoding='utf-8') as o:
        for voca, freq in sorted_voca:
            o.write(str(voca) + ' ' + str(freq) + '\n')
            bpe2idx[voca] = idx
            idx2bpe[idx] = voca
            idx += 1

    return bpe2idx, idx2bpe
170+
171+
172+
169173
# 문서를 읽고, bpe 적용. cache 사용할것. apply_bpe에서 사용.
170174
def _apply_bpe(path, out_path, space_symbol='</w>', merge_info=None, cache={}):
171175
start = time.time()
@@ -197,15 +201,15 @@ def _apply_bpe(path, out_path, space_symbol='</w>', merge_info=None, cache={}):
197201
row.extend(merge.split())
198202
wr.writerow(row)
199203

200-
if (i+1) % 500 == 0:
204+
if (i+1) % 100000 == 0:
201205
current_cache_len = len(cache)
202206
print('out_path:', out_path, 'line:', i+1, 'total cache:', current_cache_len, 'added:', current_cache_len-cache_len)
203207
cache_len = current_cache_len
204208

205209
o.close()
206210

207211

208-
def _learn_bpe(word_frequency_dict, num_merges=37000, multi_proc=1):
212+
def _learn_bpe(word_frequency_dict, npy_path, num_merges=37000, multi_proc=1):
209213
#word_frequency_dict = {'l o w </w>' : 1, 'l o w e r </w>' : 1, 'n e w e s t </w>':1, 'w i d e s t </w>':1}
210214

211215
merge_info = [] # 합친 정보를 기억하고있다가 다른 데이터에 적용.
@@ -262,63 +266,67 @@ def _learn_bpe(word_frequency_dict, num_merges=37000, multi_proc=1):
262266
word_frequency = merge_bpe_word((best, word_frequency)) # 가장 높은 빈도의 2gram을 합침.
263267
######
264268

265-
269+
# multiproc close
266270
if multi_proc > 1:
267271
pool.close()
268272

273+
274+
# make npy
275+
if not os.path.exists(npy_path):
276+
print("create" + npy_path + "directory")
277+
os.makedirs(npy_path)
278+
269279
# 빠른 변환을 위한 cache 저장. 기존 word를 key로, bpe 결과를 value로.
270280
cache = {}
271281
for i in range(len(cache_list)):
272282
key = cache_list[i][0]
273283
value = word_frequency[i][0]
274284
cache[key] = value
275285

276-
# voca 추출.
277-
bpe2idx, idx2bpe = make_bpe2idx(word_frequency)
278-
return bpe2idx, idx2bpe, merge_info, cache # dict, dict, list, dict
286+
save_data(npy_path+'merge_info.npy', merge_info) # list
287+
save_data(npy_path+'cache.npy', cache) # dict
288+
print('save merge_info.npy', ', size:', len(merge_info))
289+
print('save cache.npy', ', size:', len(cache))
290+
291+
292+
bpe2idx, idx2bpe = make_bpe2idx(word_frequency, npy_path)
293+
save_data(npy_path+'bpe2idx.npy', bpe2idx) # dict
294+
save_data(npy_path+'idx2bpe.npy', idx2bpe) # dict
295+
print('save bpe2idx.npy', ', size:', len(bpe2idx))
296+
print('save idx2bpe.npy', ', size:', len(idx2bpe))
279297

280298

281299

282-
def learn_bpe(path_list, npy_path, space_symbol='</w>', top_k=None, num_merges=37000, multi_proc=1):
300+
def learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=37000, voca_threshold=5, multi_proc=1):
283301

284302
print('get word frequency dictionary')
285303
total_word_frequency_dict = {}
286304
for path in path_list:
287305
word_frequency_dict = get_word_frequency_dict_from_document(
288306
path=path,
289307
space_symbol=space_symbol,
290-
top_k=top_k#None
291308
) #ok
292309
total_word_frequency_dict = merge_dictionary(total_word_frequency_dict, word_frequency_dict)
293310

294-
'''
295-
save_data('./word_frequency_dictionary.npy', total_word_frequency_dict)
296-
print('save ./word_frequency_dictionary.npy', 'size:', len(total_word_frequency_dict), '\n')
297-
total_word_frequency_dict = load_data('./word_frequency_dictionary.npy', mode='dictionary')
298-
'''
311+
312+
# 빈도수가 일정 미만인 단어 제외.
313+
total_word_frequency_dict_size = len(total_word_frequency_dict)
314+
for item in list(total_word_frequency_dict.items()):
315+
if item[1] < voca_threshold: # item[0] is key, item[1] is value
316+
del total_word_frequency_dict[item[0]]
317+
print('frequency word dict size:', total_word_frequency_dict_size)
318+
print('threshold applied frequency word dict size:', len(total_word_frequency_dict), 'removed:', total_word_frequency_dict_size-len(total_word_frequency_dict), '\n')
319+
299320

300321
print('learn bpe')
301-
check= time.time()
302-
bpe2idx, idx2bpe, merge_info, cache = _learn_bpe(
322+
_learn_bpe(
303323
total_word_frequency_dict,
324+
npy_path=npy_path,
304325
num_merges=num_merges,
305326
multi_proc=multi_proc
306-
)# dict, dict, list, dict
307-
print('multiproc:', multi_proc, 'time:', time.time()-check)
327+
)
308328

309-
if not os.path.exists(npy_path):
310-
print("create" + npy_path + "directory")
311-
os.makedirs(npy_path)
312-
313-
save_data(npy_path+'bpe2idx.npy', bpe2idx)
314-
save_data(npy_path+'idx2bpe.npy', idx2bpe)
315-
save_data(npy_path+'merge_info.npy', merge_info)
316-
save_data(npy_path+'cache.npy', cache)
317-
print('save bpe2idx.npy', 'size:', len(bpe2idx))
318-
print('save idx2bpe.npy', 'size:', len(idx2bpe))
319-
print('save merge_info.npy', 'size:', len(merge_info))
320-
print('save cache.npy', 'size:', len(cache))
321-
print()
329+
print('\n\n\n')
322330

323331

324332

@@ -327,27 +335,25 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>',
327335
print("create" + out_bpe_path + "directory")
328336
os.makedirs(out_bpe_path)
329337

330-
print('load bpe info')
331338
merge_info = load_data(npy_path+'merge_info.npy')
332339
cache = load_data(npy_path+'cache.npy', mode='dictionary')
333-
340+
341+
print('apply bpe')
334342
for i in range(len(path_list)):
335343
path = path_list[i]
336344
out_path = out_list[i]
337345

338-
print('apply bpe', path, out_path)
346+
print('path:', path, ', out_path:', out_path)
339347
_apply_bpe(
340348
path=path,
341349
out_path=out_bpe_path+out_path,
342350
space_symbol=space_symbol,
343351
merge_info=merge_info,
344352
cache=cache
345353
)
346-
print('save ok', out_path)
347354
save_data(npy_path+'cache.npy', cache)
348-
print('save updated cache ./cache.npy', 'size:', len(cache))
355+
print('\n\n\n')
349356

350-
print()
351357

352358

353359
# save directory
@@ -375,11 +381,12 @@ def apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>',
375381
if __name__ == '__main__':
376382
# if don't use multiprocessing:
377383
# learn_bpe(path_list, npy_path, space_symbol='</w>', top_k=None)
378-
379-
# multiprocessing, multi_proc: # process, os.cpu_count(): # cpu processor of current computer
380-
# learn bpe from documents
381-
learn_bpe(path_list, npy_path, space_symbol='</w>', top_k=None, num_merges=37000, multi_proc=os.cpu_count())
384+
# multi_proc: # process, os.cpu_count(): # cpu processor of current computer
382385

386+
# learn bpe from documents
387+
learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=35000, voca_threshold=50, multi_proc=os.cpu_count())
388+
#learn_bpe(path_list, npy_path, space_symbol='</w>', num_merges=30000, voca_threshold=5, multi_proc=os.cpu_count())
389+
383390
# apply bpe to documents
384391
apply_bpe(path_list, out_bpe_path, out_list, npy_path, space_symbol='</w>', pad_symbol='</p>')
385392
apply_bpe(test_path_list, out_bpe_path, test_out_list, npy_path, space_symbol='</w>', pad_symbol='</p>')

0 commit comments

Comments
 (0)