fasttext.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Author: Jayant Jain <jayantjain1992@gmail.com>
# Copyright (C) 2017 Radim Rehurek <me@radimrehurek.com>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
"""
Python wrapper around word representation learning from FastText, a library for efficient learning
of word representations and sentence classification [1].
This module allows training a word embedding from a training corpus with the additional ability
to obtain word vectors for out-of-vocabulary words, using the fastText C implementation.
The wrapped model can NOT be updated with new documents for online training -- use gensim's
`Word2Vec` for that.
Example:
>>> from gensim.models.wrappers import FastText
>>> model = fasttext.FastText.train('/Users/kofola/fastText/fasttext', corpus_file='text8')
>>> print model['forests'] # prints vector for given out-of-vocabulary word
.. [1] https://github.com/facebookresearch/fastText#enriching-word-vectors-with-subword-information
"""
import logging
import tempfile
import os
import struct

import numpy as np
from numpy import float32 as REAL, sqrt, newaxis

from gensim import utils
from gensim.models.keyedvectors import KeyedVectors
from gensim.models.word2vec import Word2Vec

from six import string_types

logger = logging.getLogger(__name__)


class FastTextKeyedVectors(KeyedVectors):
    """
    Class to contain vectors, vocab and ngrams for the FastText training class and other methods not directly
    involved in training, such as most_similar().

    Subclasses KeyedVectors to implement out-of-vocabulary lookups, store ngrams and provide other
    FastText-specific methods.
    """

    def __init__(self):
        super(FastTextKeyedVectors, self).__init__()
        self.syn0_all_norm = None
        self.ngrams = {}

    def save(self, *args, **kwargs):
        # don't bother storing the cached normalized vectors
        kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'syn0_all_norm'])
        super(FastTextKeyedVectors, self).save(*args, **kwargs)

    def word_vec(self, word, use_norm=False):
        """
        Accept a single word as input.
        Returns the word's representation in vector space, as a 1D numpy array.

        The word can be out-of-vocabulary as long as ngrams for the word are present.
        For words with all ngrams absent, a KeyError is raised.

        Example::

          >>> trained_model['office']
          array([ -1.40128313e-02, ...])

        """
        if word in self.vocab:
            return super(FastTextKeyedVectors, self).word_vec(word, use_norm)
        else:
            word_vec = np.zeros(self.syn0_all.shape[1])
            ngrams = FastText.compute_ngrams(word, self.min_n, self.max_n)
            if use_norm:
                ngram_weights = self.syn0_all_norm
            else:
                ngram_weights = self.syn0_all
            for ngram in ngrams:
                if ngram in self.ngrams:
                    word_vec += ngram_weights[self.ngrams[ngram]]
            if word_vec.any():
                return word_vec / len(ngrams)
            else:  # No ngrams of the word are present in self.ngrams
                raise KeyError('all ngrams for word %s absent from model' % word)
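
    # Usage sketch for the out-of-vocabulary path above (the word and the output values are
    # illustrative, not taken from a real model):
    #
    #     >>> model.wv.word_vec('lingodroids')   # absent from vocab, but its char ngrams are known
    #     array([ 0.0049, -0.0121, ...])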

    def init_sims(self, replace=False):
        """
        Precompute L2-normalized vectors.

        If `replace` is set, forget the original vectors and only keep the normalized
        ones = saves lots of memory!

        Note that you **cannot continue training** after doing a replace. The model becomes
        effectively read-only = you can only call `most_similar`, `similarity` etc.
        """
        super(FastTextKeyedVectors, self).init_sims(replace)
        if getattr(self, 'syn0_all_norm', None) is None or replace:
            logger.info("precomputing L2-norms of ngram weight vectors")
            if replace:
                for i in range(self.syn0_all.shape[0]):
                    self.syn0_all[i, :] /= sqrt((self.syn0_all[i, :] ** 2).sum(-1))
                self.syn0_all_norm = self.syn0_all
            else:
                self.syn0_all_norm = (self.syn0_all / sqrt((self.syn0_all ** 2).sum(-1))[..., newaxis]).astype(REAL)

    def __contains__(self, word):
        """
        Check if `word` is present in the vocabulary, or if any of its ngrams are present.
        A vector for the word is guaranteed to exist if `__contains__` returns True.
        """
        if word in self.vocab:
            return True
        else:
            word_ngrams = set(FastText.compute_ngrams(word, self.min_n, self.max_n))
            if len(word_ngrams & set(self.ngrams.keys())):
                return True
            else:
                return False
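
    # Membership sketch (words are illustrative): `in` also returns True for out-of-vocabulary
    # words, as long as at least one of their char ngrams is stored:
    #
    #     >>> 'night' in model.wv        # in the vocabulary
    #     True
    #     >>> 'nights' in model.wv       # OOV, but shares ngrams with in-vocab words
    #     True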


class FastText(Word2Vec):
    """
    Class for word vector training using FastText. Communication between FastText and Python
    takes place by working with data files on disk and calling the FastText binary with
    subprocess.call().

    Implements functionality similar to [fasttext.py](https://github.com/salestock/fastText.py),
    improving speed and scope of functionality like `most_similar`, `similarity` by extracting the
    vectors into a numpy matrix.
    """

    def initialize_word_vectors(self):
        self.wv = FastTextKeyedVectors()

    @classmethod
    def train(cls, ft_path, corpus_file, output_file=None, model='cbow', size=100, alpha=0.025, window=5,
              min_count=5, loss='ns', sample=1e-3, negative=5, iter=5, min_n=3, max_n=6, sorted_vocab=1,
              threads=12):
        """
        `ft_path` is the path to the FastText executable, e.g. `/home/kofola/fastText/fasttext`.

        `corpus_file` is the filename of the text file to be used for training the FastText model.
        Expects file to contain utf-8 encoded text.

        `model` defines the training algorithm. By default, cbow is used. Accepted values are
        'cbow' and 'skipgram'.

        `size` is the dimensionality of the feature vectors.

        `window` is the maximum distance between the current and predicted word within a sentence.

        `alpha` is the initial learning rate.

        `min_count` = ignore all words with total occurrences lower than this.

        `loss` = defines the training objective. Allowed values are `hs` (hierarchical softmax),
        `ns` (negative sampling) and `softmax`. Defaults to `ns`.

        `sample` = threshold for configuring which higher-frequency words are randomly downsampled;
        default is 1e-3, useful range is (0, 1e-5).

        `negative` = specifies how many "noise words" should be drawn (usually between 5-20).
        Default is 5. If set to 0, no negative sampling is used. Only relevant when `loss` is set to `ns`.

        `iter` = number of iterations (epochs) over the corpus. Default is 5.

        `min_n` = min length of char ngrams to be used for training word representations. Default is 3.

        `max_n` = max length of char ngrams to be used for training word representations. Set `max_n`
        lower than `min_n` to avoid char ngrams being used. Default is 6.

        `sorted_vocab` = if 1 (default), sort the vocabulary by descending frequency before
        assigning word indexes.

        `threads` = number of threads to use. Default is 12.
        """
        output_file = output_file or os.path.join(tempfile.gettempdir(), 'ft_model')
        ft_args = {
            'input': corpus_file,
            'output': output_file,
            'lr': alpha,
            'dim': size,
            'ws': window,
            'epoch': iter,
            'minCount': min_count,
            'neg': negative,
            'loss': loss,
            'minn': min_n,
            'maxn': max_n,
            'thread': threads,
            't': sample
        }
        cmd = [ft_path, model]
        for option, value in ft_args.items():
            cmd.append("-%s" % option)
            cmd.append(str(value))
        output = utils.check_output(args=cmd)
        model = cls.load_fasttext_format(output_file)
        cls.delete_training_files(output_file)
        return model
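
    # Training sketch (paths are hypothetical; assumes a compiled fastText binary on disk):
    #
    #     >>> model = FastText.train('/path/to/fastText/fasttext', corpus_file='text8',
    #     ...                        model='skipgram', size=100, min_n=3, max_n=6)
    #     >>> model.wv['forests']        # works even for out-of-vocabulary words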

    def save(self, *args, **kwargs):
        # don't bother storing the cached normalized vectors
        kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'syn0_all_norm'])
        super(FastText, self).save(*args, **kwargs)

    @classmethod
    def load_word2vec_format(cls, *args, **kwargs):
        return FastTextKeyedVectors.load_word2vec_format(*args, **kwargs)

    @classmethod
    def load_fasttext_format(cls, model_file, encoding='utf8'):
        """
        Load the input-hidden weight matrix from the fastText output files.

        Note that due to limitations in the FastText API, you cannot continue training
        with a model loaded this way, though you can query for word similarity etc.

        `model_file` is the path prefix of the FastText output files. FastText training produces
        two files, e.g. `/path/to/train.vec` and `/path/to/train.bin`; the expected value here
        is `/path/to/train`.
        """
        model = cls()
        model.wv = cls.load_word2vec_format('%s.vec' % model_file, encoding=encoding)
        model.load_binary_data('%s.bin' % model_file, encoding=encoding)
        return model
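
    # Loading sketch (prefix is hypothetical): pass the shared prefix of the .vec/.bin pair written
    # by the fastText binary, not the path to either file itself:
    #
    #     >>> model = FastText.load_fasttext_format('/tmp/ft_model')   # reads ft_model.vec and ft_model.bin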

    @classmethod
    def delete_training_files(cls, model_file):
        """Deletes the files created by FastText training."""
        try:
            os.remove('%s.vec' % model_file)
            os.remove('%s.bin' % model_file)
        except FileNotFoundError:
            logger.debug('Training files %s not found when attempting to delete', model_file)

    def load_binary_data(self, model_binary_file, encoding='utf8'):
        """Loads data from the output binary file created by FastText training."""
        with utils.smart_open(model_binary_file, 'rb') as f:
            self.load_model_params(f)
            self.load_dict(f, encoding=encoding)
            self.load_vectors(f)

    def load_model_params(self, file_handle):
        # Parameters stored by [Args::save](https://github.com/facebookresearch/fastText/blob/master/src/args.cc)
        (dim, ws, epoch, minCount, neg, _, loss, model, bucket, minn, maxn, _, t) = \
            self.struct_unpack(file_handle, '@12i1d')
        self.vector_size = dim
        self.window = ws
        self.iter = epoch
        self.min_count = minCount
        self.negative = neg
        self.hs = loss == 1
        self.sg = model == 2
        self.bucket = bucket
        self.wv.min_n = minn
        self.wv.max_n = maxn
        self.sample = t

    def load_dict(self, file_handle, encoding='utf8'):
        # Vocab stored by [Dictionary::save](https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc)
        (vocab_size, nwords, _) = self.struct_unpack(file_handle, '@3i')
        assert len(self.wv.vocab) == nwords, 'mismatch between vocab sizes'
        assert len(self.wv.vocab) == vocab_size, 'mismatch between vocab sizes'
        ntokens, = self.struct_unpack(file_handle, '@q')
        for i in range(nwords):
            # Read vocab word (null-terminated string)
            word_bytes = b''
            char_byte = file_handle.read(1)
            while char_byte != b'\x00':
                word_bytes += char_byte
                char_byte = file_handle.read(1)
            word = word_bytes.decode(encoding)
            count, _ = self.struct_unpack(file_handle, '@ib')
            _ = self.struct_unpack(file_handle, '@i')
            assert self.wv.vocab[word].index == i, 'mismatch between gensim word index and fastText word index'
            self.wv.vocab[word].count = count

    def load_vectors(self, file_handle):
        # Vectors stored by [Matrix::save](https://github.com/facebookresearch/fastText/blob/master/src/matrix.cc)
        num_vectors, dim = self.struct_unpack(file_handle, '@2q')
        assert self.vector_size == dim, 'mismatch between model sizes'
        float_size = struct.calcsize('@f')
        if float_size == 4:
            dtype = np.dtype(np.float32)
        elif float_size == 8:
            dtype = np.dtype(np.float64)
        self.num_original_vectors = num_vectors
        self.wv.syn0_all = np.fromfile(file_handle, dtype=dtype, count=num_vectors * dim)
        self.wv.syn0_all = self.wv.syn0_all.reshape((num_vectors, dim))
        assert self.wv.syn0_all.shape == (self.bucket + len(self.wv.vocab), self.vector_size), \
            'mismatch between weight matrix shape and vocab/model size'
        self.init_ngrams()

    def struct_unpack(self, file_handle, fmt):
        num_bytes = struct.calcsize(fmt)
        return struct.unpack(fmt, file_handle.read(num_bytes))

    def init_ngrams(self):
        """
        Computes ngrams of all words present in vocabulary and stores vectors for only those ngrams.
        Vectors for other ngrams are initialized with a random uniform distribution in FastText. These
        vectors are discarded here to save space.
        """
        self.wv.ngrams = {}
        all_ngrams = []
        for w in self.wv.vocab:
            all_ngrams += self.compute_ngrams(w, self.wv.min_n, self.wv.max_n)
        all_ngrams = set(all_ngrams)
        self.num_ngram_vectors = len(all_ngrams)
        ngram_indices = []
        for i, ngram in enumerate(all_ngrams):
            ngram_hash = self.ft_hash(ngram)
            ngram_indices.append(len(self.wv.vocab) + ngram_hash % self.bucket)
            self.wv.ngrams[ngram] = i
        self.wv.syn0_all = self.wv.syn0_all.take(ngram_indices, axis=0)
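
    # Note on the row arithmetic above: fastText lays out its input matrix with word vectors in
    # rows [0, vocab_size) and hashed ngram vectors in rows [vocab_size, vocab_size + bucket),
    # so an ngram's row is vocab_size + ft_hash(ngram) % bucket.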

    @staticmethod
    def compute_ngrams(word, min_n, max_n):
        BOW, EOW = ('<', '>')  # Used by FastText to attach to all words as prefix and suffix
        extended_word = BOW + word + EOW
        ngrams = set()
        for i in range(len(extended_word) - min_n + 1):
            for j in range(min_n, max(len(extended_word) - max_n, max_n + 1)):
                ngrams.add(extended_word[i:i + j])
        return ngrams
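
    # Illustrative expansion (computed by hand, not from a run): with min_n=3, max_n=6 the word
    # 'cat' is extended to '<cat>' and yields {'<ca', '<cat', '<cat>', 'cat', 'cat>', 'at>'}.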

    @staticmethod
    def ft_hash(string):
        """
        Reproduces the [hash method](https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc)
        used in fastText (32-bit FNV-1a).
        """
        # Runtime warnings for integer overflow are raised; this is expected behaviour, so the warnings are suppressed.
        old_settings = np.seterr(all='ignore')
        h = np.uint32(2166136261)
        for c in string:
            h = h ^ np.uint32(ord(c))
            h = h * np.uint32(16777619)
        np.seterr(**old_settings)
        return h
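

# Usage sketch (names and values are assumptions, not recorded outputs): ft_hash drives the bucket
# lookup performed in init_ngrams, e.g. with fastText's default of 2,000,000 buckets:
#
#     >>> bucket = 2000000
#     >>> row = vocab_size + int(FastText.ft_hash('<ca')) % bucket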