Add rules for BaseTokenizer, update Vocabulary (#61)
* Add pre-rules and post-rules for BaseTokenizer, update Vocabulary

* Update tests for BaseTokenizer and Vocabulary
tqtg committed Mar 21, 2019
1 parent d1d68a7 commit 469fc52
Showing 2 changed files with 87 additions and 23 deletions.
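As a quick orientation before the diff: BaseTokenizer now takes pre_rules (callables applied to the raw string before splitting) and post_rules (callables applied to the token list after splitting). A minimal usage sketch, with the expected output inferred from the new test_default_rules test below:

from cornac.data.text import BaseTokenizer, DEFAULT_PRE_RULES, DEFAULT_POST_RULES

# Pre-rules run on the raw string, post-rules on the resulting token list.
# The defaults collapse duplicate spaces and lowercase every token.
tok = BaseTokenizer(sep=' ',
                    pre_rules=DEFAULT_PRE_RULES,
                    post_rules=DEFAULT_POST_RULES)
tok.tokenize('a  B C d E')  # expected: ['a', 'b', 'c', 'd', 'e']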
85 changes: 70 additions & 15 deletions cornac/data/text.py
@@ -5,13 +5,15 @@
"""

from . import FeatureModule
from typing import List, Dict
from typing import List, Dict, Callable
from collections import defaultdict, Counter
import pickle
import numpy as np
import itertools
import re

SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>']
PAD, UNK, BOS, EOS = '<PAD>', '<UNK>', '<BOS>', '<EOS>'
SPECIAL_TOKENS = [PAD, UNK, BOS, EOS]


class Tokenizer():
@@ -41,13 +43,35 @@ def batch_tokenize(self, texts: List[str]) -> List[List[str]]:
raise NotImplementedError


def rm_dup_spaces(t: str) -> str:
"""
Remove duplicate spaces in `t`.
"""
return re.sub(' {2,}', ' ', t)


def lower(tokens: List[str]) -> List[str]:
"""
Lowercase all characters of every token in `tokens`.
"""
return [t.lower() for t in tokens]


DEFAULT_PRE_RULES = [rm_dup_spaces]
DEFAULT_POST_RULES = [lower]


class BaseTokenizer(Tokenizer):
"""
A base tokenizer that uses a provided delimiter `sep` to split text.
"""

def __init__(self, sep=' '):
def __init__(self, sep=' ',
pre_rules: List[Callable[[str], str]] = None,
post_rules: List[Callable[[str], List[str]]] = None):
self.sep = sep
self.pre_rules = DEFAULT_PRE_RULES if pre_rules is None else pre_rules
self.post_rules = DEFAULT_POST_RULES if post_rules is None else post_rules

def tokenize(self, t: str) -> List[str]:
"""
@@ -57,7 +81,12 @@ def tokenize(self, t: str) -> List[str]:
-------
tokens : ``List[str]``
"""
return t.split(self.sep)
for rule in self.pre_rules:
t = rule(t)
tokens = t.split(self.sep)
for rule in self.post_rules:
tokens = rule(tokens)
return tokens

# TODO: this function can be parallelized
def batch_tokenize(self, texts: List[str]) -> List[List[str]]:
@@ -77,9 +106,16 @@ class Vocabulary():
"""

def __init__(self, idx2tok: List[str]):
self.idx2tok = idx2tok
self.idx2tok = self._add_special_tokens(idx2tok)
self.tok2idx = defaultdict(int, {tok: idx for idx, tok in enumerate(self.idx2tok)})

def _add_special_tokens(self, idx2tok: List[str]) -> List[str]:
for tok in reversed(SPECIAL_TOKENS): # <PAD>:0, '<UNK>':1, '<BOS>':2, '<EOS>':3
if tok in idx2tok:
idx2tok.remove(tok)
idx2tok.insert(0, tok)
return idx2tok

@property
def size(self):
return len(self.idx2tok)
@@ -88,7 +124,7 @@ def to_idx(self, tokens: List[str]) -> List[int]:
"""
Convert a list of `tokens` to their integer indices.
"""
return [self.tok2idx.get(tok, 1) for tok in tokens] # 1 is <UNK> idx
return [self.tok2idx.get(tok, 1) for tok in tokens] # 1 is <UNK> idx

def to_text(self, indices: List[int], sep=' ') -> List[str]:
"""
@@ -109,10 +145,6 @@ def from_tokens(cls, tokens: List[str], max_vocab: int = None, min_freq: int = 1
"""
freq = Counter(tokens)
idx2tok = [tok for tok, cnt in freq.most_common(max_vocab) if cnt >= min_freq]
for tok in reversed(SPECIAL_TOKENS): # <PAD>:0, '<UNK>':1, '<BOS>':2, '<EOS>':3
if tok in idx2tok:
idx2tok.remove(tok)
idx2tok.insert(0, tok)
return cls(idx2tok)

@classmethod
@@ -126,20 +158,42 @@ def load(cls, path):
class TextModule(FeatureModule):
"""Text module
Parameters
----------
id_text: Dict, optional, default = None
A dictionary containing the mapping from user/item id to text.
tokenizer: Tokenizer, optional, default = None
Tokenizer for text splitting. If None, the BaseTokenizer will be used.
vocab: Vocabulary, optional, default = None
Vocabulary of tokens. It contains the mapping between tokens and their
integer ids, and vice versa.
max_vocab: int, optional, default = None
The maximum size of the vocabulary.
If vocab is provided, this will be ignored.
min_freq: int, default = 1
The minimum frequency for a token to be included in the vocabulary.
If vocab is provided, this will be ignored.
"""

def __init__(self,
id_text: Dict = None,
vocab: List[str] = None,
max_vocab: int = None,
tokenizer: Tokenizer = None,
vocab: Vocabulary = None,
max_vocab: int = None,
min_freq: int = 1,
**kwargs):
super().__init__(**kwargs)

self._id_text = id_text
self.tokenizer = tokenizer
self.vocab = vocab
self.max_vocab = max_vocab
self.tokenizer = tokenizer
self.min_freq = min_freq
self.sequences = None

def _build_text(self, global_id_map: Dict):
@@ -163,7 +217,8 @@ def _build_text(self, global_id_map: Dict):

if self.vocab is None:
self.vocab = Vocabulary.from_tokens(tokens=list(itertools.chain(*self.sequences)),
max_vocab=self.max_vocab)
max_vocab=self.max_vocab,
min_freq=self.min_freq)

# Map tokens into integer ids
for i, seq in enumerate(self.sequences):
@@ -206,4 +261,4 @@ def batch_freq(self, batch_ids):
def batch_tfidf(self, batch_ids):
"""Return matrix of TF-IDF features corresponding to provided batch_ids
"""
raise NotImplementedError
raise NotImplementedError
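A short sketch of the updated Vocabulary behaviour, following the hunks above and the expectations in the updated tests (the concrete indices are an assumption that follows from SPECIAL_TOKENS being prepended in __init__):

from cornac.data.text import Vocabulary, SPECIAL_TOKENS

# __init__ now prepends the special tokens via _add_special_tokens,
# so every Vocabulary starts with <PAD>:0, <UNK>:1, <BOS>:2, <EOS>:3.
vocab = Vocabulary(['a', 'b', 'c'])
vocab.size                     # len(SPECIAL_TOKENS) + 3 == 7
vocab.idx2tok                  # ['<PAD>', '<UNK>', '<BOS>', '<EOS>', 'a', 'b', 'c']
vocab.to_idx(['a', 'x', 'c'])  # [4, 1, 6]; the unknown 'x' falls back to 1, the <UNK> index

# from_tokens keeps the max_vocab/min_freq filtering but no longer inserts
# special tokens itself; that moved into the constructor.
Vocabulary.from_tokens(['a', 'a', 'b'], min_freq=2).idx2tok  # SPECIAL_TOKENS + ['a']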
25 changes: 17 additions & 8 deletions tests/cornac/data/test_text.py
@@ -9,8 +9,8 @@
from cornac.data import TextModule
from cornac.data.text import BaseTokenizer
from cornac.data.text import Vocabulary
from cornac.data.text import SPECIAL_TOKENS
from collections import OrderedDict
from cornac.data.text import SPECIAL_TOKENS, DEFAULT_PRE_RULES, DEFAULT_POST_RULES
from collections import OrderedDict, defaultdict
import numpy as np


@@ -32,19 +32,29 @@ def test_batch_tokenize(self):
self.assertListEqual(token_list, [['a', 'b', 'c'],
['d', 'e', 'f']])

def test_default_rules(self):
tok = BaseTokenizer(pre_rules=DEFAULT_PRE_RULES, post_rules=DEFAULT_POST_RULES)
token_list = tok.tokenize('a B C d E')
self.assertListEqual(token_list, ['a', 'b', 'c', 'd', 'e'])


class TestVocabulary(unittest.TestCase):

def setUp(self):
self.tokens = ['a', 'b', 'c']
self.vocab = Vocabulary(self.tokens)
(a, b, c) = (self.vocab.tok2idx[tok] for tok in self.tokens[-3:])
self.tok_seq = ['a', 'a', 'b', 'c']
self.idx_seq = [0, 0, 1, 2]
self.idx_seq = [a, a, b, c]

def test_init(self):
self.assertEqual(self.vocab.size, 3)
self.assertListEqual(self.vocab.idx2tok, ['a', 'b', 'c'])
self.assertDictEqual(self.vocab.tok2idx, {'a': 0, 'b': 1, 'c': 2})
self.assertEqual(self.vocab.size, len(SPECIAL_TOKENS) + 3)
self.assertListEqual(self.vocab.idx2tok, SPECIAL_TOKENS + ['a', 'b', 'c'])

tok2idx = defaultdict()
for tok in SPECIAL_TOKENS + self.tokens:
tok2idx.setdefault(tok, len(tok2idx))
self.assertDictEqual(self.vocab.tok2idx, tok2idx)

def test_to_idx(self):
self.assertEqual(self.vocab.to_idx(self.tok_seq), self.idx_seq)
@@ -60,8 +70,7 @@ def test_save(self):

def test_from_tokens(self):
from_tokens_vocab = Vocabulary.from_tokens(self.tokens)
self.assertCountEqual(SPECIAL_TOKENS + self.vocab.idx2tok,
from_tokens_vocab.idx2tok)
self.assertCountEqual(self.vocab.idx2tok, from_tokens_vocab.idx2tok)


class TestTextModule(unittest.TestCase):
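For completeness, a hypothetical TextModule construction using the keyword arguments documented in the new docstring above (the id_text contents and values here are made up for illustration; sequences are only built later in _build_text):

from cornac.data import TextModule
from cornac.data.text import BaseTokenizer

docs = {'u1': 'this item is GREAT', 'u2': 'this  one  is ok'}  # hypothetical id -> text map
md = TextModule(id_text=docs,
                tokenizer=BaseTokenizer(),  # default rules: collapse spaces, lowercase
                max_vocab=5000,             # ignored if a Vocabulary instance is supplied
                min_freq=1)                 # new: minimum token frequency kept in the vocabulary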
