Update TextModule to support batch of sequences (#60)
* Update TextModule to support a batch of sequences

* Update .gitignore to exclude tests/vocab.pkl

* Directly map tokens to integer ids in sequences after building vocab

* Update tests for TextModule
tqtg committed Mar 21, 2019
1 parent 793bef7 commit d1d68a7
Showing 4 changed files with 301 additions and 14 deletions.
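
The diffs below add batch handling to `TextModule`. As a quick orientation, here is a minimal usage sketch of the API introduced in this commit, mirroring the new tests at the bottom; the texts and the id map are made-up example data, not part of the diff:

```python
from collections import OrderedDict
from cornac.data import TextModule

# Illustrative raw texts keyed by raw user/item id (example data only).
id_text = {'u1': 'a b c',
           'u2': 'b c d d'}

# Global id map from raw ids to consecutive mapped ids,
# as a train set would normally provide it.
id_map = OrderedDict([('u1', 0), ('u2', 1)])

md = TextModule(id_text)  # default BaseTokenizer, vocab built from the corpus
md.build(id_map)

# Matrix of token ids for mapped ids [0, 1], zero-padded (<PAD> = 0)
# to the longest sequence in the batch -> shape (2, 4).
print(md.batch_seq([0, 1]))
```
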
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
tests/vocab.pkl

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
15 changes: 6 additions & 9 deletions cornac/data/__init__.py
@@ -2,19 +2,16 @@
from .text import TextModule
from .image import ImageModule
from .graph import GraphModule

from .trainset import TrainSet
from .trainset import MatrixTrainSet
from .trainset import MultimodalTrainSet

from .testset import TestSet
from .testset import MultimodalTestSet

from . import reader

__all__ = ['FeatureModule',
'TextModule',
'ImageModule',
'GraphModule',
'TrainSet',
'MatrixTrainSet',
'MultimodalTrainSet',
'TestSet',
'MultimodalTestSet']
__all__ = ['FeatureModule', 'TextModule', 'ImageModule', 'GraphModule',
'TrainSet', 'MatrixTrainSet', 'MultimodalTrainSet',
'TestSet', 'MultimodalTestSet']
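
The public names exported by `cornac.data` are unchanged; only the formatting of `__all__` differs. A trivial check (not part of the diff):

```python
# All previously exported names remain importable.
from cornac.data import FeatureModule, TextModule, TrainSet, TestSet
```
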
194 changes: 192 additions & 2 deletions cornac/data/text.py
@@ -5,15 +5,205 @@
"""

from . import FeatureModule
from typing import List, Dict
from collections import defaultdict, Counter
import pickle
import numpy as np
import itertools

SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>']


class Tokenizer():
"""
Generic class for other tokenizers to extend from. A tokenizer typically
splits text into either word tokens or character tokens.
"""

def tokenize(self, t: str) -> List[str]:
"""
Splitting text into tokens.
Returns
-------
tokens : ``List[str]``
"""
raise NotImplementedError

def batch_tokenize(self, texts: List[str]) -> List[List[str]]:
"""
Splitting a corpus of multiple text documents.
Returns
-------
tokens : ``List[List[str]]``
"""
raise NotImplementedError


class BaseTokenizer(Tokenizer):
"""
A base tokenizer that uses a provided delimiter `sep` to split text.
"""

def __init__(self, sep=' '):
self.sep = sep

def tokenize(self, t: str) -> List[str]:
"""
Splitting text into tokens.
Returns
-------
tokens : ``List[str]``
"""
return t.split(self.sep)

# TODO: this function can be parallelized
def batch_tokenize(self, texts: List[str]) -> List[List[str]]:
"""
Splitting a corpus of multiple text documents.
Returns
-------
tokens : ``List[List[str]]``
"""
return [self.tokenize(t) for t in texts]


class Vocabulary():
"""
Vocabulary contains the mapping between tokens and their integer indices, and vice versa.
"""

def __init__(self, idx2tok: List[str]):
self.idx2tok = idx2tok
self.tok2idx = defaultdict(int, {tok: idx for idx, tok in enumerate(self.idx2tok)})

@property
def size(self):
return len(self.idx2tok)

def to_idx(self, tokens: List[str]) -> List[int]:
"""
Convert a list of `tokens` to their integer indices.
"""
return [self.tok2idx.get(tok, 1) for tok in tokens] # 1 is <UNK> idx

def to_text(self, indices: List[int], sep=' ') -> List[str]:
"""
Convert a list of integer `indices` to their tokens, joined by `sep`.
If `sep` is None, the list of tokens is returned instead.
"""
return sep.join([self.idx2tok[i] for i in indices]) if sep is not None else [self.idx2tok[i] for i in indices]

def save(self, path):
"""
Save idx2tok into a pickle file.
"""
pickle.dump(self.idx2tok, open(path, 'wb'))

@classmethod
def from_tokens(cls, tokens: List[str], max_vocab: int = None, min_freq: int = 1) -> 'Vocabulary':
"""
Build a vocabulary from a list of tokens.
"""
freq = Counter(tokens)
idx2tok = [tok for tok, cnt in freq.most_common(max_vocab) if cnt >= min_freq]
for tok in reversed(SPECIAL_TOKENS): # <PAD>:0, '<UNK>':1, '<BOS>':2, '<EOS>':3
if tok in idx2tok:
idx2tok.remove(tok)
idx2tok.insert(0, tok)
return cls(idx2tok)

@classmethod
def load(cls, path):
"""
Load a vocabulary from a pickle file at `path`.
"""
return cls(pickle.load(open(path, 'rb')))


class TextModule(FeatureModule):
"""Text module
"""

def __init__(self, **kwargs):
def __init__(self,
id_text: Dict = None,
vocab: List[str] = None,
max_vocab: int = None,
tokenizer: Tokenizer = None,
**kwargs):
super().__init__(**kwargs)

self._id_text = id_text
self.vocab = vocab
self.max_vocab = max_vocab
self.tokenizer = tokenizer
self.sequences = None

def _build_text(self, global_id_map: Dict):
"""Build the text based on provided global id map
"""
if self._id_text is None:
return

if self.tokenizer is None:
self.tokenizer = BaseTokenizer()

# Tokenize texts
self.sequences = []
mapped2raw = {mapped_id: raw_id for raw_id, mapped_id in global_id_map.items()}
for mapped_id in range(len(global_id_map)):
raw_id = mapped2raw[mapped_id]
text = self._id_text[raw_id]
self.sequences.append(self.tokenizer.tokenize(text))
del self._id_text[raw_id]
del self._id_text

if self.vocab is None:
self.vocab = Vocabulary.from_tokens(tokens=list(itertools.chain(*self.sequences)),
max_vocab=self.max_vocab)

# Map tokens into integer ids
for i, seq in enumerate(self.sequences):
self.sequences[i] = self.vocab.to_idx(seq)

def build(self, global_id_map):
pass
"""Build the model based on provided list of ordered ids
"""
FeatureModule.build(self, global_id_map)
self._build_text(global_id_map)

def batch_seq(self, batch_ids, max_length=None):
"""Return a numpy matrix of text sequences containing token ids with size=(len(batch_ids), max_length).
If max_length=None, it will be inferred based on retrieved sequences.
"""
if self.sequences is None:
raise ValueError('self.sequences is required but None!')

if max_length is None:
max_length = max(len(self.sequences[mapped_id]) for mapped_id in batch_ids)

seq_mat = np.zeros((len(batch_ids), max_length), dtype=np.int)
for i, mapped_id in enumerate(batch_ids):
idx_seq = self.sequences[mapped_id][:max_length]
for j, idx in enumerate(idx_seq):
seq_mat[i, j] = idx

return seq_mat

def batch_bow(self, batch_ids):
"""Return matrix of bag-of-words corresponding to provided batch_ids
"""
raise NotImplementedError

def batch_freq(self, batch_ids):
"""Return matrix of word frequencies corresponding to provided batch_ids
"""
raise NotImplementedError

def batch_tfidf(self, batch_ids):
"""Return matrix of TF-IDF features corresponding to provided batch_ids
"""
raise NotImplementedError
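
Before the updated tests, a short sketch of how the tokenizer and vocabulary classes added above fit together (made-up input, not part of the commit):

```python
from cornac.data.text import BaseTokenizer, Vocabulary

tok = BaseTokenizer()                   # splits on whitespace by default
tokens = tok.tokenize('a b a c')        # ['a', 'b', 'a', 'c']

vocab = Vocabulary.from_tokens(tokens)  # special tokens occupy indices 0-3
ids = vocab.to_idx(['a', 'c', 'z'])     # out-of-vocab 'z' maps to 1 (<UNK>)
text = vocab.to_text(ids)               # back to a whitespace-joined string
print(ids, text)                        # [4, 6, 1] a c <UNK>
```
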
104 changes: 101 additions & 3 deletions tests/cornac/data/test_text.py
@@ -4,8 +4,106 @@
@author: Quoc-Tuan Truong <tuantq.vnu@gmail.com>
"""

import unittest

from cornac.data import TextModule
from cornac.data.text import BaseTokenizer
from cornac.data.text import Vocabulary
from cornac.data.text import SPECIAL_TOKENS
from collections import OrderedDict
import numpy as np


class TestBaseTokenizer(unittest.TestCase):

def setUp(self):
self.tok = BaseTokenizer()

def test_init(self):
self.assertEqual(self.tok.sep, ' ')

def test_tokenize(self):
tokens = self.tok.tokenize('a b c')
self.assertListEqual(tokens, ['a', 'b', 'c'])

def test_batch_tokenize(self):
token_list = self.tok.batch_tokenize(['a b c',
'd e f'])
self.assertListEqual(token_list, [['a', 'b', 'c'],
['d', 'e', 'f']])


class TestVocabulary(unittest.TestCase):

def setUp(self):
self.tokens = ['a', 'b', 'c']
self.vocab = Vocabulary(self.tokens)
self.tok_seq = ['a', 'a', 'b', 'c']
self.idx_seq = [0, 0, 1, 2]

def test_init(self):
self.assertEqual(self.vocab.size, 3)
self.assertListEqual(self.vocab.idx2tok, ['a', 'b', 'c'])
self.assertDictEqual(self.vocab.tok2idx, {'a': 0, 'b': 1, 'c': 2})

def test_to_idx(self):
self.assertEqual(self.vocab.to_idx(self.tok_seq), self.idx_seq)

def test_to_text(self):
self.assertEqual(self.vocab.to_text(self.idx_seq), ' '.join(self.tok_seq))
self.assertEqual(self.vocab.to_text(self.idx_seq, sep=None), self.tok_seq)

def test_save(self):
self.vocab.save('tests/vocab.pkl')
loaded_vocab = Vocabulary.load('tests/vocab.pkl')
self.assertListEqual(self.vocab.idx2tok, loaded_vocab.idx2tok)

def test_from_tokens(self):
from_tokens_vocab = Vocabulary.from_tokens(self.tokens)
self.assertCountEqual(SPECIAL_TOKENS + self.vocab.idx2tok,
from_tokens_vocab.idx2tok)


class TestTextModule(unittest.TestCase):

def setUp(self):
self.tokens = ['a', 'b', 'c', 'd', 'e', 'f']
self.id_text = {'u1': 'a b c',
'u2': 'b c d d',
'u3': 'c b e c f'}

self.module = TextModule(self.id_text)
self.id_map = OrderedDict({'u1': 0, 'u2': 1, 'u3': 2})
self.module.build(self.id_map)
self.token_ids = (self.module.vocab.tok2idx[tok] for tok in self.tokens)


def test_init(self):
self.assertCountEqual(self.module.vocab.idx2tok,
SPECIAL_TOKENS + self.tokens)

def test_sequences(self):
(a, b, c, d, e, f) = self.token_ids

self.assertListEqual(self.module.sequences,
[[a, b, c],
[b, c, d, d],
[c, b, e, c, f]])

def test_batch_seq(self):
(a, b, c, d, e, f) = self.token_ids

batch_seqs = self.module.batch_seq([2, 1])
self.assertEqual((2, 5), batch_seqs.shape)
np.testing.assert_array_equal(batch_seqs,
np.asarray([[c, b, e, c, f],
[b, c, d, d, 0]]))

batch_seqs = self.module.batch_seq([0, 2], max_length=4)
self.assertEqual((2, 4), batch_seqs.shape)
np.testing.assert_array_equal(batch_seqs,
np.asarray([[a, b, c, 0],
[c, b, e, c]]))

def test_init():
md = TextModule()
md.build(global_id_map=None)
if __name__ == '__main__':
unittest.main()
