Skip to content

Commit

Permalink
Separate tokenizer from tagger
Browse files Browse the repository at this point in the history
  • Loading branch information
Hironsan committed Nov 23, 2017
1 parent acad40a commit 0b528c1
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 27 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,9 @@ After evaluation, F1 value is output:
Let's try tagging a sentence, "President Obama is speaking at the White House."
We can do it as follows:
```python
>>> sent = 'President Obama is speaking at the White House.'
>>> model.analyze(sent)
>>> words = 'President Obama is speaking at the White House.'.split()
>>> model.analyze(words)
{
'text': 'President Obama is speaking at the White House.',
'words': [
'President',
'Obama',
Expand Down
28 changes: 9 additions & 19 deletions anago/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,9 @@

class Tagger(object):

def __init__(self,
model,
preprocessor=None,
tokenizer=str.split):

def __init__(self, model, preprocessor=None):
self.model = model
self.preprocessor = preprocessor
self.tokenizer = tokenizer

def predict(self, words):
length = np.array([len(words)])
Expand All @@ -34,10 +29,8 @@ def _get_prob(self, pred):

return prob

def _build_response(self, sent, tags, prob):
words = self.tokenizer(sent)
def _build_response(self, words, tags, prob):
res = {
'text': sent,
'words': words,
'entities': [

Expand All @@ -57,18 +50,17 @@ def _build_response(self, sent, tags, prob):

return res

def analyze(self, sent):
assert isinstance(sent, str)
def analyze(self, words):
assert isinstance(words, list)

words = self.tokenizer(sent)
pred = self.predict(words)
tags = self._get_tags(pred)
prob = self._get_prob(pred)
res = self._build_response(sent, tags, prob)
res = self._build_response(words, tags, prob)

return res

def tag(self, sent):
def tag(self, words):
"""Tags a sentence named entities.
Args:
Expand All @@ -84,15 +76,14 @@ def tag(self, sent):
('speaking', 'O'), ('at', 'O'), ('the', 'O'),
('White', 'LOCATION'), ('House', 'LOCATION'), ('.', 'O')]
"""
assert isinstance(sent, str)
assert isinstance(words, list)

words = self.tokenizer(sent)
pred = self.predict(words)
pred = [t.split('-')[-1] for t in pred] # remove prefix: e.g. B-Person -> Person

return list(zip(words, pred))

def get_entities(self, sent):
def get_entities(self, words):
"""Gets entities from a sentence.
Args:
Expand All @@ -105,9 +96,8 @@ def get_entities(self, sent):
sent = 'President Obama is speaking at the White House.'
result = {'Person': ['Obama'], 'LOCATION': ['White House']}
"""
assert isinstance(sent, str)
assert isinstance(words, list)

words = self.tokenizer(sent)
pred = self.predict(words)
entities = self._get_chunks(words, pred)

Expand Down
4 changes: 2 additions & 2 deletions anago/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def eval(self, x_test, y_test):
else:
raise (OSError('Could not find a model. Call load(dir_path).'))

def analyze(self, sent):
def analyze(self, words):
if self.model:
tagger = Tagger(self.model, preprocessor=self.p)
return tagger.analyze(sent)
return tagger.analyze(words)
else:
raise (OSError('Could not find a model. Call load(dir_path).'))

Expand Down
5 changes: 2 additions & 3 deletions tests/wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def setUpClass(cls):

cls.embeddings = load_glove(EMBEDDING_PATH)

cls.sent = 'President Obama is speaking at the White House.'
cls.words = 'President Obama is speaking at the White House.'.split()

cls.dir_path = 'models'

Expand All @@ -53,10 +53,9 @@ def test_eval(self):
def test_analyze(self):
model = anago.Sequence(max_epoch=1, embeddings=self.embeddings)
model.train(self.x_train, self.y_train, self.x_valid, self.y_valid)
res = model.analyze(self.sent)
res = model.analyze(self.words)
pprint(res)

self.assertIn('text', res)
self.assertIn('words', res)
self.assertIn('entities', res)

Expand Down

0 comments on commit 0b528c1

Please sign in to comment.