Commit: Update vocabulary

Hironsan committed May 29, 2018
1 parent dfd71e4 commit 0c0b8a1
Showing 2 changed files with 137 additions and 33 deletions.
117 changes: 88 additions & 29 deletions anago/utils.py
@@ -102,73 +102,132 @@ def data_generator():

 class Vocabulary(object):
     """
-    A vocabulary that maps words to ints (storing a vocabulary)
+    A vocabulary that maps tokens to ints (storing a vocabulary)
     Attributes:
         _token_count: A collections.Counter object holding the frequencies of tokens
             in the data used to build the Vocab.
         _token2id: A collections.defaultdict instance mapping token strings to
             numerical identifiers.
         _id2token: A list of token strings indexed by their numerical identifiers.
     """

-    def __init__(self, num_words=None, lower=True, start=1, oov_token=None):
-        self._token2id = {}
-        self._id2token = {}
-        self._start = start
+    def __init__(self, max_size=None, lower=True, unk=True, specials=('<pad>',)):
+        """Create a Vocab object from a collections.Counter.
+        Args:
+            max_size: The maximum size of the vocabulary, or None for no
+                maximum. Default: None.
+            specials: The list of special tokens (e.g., padding or eos) that
+                will be prepended to the vocabulary in addition to an <unk>
+                token. Default: ['<pad>']
+        """
+        self._token2id = {token: i for i, token in enumerate(specials)}
+        self._id2token = list(specials)
         self._lower = lower
-        self._num_words = num_words
+        self._max_size = max_size
         self._token_count = Counter()
+        self._unk = unk

-    def add_word(self, token):
+    def __len__(self):
+        return len(self._token2id)
+
+    def add_token(self, token):
         """Add token to vocabulary.
         Args:
-            token (str): token to add
+            token (str): token to add.
         """
         self._token_count.update([token])

     def add_documents(self, docs):
         """Update dictionary from a collection of documents. Each document is a list
         of tokens.
         Args:
             docs (list): documents to add.
         """
         for sent in docs:
             self._token_count.update(sent)

     def doc2id(self, doc):
-        return [self._token2id.get(token) for token in doc]
+        """Get the list of token_id given doc.
+        Args:
+            doc (list): document.
+        Returns:
+            list: int id of doc.
+        """
+        return [self.token_to_id(token) for token in doc]
+
+    def id2doc(self, ids):
+        """Get the token list.
+        Args:
+            ids (list): token ids.
+        Returns:
+            list: token list.
+        """
+        return [self.id_to_token(idx) for idx in ids]

     def build(self):
-        token_freq = self._token_count.most_common(self._num_words)
-        idx = self._start
+        """
+        Build vocabulary.
+        """
+        token_freq = self._token_count.most_common(self._max_size)
+        idx = len(self.vocab)
         for token, _ in token_freq:
             self._token2id[token] = idx
-            self._id2token[idx] = token
+            self._id2token.append(token)
             idx += 1
+        if self._unk:
+            unk = '<unk>'
+            self._token2id[unk] = idx
+            self._id2token.append(unk)

     def process_token(self, token):
         if self._lower:
             token = token.lower()

-    def word_id(self, word):
-        """Get the word_id of given word.
+        return token
+
+    def token_to_id(self, token):
+        """Get the token_id of given token.
         Args:
-            word (str): word from vocabulary
+            token (str): token from vocabulary.
         Returns:
-            int: int id of word
+            int: int id of token.
         """
-        return self._token2id.get(word, None)
+        return self._token2id.get(token, len(self._token2id) - 1)

-    def __len__(self):
-        return len(self._token2id)
-
-    def id_to_word(self, wid):
-        """Word-id to word (string).
+    def id_to_token(self, idx):
+        """token-id to token (string).
         Args:
-            wid (int): word id
+            idx (int): token id.
         Returns:
-            str: string of given word id
+            str: string of given token id.
         """
-        return self._id2token.get(wid)
+        return self._id2token[idx]

     @property
     def vocab(self):
-        """
-        dict: get the dict object of the vocabulary
+        """Return the vocabulary.
+        Returns:
+            dict: get the dict object of the vocabulary
         """
         return self._token2id

     @property
     def reverse_vocab(self):
-        """
-        Return the vocabulary as a reversed dict object
+        """Return the vocabulary as a reversed dict object.
+        Returns:
+            dict: reversed vocabulary object
         """
[remainder of hunk not shown]
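
As a usage note (not part of the commit): the sketch below shows how the updated Vocabulary behaves, assuming it is imported from anago.utils. The ids follow from the default specials=('<pad>',) plus the '<unk>' entry that build() appends last.

    from anago.utils import Vocabulary

    # Build a vocabulary from tokenized documents.
    vocab = Vocabulary()  # defaults: lower=True, unk=True, specials=('<pad>',)
    vocab.add_documents([['a'], ['a', 'b'], ['a', 'b', 'c']])
    vocab.build()

    # '<pad>' takes id 0, corpus tokens follow by frequency, '<unk>' comes last.
    print(vocab.vocab)                    # {'<pad>': 0, 'a': 1, 'b': 2, 'c': 3, '<unk>': 4}
    print(vocab.doc2id(['a', 'c', 'd']))  # [1, 3, 4]; 'd' falls back to the <unk> id
    print(vocab.id2doc([1, 2, 3]))        # ['a', 'b', 'c']

A caveat visible in the diff: token_to_id() falls back to len(self._token2id) - 1, which is the '<unk>' id only because build() appends '<unk>' last; with unk=False, an out-of-vocabulary token silently maps to the last regular token instead of raising.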
53 changes: 49 additions & 4 deletions tests/test_utils.py
@@ -22,17 +22,62 @@ def test_batch_iter(self):
         steps, batches = batch_iter(list(zip(sents, labels)), batch_size, preprocessor=p)
         self.assertEqual(len([_ for _ in batches]), steps)  # Todo: infinite loop

-    def test_vocabulary(self):
+
+class TestVocabulary(unittest.TestCase):
+
+    def test_add_documents(self):
+        # word vocabulary.
         docs = [['a'], ['a', 'b'], ['a', 'b', 'c']]
-        token2id = {'a': 1, 'b': 2, 'c': 3}
+        token2id = {'<pad>': 0, 'a': 1, 'b': 2, 'c': 3, '<unk>': 4}
         vocab = Vocabulary()
         vocab.add_documents(docs)
         vocab.build()
+        print(vocab._token2id)
         self.assertEqual(vocab._token2id, token2id)
+
+        token2id = {'<pad>': 0, 'a': 1, 'b': 2, 'c': 3}
+        vocab = Vocabulary(unk=False)
+        vocab.add_documents(docs)
+        vocab.build()
+        self.assertEqual(vocab._token2id, token2id)
+
+        token2id = {'<pad>': 0, '<s>': 1, 'a': 2, 'b': 3, 'c': 4}
+        vocab = Vocabulary(unk=False, specials=('<pad>', '<s>'))
+        vocab.add_documents(docs)
+        vocab.build()
+        self.assertEqual(vocab._token2id, token2id)
+
+        token2id = {'a': 0, 'b': 1, 'c': 2}
+        vocab = Vocabulary(unk=False, specials=())
+        vocab.add_documents(docs)
+        vocab.build()
+        self.assertEqual(vocab._token2id, token2id)
+
+        # char vocabulary.
+        docs = ['hoge', 'fuga', 'bar']
+        vocab = Vocabulary()
+        vocab.add_documents(docs)
+        vocab.build()
+        num_chars = len(set(''.join(docs))) + 2
+        self.assertEqual(len(vocab._token2id), num_chars)
+
+    def test_doc2id(self):
+        # word ids.
+        docs = [['a'], ['a', 'b'], ['a', 'b', 'c']]
+        vocab = Vocabulary()
+        vocab.add_documents(docs)
+        vocab.build()
+        another_doc = ['a', 'b', 'c', 'd']
+        doc_ids = vocab.doc2id(another_doc)
+        self.assertEqual(doc_ids, [1, 2, 3, 4])
+
+        # char_ids.
+        docs = ['hoge', 'fuga', 'bar']
+        vocab = Vocabulary()
+        vocab.add_documents(docs)
+        vocab.build()
+        print(vocab._token2id)
+        doc_ids = vocab.doc2id(docs[0])
+        correct = [vocab.token_to_id(c) for c in docs[0]]
+        self.assertEqual(doc_ids, correct)
+
+    def test_id2doc(self):
+        pass
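The test_id2doc case is left as a stub in this commit. A round-trip check along the following lines (a hypothetical sketch, not part of the commit) would exercise id2doc:

    import unittest

    from anago.utils import Vocabulary


    class TestVocabulary(unittest.TestCase):

        def test_id2doc(self):
            # ids produced by doc2id should map back to the original tokens.
            docs = [['a'], ['a', 'b'], ['a', 'b', 'c']]
            vocab = Vocabulary()
            vocab.add_documents(docs)
            vocab.build()
            doc = ['a', 'b', 'c']
            self.assertEqual(vocab.id2doc(vocab.doc2id(doc)), doc)
            # An out-of-vocabulary token comes back as '<unk>'.
            self.assertEqual(vocab.id2doc(vocab.doc2id(['d'])), ['<unk>'])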
