Skip to content

Commit

Permalink
Initial tests to named entity detector
Browse files Browse the repository at this point in the history
  • Loading branch information
aCampello committed Oct 29, 2020
1 parent ecd9654 commit aaf36a0
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions tests/test_detector_named_entity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest

from scrubadub.detectors import NamedEntityDetector
from scrubadub.filth import NameFilth, OrganizationFilth, NamedEntityFilth
import scrubadub

from base import BaseTestCase


class NamedEntityTestCase(unittest.TestCase, BaseTestCase):
"""
Tests whether the detector is performing correctly from a function point of view.
For accuracy tests use .benchmark_accuracy instead
"""

def setUp(self):
self.detector = NamedEntityDetector()

def _assert_filth_type_and_pos(self, doc_list, beg_end_list, filth_class):
doc_names = [str(x) for x in range(len(doc_list))]

filth_list = list(self.detector.iter_filth_documents(doc_names, doc_list))

for filth, beg_end in zip(filth_list, beg_end_list):
self.assertIsInstance(filth, filth_class)
self.assertEqual((filth.beg, filth.end), beg_end)

def test_names(self):
doc_list = ["John is a cat",
"When was Maria born?",
"john is a cat",
"when was maria born"]
beg_end_list = [(0, 4),
(9, 14),
(0, 4),
(9, 14)]

self._assert_filth_type_and_pos(doc_list, beg_end_list, NameFilth)

def test_organisations(self):
doc_list = ["She started working for Apple this year",
"But used to work for Google"]
beg_end_list = [(24, 30),
(21, 27)]

self._assert_filth_type_and_pos(doc_list, beg_end_list, OrganizationFilth)

def test_other_entity(self):
self.detector.named_entities = {"GPE"}
doc_list = ["London is a city in England"]
beg_end_list = [(0, 6),
(20, 27)]

self._assert_filth_type_and_pos(doc_list, beg_end_list, NamedEntityFilth)

def test_wrong_model(self):
"""Test that it raises an error if user inputs invalid spacy model"""
with self.assertRaises(SystemExit):
NamedEntityDetector(model='not_a_valid_spacy_model')

def test_iter_filth(self):
doc = "John is a cat"

output_iter_docs = list(self.detector.iter_filth_documents(doc_list=[doc], doc_names=["0"]))
output_iter = list(self.detector.iter_filth(text=doc, document_name="0"))

self.assertListEqual(output_iter, output_iter_docs)

0 comments on commit aaf36a0

Please sign in to comment.