Skip to content

Commit

Permalink
Record lifecycle events in Gensim models (#3060)
Browse files Browse the repository at this point in the history
* fix docs

* re #2863: record lifecycle events

* log events even if lifecycle attribute is turned off

* fix overlong lines + use f-strings

* bump up internal version + improve docs

* record more info for load + save

* lifecycle events for KeyedVectors

* add lifecycle events to remaining models

* ask for lifecycle log in ISSUE_TEMPLATE

* improve logging

* Update word2vec.py

remove unused import

* Update CHANGELOG.md

Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
piskvorky and mpenkov committed Mar 7, 2021
1 parent dd9c01c commit 60ad052
Show file tree
Hide file tree
Showing 18 changed files with 230 additions and 92 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -4,6 +4,7 @@ Changes
## Unreleased

- fix RuntimeError in export_phrases (change defaultdict to dict) (PR [#3041](https://github.com/RaRe-Technologies/gensim/pull/3041), [@thalishsajeed](https://github.com/thalishsajeed))
- Record lifecycle events in Gensim models (PR [#3060](https://github.com/RaRe-Technologies/gensim/pull/3060), [@piskvorky](https://github.com/piskvorky))

## 4.0.0beta, 2020-10-31

Expand Down
6 changes: 6 additions & 0 deletions ISSUE_TEMPLATE.md
Expand Up @@ -15,6 +15,12 @@ What are you trying to achieve? What is the expected result? What are you seeing

Include full tracebacks, logs and datasets if necessary. Please keep the examples minimal ("minimal reproducible example").

If your problem is with a specific Gensim model (word2vec, lsimodel, doc2vec, fasttext, ldamodel etc), include the following:

```python
print(my_model.lifecycle_events)
```

#### Versions

Please provide the output of:
Expand Down
2 changes: 1 addition & 1 deletion docs/src/conf.py
Expand Up @@ -63,7 +63,7 @@
# The short X.Y version.
version = '4.0.0beta'
# The full version, including alpha/beta/rc tags.
release = '4.0.0beta'
release = '4.0.0rc1'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down
3 changes: 2 additions & 1 deletion gensim/__init__.py
Expand Up @@ -4,11 +4,12 @@
"""

__version__ = '4.0.0rc1'

import logging

from gensim import parsing, corpora, matutils, interfaces, models, similarities, utils # noqa:F401

__version__ = '4.0.0beta'

logger = logging.getLogger('gensim')
if not logger.handlers: # To ensure reload() doesn't add another one
Expand Down
4 changes: 4 additions & 0 deletions gensim/corpora/dictionary.py
Expand Up @@ -77,6 +77,10 @@ def __init__(self, documents=None, prune_at=2000000):

if documents is not None:
self.add_documents(documents, prune_at=prune_at)
self.add_lifecycle_event(
"created",
msg=f"built {self} from {self.num_docs} documents (total {self.num_pos} corpus positions)",
)

def __getitem__(self, tokenid):
"""Get the string token that corresponds to `tokenid`.
Expand Down
6 changes: 0 additions & 6 deletions gensim/models/doc2vec.py
Expand Up @@ -51,12 +51,6 @@
>>> model.save(fname)
>>> model = Doc2Vec.load(fname) # you can continue training with the loaded model!
If you're finished training a model (=no more updates, only querying, reduce memory usage), you can do:
.. sourcecode:: pycon
>>> model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True)
Infer vector for a new document:
.. sourcecode:: pycon
Expand Down
7 changes: 5 additions & 2 deletions gensim/models/fasttext.py
Expand Up @@ -653,7 +653,7 @@ def _pad_ones(m, new_len):


def load_facebook_model(path, encoding='utf-8'):
"""Load the input-hidden weight matrix from Facebook's native fasttext `.bin` output file.
"""Load the model from Facebook's native fasttext `.bin` output file.
Notes
------
Expand Down Expand Up @@ -835,7 +835,10 @@ def _load_fasttext_format(model_file, encoding='utf-8', full_model=True):

_check_model(model)

logger.info("loaded %s weight matrix for fastText model from %s", m.vectors_ngrams.shape, fin.name)
model.add_lifecycle_event(
"load_fasttext_format",
msg=f"loaded {m.vectors_ngrams.shape} weight matrix for fastText model from {fin.name}",
)
return model


Expand Down
58 changes: 34 additions & 24 deletions gensim/models/keyedvectors.py
Expand Up @@ -191,8 +191,9 @@


class KeyedVectors(utils.SaveLoad):

def __init__(self, vector_size, count=0, dtype=np.float32, mapfile_path=None):
"""Mapping between keys (such as words) and vectors for :class:`~gensim.models.Word2Vec`
"""Mapping between keys (such as words) and vectors for :class:`~gensim.models.Word2Vec`
and related models.
Used to perform operations on the vectors such as vector lookup, distance, similarity etc.
Expand All @@ -215,7 +216,7 @@ def __init__(self, vector_size, count=0, dtype=np.float32, mapfile_path=None):
Vector dimensions will default to `np.float32` (AKA `REAL` in some Gensim code) unless
another type is provided here.
mapfile_path : string, optional
FIXME: UNDER CONSTRUCTION / WILL CHANGE PRE-4.0.0 PER #2955 / #2975.
Currently unused.
"""
self.vector_size = vector_size
# pre-allocating `index_to_key` to full size helps avoid redundant re-allocations, esp for `expandos`
Expand Down Expand Up @@ -259,7 +260,7 @@ def _load_specials(self, *args, **kwargs):
self._upconvert_old_vocab()

def _upconvert_old_vocab(self):
"""Convert a loaded, pre-gensim-4.0.0 version instance that had a 'vocab' dict of data objects"""
"""Convert a loaded, pre-gensim-4.0.0 version instance that had a 'vocab' dict of data objects."""
old_vocab = self.__dict__.pop('vocab', None)
self.key_to_index = {}
for k in old_vocab.keys():
Expand All @@ -277,6 +278,7 @@ def allocate_vecattrs(self, attrs=None, types=None):
The length of the index_to_key list is canonical 'intended size' of KeyedVectors,
even if other properties (vectors array) hasn't yet been allocated or expanded.
So this allocation targets that size.
"""
# with no arguments, adjust lengths of existing vecattr arrays to match length of index_to_key
if attrs is None:
Expand Down Expand Up @@ -351,13 +353,8 @@ def get_vecattr(self, key, attr):

def resize_vectors(self, seed=0):
"""Make underlying vectors match index_to_key size; random-initialize any new rows."""

target_shape = (len(self.index_to_key), self.vector_size)
self.vectors = prep_vectors(target_shape, prior_vectors=self.vectors, seed=seed)
# FIXME BEFORE 4.0.0 PER #2955 / #2975 : support memmap & cleanup
# if hasattr(self, 'mapfile_path') and self.mapfile_path:
# self.vectors = np.memmap(self.mapfile_path, shape=(target_count, self.vector_size), mode='w+', dtype=REAL)

self.allocate_vecattrs()
self.norms = None

Expand All @@ -370,7 +367,7 @@ def __getitem__(self, key_or_keys):
Parameters
----------
key_or_keys : {str, list of str, int, list of int}
Requested key or list-of-keys
Requested key or list-of-keys.
Returns
-------
Expand Down Expand Up @@ -784,7 +781,7 @@ def most_similar(
return result[:topn]

def similar_by_word(self, word, topn=10, restrict_vocab=None):
"""Compatibility alias for similar_by_key()"""
"""Compatibility alias for similar_by_key()."""
return self.similar_by_key(word, topn, restrict_vocab)

def similar_by_key(self, key, topn=10, restrict_vocab=None):
Expand Down Expand Up @@ -1193,14 +1190,19 @@ def _log_evaluate_word_analogies(section):
Returns
-------
float
Accuracy score.
Accuracy score if at least one prediction was made (correct or incorrect).
Or return 0.0 if there were no predictions at all in this section.
"""
correct, incorrect = len(section['correct']), len(section['incorrect'])
if correct + incorrect > 0:
score = correct / (correct + incorrect)
logger.info("%s: %.1f%% (%i/%i)", section['section'], 100.0 * score, correct, correct + incorrect)
return score

if correct + incorrect == 0:
return 0.0

score = correct / (correct + incorrect)
logger.info("%s: %.1f%% (%i/%i)", section['section'], 100.0 * score, correct, correct + incorrect)
return score

def evaluate_word_analogies(self, analogies, restrict_vocab=300000, case_insensitive=True, dummy4unknown=False):
"""Compute performance of the model on an analogy test set.
Expand Down Expand Up @@ -1324,7 +1326,7 @@ def log_accuracy(section):
if correct + incorrect > 0:
logger.info(
"%s: %.1f%% (%i/%i)",
section['section'], 100.0 * correct / (correct + incorrect), correct, correct + incorrect
section['section'], 100.0 * correct / (correct + incorrect), correct, correct + incorrect,
)

@staticmethod
Expand Down Expand Up @@ -1463,7 +1465,7 @@ def init_sims(self, replace=False):
def unit_normalize_all(self):
"""Destructively scale all vectors to unit-length.
(You cannot sensibly continue training after such a step.)
You cannot sensibly continue training after such a step.
"""
self.fill_norms()
Expand Down Expand Up @@ -1495,7 +1497,8 @@ def relative_cosine_similarity(self, wa, wb, topn=10):
"""
sims = self.similar_by_word(wa, topn)
assert sims, "Failed code invariant: list of similar words must never be empty."
if not sims:
raise ValueError("Cannot calculate relative cosine similarity without any similar words.")
rcs = float(self.similarity(wa, wb)) / (sum(sim for _, sim in sims))

return rcs
Expand Down Expand Up @@ -1583,7 +1586,7 @@ def load_word2vec_format(
cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict',
limit=None, datatype=REAL, no_header=False,
):
"""Load the input-hidden weight matrix from the original C word2vec-tool format.
"""Load KeyedVectors from a file produced by the original C word2vec-tool format.
Warnings
--------
Expand Down Expand Up @@ -1660,7 +1663,7 @@ def intersect_word2vec_format(self, fname, lockf=0.0, binary=False, encoding='ut
vocab_size, vector_size = (int(x) for x in header.split()) # throws for invalid file format
if not vector_size == self.vector_size:
raise ValueError("incompatible vector size %d in file %s" % (vector_size, fname))
# TOCONSIDER: maybe mismatched vectors still useful enough to merge (truncating/padding)?
# TODO: maybe mismatched vectors still useful enough to merge (truncating/padding)?
if binary:
binary_len = dtype(REAL).itemsize * vector_size
for _ in range(vocab_size):
Expand Down Expand Up @@ -1688,7 +1691,10 @@ def intersect_word2vec_format(self, fname, lockf=0.0, binary=False, encoding='ut
overlap_count += 1
self.vectors[self.get_index(word)] = weights
self.vectors_lockf[self.get_index(word)] = lockf # lock-factor: 0.0=no changes
logger.info("merged %d vectors into %s matrix from %s", overlap_count, self.wv.vectors.shape, fname)
self.add_lifecycle_event(
"intersect_word2vec_format",
msg=f"merged {overlap_count} vectors into {self.vectors.shape} matrix from {fname}",
)

def _upconvert_old_d2vkv(self):
"""Convert a deserialized older Doc2VecKeyedVectors instance to latest generic KeyedVectors"""
Expand Down Expand Up @@ -1721,6 +1727,7 @@ def similarity_unseen_docs(self, *args, **kwargs):


class CompatVocab:

def __init__(self, **kwargs):
"""A single vocabulary item, used internally for collecting per-word frequency/sampling info,
and for constructing binary trees (incl. both word leaves and inner nodes).
Expand Down Expand Up @@ -1847,7 +1854,7 @@ def _load_word2vec_format(
fname : str
The file path to the saved word2vec-format file.
fvocab : str, optional
File path to the vocabulary.Word counts are read from `fvocab` filename, if set
File path to the vocabulary. Word counts are read from `fvocab` filename, if set
(this is the file generated by `-save-vocab` flag of the original C tool).
binary : bool, optional
If True, indicates whether the data is in binary word2vec format.
Expand Down Expand Up @@ -1913,7 +1920,11 @@ def _load_word2vec_format(
kv.vectors = ascontiguousarray(kv.vectors[: len(kv)])
assert (len(kv), vector_size) == kv.vectors.shape

logger.info("loaded %s matrix from %s", kv.vectors.shape, fname)
kv.add_lifecycle_event(
"load_word2vec_format",
msg=f"loaded {kv.vectors.shape} matrix of type {kv.vectors.dtype} from {fname}",
binary=binary, encoding=encoding,
)
return kv


Expand All @@ -1939,7 +1950,6 @@ def prep_vectors(target_shape, prior_vectors=None, seed=0, dtype=REAL):
"""Return a numpy array of the given shape. Reuse prior_vectors object or values
to extent possible. Initialize new values randomly if requested.
FIXME: NAME/DOCS CHANGES PRE-4.0.0 FOR #2955/#2975 MMAP & OTHER INITIALIZATION CLEANUP WORK.
"""
if prior_vectors is None:
prior_vectors = np.zeros((0, 0))
Expand Down
12 changes: 9 additions & 3 deletions gensim/models/ldamodel.py
Expand Up @@ -88,16 +88,17 @@
import logging
import numbers
import os
import time
from collections import defaultdict

import numpy as np
from scipy.special import gammaln, psi # gamma function utils
from scipy.special import polygamma
from collections import defaultdict

from gensim import interfaces, utils, matutils
from gensim.matutils import (
kullback_leibler, hellinger, jaccard_distance, jensen_shannon,
dirichlet_expectation, logsumexp, mean_absolute_difference
dirichlet_expectation, logsumexp, mean_absolute_difference,
)
from gensim.models import basemodel, CoherenceModel
from gensim.models.callbacks import Callback
Expand Down Expand Up @@ -375,7 +376,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
Set to 0 for batch learning, > 1 for online iterative learning.
alpha : {numpy.ndarray, str}, optional
Can be set to an 1D array of length equal to the number of expected topics that expresses
our a-priori belief for the each topics' probability.
our a-priori belief for each topics' probability.
Alternatively default prior selecting strategies can be employed by supplying a string:
* 'symmetric': Default; uses a fixed symmetric prior per topic,
Expand Down Expand Up @@ -518,7 +519,12 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
# if a training corpus was provided, start estimating the model right away
if corpus is not None:
use_numpy = self.dispatcher is not None
start = time.time()
self.update(corpus, chunks_as_numpy=use_numpy)
self.add_lifecycle_event(
"created",
msg=f"trained {self} in {time.time() - start:.2f}s",
)

def init_dir_prior(self, prior, name):
"""Initialize priors for the Dirichlet distribution.
Expand Down
2 changes: 1 addition & 1 deletion gensim/models/ldamulticore.py
Expand Up @@ -181,7 +181,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, workers=None,
id2word=id2word, chunksize=chunksize, passes=passes, alpha=alpha, eta=eta,
decay=decay, offset=offset, eval_every=eval_every, iterations=iterations,
gamma_threshold=gamma_threshold, random_state=random_state, minimum_probability=minimum_probability,
minimum_phi_value=minimum_phi_value, per_word_topics=per_word_topics, dtype=dtype
minimum_phi_value=minimum_phi_value, per_word_topics=per_word_topics, dtype=dtype,
)

def update(self, corpus, chunks_as_numpy=False):
Expand Down
20 changes: 13 additions & 7 deletions gensim/models/lsimodel.py
Expand Up @@ -61,6 +61,7 @@

import logging
import sys
import time

import numpy as np
import scipy.linalg
Expand Down Expand Up @@ -351,17 +352,17 @@ class LsiModel(interfaces.TransformationABC, basemodel.BaseTopicModel):
"""

def __init__(self, corpus=None, num_topics=200, id2word=None, chunksize=20000,
decay=1.0, distributed=False, onepass=True,
power_iters=P2_EXTRA_ITERS, extra_samples=P2_EXTRA_DIMS, dtype=np.float64):
"""Construct an `LsiModel` object.
Either `corpus` or `id2word` must be supplied in order to train the model.
def __init__(
self, corpus=None, num_topics=200, id2word=None, chunksize=20000,
decay=1.0, distributed=False, onepass=True,
power_iters=P2_EXTRA_ITERS, extra_samples=P2_EXTRA_DIMS, dtype=np.float64
):
"""Build an LSI model.
Parameters
----------
corpus : {iterable of list of (int, float), scipy.sparse.csc}, optional
Stream of document vectors or sparse matrix of shape (`num_documents`, `num_terms`).
Stream of document vectors or a sparse matrix of shape (`num_documents`, `num_terms`).
num_topics : int, optional
Number of requested factors (latent dimensions)
id2word : dict of {int: str}, optional
Expand Down Expand Up @@ -440,7 +441,12 @@ def __init__(self, corpus=None, num_topics=200, id2word=None, chunksize=20000,
raise RuntimeError("failed to initialize distributed LSI (%s)" % err)

if corpus is not None:
start = time.time()
self.add_documents(corpus)
self.add_lifecycle_event(
"created",
msg=f"trained {self} in {time.time() - start:.2f}s",
)

def add_documents(self, corpus, chunksize=None, decay=None):
"""Update model with new `corpus`.
Expand Down

0 comments on commit 60ad052

Please sign in to comment.