Commit

Merge aaf36a0 into 0eec66e
aCampello committed Oct 29, 2020
2 parents 0eec66e + aaf36a0 commit 9c165b0
Showing 9 changed files with 199 additions and 50 deletions.
3 changes: 2 additions & 1 deletion requirements/python
@@ -3,4 +3,5 @@ argcomplete
phonenumbers
pandas
sklearn
typing_extensions
spacy-nightly[transformers]>=3.0.0rc1
typing_extensions
4 changes: 2 additions & 2 deletions scrubadub/__init__.py
@@ -1,5 +1,5 @@

from typing import Union, List, Dict, Sequence
from typing import Union, List, Dict, Sequence, Optional

# convenient imports
from .scrubbers import Scrubber
@@ -82,7 +82,7 @@ def list_filth(text: str, **kwargs) -> List[Filth]:
return list(scrubber.iter_filth(text, **kwargs))


def list_filth_documents(documents: Union[List[str], Dict[str, str]], **kwargs) -> List[Filth]:
def list_filth_documents(documents: Union[List[str], Dict[Optional[str], str]], **kwargs) -> List[Filth]:
"""Return a list of `Filth` that was detected in the string `text`.
`documents` can be in a dict, in the format of ``{'document_name': 'document'}``, or as a list of strings
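
For orientation (not part of the commit), a minimal sketch of calling the convenience function whose signature changes above; the document name and text are made up, and after this change a dict key may also be None:

import scrubadub

# Hypothetical document; dict keys are document names (now Optional[str]).
documents = {"email_1.txt": "Please contact jane.doe@example.com tomorrow."}

for filth in scrubadub.list_filth_documents(documents):
    print(filth.document_name, filth.type, filth.text)
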
1 change: 1 addition & 0 deletions scrubadub/detectors/__init__.py
@@ -13,6 +13,7 @@
from .credential import CredentialDetector
from .email import EmailDetector, NewEmailDetector
from .name import NameDetector
from .named_entity import NamedEntityDetector
from .phone import PhoneDetector
from .postalcode import PostalCodeDetector
from .known import KnownFilthDetector
52 changes: 52 additions & 0 deletions scrubadub/detectors/named_entity.py
@@ -0,0 +1,52 @@
from typing import Generator, Iterable, Optional, Sequence

import spacy
from wasabi import msg

from .base import Detector
from ..filth import NamedEntityFilth, Filth, NameFilth, OrganizationFilth
from ..utils import CanonicalStringSet


class NamedEntityDetector(Detector):
    """Use spaCy's named entity recognition to clean named entities.
    Specific entities to include can be selected by passing ``named_entities``,
    e.g. ``{'PERSON'}``.
    """
    filth_cls_map = {
        'PERSON': NameFilth,
        'ORG': OrganizationFilth
    }
    name = 'named_entity'

    disallowed_nouns = CanonicalStringSet(["skype"])

    def __init__(self, named_entities: Iterable[str] = {'PERSON'},
                 model: str = "en_core_web_trf", **kwargs):
        # spaCy NER labels are all upper-cased
        self.named_entities = {entity.upper() for entity in named_entities}
        if model not in spacy.info()['pipelines']:
            msg.warn("Downloading spacy model {}".format(model))
            spacy.cli.download(model)

        self.nlp = spacy.load(model)
        # Only enable the pipes needed for named entity recognition
        self.nlp.select_pipes(enable=["transformer", "tagger", "parser", "ner"])
        super(NamedEntityDetector, self).__init__(**kwargs)

    def iter_filth_documents(self, doc_names: Sequence[Optional[str]],
                             doc_list: Sequence[str]) -> Generator[Filth, None, None]:
        for doc_name, doc in zip(doc_names, self.nlp.pipe(doc_list)):
            for ent in doc.ents:
                if ent.label_ in self.named_entities:
                    # Fall back to NamedEntityFilth if there is no specific filth class for this label
                    filth_cls = self.filth_cls_map.get(ent.label_, NamedEntityFilth)
                    yield filth_cls(beg=ent.start_char,
                                    end=ent.end_char,
                                    text=ent.text,
                                    document_name=(str(doc_name) if doc_name else None),  # None if no doc_name provided
                                    detector_name=self.name,
                                    label=ent.label_)

    def iter_filth(self, text: str, document_name: Optional[str] = None) -> Generator[Filth, None, None]:
        yield from self.iter_filth_documents([document_name], [text])
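
To show how the new detector slots into a scrubber, here is a minimal usage sketch (not part of the commit; the example text is made up and the spaCy model is assumed to be available locally or downloadable, as handled in __init__ above):

import scrubadub
from scrubadub.detectors import NamedEntityDetector

# Use only the spaCy-based detector; 'PERSON' is the default entity set and
# 'en_core_web_trf' the default model, as defined above.
scrubber = scrubadub.Scrubber(detector_list=[
    NamedEntityDetector(named_entities={'PERSON', 'ORG'}),
])

for filth in scrubber.iter_filth("Maria started working for Apple this year"):
    print(filth.type, filth.beg, filth.end, filth.text)

The detector_list pattern mirrors the (commented-out) benchmark further down, which pairs NamedEntityDetector with a KnownFilthDetector.
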
1 change: 1 addition & 0 deletions scrubadub/filth/__init__.py
@@ -4,6 +4,7 @@
from .email import EmailFilth
from .known import KnownFilth
from .name import NameFilth
from .named_entity import NamedEntityFilth
from .organization import OrganizationFilth
from .phone import PhoneFilth
from .postalcode import PostalCodeFilth
14 changes: 14 additions & 0 deletions scrubadub/filth/named_entity.py
@@ -0,0 +1,14 @@
from .base import Filth


class NamedEntityFilth(Filth):
    """
    Default filth type for named entities (e.g. the ones in
    https://nightly.spacy.io/models/en#en_core_web_lg-labels), except those
    already represented by another filth type.
    """
    type = 'named_entity'

    def __init__(self, *args, label: str, **kwargs):
        super(NamedEntityFilth, self).__init__(*args, **kwargs)
        self.label = label.lower()
        self.replacement_string = "{}_{}".format(self.type, self.label)
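
A small illustrative sketch (not part of the commit) of how the label drives the replacement string; the positions and text here are made up:

from scrubadub.filth import NamedEntityFilth

# Hypothetical values; in practice these come from NamedEntityDetector.
filth = NamedEntityFilth(beg=0, end=6, text="London", label="GPE")
print(filth.type)                # named_entity
print(filth.label)               # gpe (lower-cased in __init__)
print(filth.replacement_string)  # named_entity_gpe
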
86 changes: 40 additions & 46 deletions scrubadub/scrubbers.py
@@ -226,23 +226,55 @@ def iter_filth(
) -> Generator[Filth, None, None]:
"""Iterate over the different types of filth that can exist.
"""
# Delegates to iter_filth_documents, wrapping the single text in a
# one-entry dict keyed by document_name (which may be None)

yield from self.iter_filth_documents(documents={document_name: text},
run_post_processors=run_post_processors)

def iter_filth_documents(
self,
documents: Union[Sequence[str], Dict[Optional[str], str]],
run_post_processors: bool = True
) -> Generator[Filth, None, None]:
"""Iterate over the different types of filth that can exist."""
if not isinstance(documents, (dict, list)):
raise TypeError('documents must be one of a string, list of strings or dict of strings.')

# Detectors that implement iter_filth_documents are called once over all documents;
# the rest fall back to per-document iter_filth

if isinstance(documents, dict):
document_names = list(documents.keys())
document_texts = list(documents.values())
elif isinstance(documents, (tuple, list)):
document_texts = documents
document_names = [str(x) for x in range(len(documents))]

# currently doing this by aggregating all_filths and then sorting
# inline instead of with a Filth.__cmp__ method, which is apparently
# much slower http://stackoverflow.com/a/988728/564709
#
# NOTE: we could probably do this in a more efficient way by iterating
# over all detectors simultaneously. just trying to get something
# working right now and we can worry about efficiency later
all_filths = [] # type: List[Filth]
for detector in self._detectors.values():
for filth in detector.iter_filth(text, document_name=document_name):
if not isinstance(filth, Filth):
raise TypeError('iter_filth must always yield Filth')
all_filths.append(filth)
filth_list = [] # type: List[Filth]
for name, detector in self._detectors.items():
document_iterator = getattr(detector, 'iter_filth_documents', None)
if callable(document_iterator):
for filth in document_iterator(document_names, document_texts):
if not isinstance(filth, Filth):
raise TypeError('iter_filth must always yield Filth')
filth_list.append(filth)
else:
for document_name, text in zip(document_names, document_texts):
for filth in detector.iter_filth(text, document_name=document_name):
if not isinstance(filth, Filth):
raise TypeError('iter_filth must always yield Filth')
filth_list.append(filth)

# This is split up so that we only have to use lists if we have to post_process Filth
if run_post_processors:
all_filths = list(self._merge_filths(all_filths))
all_filths = list(self._merge_filths(filth_list))
all_filths = list(self._post_process_filth_list(all_filths))

# Here we loop over a list of Filth...
@@ -251,47 +283,9 @@
else:
# ... but here, we're using a generator. If we try to use the same variable it would have two types and
# fail static typing in mypy
for filth in self._merge_filths(all_filths):
for filth in self._merge_filths(filth_list):
yield filth

def iter_filth_documents(
self,
documents: Union[Sequence[str], Dict[str, str]],
run_post_processors: bool = True
) -> Generator[Filth, None, None]:
"""Iterate over the different types of filth that can exist."""
if not isinstance(documents, (dict, list)):
raise TypeError('documents must be one of a string, list of strings or dict of strings.')

if run_post_processors:
# Only collect the filts into a list if we need to do post processing
filth_list = [] # type: List[Filth]
if isinstance(documents, dict):
filth_list = [
filth
for name, text in documents.items()
for filth in self.iter_filth(text, document_name=name, run_post_processors=False)
]
elif isinstance(documents, list):
filth_list = [
filth
for i_name, text in enumerate(documents)
for filth in self.iter_filth(text, document_name=str(i_name), run_post_processors=False)
]

for filth in self._post_process_filth_list(filth_list):
yield filth
else:
# Use generators when we dont post process the Filth
if isinstance(documents, dict):
for name, text in documents.items():
for filth in self.iter_filth(text, document_name=name, run_post_processors=False):
yield filth
elif isinstance(documents, list):
for i_name, text in enumerate(documents):
for filth in self.iter_filth(text, document_name=str(i_name), run_post_processors=False):
yield filth

@staticmethod
def _sort_filths(filth_list: Sequence[Filth]) -> List[Filth]:
"""Sorts a list of filths, needed before merging and concatenating"""
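
A brief sketch (not part of the commit) of calling the reworked Scrubber.iter_filth_documents defined above; the document names and contents are made up. A dict maps document names to text, while a list is named positionally as "0", "1", ...:

import scrubadub

scrubber = scrubadub.Scrubber()

# Dict form: keys become the document names on the yielded Filth
documents = {
    "letter.txt": "Write to john.smith@example.com for details.",
    "memo.txt": "Steve will call +44 7700 900123 tomorrow.",
}
for filth in scrubber.iter_filth_documents(documents):
    print(filth.document_name, filth.type, filth.text)

# List form: documents are named by their index, "0" and "1"
for filth in scrubber.iter_filth_documents(list(documents.values())):
    print(filth.document_name, filth.type, filth.text)
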
21 changes: 20 additions & 1 deletion tests/benchmark_accuracy.py
@@ -11,9 +11,11 @@

def main():
general_docs = []
named_entity_docs = []
# address_docs = []
# uk_phone_docs = []
known_general_pii = []
known_named_entity_pii = []
# known_address_pii = []
# known_uk_phone_pii = []
start_time = time.time()
@@ -23,6 +25,15 @@ def main():
general_docs.append(new_doc)
known_general_pii += new_known_pii

#new_doc, new_known_pii = make_fake_document(paragraphs=4, seed=i_doc, filth_types=['name'])
# Change the filth name to allow for comparison with NamedEntityDetector. Probably there is a better way to do it

#for pii in new_known_pii:
# pii['filth_type'] = 'named_entity'

#named_entity_docs.append(new_doc)
#known_named_entity_pii += new_known_pii

# new_doc, new_known_pii = make_fake_document(paragraphs=4, seed=i_doc, filth_types=['gb_address', 'us_address'])
# address_docs.append(new_doc)
# known_address_pii += new_known_pii
@@ -35,7 +46,6 @@

scrubber_time = time.time()
scrubber = scrubadub.Scrubber()
# scrubber.add_detector(scrubadub.detectors.stanford_ner.StanfordNERDetector())
scrubber.add_detector(scrubadub.detectors.KnownFilthDetector(known_filth_items=known_general_pii))
filth_list = list(scrubber.iter_filth_documents(general_docs))

@@ -57,6 +67,15 @@ def main():
print("Scrubbed documents in {:.2f}s".format(end_time-scrubber_time))
print(get_filth_classification_report(filth_list))

# scrubber_time = time.time()
# scrubber = scrubadub.Scrubber(detector_list=[scrubadub.detectors.NamedEntityDetector(),
# scrubadub.detectors.KnownFilthDetector(known_filth_items=known_named_entity_pii)])
# filth_list = list(scrubber.iter_filth_documents(named_entity_docs))
# end_time = time.time()
# print("Documents generated in {:.2f}s".format(scrubber_time-start_time))
# print("Scrubbed documents in {:.2f}s".format(end_time-scrubber_time))
# print(get_filth_classification_report(filth_list))

sys.exit(0)


67 changes: 67 additions & 0 deletions tests/test_detector_named_entity.py
@@ -0,0 +1,67 @@
import unittest

from scrubadub.detectors import NamedEntityDetector
from scrubadub.filth import NameFilth, OrganizationFilth, NamedEntityFilth
import scrubadub

from base import BaseTestCase


class NamedEntityTestCase(unittest.TestCase, BaseTestCase):
    """
    Tests whether the detector behaves correctly from a functional point of view.
    For accuracy tests, use benchmark_accuracy.py instead.
    """

    def setUp(self):
        self.detector = NamedEntityDetector()

    def _assert_filth_type_and_pos(self, doc_list, beg_end_list, filth_class):
        doc_names = [str(x) for x in range(len(doc_list))]

        filth_list = list(self.detector.iter_filth_documents(doc_names, doc_list))

        for filth, beg_end in zip(filth_list, beg_end_list):
            self.assertIsInstance(filth, filth_class)
            self.assertEqual((filth.beg, filth.end), beg_end)

    def test_names(self):
        doc_list = ["John is a cat",
                    "When was Maria born?",
                    "john is a cat",
                    "when was maria born"]
        beg_end_list = [(0, 4),
                        (9, 14),
                        (0, 4),
                        (9, 14)]

        self._assert_filth_type_and_pos(doc_list, beg_end_list, NameFilth)

    def test_organisations(self):
        doc_list = ["She started working for Apple this year",
                    "But used to work for Google"]
        beg_end_list = [(24, 30),
                        (21, 27)]

        self._assert_filth_type_and_pos(doc_list, beg_end_list, OrganizationFilth)

    def test_other_entity(self):
        self.detector.named_entities = {"GPE"}
        doc_list = ["London is a city in England"]
        beg_end_list = [(0, 6),
                        (20, 27)]

        self._assert_filth_type_and_pos(doc_list, beg_end_list, NamedEntityFilth)

    def test_wrong_model(self):
        """Test that an error is raised when an invalid spaCy model is given."""
        with self.assertRaises(SystemExit):
            NamedEntityDetector(model='not_a_valid_spacy_model')

    def test_iter_filth(self):
        doc = "John is a cat"

        output_iter_docs = list(self.detector.iter_filth_documents(doc_list=[doc], doc_names=["0"]))
        output_iter = list(self.detector.iter_filth(text=doc, document_name="0"))

        self.assertListEqual(output_iter, output_iter_docs)
