-
Notifications
You must be signed in to change notification settings - Fork 0
/
utilities.py
132 lines (118 loc) · 4.8 KB
/
utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import chainer.functions as chainFunc
import chainer.links as chainLink
import chainer
import itertools
import datetime
import collections
import sys
import random
from gensim.models import word2vec
import cupy as xp
class Vocabulary:
    """Bidirectional word <-> id mapping for corpus processing.

    Reserved ids: 0 = '<unk>', 1 = '<bos>', 2 = '<eos>', -1 = '' (padding).
    Regular words start at id 3.  ``size`` counts every entry except the
    '' padding entry, so it equals 3 + number of real words.
    """

    def _init_specials(self):
        """Install the reserved special-token entries in both directions."""
        # defaultdict(int): any unseen word looks up as 0, i.e. '<unk>'.
        self.word2id = collections.defaultdict(int)
        self.id2word = dict()
        for idx, token in ((0, '<unk>'), (1, '<bos>'), (2, '<eos>'), (-1, '')):
            self.word2id[token] = idx
            self.id2word[idx] = token

    @staticmethod
    def make(file_path, vocab_size):
        """Build a vocabulary from a whitespace-tokenized corpus file.

        Keeps the ``vocab_size - 3`` most frequent words (so the total
        including the three visible specials is at most ``vocab_size``).
        Equal-frequency ties are broken alphabetically for reproducibility.
        """
        self = Vocabulary()
        self._init_specials()
        with open(file_path) as i_f:
            word_frequency = collections.defaultdict(int)
            for line in i_f:
                for word in line.strip().split():
                    word_frequency[word] += 1
        # Inner sorted() pins the order of equal-frequency words so the
        # resulting id assignment is deterministic run-to-run.
        for i, (word, _freq) in zip(range(vocab_size - 3),
                sorted(sorted(word_frequency.items()), key=lambda x: x[1], reverse=True)):
            self.word2id[word] = i + 3
            self.id2word[i + 3] = word
        self.size = len(self.word2id) - 1  # exclude the '' padding entry
        return self

    @staticmethod
    def load(file_path):
        """Restore a vocabulary from a file holding one word per line.

        The specials are implicit (not stored in the file); line ``i``
        becomes id ``i + 3``.
        """
        self = Vocabulary()
        self._init_specials()
        with open(file_path) as vocab_file:
            for i, word in enumerate(vocab_file):
                token = word.strip()
                self.word2id[token] = i + 3
                self.id2word[i + 3] = token
        # len(...) - 1 excludes the '' padding entry.  Unlike the original
        # `i + 4`, this does not raise NameError when the file is empty.
        self.size = len(self.id2word) - 1
        trace('vocab size: {}'.format(self.size))
        return self

    def save(self, file_path):
        """Write the vocabulary, one word per line, ids 0..size-1 in order."""
        with open(file_path, 'w') as output_file:
            for i in range(self.size):
                output_file.write(self.id2word[i] + '\n')
def miniBatch(src_corpus, tgt_corpus, src_vocab, tgt_vocab, batch_size, pooling):
    """Yield shuffled (src, tgt) mini-batches as time-major int32 Variables.

    Pools ``batch_size * pooling`` sentence pairs at a time via gen_Npairs,
    sorts each pool by source length (longest first) so a batch holds
    similarly-sized sentences, shuffles the finished batches, and yields
    each one transposed column-wise with -1 as the padding id.
    """
    batches = []
    pending = []
    for pool in gen_Npairs(src_corpus, tgt_corpus, src_vocab, tgt_vocab, batch_size * pooling):
        # gen_Npairs hands us a fresh list each time, so in-place sort is safe.
        pool.sort(key=lambda pair: len(pair[0]), reverse=True)
        for pair in pool:
            pending.append(pair)
            if len(pending) == batch_size:
                batches.append(pending)
                pending = []
    if pending:  # final, possibly short, batch
        batches.append(pending)
    random.shuffle(batches)
    for batch in batches:
        sources = [pair[0] for pair in batch]
        targets = [pair[1] for pair in batch]
        src_cols = [chainer.Variable(xp.array(list(col), dtype=xp.int32))
                    for col in itertools.zip_longest(*sources, fillvalue=-1)]
        tgt_cols = [chainer.Variable(xp.array(list(col), dtype=xp.int32))
                    for col in itertools.zip_longest(*targets, fillvalue=-1)]
        yield (src_cols, tgt_cols)
