This repository has been archived by the owner on Oct 19, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
prepare_vae_dataset.py
135 lines (104 loc) · 4.85 KB
/
prepare_vae_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# -*- coding: utf-8 -*-
"""
Подготовка датасета для экспериментов с автоэнкодерами и вариационными
автоэнкодерами для русскогоязычных предложений.
На входе используется файл со списком предложений.
На выходе - два файла с векторизованными предложениями и словарь
с соответствиями слов и векторов.
"""
from __future__ import print_function, division
import gensim
import codecs
import os
import pickle
import numpy as np
from future.utils import iteritems
data_folder = '../data'
# Текстовый файл с векторыми слов в word2vec формате, который может прочитать
# gensim. Векторы слов отсюда используются для заполнения тензора с предложениями.
#w2v_path = r'f:\Word2Vec\word_vectors_cbow=1_win=5_dim=32.txt'
w2v_path = '/home/eek/polygon/w2v/w2v.CBOW=0_WIN=5_DIM=8.txt'
# Путь к файлу со списком фраз, на которых будет тренироваться модель.
corpus_path = '../data/phrases.txt'
# Минимальное и максимальное количество слов в предложениях, которые
# будут векторизованы и попадут в датасет.
MIN_SENT_LEN = 1
MAX_SENT_LEN = 6
# Обучение на полном наборе предложений может занять слишком много времени, это
# делает эксперименты с архитектурой некомфортными. Поэтому будем ограничивать датасет
# зданным числом фраз.
MAX_NB_PHRASES = 10000000
def decode_output(y, v2w):
"""
Декодируем выходной тензор автоэнкодера, получаем читабельные
предложения в том же порядке, как входные.
"""
decoded_phrases = []
for iphrase in range(y.shape[0]):
phrase_vectors = y[iphrase]
phrase_words = []
for iword in range(y.shape[1]):
word_vector = phrase_vectors[iword]
l2 = np.linalg.norm(word_vector)
if l2<0.1:
break
min_dist = 1e38
best_word = u''
for v, w in v2w:
d = np.linalg.norm(v - word_vector)
if d < min_dist:
min_dist = d
best_word = w
phrase_words.append(best_word)
decoded_phrases.append(u' '.join(phrase_words))
return decoded_phrases
if __name__ == '__main__':
print('Loading the w2v model {}'.format(w2v_path))
w2v = gensim.models.KeyedVectors.load_word2vec_format(w2v_path, binary=False)
w2v_dims = len(w2v.syn0[0])
print('w2v_dims={0}'.format(w2v_dims))
phrases = []
all_words = set()
with codecs.open(corpus_path, 'r', 'utf-8') as rdr:
for line in rdr:
words = line.strip().split()
if MIN_SENT_LEN <= len(words) <= MAX_SENT_LEN:
all_words_known = True
for word in words:
if word not in w2v:
all_words_known = False
break
if all_words_known:
phrases.append(words)
all_words.update(words)
if len(phrases) >= MAX_NB_PHRASES:
break
nb_phrases = len(phrases)
print('nb_phrases={}'.format(nb_phrases))
max_sent_len = max(map(len, phrases))
print('max_sent_len={}'.format(max_sent_len))
vtexts = np.zeros((nb_phrases, max_sent_len, w2v_dims))
for iphrase, phrase_words in enumerate(phrases):
for iword, word in enumerate(phrase_words):
vtexts[iphrase, iword, :] = w2v[word]
# выполним нормализацию векторов
vmin = np.amin(vtexts)
vmax = np.amax(vtexts)
scale = 1.0 / max(vmax, -vmin) # приводим к диапазону -1..+1
vtexts *= scale
print('scale={}'.format(scale))
word2vec = dict([(word, w2v[word]*scale) for word in all_words])
vmin = np.amin(vtexts)
vmax = np.amax(vtexts)
print('After scaling: vmin={} vmax={}'.format(vmin, vmax))
# тестируем векторизацию
v2w = [(v, w) for w, v in iteritems(word2vec)]
X_probe = vtexts[0:10]
probe_phrases = decode_output(X_probe, v2w)
for phrase in probe_phrases:
print(u'{}'.format(phrase))
print('Storing dataset...')
with open('../data/vtexts.npz', 'wb') as f:
np.savez_compressed(f, vtexts)
with open('../data/word2vec.pkl', 'wb') as f:
pickle.dump(word2vec, f)