Skip to content

Commit

Permalink
Add pythainlp.coref docs and testset
Browse files Browse the repository at this point in the history
  • Loading branch information
wannaphong committed Jun 4, 2023
1 parent a027c30 commit a40d6e3
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 6 deletions.
10 changes: 10 additions & 0 deletions docs/api/coref.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
.. currentmodule:: pythainlp.coref

pythainlp.coref
===============
The :class:`pythainlp.coref` is Coreference Resolution for Thai.

Modules
-------

.. autofunction:: coreference_resolution
15 changes: 12 additions & 3 deletions pythainlp/coref/_fastcoref.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import spacy


class FastCoref:
def __init__(self, model_name, nlp=spacy.blank("th"), device="cpu", type="FCoref") -> None:
def __init__(self, model_name, nlp=spacy.blank("th"), device:str="cpu", type:str="FCoref") -> None:
if type == "FCoref":
from fastcoref import FCoref as _model
else:
Expand All @@ -25,5 +26,13 @@ def __init__(self, model_name, nlp=spacy.blank("th"), device="cpu", type="FCoref
self.nlp = nlp
self.model = _model(self.model_name,device=device,nlp=self.nlp)

def predict(self, texts:list):
return self.model.predict(texts=texts)
def _to_json(self, _predict):
return {
"text":_predict.text,
"clusters_string":_predict.get_clusters(as_strings=True),
"clusters":_predict.get_clusters(as_strings=False)
}


def predict(self, texts:List[str])->dict:
return [self._to_json(i) for i in self.model.predict(texts=texts)]
36 changes: 34 additions & 2 deletions pythainlp/coref/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
model = None


def coreference_resolution(text, model_name="han-coref-v1.0", device="cpu"):
def coreference_resolution(texts:List[str], model_name:str="han-coref-v1.0", device:str="cpu"):
"""
Coreference Resolution
:param List[str] texts: list texts to do coreference resolution
:param str model_name: coreference resolution model
:param str device: device for running coreference resolution model (cpu, cuda, and other)
:return: List txets of coreference resolution
:rtype: List[dict]
:Options for model_name:
* *han-coref-v1.0* - (default) Han-Corf: Thai oreference resolution by PyThaiNLP v1.0
:Example:
::
from pythainlp.coref import coreference_resolution
print(
coreference_resolution(
["Bill Gates ได้รับวัคซีน COVID-19 เข็มแรกแล้ว ระบุ ผมรู้สึกสบายมาก"]
)
)
# output:
# [
# {'text': 'Bill Gates ได้รับวัคซีน COVID-19 เข็มแรกแล้ว ระบุ ผมรู้สึกสบายมาก',
# 'clusters_string': [['Bill Gates', 'ผม']],
# 'clusters': [[(0, 10), (50, 52)]]}
# ]
"""
global model
if isinstance(texts, str):
texts = [texts]
if model == None and model_name=="han-coref-v1.0":
from pythainlp.coref.han_coref import HanCoref
model = HanCoref(device=device)
return model.predict(text)
return model.predict(texts)
2 changes: 1 addition & 1 deletion pythainlp/coref/han_coref.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


class HanCoref(FastCoref):
def __init__(self,device="cpu",nlp=spacy.blank("th")) -> None:
def __init__(self,device:str="cpu",nlp=spacy.blank("th")) -> None:
super(self.__class__, self).__init__(
model_name="pythainlp/han-coref-v1.0",
device=device,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_coref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

import unittest
from pythainlp.coref import coreference_resolution


class TestParsePackage(unittest.TestCase):
def test_coreference_resolution(self):
self.assertIsNotNone(
coreference_resolution(
"Bill Gates ได้รับวัคซีน COVID-19 เข็มแรกแล้ว ระบุ ผมรู้สึกสบายมาก"
)
)

0 comments on commit a40d6e3

Please sign in to comment.