def monoBatch(corpus_file, vocab, args):
    """Yield one sentence at a time as a list of length-1 int32 Variables.

    Each sentence is its word ids followed by a trailing '<eos>' id.
    ``args`` is accepted for signature compatibility but unused here.
    """
    with open(corpus_file) as i_f:
        for line in i_f:
            # split() with no argument already drops surrounding whitespace
            ids = [vocab.word2id[token] for token in line.split()]
            ids.append(vocab.word2id['<eos>'])
            yield [chainer.Variable(xp.array([wid], dtype=xp.int32)) for wid in ids]
def make_word2vec(corpus, embed_size):
    """Train a word2vec model over ``corpus`` (one sentence per line).

    min_count=1 keeps every word; 5 worker threads for training.
    """
    sentences = word2vec.LineSentence(corpus)
    return word2vec.Word2Vec(sentences, size=embed_size, min_count=1, workers=5)
def save_word2vec(model, save_path):
    """Persist a trained word2vec model to ``save_path``."""
    model.save(save_path)
def load_word2vec(path):
    """Load a word2vec model previously stored with save_word2vec."""
    return word2vec.Word2Vec.load(path)
def convert_b2w(batch, vocab):
    """Turn a time-major batch of id Variables back into sentence strings.

    Stacks the batch, transposes it to sentence-major, copies to CPU, maps
    ids to words, and truncates each sentence at its first '<eos>' (the
    whole row is kept when no '<eos>' appears).
    """
    sentence_major = chainFunc.transpose(chainFunc.vstack(batch))
    for row in list(chainer.cuda.to_cpu(sentence_major.data)):
        words = [vocab.id2word[wid] for wid in list(row)]
        if '<eos>' in words:
            words = words[:words.index('<eos>')]
        yield ' '.join(words)
def gen_Npairs(src_corpus, tgt_corpus, src_vocab, tgt_vocab, N):
    """Yield lists of up to ``N`` parallel sentence pairs from two corpora.

    Each pair is ``[src_ids, tgt_ids]``: the source sentence as word ids
    plus a trailing '<eos>'; the target wrapped in '<bos>' ... '<eos>'.
    Because the files are zipped, extra lines in the longer corpus are
    silently dropped.  The final yield may hold fewer than ``N`` pairs.
    (Removed an unused ``flag = True`` local left over in the original.)
    """
    with open(src_corpus) as src, open(tgt_corpus) as tgt:
        Npairs = list()
        for line_src, line_tgt in zip(src, tgt):
            sent_src = [src_vocab.word2id[word] for word in line_src.strip().split()]
            sent_src.append(src_vocab.word2id['<eos>'])
            sent_tgt = [tgt_vocab.word2id['<bos>']]
            for word in line_tgt.strip().split():
                sent_tgt.append(tgt_vocab.word2id[word])
            sent_tgt.append(tgt_vocab.word2id['<eos>'])
            Npairs.append([sent_src, sent_tgt])
            if len(Npairs) == N:
                yield Npairs
                Npairs = list()
        if Npairs:
            yield Npairs
def trace(*args):
    """Write a timestamped, space-joined debug message to stderr and flush."""
    parts = [str(datetime.datetime.now()), '...'] + [str(arg) for arg in args]
    sys.stderr.write(' '.join(parts) + '\n')
    sys.stderr.flush()