Update vocabulary
Hironsan committed May 30, 2018
1 parent 0c0b8a1 commit 35ec1db
Showing 2 changed files with 40 additions and 16 deletions.
40 changes: 28 additions & 12 deletions anago/utils.py
@@ -101,33 +101,33 @@ def data_generator():
 
 
 class Vocabulary(object):
-    """
-    A vocabulary that maps tokens to ints (storing a vocabulary)
+    """A vocabulary that maps tokens to ints (storing a vocabulary).
+
     Attributes:
         _token_count: A collections.Counter object holding the frequencies of tokens
-            in the data used to build the Vocab.
+            in the data used to build the Vocabulary.
         _token2id: A collections.defaultdict instance mapping token strings to
             numerical identifiers.
         _id2token: A list of token strings indexed by their numerical identifiers.
     """
 
-    def __init__(self, max_size=None, lower=True, unk=True, specials=('<pad>',)):
-        """Create a Vocab object from a collections.Counter.
+    def __init__(self, max_size=None, lower=True, unk_token=True, specials=('<pad>',)):
+        """Create a Vocabulary object.
 
         Args:
             max_size: The maximum size of the vocabulary, or None for no
                 maximum. Default: None.
             lower: boolean. Whether to convert the texts to lowercase.
+            unk_token: boolean. Whether to add unknown token.
             specials: The list of special tokens (e.g., padding or eos) that
-                will be prepended to the vocabulary in addition to an <unk>
-                token. Default: ['<pad>']
+                will be prepended to the vocabulary. Default: ('<pad>',)
         """
+        self._max_size = max_size
+        self._lower = lower
+        self._unk = unk_token
         self._token2id = {token: i for i, token in enumerate(specials)}
         self._id2token = list(specials)
-        self._lower = lower
-        self._max_size = max_size
         self._token_count = Counter()
-        self._unk = unk
 
     def __len__(self):
         return len(self._token2id)
@@ -138,6 +138,7 @@ def add_token(self, token):
         Args:
             token (str): token to add.
         """
+        token = self.process_token(token)
         self._token_count.update([token])
 
     def add_documents(self, docs):
@@ -148,6 +149,7 @@ def add_documents(self, docs):
             docs (list): documents to add.
         """
         for sent in docs:
+            sent = map(self.process_token, sent)
             self._token_count.update(sent)
 
     def doc2id(self, doc):
@@ -159,6 +161,7 @@ def doc2id(self, doc):
         Returns:
             list: int id of doc.
         """
+        doc = map(self.process_token, doc)
         return [self.token_to_id(token) for token in doc]
 
     def id2doc(self, ids):
@@ -188,6 +191,18 @@ def build(self):
             self._id2token.append(unk)
 
     def process_token(self, token):
+        """Process token before following methods:
+        * add_token
+        * add_documents
+        * doc2id
+        * token_to_id
+
+        Args:
+            token (str): token to process.
+
+        Returns:
+            str: processed token string.
+        """
         if self._lower:
             token = token.lower()
 
@@ -202,6 +217,7 @@ def token_to_id(self, token):
         Returns:
             int: int id of token.
         """
+        token = self.process_token(token)
         return self._token2id.get(token, len(self._token2id) - 1)
 
     def id_to_token(self, idx):
@@ -220,7 +236,7 @@ def vocab(self):
         """Return the vocabulary.
 
         Returns:
-            dict: get the dict object of the vocabulary
+            dict: get the dict object of the vocabulary.
         """
         return self._token2id
 
@@ -229,6 +245,6 @@ def reverse_vocab(self):
         """Return the vocabulary as a reversed dict object.
 
         Returns:
-            dict: reversed vocabulary object
+            dict: reversed vocabulary object.
         """
         return self._id2token
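
Taken together, the changes to anago/utils.py rename the unk flag to unk_token and route add_token, add_documents, doc2id, and token_to_id through process_token, so lower-casing happens consistently on the way in. A minimal usage sketch of the updated API (the import path follows the repo layout; the documents and queried tokens below are illustrative, not from the commit):

from anago.utils import Vocabulary

# Illustrative corpus: any list of token lists works.
docs = [['Tokyo', 'is', 'big'], ['Kyoto', 'is', 'old']]

vocab = Vocabulary(lower=True, unk_token=True, specials=('<pad>',))
vocab.add_documents(docs)   # each token passes through process_token (lower-casing)
vocab.build()               # assigns ids; appends <unk> last because unk_token=True

# 'huge' was never added, so token_to_id falls back to len(_token2id) - 1,
# the id that build() reserved for <unk>.
print(vocab.doc2id(['Tokyo', 'is', 'huge']))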
16 changes: 12 additions & 4 deletions tests/test_utils.py
@@ -35,19 +35,19 @@ def test_add_documents(self):
         self.assertEqual(vocab._token2id, token2id)
 
         token2id = {'<pad>': 0, 'a': 1, 'b': 2, 'c': 3}
-        vocab = Vocabulary(unk=False)
+        vocab = Vocabulary(unk_token=False)
         vocab.add_documents(docs)
         vocab.build()
         self.assertEqual(vocab._token2id, token2id)
 
         token2id = {'<pad>': 0, '<s>': 1, 'a': 2, 'b': 3, 'c': 4}
-        vocab = Vocabulary(unk=False, specials=('<pad>', '<s>'))
+        vocab = Vocabulary(unk_token=False, specials=('<pad>', '<s>'))
         vocab.add_documents(docs)
         vocab.build()
         self.assertEqual(vocab._token2id, token2id)
 
         token2id = {'a': 0, 'b': 1, 'c': 2}
-        vocab = Vocabulary(unk=False, specials=())
+        vocab = Vocabulary(unk_token=False, specials=())
         vocab.add_documents(docs)
         vocab.build()
         self.assertEqual(vocab._token2id, token2id)
@@ -80,4 +80,12 @@ def test_doc2id(self):
         self.assertEqual(doc_ids, correct)
 
     def test_id2doc(self):
-        pass
+        # word ids.
+        docs = [['B-PSN'], ['B-ORG', 'I-ORG'], ['B-LOC', 'I-LOC', 'O']]
+        vocab = Vocabulary(unk_token=False, lower=False)
+        vocab.add_documents(docs)
+        vocab.build()
+        true_doc = ['O', 'B-LOC', 'O', 'O']
+        doc_ids = vocab.doc2id(true_doc)
+        pred_doc = vocab.id2doc(doc_ids)
+        self.assertEqual(pred_doc, true_doc)
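
The new test pins the round trip with lower=False; with the default lower=True the round trip is not the identity, because process_token normalizes case before tokens are stored and looked up. A small sketch of that edge case (illustrative data; it assumes id2doc simply indexes _id2token, as the attribute docstring describes):

from anago.utils import Vocabulary

vocab = Vocabulary(unk_token=False, lower=True)
vocab.add_documents([['Tokyo', 'Kyoto']])
vocab.build()

ids = vocab.doc2id(['Tokyo'])  # process_token lower-cases, so this looks up 'tokyo'
print(vocab.id2doc(ids))       # ['tokyo'], the stored, normalized form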
