-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
base_any2vec.py
1458 lines (1223 loc) · 66.7 KB
/
base_any2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Author: Shiva Manne <manneshiva@gmail.com>
# Copyright (C) 2018 RaRe Technologies s.r.o.
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
r"""This module contains base classes required for implementing \*2vec algorithms.
The class hierarchy is designed to facilitate adding more concrete implementations for creating embeddings.
In the most general case, the purpose of this class is to transform an arbitrary representation to a numerical vector
(embedding). This is represented by the base :class:`~gensim.models.base_any2vec.BaseAny2VecModel`. The input space in
most cases (in the NLP field at least) is plain text. For this reason, we enrich the class hierarchy with the abstract
:class:`~gensim.models.base_any2vec.BaseWordEmbeddingsModel` to be used as a base for models where the input
space is text.
Notes
-----
Even though this is the usual case, not all embeddings transform text, such as the
:class:`~gensim.models.poincare.PoincareModel` that embeds graphs.
See Also
--------
:class:`~gensim.models.word2vec.Word2Vec`.
Word2Vec model - embeddings for words.
:class:`~gensim.models.fasttext.FastText`.
FastText model - embeddings for words (ngram-based).
:class:`~gensim.models.doc2vec.Doc2Vec`.
Doc2Vec model - embeddings for documents.
:class:`~gensim.models.poincare.PoincareModel`
Poincare model - embeddings for graphs.
"""
from gensim import utils
import logging
from timeit import default_timer
import threading
from six.moves import range
from six import itervalues, string_types
from gensim import matutils
from numpy import float32 as REAL, ones, random, dtype
from types import GeneratorType
from gensim.utils import deprecated
import os
import copy
try:
from queue import Queue
except ImportError:
from Queue import Queue
logger = logging.getLogger(__name__)
class BaseAny2VecModel(utils.SaveLoad):
    r"""Base class for training, using and evaluating \*2vec model.

    Contains implementation for multi-threaded training. The purpose of this class is to provide a
    reference interface for concrete embedding implementations, whether the input space is a corpus
    of words, documents or anything else. At the same time, functionality that we expect to be common
    for those implementations is provided here to avoid code duplication.

    In the special but usual case where the input space consists of words, a more specialized layer
    is provided, consider inheriting from :class:`~gensim.models.base_any2vec.BaseWordEmbeddingsModel`.

    Notes
    -----
    A subclass should initialize the following attributes:

    * self.wv - keyed vectors in model (see :class:`~gensim.models.keyedvectors.Word2VecKeyedVectors` as example)
    * self.vocabulary - vocabulary (see :class:`~gensim.models.word2vec.Word2VecVocab` as example)
    * self.trainables - internal matrices (see :class:`~gensim.models.word2vec.Word2VecTrainables` as example)

    Training from a `corpus_file` additionally reads `self.wv` and `self.hs` (see `_train_epoch_corpusfile`).
    A subclass must also implement every hook below that raises :class:`NotImplementedError`.

    """
    def __init__(self, workers=3, vector_size=100, epochs=5, callbacks=(), batch_words=10000):
        """

        Parameters
        ----------
        workers : int, optional
            Number of working threads, used for multithreading.
        vector_size : int, optional
            Dimensionality of the feature vectors.
        epochs : int, optional
            Number of iterations (epochs) of training through the corpus.
        callbacks : list of :class:`~gensim.models.callbacks.CallbackAny2Vec`, optional
            List of callbacks that need to be executed/run at specific stages during training.
        batch_words : int, optional
            Number of words to be processed by a single job.

        """
        self.vector_size = int(vector_size)
        self.workers = int(workers)
        self.epochs = epochs
        self.train_count = 0  # incremented once per completed call to train()
        self.total_train_time = 0  # accumulated wall-clock seconds across all epochs
        self.batch_words = batch_words
        self.model_trimmed_post_training = False
        self.callbacks = callbacks

    def _get_job_params(self, cur_epoch):
        """Get job parameters required for each batch."""
        raise NotImplementedError()

    def _set_train_params(self, **kwargs):
        """Set model parameters required for training."""
        raise NotImplementedError()

    def _update_job_params(self, job_params, epoch_progress, cur_epoch):
        """Get updated job parameters based on the epoch_progress and cur_epoch."""
        raise NotImplementedError()

    def _get_thread_working_mem(self):
        """Get private working memory per thread."""
        raise NotImplementedError()

    def _raw_word_count(self, job):
        """Get the number of words in a given job."""
        raise NotImplementedError()

    def _clear_post_train(self):
        """Resets certain properties of the model post training. eg. `keyedvectors.vectors_norm`."""
        raise NotImplementedError()

    def _do_train_epoch(self, corpus_file, thread_id, offset, cython_vocab, thread_private_mem, cur_epoch,
                        total_examples=None, total_words=None, **kwargs):
        """Train one worker's share of `corpus_file` for a single epoch.

        Must return a 3-tuple `(examples, effective word count, raw word count)`; it is forwarded
        unchanged to the progress queue by :meth:`_worker_loop_corpusfile`.

        """
        raise NotImplementedError()

    def _do_train_job(self, data_iterable, job_parameters, thread_private_mem):
        """Train a single batch. Return 2-tuple `(effective word count, total word count)`."""
        raise NotImplementedError()

    def _check_training_sanity(self, epochs=None, total_examples=None, total_words=None, **kwargs):
        """Check that the training parameters provided make sense. e.g. raise error if `epochs` not provided."""
        raise NotImplementedError()

    def _check_input_data_sanity(self, data_iterable=None, corpus_file=None):
        """Check that exactly one of `data_iterable` and `corpus_file` is provided.

        Raises
        ------
        ValueError
            If both arguments are None, or both are not None.

        """
        if not (data_iterable is None) ^ (corpus_file is None):
            raise ValueError("You must provide only one of singlestream or corpus_file arguments.")

    def _worker_loop_corpusfile(self, corpus_file, thread_id, offset, cython_vocab, progress_queue, cur_epoch=0,
                                total_examples=None, total_words=None, **kwargs):
        """Train the model on a `corpus_file` in LineSentence format.

        This function will be called in parallel by multiple workers (threads or processes) to make
        optimal use of multicore machines.

        Parameters
        ----------
        corpus_file : str
            Path to a corpus file in :class:`~gensim.models.word2vec.LineSentence` format.
        thread_id : int
            Thread index starting from 0 to `number of workers - 1`.
        offset : int
            Offset (in bytes) in the `corpus_file` for particular worker.
        cython_vocab : :class:`~gensim.models.word2vec_inner.CythonVocab`
            Copy of the vocabulary in order to access it without GIL.
        progress_queue : Queue of (int, int, int)
            A queue of progress reports. Each report is represented as a tuple of these 3 elements:
                * Size of data chunk processed, for example number of sentences in the corpus chunk.
                * Effective word count used in training (after ignoring unknown words and trimming the sentence length).
                * Total word count used in training.
        cur_epoch : int, optional
            The current training epoch, forwarded to :meth:`_do_train_epoch`.
        total_examples : int, optional
            Count of objects in the corpus, forwarded to :meth:`_do_train_epoch`.
        total_words : int, optional
            Count of raw words in the corpus, forwarded to :meth:`_do_train_epoch`.
        **kwargs : object
            Additional key word parameters for the specific model inheriting from this class.

        Notes
        -----
        Exactly one report is pushed for the whole epoch, followed by a `None` sentinel that tells
        :meth:`_log_epoch_progress` this worker is done.

        """
        thread_private_mem = self._get_thread_working_mem()

        examples, tally, raw_tally = self._do_train_epoch(
            corpus_file, thread_id, offset, cython_vocab, thread_private_mem, cur_epoch,
            total_examples=total_examples, total_words=total_words, **kwargs)

        progress_queue.put((examples, tally, raw_tally))
        progress_queue.put(None)

    def _worker_loop(self, job_queue, progress_queue):
        """Train the model, lifting batches of data from the queue.

        This function will be called in parallel by multiple workers (threads or processes) to make
        optimal use of multicore machines.

        Parameters
        ----------
        job_queue : Queue of (list of objects, dict of (str, int))
            A queue of jobs still to be processed. The worker will take up jobs from this queue.
            Each job is represented by a tuple where the first element is the corpus chunk to be processed and
            the second is the dictionary of parameters. A `None` item tells the worker to stop.
        progress_queue : Queue of (int, int, int)
            A queue of progress reports. Each report is represented as a tuple of these 3 elements:
                * Size of data chunk processed, for example number of sentences in the corpus chunk.
                * Effective word count used in training (after ignoring unknown words and trimming the sentence length).
                * Total word count used in training.
            A `None` item is pushed once the worker exits.

        """
        thread_private_mem = self._get_thread_working_mem()
        jobs_processed = 0
        while True:
            job = job_queue.get()
            if job is None:
                # propagate the shutdown sentinel so _log_epoch_progress can count finished workers
                progress_queue.put(None)
                break  # no more jobs => quit this worker
            data_iterable, job_parameters = job

            for callback in self.callbacks:
                callback.on_batch_begin(self)

            tally, raw_tally = self._do_train_job(data_iterable, job_parameters, thread_private_mem)

            for callback in self.callbacks:
                callback.on_batch_end(self)

            progress_queue.put((len(data_iterable), tally, raw_tally))  # report back progress
            jobs_processed += 1
        logger.debug("worker exiting, processed %i jobs", jobs_processed)

    def _job_producer(self, data_iterator, job_queue, cur_epoch=0, total_examples=None, total_words=None):
        """Fill the jobs queue using the data found in the input stream.

        Each job is represented by a tuple where the first element is the corpus chunk to be processed and
        the second is a dictionary of parameters.

        Parameters
        ----------
        data_iterator : iterable of list of objects
            The input dataset. This will be split in chunks and these chunks will be pushed to the queue.
        job_queue : Queue of (list of object, dict of (str, int))
            A queue of jobs still to be processed. The worker will take up jobs from this queue.
            Each job is represented by a tuple where the first element is the corpus chunk to be processed and
            the second is the dictionary of parameters.
        cur_epoch : int, optional
            The current training epoch, needed to compute the training parameters for each job.
            For example in many implementations the learning rate would be dropping with the number of epochs.
        total_examples : int, optional
            Count of objects in the `data_iterator`. In the usual case this would correspond to the number of sentences
            in a corpus. Used to compute the per-job parameter decay (examples-based) when provided.
        total_words : int, optional
            Count of total objects in `data_iterator`. In the usual case this would correspond to the number of raw
            words in a corpus. Used for words-based decay when `total_examples` is not provided.

        Notes
        -----
        After the input is exhausted, one `None` sentinel per worker is pushed so every worker can exit.

        """
        job_batch, batch_size = [], 0
        pushed_words, pushed_examples = 0, 0
        next_job_params = self._get_job_params(cur_epoch)
        job_no = 0

        for data in data_iterator:
            data_length = self._raw_word_count([data])

            # can we fit this sentence into the existing job batch?
            if batch_size + data_length <= self.batch_words:
                # yes => add it to the current job
                job_batch.append(data)
                batch_size += data_length
            else:
                job_no += 1
                job_queue.put((job_batch, next_job_params))

                # update the learning rate for the next job
                if total_examples:
                    # examples-based decay
                    pushed_examples += len(job_batch)
                    epoch_progress = 1.0 * pushed_examples / total_examples
                else:
                    # words-based decay
                    pushed_words += self._raw_word_count(job_batch)
                    epoch_progress = 1.0 * pushed_words / total_words
                next_job_params = self._update_job_params(next_job_params, epoch_progress, cur_epoch)

                # add the sentence that didn't fit as the first item of a new job
                job_batch, batch_size = [data], data_length
        # add the last job too (may be significantly smaller than batch_words)
        if job_batch:
            job_no += 1
            job_queue.put((job_batch, next_job_params))

        if job_no == 0 and self.train_count == 0:
            logger.warning(
                "train() called with an empty iterator (if not intended, "
                "be sure to provide a corpus that offers restartable iteration = an iterable)."
            )

        # give the workers heads up that they can finish -- no more work!
        for _ in range(self.workers):
            job_queue.put(None)
        logger.debug("job loop exiting, total %i jobs", job_no)

    def _log_progress(self, job_queue, progress_queue, cur_epoch, example_count, total_examples,
                      raw_word_count, total_words, trained_word_count, elapsed):
        """Log an intermediate progress report during an epoch.

        Called by :meth:`_log_epoch_progress`; `job_queue` is None in corpus_file mode.

        """
        raise NotImplementedError()

    def _log_epoch_end(self, cur_epoch, example_count, total_examples, raw_word_count, total_words,
                       trained_word_count, elapsed, is_corpus_file_mode):
        """Log the final statistics of one epoch; called by :meth:`_log_epoch_progress`."""
        raise NotImplementedError()

    def _log_train_end(self, raw_word_count, trained_word_count, total_elapsed, job_tally):
        """Log the overall statistics once all epochs are done; called by :meth:`train`."""
        raise NotImplementedError()

    def _log_epoch_progress(self, progress_queue=None, job_queue=None, cur_epoch=0, total_examples=None,
                            total_words=None, report_delay=1.0, is_corpus_file_mode=None):
        """Consume worker progress reports for one epoch until every worker has finished.

        Progress is logged via :meth:`_log_progress` at most once every `report_delay` seconds, the epoch
        summary via :meth:`_log_epoch_end`, and the epoch's elapsed time is added to `self.total_train_time`.

        Parameters
        ----------
        progress_queue : Queue of (int, int, int)
            A queue of progress reports. Each report is represented as a tuple of these 3 elements:
                * size of data chunk processed, for example number of sentences in the corpus chunk.
                * Effective word count used in training (after ignoring unknown words and trimming the sentence length).
                * Total word count used in training.
            A `None` item signals that one worker finished; the loop ends after `self.workers` of them.
        job_queue : Queue of (list of object, dict of (str, int))
            A queue of jobs still to be processed. The worker will take up jobs from this queue.
            Each job is represented by a tuple where the first element is the corpus chunk to be processed and
            the second is the dictionary of parameters. None in corpus_file mode.
        cur_epoch : int, optional
            The current training epoch, needed to compute the training parameters for each job.
            For example in many implementations the learning rate would be dropping with the number of epochs.
        total_examples : int, optional
            Count of objects in the `data_iterator`. In the usual case this would correspond to the number of sentences
            in a corpus. Used to log progress.
        total_words : int, optional
            Count of total objects in `data_iterator`. In the usual case this would correspond to the number of raw
            words in a corpus. Used to log progress.
        report_delay : float, optional
            Number of seconds between two consecutive progress report messages in the logger.
        is_corpus_file_mode : bool, optional
            Whether training is file-based (corpus_file argument) or not.

        Returns
        -------
        (int, int, int)
            The epoch report consisting of three elements:
                * Effective word count used in training (after ignoring unknown words and trimming the sentence length).
                * Total (raw) word count processed.
                * Number of progress reports received: one per job in iterable mode, one per worker in
                  corpus_file mode.

        """
        example_count, trained_word_count, raw_word_count = 0, 0, 0
        # the tiny offset avoids a zero elapsed time if the first report arrives immediately;
        # the first report is due after `report_delay` seconds, like every later one
        start, next_report = default_timer() - 0.00001, report_delay
        job_tally = 0
        unfinished_worker_count = self.workers

        while unfinished_worker_count > 0:
            report = progress_queue.get()  # blocks if workers too slow
            if report is None:  # a thread reporting that it finished
                unfinished_worker_count -= 1
                logger.info("worker thread finished; awaiting finish of %i more threads", unfinished_worker_count)
                continue

            examples, trained_words, raw_words = report
            job_tally += 1

            # update progress stats
            example_count += examples
            trained_word_count += trained_words  # only words in vocab & sampled
            raw_word_count += raw_words

            # log progress once every report_delay seconds
            elapsed = default_timer() - start
            if elapsed >= next_report:
                self._log_progress(
                    job_queue, progress_queue, cur_epoch, example_count, total_examples,
                    raw_word_count, total_words, trained_word_count, elapsed)
                next_report = elapsed + report_delay
        # all done; report the final stats
        elapsed = default_timer() - start
        self._log_epoch_end(
            cur_epoch, example_count, total_examples, raw_word_count, total_words,
            trained_word_count, elapsed, is_corpus_file_mode)
        self.total_train_time += elapsed
        return trained_word_count, raw_word_count, job_tally

    def _train_epoch_corpusfile(self, corpus_file, cur_epoch=0, total_examples=None, total_words=None, **kwargs):
        """Train the model for a single epoch.

        Parameters
        ----------
        corpus_file : str
            Path to a corpus file in :class:`~gensim.models.word2vec.LineSentence` format.
        cur_epoch : int, optional
            The current training epoch, needed to compute the training parameters for each job.
            For example in many implementations the learning rate would be dropping with the number of epochs.
        total_examples : int, optional
            Count of objects in the `data_iterator`. In the usual case this would correspond to the number of sentences
            in a corpus, used to log progress.
        total_words : int
            Count of total objects in `data_iterator`. In the usual case this would correspond to the number of raw
            words in a corpus, used to log progress. Must be provided in order to seek in `corpus_file`.
        **kwargs : object
            Additional key word parameters for the specific model inheriting from this class.

        Returns
        -------
        (int, int, int)
            The training report for this epoch consisting of three elements:
                * Effective word count used in training (after ignoring unknown words and trimming the sentence length).
                * Total (raw) word count processed.
                * Number of progress reports received (one per worker thread).

        Raises
        ------
        ValueError
            If `total_words` is not provided (or is 0).

        """
        if not total_words:
            raise ValueError("total_words must be provided alongside corpus_file argument.")

        # imported lazily: these modules depend on this one / on compiled extensions
        from gensim.models.word2vec_corpusfile import CythonVocab
        from gensim.models.fasttext import FastText
        cython_vocab = CythonVocab(self.wv, hs=self.hs, fasttext=isinstance(self, FastText))

        progress_queue = Queue()

        corpus_file_size = os.path.getsize(corpus_file)

        thread_kwargs = copy.copy(kwargs)
        thread_kwargs['cur_epoch'] = cur_epoch
        thread_kwargs['total_examples'] = total_examples
        thread_kwargs['total_words'] = total_words

        # each worker starts reading at an evenly spaced byte offset into the file
        workers = [
            threading.Thread(
                target=self._worker_loop_corpusfile,
                args=(
                    corpus_file, thread_id, corpus_file_size / self.workers * thread_id, cython_vocab, progress_queue
                ),
                kwargs=thread_kwargs
            ) for thread_id in range(self.workers)
        ]

        for thread in workers:
            thread.daemon = True
            thread.start()

        trained_word_count, raw_word_count, job_tally = self._log_epoch_progress(
            progress_queue=progress_queue, job_queue=None, cur_epoch=cur_epoch,
            total_examples=total_examples, total_words=total_words, is_corpus_file_mode=True)

        return trained_word_count, raw_word_count, job_tally

    def _train_epoch(self, data_iterable, cur_epoch=0, total_examples=None, total_words=None,
                     queue_factor=2, report_delay=1.0):
        """Train the model for a single epoch.

        Parameters
        ----------
        data_iterable : iterable of list of object
            The input corpus. This will be split in chunks and these chunks will be pushed to the queue.
        cur_epoch : int, optional
            The current training epoch, needed to compute the training parameters for each job.
            For example in many implementations the learning rate would be dropping with the number of epochs.
        total_examples : int, optional
            Count of objects in the `data_iterator`. In the usual case this would correspond to the number of sentences
            in a corpus, used to log progress.
        total_words : int, optional
            Count of total objects in `data_iterator`. In the usual case this would correspond to the number of raw
            words in a corpus, used to log progress.
        queue_factor : int, optional
            Multiplier for size of queue -> size = number of workers * queue_factor.
        report_delay : float, optional
            Number of seconds between two consecutive progress report messages in the logger.

        Returns
        -------
        (int, int, int)
            The training report for this epoch consisting of three elements:
                * Effective word count used in training (after ignoring unknown words and trimming the sentence length).
                * Total (raw) word count processed.
                * Number of jobs processed.

        """
        # bounded queues keep the producer from racing too far ahead of the workers
        job_queue = Queue(maxsize=queue_factor * self.workers)
        progress_queue = Queue(maxsize=(queue_factor + 1) * self.workers)

        workers = [
            threading.Thread(
                target=self._worker_loop,
                args=(job_queue, progress_queue,))
            for _ in range(self.workers)
        ]

        workers.append(threading.Thread(
            target=self._job_producer,
            args=(data_iterable, job_queue),
            kwargs={'cur_epoch': cur_epoch, 'total_examples': total_examples, 'total_words': total_words}))

        for thread in workers:
            thread.daemon = True  # make interrupting the process with ctrl+c easier
            thread.start()

        trained_word_count, raw_word_count, job_tally = self._log_epoch_progress(
            progress_queue, job_queue, cur_epoch=cur_epoch, total_examples=total_examples, total_words=total_words,
            report_delay=report_delay, is_corpus_file_mode=False)

        return trained_word_count, raw_word_count, job_tally

    def train(self, data_iterable=None, corpus_file=None, epochs=None, total_examples=None,
              total_words=None, queue_factor=2, report_delay=1.0, callbacks=(), **kwargs):
        """Train the model for multiple epochs using multiple workers.

        Parameters
        ----------
        data_iterable : iterable of list of object, optional
            The input corpus. This will be split in chunks and these chunks will be pushed to the queue.
            If None, training reads from `corpus_file` instead.
        corpus_file : str, optional
            Path to a corpus file in :class:`~gensim.models.word2vec.LineSentence` format.
            If you use this argument instead of `data_iterable`, you must provide `total_words` argument as well.
        epochs : int, optional
            Number of epochs (training iterations over the whole input) of training. Overwrites `self.epochs`.
        total_examples : int, optional
            Count of objects in the `data_iterator`. In the usual case this would correspond to the number of sentences
            in a corpus, used to log progress.
        total_words : int, optional
            Count of total objects in `data_iterator`. In the usual case this would correspond to the number of raw
            words in a corpus, used to log progress.
        queue_factor : int, optional
            Multiplier for size of queue -> size = number of workers * queue_factor.
        report_delay : float, optional
            Number of seconds between two consecutive progress report messages in the logger.
        callbacks : list of :class:`~gensim.models.callbacks.CallbackAny2Vec`, optional
            List of callbacks to execute at specific stages during training. If non-empty, replaces `self.callbacks`.
        **kwargs : object
            Additional key word parameters for the specific model inheriting from this class.

        Returns
        -------
        (int, int)
            The total training report consisting of two elements:
                * Effective word count used in training, summed over all epochs.
                * Total (raw) word count processed, summed over all epochs.

        """
        self._set_train_params(**kwargs)
        if callbacks:
            self.callbacks = callbacks
        self.epochs = epochs
        self._check_training_sanity(
            epochs=epochs,
            total_examples=total_examples,
            total_words=total_words, **kwargs)

        for callback in self.callbacks:
            callback.on_train_begin(self)

        trained_word_count = 0
        raw_word_count = 0
        start = default_timer() - 0.00001
        job_tally = 0

        for cur_epoch in range(self.epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(self)

            if data_iterable is not None:
                trained_word_count_epoch, raw_word_count_epoch, job_tally_epoch = self._train_epoch(
                    data_iterable, cur_epoch=cur_epoch, total_examples=total_examples,
                    total_words=total_words, queue_factor=queue_factor, report_delay=report_delay)
            else:
                trained_word_count_epoch, raw_word_count_epoch, job_tally_epoch = self._train_epoch_corpusfile(
                    corpus_file, cur_epoch=cur_epoch, total_examples=total_examples, total_words=total_words, **kwargs)

            trained_word_count += trained_word_count_epoch
            raw_word_count += raw_word_count_epoch
            job_tally += job_tally_epoch

            for callback in self.callbacks:
                callback.on_epoch_end(self)

        # Log overall time
        total_elapsed = default_timer() - start
        self._log_train_end(raw_word_count, trained_word_count, total_elapsed, job_tally)

        self.train_count += 1  # number of times train() has been called
        self._clear_post_train()

        for callback in self.callbacks:
            callback.on_train_end(self)
        return trained_word_count, raw_word_count

    @classmethod
    def load(cls, fname_or_handle, **kwargs):
        """Load a previously saved object (using :meth:`gensim.models.base_any2vec.BaseAny2VecModel.save`) from a file.

        Parameters
        ----------
        fname_or_handle : {str, file-like object}
            Path to file that contains needed object or handle to an open file.
        **kwargs : object
            Keyword arguments propagated to :meth:`~gensim.utils.SaveLoad.load`.

        See Also
        --------
        :meth:`~gensim.models.base_any2vec.BaseAny2VecModel.save`
            Method for save a model.

        Returns
        -------
        object
            Object loaded from `fname_or_handle`.

        Notes
        -----
        This is a thin delegation to :meth:`~gensim.utils.SaveLoad.load`; any exception it raises propagates.

        """
        return super(BaseAny2VecModel, cls).load(fname_or_handle, **kwargs)

    def save(self, fname_or_handle, **kwargs):
        """Save the object to file.

        Parameters
        ----------
        fname_or_handle : {str, file-like object}
            Path to file where the model will be persisted.
        **kwargs : object
            Key word arguments propagated to :meth:`~gensim.utils.SaveLoad.save`.

        See Also
        --------
        :meth:`~gensim.models.base_any2vec.BaseAny2VecModel.load`
            Method for load model after current method.

        """
        super(BaseAny2VecModel, self).save(fname_or_handle, **kwargs)
class BaseWordEmbeddingsModel(BaseAny2VecModel):
"""Base class containing common methods for training, using & evaluating word embeddings learning models.
See Also
--------
:class:`~gensim.models.word2vec.Word2Vec`.
Word2Vec model - embeddings for words.
:class:`~gensim.models.fasttext.FastText`.
FastText model - embeddings for words (ngram-based).
:class:`~gensim.models.doc2vec.Doc2Vec`.
Doc2Vec model - embeddings for documents.
:class:`~gensim.models.poincare.PoincareModel`
Poincare model - embeddings for graphs.
"""
def _clear_post_train(self):
    """Hook for resetting post-training model state; concrete word-embedding models must override it."""
    raise NotImplementedError
def _do_train_job(self, data_iterable, job_parameters, thread_private_mem):
    """Hook for training one batch of data; concrete word-embedding models must override it."""
    raise NotImplementedError
def _set_train_params(self, **kwargs):
    """Hook for applying model-specific training parameters; concrete word-embedding models must override it."""
    raise NotImplementedError
def __init__(self, sentences=None, corpus_file=None, workers=3, vector_size=100, epochs=5, callbacks=(),
             batch_words=10000, trim_rule=None, sg=0, alpha=0.025, window=5, seed=1, hs=0, negative=5,
             ns_exponent=0.75, cbow_mean=1, min_alpha=0.0001, compute_loss=False, **kwargs):
    """Store the shared word-embedding hyper-parameters and, if a corpus is given, build the vocabulary and train.

    Parameters
    ----------
    sentences : iterable of list of str, optional
        Can be simply a list of lists of tokens, but for larger corpora,
        consider an iterable that streams the sentences directly from disk/network.
        See :class:`~gensim.models.word2vec.BrownCorpus`, :class:`~gensim.models.word2vec.Text8Corpus`
        or :class:`~gensim.models.word2vec.LineSentence` for such examples.
        Must not be a generator (it has to be iterable more than once).
    corpus_file : str, optional
        Path to a corpus file in :class:`~gensim.models.word2vec.LineSentence` format.
        You may use this argument instead of `sentences` to get performance boost. Only one of `sentences` or
        `corpus_file` arguments need to be passed (or none of them, in that case, the model is left uninitialized).
    workers : int, optional
        Number of working threads, used for multithreading.
    vector_size : int, optional
        Dimensionality of the feature vectors. A warning is logged if it is not a multiple of 4.
    epochs : int, optional
        Number of iterations (epochs) of training through the corpus.
    callbacks : list of :class:`~gensim.models.callbacks.CallbackAny2Vec`, optional
        List of callbacks that need to be executed/run at specific stages during training.
    batch_words : int, optional
        Number of words to be processed by a single job.
    trim_rule : function, optional
        Vocabulary trimming rule, specifies whether certain words should remain in the vocabulary,
        be trimmed away, or handled using the default (discard if word count < min_count).
        Can be None (min_count will be used, look to :func:`~gensim.utils.keep_vocab_item`),
        or a callable that accepts parameters (word, count, min_count) and returns either
        :attr:`gensim.utils.RULE_DISCARD`, :attr:`gensim.utils.RULE_KEEP` or :attr:`gensim.utils.RULE_DEFAULT`.
        The rule, if given, is only used to prune vocabulary during current method call and is not stored as part
        of the model.

        The input parameters are of the following types:
            * `word` (str) - the word we are examining
            * `count` (int) - the word's frequency count in the corpus
            * `min_count` (int) - the minimum count threshold.

    sg : {1, 0}, optional
        Defines the training algorithm. If 1, skip-gram is used, otherwise, CBOW is employed.
    alpha : float, optional
        The beginning learning rate. This will linearly reduce with iterations until it reaches `min_alpha`.
    window : int, optional
        The maximum distance between the current and predicted word within a sentence.
    seed : int, optional
        Seed for the random number generator. Initial vectors for each word are seeded with a hash of
        the concatenation of word + `str(seed)`.
        Note that for a fully deterministically-reproducible run, you must also limit the model to a single worker
        thread (`workers=1`), to eliminate ordering jitter from OS thread scheduling.
        In Python 3, reproducibility between interpreter launches also requires use of the `PYTHONHASHSEED`
        environment variable to control hash randomization.
    hs : {1,0}, optional
        If 1, hierarchical softmax will be used for model training.
        If set to 0, and `negative` is non-zero, negative sampling will be used.
    negative : int, optional
        If > 0, negative sampling will be used, the int for negative specifies how many "noise words"
        should be drawn (usually between 5-20).
        If set to 0, no negative sampling is used.
    ns_exponent : float, optional
        Stored unchanged as `self.ns_exponent` for use by subclasses. In the Word2Vec family this is the
        exponent that shapes the negative-sampling distribution (0.75 is the value used by the original
        word2vec implementation) -- confirm against the concrete subclass.
    cbow_mean : {1,0}, optional
        If 0, use the sum of the context word vectors. If 1, use the mean, only applies when cbow is used.
    min_alpha : float, optional
        Final learning rate. Drops linearly with the number of iterations from `alpha`.
    compute_loss : bool, optional
        If True, loss will be computed while training the Word2Vec model and stored in
        :attr:`~gensim.models.base_any2vec.BaseWordEmbeddingsModel.running_training_loss` attribute.
    **kwargs : object
        Key word arguments needed to allow children classes to accept more arguments.

    Raises
    ------
    ValueError
        If both `sentences` and `corpus_file` are given.
    TypeError
        If `corpus_file` is not a string, or `sentences` is a generator.

    Notes
    -----
    When `sentences` or `corpus_file` is given, this immediately calls `self.build_vocab(...)` and then
    `self.train(...)`. The `train` call passes the corpus as the `sentences=` keyword, so it relies on a
    subclass `train` that accepts that keyword (the base `BaseAny2VecModel.train` names it `data_iterable`).

    """
    self.sg = int(sg)
    if vector_size % 4 != 0:
        logger.warning("consider setting layer size to a multiple of 4 for greater performance")
    self.alpha = float(alpha)
    self.window = int(window)
    self.random = random.RandomState(seed)
    self.min_alpha = float(min_alpha)
    self.hs = int(hs)
    self.negative = int(negative)
    self.ns_exponent = ns_exponent
    self.cbow_mean = int(cbow_mean)
    self.compute_loss = bool(compute_loss)
    self.running_training_loss = 0
    self.min_alpha_yet_reached = float(alpha)
    self.corpus_count = 0
    self.corpus_total_words = 0

    super(BaseWordEmbeddingsModel, self).__init__(
        workers=workers, vector_size=vector_size, epochs=epochs, callbacks=callbacks, batch_words=batch_words)

    if sentences is not None or corpus_file is not None:
        self._check_input_data_sanity(data_iterable=sentences, corpus_file=corpus_file)
        if corpus_file is not None and not isinstance(corpus_file, string_types):
            raise TypeError("You must pass string as the corpus_file argument.")
        elif isinstance(sentences, GeneratorType):
            # a generator would be exhausted after build_vocab and yield nothing for training
            raise TypeError("You can't pass a generator as the sentences argument. Try a sequence.")

        self.build_vocab(sentences=sentences, corpus_file=corpus_file, trim_rule=trim_rule)
        self.train(
            sentences=sentences, corpus_file=corpus_file, total_examples=self.corpus_count,
            total_words=self.corpus_total_words, epochs=self.epochs, start_alpha=self.alpha,
            end_alpha=self.min_alpha, compute_loss=compute_loss)
    else:
        if trim_rule is not None:
            logger.warning(
                "The rule, if given, is only used to prune vocabulary during build_vocab() "
                "and is not stored as part of the model. Model initialized without sentences. "
                "trim_rule provided, if any, will be ignored.")
# Deprecated aliases kept for backward compatibility; each forwards to the new attribute
# (on self or on self.trainables). The ``deprecated`` wrapper sits below ``property``, so it
# wraps the accessor itself and runs on every get/set/delete.
@property
@deprecated("Attribute will be removed in 4.0.0, use self.epochs instead")
def iter(self):
    """Deprecated alias: get the number of training epochs (``self.epochs``)."""
    return self.epochs

@iter.setter
@deprecated("Attribute will be removed in 4.0.0, use self.epochs instead")
def iter(self, value):
    """Deprecated alias: set the number of training epochs (writes ``self.epochs``)."""
    self.epochs = value
@property
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.syn1 instead")
def syn1(self):
    """Deprecated alias of ``self.trainables.syn1``.

    Getting, setting and deleting this attribute are all forwarded to ``self.trainables``.
    NOTE(review): presumably the hierarchical-softmax output weights -- confirm against the trainables class.
    """
    return self.trainables.syn1

@syn1.setter
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.syn1 instead")
def syn1(self, value):
    # Rebind the value on the trainables object; the model itself stores nothing.
    self.trainables.syn1 = value

@syn1.deleter
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.syn1 instead")
def syn1(self):
    # Deleting the alias deletes the underlying attribute on `self.trainables`.
    del self.trainables.syn1
@property
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.syn1neg instead")
def syn1neg(self):
    """Deprecated alias of ``self.trainables.syn1neg``.

    Getting, setting and deleting this attribute are all forwarded to ``self.trainables``.
    NOTE(review): presumably the negative-sampling output weights -- confirm against the trainables class.
    """
    return self.trainables.syn1neg

@syn1neg.setter
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.syn1neg instead")
def syn1neg(self, value):
    # Rebind the value on the trainables object; the model itself stores nothing.
    self.trainables.syn1neg = value

@syn1neg.deleter
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.syn1neg instead")
def syn1neg(self):
    # Deleting the alias deletes the underlying attribute on `self.trainables`.
    del self.trainables.syn1neg
@property
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.vectors_lockf instead")
def syn0_lockf(self):
    """Deprecated alias of ``self.trainables.vectors_lockf`` (note the different attribute name).

    Getting, setting and deleting this attribute are all forwarded to ``self.trainables``.
    NOTE(review): presumably per-vector lock factors controlling how much vectors are updated -- confirm.
    """
    return self.trainables.vectors_lockf

@syn0_lockf.setter
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.vectors_lockf instead")
def syn0_lockf(self, value):
    # Old name `syn0_lockf` maps onto the new name `vectors_lockf`.
    self.trainables.vectors_lockf = value

@syn0_lockf.deleter
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.vectors_lockf instead")
def syn0_lockf(self):
    # Deleting the alias deletes `vectors_lockf` on `self.trainables`.
    del self.trainables.vectors_lockf
@property
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.layer1_size instead")
def layer1_size(self):
    """Deprecated alias of ``self.trainables.layer1_size``.

    Only get and set are forwarded; unlike some sibling aliases, this property defines no deleter.
    """
    return self.trainables.layer1_size

@layer1_size.setter
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.layer1_size instead")
def layer1_size(self, value):
    # Write-through to the trainables object.
    self.trainables.layer1_size = value
@property
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.hashfxn instead")
def hashfxn(self):
    """Deprecated alias of ``self.trainables.hashfxn``.

    Only get and set are forwarded; this property defines no deleter.
    NOTE(review): presumably the hash function used when seeding vector initialisation -- confirm.
    """
    return self.trainables.hashfxn

@hashfxn.setter
@deprecated("Attribute will be removed in 4.0.0, use self.trainables.hashfxn instead")
def hashfxn(self, value):
    # Write-through to the trainables object.
    self.trainables.hashfxn = value
@property
@deprecated("Attribute will be removed in 4.0.0, use self.vocabulary.sample instead")
def sample(self):
    """Deprecated alias of ``self.vocabulary.sample``.

    Only get and set are forwarded; this property defines no deleter.
    NOTE(review): presumably the frequent-word downsampling threshold -- confirm against the vocabulary class.
    """
    return self.vocabulary.sample

@sample.setter
@deprecated("Attribute will be removed in 4.0.0, use self.vocabulary.sample instead")
def sample(self, value):
    # Write-through to the vocabulary object.
    self.vocabulary.sample = value
@property
@deprecated("Attribute will be removed in 4.0.0, use self.vocabulary.min_count instead")
def min_count(self):
    """Deprecated alias of ``self.vocabulary.min_count``.

    This is the count threshold that vocabulary trimming compares word counts against: by default, words
    with count < min_count are discarded (see the `trim_rule` docs of `build_vocab`).
    Only get and set are forwarded; this property defines no deleter.
    """
    return self.vocabulary.min_count

@min_count.setter
@deprecated("Attribute will be removed in 4.0.0, use self.vocabulary.min_count instead")
def min_count(self, value):
    # Write-through to the vocabulary object.
    self.vocabulary.min_count = value
@property
@deprecated("Attribute will be removed in 4.0.0, use self.vocabulary.cum_table instead")
def cum_table(self):
    """Deprecated alias of ``self.vocabulary.cum_table``.

    Getting, setting and deleting this attribute are all forwarded to ``self.vocabulary``.
    NOTE(review): presumably the cumulative-frequency table used to draw negative samples -- confirm.
    """
    return self.vocabulary.cum_table

@cum_table.setter
@deprecated("Attribute will be removed in 4.0.0, use self.vocabulary.cum_table instead")
def cum_table(self, value):
    # Write-through to the vocabulary object.
    self.vocabulary.cum_table = value

@cum_table.deleter
@deprecated("Attribute will be removed in 4.0.0, use self.vocabulary.cum_table instead")
def cum_table(self):
    # Deleting the alias deletes the underlying attribute on `self.vocabulary`.
    del self.vocabulary.cum_table
def __str__(self):
    """Summarise the model in one line.

    Returns
    -------
    str
        ``"<ClassName>(vocab=<len(self.wv.index2word)>, size=<self.vector_size>, alpha=<self.alpha>)"``,
        i.e. the concrete class name, the number of words in the vocabulary, the vector dimensionality
        and the starting learning rate.
    """
    class_name = self.__class__.__name__
    vocab_size = len(self.wv.index2word)
    template = "%s(vocab=%s, size=%s, alpha=%s)"
    return template % (class_name, vocab_size, self.vector_size, self.alpha)
def build_vocab(self, sentences=None, corpus_file=None, update=False, progress_per=10000,
                keep_raw_vocab=False, trim_rule=None, **kwargs):
    """Scan a corpus once, build the vocabulary from it and prepare the model weights.

    Because the corpus is read in a single pass, `sentences` may be a once-only generator stream.

    Parameters
    ----------
    sentences : iterable of list of str
        Tokenised sentences: a plain list of token lists works, but for larger corpora prefer an iterable
        that streams sentences from disk/network, such as :class:`~gensim.models.word2vec.BrownCorpus`,
        :class:`~gensim.models.word2vec.Text8Corpus` or :class:`~gensim.models.word2vec.LineSentence`.
    corpus_file : str, optional
        Path to a corpus in :class:`~gensim.models.word2vec.LineSentence` format, usable instead of
        `sentences` for better performance. Pass only one of `sentences` / `corpus_file`.
    update : bool
        If true, new words found in the corpus are added to the model's existing vocabulary.
    progress_per : int, optional
        How many words to process between progress reports.
    keep_raw_vocab : bool, optional
        If False, the raw vocabulary is deleted after scaling to free up RAM.
    trim_rule : function, optional
        Vocabulary trimming rule deciding whether a word is kept, discarded, or handled by the default
        (discard if word count < min_count). Either None (use min_count, see
        :func:`~gensim.utils.keep_vocab_item`) or a callable ``rule(word, count, min_count)`` returning
        :attr:`gensim.utils.RULE_DISCARD`, :attr:`gensim.utils.RULE_KEEP` or :attr:`gensim.utils.RULE_DEFAULT`,
        where `word` is the str being examined, `count` its int frequency in the corpus and `min_count` the
        int threshold. The rule only prunes the vocabulary during this call; it is not stored in the model.
    **kwargs : object
        Keyword arguments forwarded to `self.vocabulary.prepare_vocab`.
    """
    # One pass over the corpus: total word occurrences and the example count
    # (the latter is what `train` receives as `total_examples`).
    scanned_words, scanned_examples = self.vocabulary.scan_vocab(
        sentences=sentences, corpus_file=corpus_file, progress_per=progress_per, trim_rule=trim_rule)
    self.corpus_total_words = scanned_words
    self.corpus_count = scanned_examples

    # Turn the raw counts into the final vocabulary stored on `self.wv`.
    vocab_report = self.vocabulary.prepare_vocab(
        self.hs, self.negative, self.wv,
        update=update, keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, **kwargs)
    retained_words = vocab_report['num_retained_words']
    vocab_report['memory'] = self.estimate_memory(vocab_size=retained_words)

    # Build the weight tables & arrays for the (possibly updated) vocabulary.
    self.trainables.prepare_weights(self.hs, self.negative, self.wv, update=update, vocabulary=self.vocabulary)
def build_vocab_from_freq(self, word_freq, keep_raw_vocab=False, corpus_count=None, trim_rule=None, update=False):
    """Build vocabulary from a dictionary of word frequencies.

    No corpus is scanned. The mapping is used directly as the raw vocabulary. It is then trimmed and scaled
    by ``self.vocabulary.prepare_vocab``, and the weights are prepared by ``self.trainables.prepare_weights``,
    the same way :meth:`build_vocab` does after scanning a corpus.

    Parameters
    ----------
    word_freq : dict of (str, int)
        A mapping from a word in the vocabulary to its frequency count. The mapping object itself (not a copy)
        is stored as ``self.vocabulary.raw_vocab``.
    keep_raw_vocab : bool, optional
        If False, delete the raw vocabulary after the scaling is done to free up RAM.
    corpus_count : int, optional
        Number of examples in the corpus the frequencies came from. Since no corpus is scanned, this is the
        only way to set ``self.corpus_count``; it defaults to 0.
    trim_rule : function, optional
        Vocabulary trimming rule, specifies whether certain words should remain in the vocabulary,
        be trimmed away, or handled using the default (discard if word count < min_count).
        Can be None (min_count will be used, look to :func:`~gensim.utils.keep_vocab_item`),
        or a callable that accepts parameters (word, count, min_count) and returns either
        :attr:`gensim.utils.RULE_DISCARD`, :attr:`gensim.utils.RULE_KEEP` or :attr:`gensim.utils.RULE_DEFAULT`.
        The rule, if given, is only used to prune vocabulary during current method call and is not stored as part
        of the model.

        The input parameters are of the following types:
            * `word` (str) - the word we are examining
            * `count` (int) - the word's frequency count in the corpus
            * `min_count` (int) - the minimum count threshold.
    update : bool, optional
        If true, the new provided words in `word_freq` dict will be added to model's vocab.

    Notes
    -----
    ``self.corpus_total_words`` is set to the sum of all frequencies in `word_freq`. This mirrors
    :meth:`build_vocab`, which sets it from the corpus scan, and keeps it consistent with ``corpus_count``.
    """
    logger.info("Processing provided word frequencies")
    # Instead of scanning text, the provided word frequencies dictionary (word_freq)
    # is used directly as the raw vocab.
    raw_vocab = word_freq
    # Total number of word occurrences the mapping represents. Computed once and used both for
    # the log line and as corpus_total_words.
    total_words = sum(itervalues(raw_vocab))
    logger.info(
        "collected %i different raw word, with total frequency of %i",
        len(raw_vocab), total_words
    )

    # Since no sentences are provided, the caller controls corpus_count.
    self.corpus_count = corpus_count or 0
    # Keep corpus_total_words in sync with corpus_count, as build_vocab() does. Otherwise a value left over
    # from an earlier build_vocab() call (or the initial 0) would silently survive.
    self.corpus_total_words = total_words
    self.vocabulary.raw_vocab = raw_vocab

    # trim by min_count & precalculate downsampling
    report_values = self.vocabulary.prepare_vocab(
        self.hs, self.negative, self.wv, keep_raw_vocab=keep_raw_vocab,
        trim_rule=trim_rule, update=update)
    report_values['memory'] = self.estimate_memory(vocab_size=report_values['num_retained_words'])
    self.trainables.prepare_weights(
        self.hs, self.negative, self.wv, update=update, vocabulary=self.vocabulary)  # build tables & arrays
def estimate_memory(self, vocab_size=None, report=None):
"""Estimate required memory for a model using current settings and provided vocabulary size.
Parameters
----------
vocab_size : int, optional
Number of unique tokens in the vocabulary
report : dict of (str, int), optional
A dictionary from string representations of the model's memory consuming members to their size in bytes.
Returns
-------
dict of (str, int)
A dictionary from string representations of the model's memory consuming members to their size in bytes.
"""
vocab_size = vocab_size or len(self.wv.vocab)
report = report or {}
report['vocab'] = vocab_size * (700 if self.hs else 500)
report['vectors'] = vocab_size * self.vector_size * dtype(REAL).itemsize