Commit

Merge pull request #1402 from vlejd/56_fix_dict_save_as_text
Fix Dictionary save_as_text method #56 + fix lint errors
menshikh-iv committed Jun 15, 2017
2 parents f62ae5f + 76ed41d commit 09e16cd
Showing 2 changed files with 104 additions and 35 deletions.
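
In short: save_as_text now writes the dictionary's num_docs as the first line of the saved file, and load_from_text reads it back, falling back with a warning (and num_docs left at 0) for files written by older versions. For the three-document corpus used in the new tests, the saved file would look roughly like this ([TAB] denotes a tab character, as in the docstring below; the ids are arbitrary, so the exact numbering may differ):

    3
    0[TAB]druhé[TAB]2
    1[TAB]prvé[TAB]1
    2[TAB]slovo[TAB]3
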
23 changes: 16 additions & 7 deletions gensim/corpora/dictionary.py
@@ -24,13 +24,13 @@

from gensim import utils

if sys.version_info[0] >= 3:
unicode = str

from six import PY3, iteritems, iterkeys, itervalues, string_types
from six.moves import xrange
from six.moves import zip as izip

if sys.version_info[0] >= 3:
unicode = str


logger = logging.getLogger('gensim.corpora.dictionary')

@@ -180,7 +180,7 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=N
2. more than `no_above` documents (fraction of total corpus size, *not*
absolute number).
3. if tokens are given in keep_tokens (list of strings), they will be kept regardless of
the `no_below` and `no_above` settings
the `no_below` and `no_above` settings
4. after (1), (2) and (3), keep only the first `keep_n` most frequent tokens (or
keep all if `None`).
@@ -196,8 +196,7 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=N
keep_ids = [self.token2id[v] for v in keep_tokens if v in self.token2id]
good_ids = (
v for v in itervalues(self.token2id)
if no_below <= self.dfs.get(v, 0) <= no_above_abs
or v in keep_ids
if no_below <= self.dfs.get(v, 0) <= no_above_abs or v in keep_ids
)
else:
good_ids = (
@@ -232,7 +231,7 @@ def filter_n_most_frequent(self, remove_n):
# do the actual filtering, then rebuild dictionary to remove gaps in ids
most_frequent_words = [(self[id], self.dfs.get(id, 0)) for id in most_frequent_ids]
logger.info("discarding %i tokens: %s...", len(most_frequent_ids), most_frequent_words[:10])

self.filter_tokens(bad_ids=most_frequent_ids)
logger.info("resulting dictionary: %s" % self)

@@ -282,6 +281,7 @@ def compactify(self):
def save_as_text(self, fname, sort_by_word=True):
"""
Save this Dictionary to a text file, in format:
`num_docs`
`id[TAB]word_utf8[TAB]document frequency[NEWLINE]`. Sorted by word,
or by decreasing word frequency.
@@ -290,6 +290,8 @@ def save_as_text(self, fname, sort_by_word=True):
"""
logger.info("saving dictionary mapping to %s", fname)
with utils.smart_open(fname, 'wb') as fout:
numdocs_line = "%d\n" % self.num_docs
fout.write(utils.to_utf8(numdocs_line))
if sort_by_word:
for token, tokenid in sorted(iteritems(self.token2id)):
line = "%i\t%s\t%i\n" % (tokenid, token, self.dfs.get(tokenid, 0))
@@ -354,6 +356,13 @@ def load_from_text(fname):
with utils.smart_open(fname) as f:
for lineno, line in enumerate(f):
line = utils.to_unicode(line)
if lineno == 0:
if line.strip().isdigit():
# Older versions of save_as_text may not write num_docs on first line.
result.num_docs = int(line.strip())
continue
else:
logging.warning("Text does not contain num_docs on the first line.")
try:
wordid, word, docfreq = line[:-1].split('\t')
except Exception:
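
Taken together, the two changes above mean num_docs now survives a save/load round trip. A minimal round-trip sketch, not part of the commit, assuming Python 3, the Dictionary API exactly as shown in this diff, and a writable temp directory:

    import os
    import tempfile

    from gensim.corpora import Dictionary

    # Same toy corpus as the new test_saveAsText test further down.
    texts = [["prvé", "slovo"], ["slovo", "druhé"], ["druhé", "slovo"]]
    d = Dictionary(texts)

    fname = os.path.join(tempfile.gettempdir(), 'dict_roundtrip.txt')
    d.save_as_text(fname)  # writes num_docs first, then one id[TAB]word[TAB]docfreq line per token
    d2 = Dictionary.load_from_text(fname)

    # With this commit, num_docs is preserved instead of silently resetting to 0.
    assert d2.num_docs == d.num_docs == 3
    assert d2.token2id == d.token2id
    assert d2.dfs == d.dfs
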
116 changes: 88 additions & 28 deletions gensim/test/test_corpora_dictionary.py
@@ -120,53 +120,52 @@ def testFilter(self):
d.filter_extremes(no_below=2, no_above=1.0, keep_n=4)
expected = {0: 3, 1: 3, 2: 3, 3: 3}
self.assertEqual(d.dfs, expected)

def testFilterKeepTokens_keepTokens(self):
# provide keep_tokens argument, keep the tokens given
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey'])
expected = set(['graph', 'trees', 'human', 'system', 'user', 'survey'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterKeepTokens_unchangedFunctionality(self):
# do not provide keep_tokens argument, filter_extremes functionality is unchanged
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0)
expected = set(['graph', 'trees', 'system', 'user'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterKeepTokens_unseenToken(self):
# do provide keep_tokens argument with unseen tokens, filter_extremes functionality is unchanged
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token'])
expected = set(['graph', 'trees', 'system', 'user'])
self.assertEqual(set(d.token2id.keys()), expected)
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterMostFrequent(self):
d = Dictionary(self.texts)
d.filter_n_most_frequent(4)
expected = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2}
self.assertEqual(d.dfs, expected)


d = Dictionary(self.texts)
d.filter_n_most_frequent(4)
expected = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2}
self.assertEqual(d.dfs, expected)

def testFilterTokens(self):
self.maxDiff = 10000
d = Dictionary(self.texts)

removed_word = d[0]
d.filter_tokens([0])

expected = {'computer': 0, 'eps': 8, 'graph': 10, 'human': 1,
'interface': 2, 'minors': 11, 'response': 3, 'survey': 4,
'system': 5, 'time': 6, 'trees': 9, 'user': 7}
expected = {
'computer': 0, 'eps': 8, 'graph': 10, 'human': 1,
'interface': 2, 'minors': 11, 'response': 3, 'survey': 4,
'system': 5, 'time': 6, 'trees': 9, 'user': 7}
del expected[removed_word]
self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys()))

expected[removed_word] = len(expected)
d.add_documents([[removed_word]])
self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys()))


def test_doc2bow(self):
d = Dictionary([["žluťoučký"], ["žluťoučký"]])

@@ -179,6 +178,66 @@ def test_doc2bow(self):
# unicode must be converted to utf8
self.assertEqual(d.doc2bow([u'\u017elu\u0165ou\u010dk\xfd']), [(0, 1)])

def test_saveAsText(self):
"""`Dictionary` can be saved as textfile. """
tmpf = get_tmpfile('save_dict_test.txt')
small_text = [
["prvé", "slovo"],
["slovo", "druhé"],
["druhé", "slovo"]]

d = Dictionary(small_text)

d.save_as_text(tmpf)
with open(tmpf) as file:
serialized_lines = file.readlines()
self.assertEqual(serialized_lines[0], "3\n")
self.assertEqual(len(serialized_lines), 4)
# We do not know which word will have which index
self.assertEqual(serialized_lines[1][1:], "\tdruhé\t2\n")
self.assertEqual(serialized_lines[2][1:], "\tprvé\t1\n")
self.assertEqual(serialized_lines[3][1:], "\tslovo\t3\n")

d.save_as_text(tmpf, sort_by_word=False)
with open(tmpf) as file:
serialized_lines = file.readlines()
self.assertEqual(serialized_lines[0], "3\n")
self.assertEqual(len(serialized_lines), 4)
self.assertEqual(serialized_lines[1][1:], "\tslovo\t3\n")
self.assertEqual(serialized_lines[2][1:], "\tdruhé\t2\n")
self.assertEqual(serialized_lines[3][1:], "\tprvé\t1\n")

def test_loadFromText_legacy(self):
"""
`Dictionary` can be loaded from a text file in the legacy format.
The legacy format does not have num_docs on the first line.
"""
tmpf = get_tmpfile('load_dict_test_legacy.txt')
no_num_docs_serialization = "1\tprvé\t1\n2\tslovo\t2\n"
with open(tmpf, "w") as file:
file.write(no_num_docs_serialization)

d = Dictionary.load_from_text(tmpf)
self.assertEqual(d.token2id[u"prvé"], 1)
self.assertEqual(d.token2id[u"slovo"], 2)
self.assertEqual(d.dfs[1], 1)
self.assertEqual(d.dfs[2], 2)
self.assertEqual(d.num_docs, 0)

def test_loadFromText(self):
"""`Dictionary` can be loaded from textfile."""
tmpf = get_tmpfile('load_dict_test.txt')
no_num_docs_serialization = "2\n1\tprvé\t1\n2\tslovo\t2\n"
with open(tmpf, "w") as file:
file.write(no_num_docs_serialization)

d = Dictionary.load_from_text(tmpf)
self.assertEqual(d.token2id[u"prvé"], 1)
self.assertEqual(d.token2id[u"slovo"], 2)
self.assertEqual(d.dfs[1], 1)
self.assertEqual(d.dfs[2], 2)
self.assertEqual(d.num_docs, 2)

def test_saveAsText_and_loadFromText(self):
"""`Dictionary` can be saved as textfile and loaded again from textfile. """
tmpf = get_tmpfile('dict_test.txt')
@@ -194,24 +253,25 @@ def test_saveAsText_and_loadFromText(self):
def test_from_corpus(self):
"""build `Dictionary` from an existing corpus"""

documents = ["Human machine interface for lab abc computer applications",
"A survey of user opinion of computer system response time",
"The EPS user interface management system",
"System and human system engineering testing of EPS",
"Relation of user perceived response time to error measurement",
"The generation of random binary unordered trees",
"The intersection graph of paths in trees",
"Graph minors IV Widths of trees and well quasi ordering",
"Graph minors A survey"]
documents = [
"Human machine interface for lab abc computer applications",
"A survey of user opinion of computer system response time",
"The EPS user interface management system",
"System and human system engineering testing of EPS",
"Relation of user perceived response time to error measurement",
"The generation of random binary unordered trees",
"The intersection graph of paths in trees",
"Graph minors IV Widths of trees and well quasi ordering",
"Graph minors A survey"]
stoplist = set('for a of the and to in'.split())
texts = [[word for word in document.lower().split() if word not in stoplist]
for document in documents]
texts = [
[word for word in document.lower().split() if word not in stoplist]
for document in documents]

# remove words that appear only once
all_tokens = sum(texts, [])
tokens_once = set(word for word in set(all_tokens) if all_tokens.count(word) == 1)
texts = [[word for word in text if word not in tokens_once]
for text in texts]
texts = [[word for word in text if word not in tokens_once] for text in texts]

dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]
@@ -260,7 +320,7 @@ def test_dict_interface(self):
self.assertTrue(isinstance(d.keys(), list))
self.assertTrue(isinstance(d.values(), list))

#endclass TestDictionary
# endclass TestDictionary


if __name__ == '__main__':
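
For completeness, the legacy path exercised by test_loadFromText_legacy can be sketched outside the test suite as well. This is illustrative only; the file name and contents are made up, and a UTF-8 default encoding is assumed:

    import os
    import tempfile

    from gensim.corpora import Dictionary

    fname = os.path.join(tempfile.gettempdir(), 'legacy_dict.txt')
    with open(fname, 'w') as f:
        # Legacy format: no num_docs header, just id[TAB]word[TAB]docfreq lines.
        f.write("1\tprvé\t1\n2\tslovo\t2\n")

    # load_from_text warns: "Text does not contain num_docs on the first line."
    d = Dictionary.load_from_text(fname)
    assert d.token2id[u"prvé"] == 1
    assert d.dfs[2] == 2
    assert d.num_docs == 0  # num_docs cannot be recovered from a legacy file
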
