Skip to content

Commit

Permalink
Add lst20_onnx to NER class
Browse files Browse the repository at this point in the history
  • Loading branch information
wannaphong committed Apr 26, 2022
1 parent dc11a9b commit 6efee29
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 25 deletions.
32 changes: 32 additions & 0 deletions pythainlp/tag/lst20_ner_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
from typing import List
from pythainlp.tag.wangchanberta_onnx import WngchanBerta_ONNX


class LST20_NER_ONNX(WngchanBerta_ONNX):
    """
    LST20 named-entity tagger running on the ONNX runtime.

    Thin specialisation of :class:`WngchanBerta_ONNX` that selects the
    ``onnx_lst20ner`` model and adapts the raw (token, label) pairs:
    :meth:`clean_output` removes the "▁" markers from tokens and
    :meth:`_config` rewrites the labels into ``B-``/``I-`` form.
    """

    def __init__(self, providers: List[str] = ['CPUExecutionProvider']) -> None:
        """
        Load the LST20 NER ONNX model.

        :param List[str] providers: ONNX runtime execution providers,
            forwarded to :class:`WngchanBerta_ONNX`
        """
        super().__init__(
            model_name="onnx_lst20ner",
            model_version="1.0",
            file_onnx="lst20-ner-model.onnx",
            # Pass a copy so the shared default list object can never be
            # mutated through the instance.
            providers=list(providers)
        )

    def clean_output(self, list_text):
        """
        Strip the "▁" markers from the tokens of a tagged sequence.

        A leading pair whose token is exactly "▁" is dropped. In the
        remaining pairs, a token that is exactly "▁" becomes a single
        space, and a token that starts with "▁" loses that first marker.

        :param list list_text: (token, label) pairs
        :return: new list of cleaned (token, label) pairs
        :rtype: list
        """
        # Guard against empty input before peeking at the first token.
        if list_text and list_text[0][0] == "▁":
            list_text = list_text[1:]
        new_list = []
        for token, label in list_text:
            if token == "▁":
                token = " "
            elif token.startswith("▁"):
                token = token.replace("▁", "", 1)
            new_list.append((token, label))
        # Return the cleaned pairs (not the untouched input list).
        return new_list

    def _config(self, list_ner):
        """
        Convert LST20 labels into ``B-``/``I-`` style labels.

        ``E_`` is rewritten to ``I_`` and every underscore becomes a
        hyphen, e.g. ``B_PER`` -> ``B-PER`` and ``E_PER`` -> ``I-PER``.

        :param list list_ner: (word, label) pairs
        :return: new list of (word, converted label) pairs
        :rtype: list
        """
        return [
            (word, label.replace('E_', 'I_').replace('_', '-'))
            for word, label in list_ner
        ]
14 changes: 0 additions & 14 deletions pythainlp/tag/lst20ner.py

This file was deleted.

6 changes: 5 additions & 1 deletion pythainlp/tag/named_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class NER:
**Options for engine**
* *thainer* - Thai NER engine
* *wangchanberta* - wangchanberta model
* *lst20_onnx* - LST20 NER model by wangchanberta with ONNX runtime
* *tltk* - wrapper for `TLTK <https://pypi.org/project/tltk/>`_.
**Options for corpus**
Expand All @@ -33,6 +34,9 @@ def load_engine(self, engine: str, corpus: str) -> None:
if engine == "thainer" and corpus == "thainer":
from pythainlp.tag.thainer import ThaiNameTagger
self.engine = ThaiNameTagger()
elif engine == "lst20_onnx":
from pythainlp.tag.lst20_ner_onnx import LST20_NER_ONNX
self.engine = LST20_NER_ONNX()
elif engine == "wangchanberta":
from pythainlp.wangchanberta import ThaiNameTagger
self.engine = ThaiNameTagger(dataset_name=corpus)
Expand Down Expand Up @@ -88,7 +92,7 @@ def tag(
"""wangchanberta is not support part-of-speech tag.
It have not part-of-speech tag in output."""
)
if self.name_engine == "wangchanberta":
if self.name_engine == "wangchanberta" or self.name_engine == "lst20_onnx":
return self.engine.get_ner(text, tag=tag)
else:
return self.engine.get_ner(text, tag=tag, pos=pos)
39 changes: 29 additions & 10 deletions pythainlp/tag/wangchanberta_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,7 @@ def postprocess(self, logits_data):
return scores

def clean_output(self, list_text):
new_list = []
for i, j in list_text:
if i.startswith("▁") and i != '▁':
i = i.replace("▁", "", 1)
elif i == '▁':
i = " "
new_list.append((i, j))
return new_list
return list_text

def totag(self, post, sent):
tag = []
Expand All @@ -84,10 +77,36 @@ def totag(self, post, sent):
)
return tag

def get_ner(self, text: str):
def _config(self, list_ner):
    """
    Hook for adjusting NER labels; this implementation is the identity.

    :param list_ner: the tagged sequence to adjust
    :return: ``list_ner`` itself, unchanged
    """
    return list_ner

def get_ner(self, text: str, tag: bool = False):
    """
    Tag named entities in a text with the ONNX model.

    :param str text: text to tag
    :param bool tag: if True, return one string in which every entity
        is wrapped as ``<TYPE>...</TYPE>``; if False, return the list
        produced by ``clean_output``
    :return: list of (word, label) pairs, or the marked-up string when
        ``tag`` is True
    """
    self._s = self.build_tokenizer(text)
    logits = self.session.run(
        output_names=[self.outputs_name],
        input_feed=self._s
    )[0]
    _tag = self.clean_output(self.totag(self.postprocess(logits), text))
    if not tag:
        return _tag

    parts = []
    opened = ""  # entity type of the currently open tag, "" if none
    for word, ner in self._config(_tag):
        if ner.startswith("B-"):
            # A new entity begins: close the previous one first, if any.
            if opened:
                parts.append("</" + opened + ">")
            opened = ner[2:]
            parts.append("<" + opened + ">")
        elif ner == "O" and opened:
            parts.append("</" + opened + ">")
            opened = ""
        parts.append(word)
    # Close an entity that runs to the end of the text.
    if opened:
        parts.append("</" + opened + ">")
    return "".join(parts)
3 changes: 3 additions & 0 deletions tests/test_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ def test_NER_class(self):
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า"))
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", pos=False))
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", tag=True))
ner = NER(engine="lst20_onnx")
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า"))
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", tag=True))
ner = NER(engine="tltk")
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า"))
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", pos=False))
Expand Down

0 comments on commit 6efee29

Please sign in to comment.