#!/usr/bin/env python
"""
==========================
Wikipedia - Word2Vec train
==========================
Usage Example
$ python train.py ../../data/zhwiki/ ../../data/zhwiki.model
"""
import os
import sys
import glob
import time
import jieba
import random
import logging
import argparse
import functools
from multiprocessing import Process
from multiprocessing import cpu_count
from gensim.models.word2vec import LineSentence
from gensim.models import Word2Vec
FORMAT = '[%(levelname)s]: %(message)s'
logging.basicConfig(format=FORMAT)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def benchmark(func):
    """
    Calculate the running time for func.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Start the clock at call time, not at decoration time,
        # so every call reports its own duration.
        start = time.time()
        rc = func(*args, **kwargs)
        print('Running time: {}'.format(time.time() - start))
        return rc
    return wrapper
def tokenize_corpus(path, index):
    """
    Tokenize one extracted wiki file with jieba and append the
    space-separated tokens to this worker's cache file.
    """
    logger.info('Processing {}'.format(path))
    with open('/tmp/tokenize.{}.cache'.format(index), 'a', encoding='utf-8') as out:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                for token in jieba.cut(line):
                    out.write(token)
                    out.write(' ')
                out.write('\n')
def clear_cache():
    logger.debug('Clearing cache...')
    for fpath in glob.glob('/tmp/tokenize.*.cache'):
        os.remove(fpath)
@benchmark
def tokenize_corpus_task(source):
    """
    Tokenize every extracted wiki file, running cpu_count() worker
    processes at a time, e.g. with 10 files and 4 CPUs the batches
    are [1, 2, 3, 4], [5, 6, 7, 8], [9, 10].
    """
    clear_cache()
    files = glob.glob(source + '/*/wiki_*')
    worker_cnt = cpu_count()
    workers = []
    i = 0
    for idx, fpath in enumerate(files):
        logger.debug('Spawning process {} for worker {}...'.format(idx, i))
        p = Process(target=tokenize_corpus, args=(fpath, i))
        workers.append(p)
        p.start()
        i += 1
        # Wait for the current batch once it is full, or once the
        # last file has been dispatched.
        if (idx + 1) % worker_cnt == 0 or idx == len(files) - 1:
            i = 0
            for p in workers:
                p.join()
            workers = []
    logger.info('Tokenize done')
@benchmark
def combine_cache_files(outpath):
    """
    Concatenate the per-worker cache files into a single corpus file.
    """
    logger.info('Combining tokenized corpus...')
    with open(outpath, 'w', encoding='utf-8') as out:
        for fpath in glob.glob('/tmp/tokenize.*.cache'):
            with open(fpath, 'r', encoding='utf-8') as infile:
                for line in infile:
                    out.write(line)
@benchmark
def train_model(corpus_path, outpath):
    """
    Train a Word2Vec model on the combined corpus (one sentence per
    line) and save it to outpath.
    """
    logger.info('Training model...')
    model = Word2Vec(LineSentence(corpus_path), workers=cpu_count())
    model.save(outpath)
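# A minimal sketch of how the saved model could be queried afterwards;
# Word2Vec.load and wv.most_similar are standard gensim calls, but the
# query word below is a placeholder, not taken from this script:
#
#   from gensim.models import Word2Vec
#   model = Word2Vec.load('../../data/zhwiki.model')
#   print(model.wv.most_similar('北京', topn=5))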
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'source',
        type=str,
        help='The wikipedia source dir'
    )
    parser.add_argument(
        'outpath',
        type=str,
        help='The output path for word2vec model'
    )
    args = parser.parse_args(sys.argv[1:])
    temp_file = '/tmp/train.{}.cache'.format(random.randint(10000, 99999))
    tokenize_corpus_task(args.source)
    combine_cache_files(temp_file)
    train_model(temp_file, args.outpath)