Skip to content

Commit

Permalink
fixed capitalization issues during spacy NER
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Feb 13, 2018
1 parent 7e37f9c commit 3b77c42
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
12 changes: 9 additions & 3 deletions rasa_nlu/extractors/spacy_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@ class SpacyEntityExtractor(EntityExtractor):

provides = ["entities"]

requires = ["spacy_doc"]
requires = ["spacy_nlp"]

def process(self, message, **kwargs):
# type: (Message, **Any) -> None

extracted = self.add_extractor_name(self.extract_entities(message.get("spacy_doc")))
message.set("entities", message.get("entities", []) + extracted, add_to_output=True)
# can't use the existing doc here (spacy_doc on the message)
# because tokens are lower cased which is bad for NER
spacy_nlp = kwargs.get("spacy_nlp", None)
doc = spacy_nlp(message.text)
extracted = self.add_extractor_name(self.extract_entities(doc))
message.set("entities",
message.get("entities", []) + extracted,
add_to_output=True)

def extract_entities(self, doc):
# type: (Doc) -> List[Dict[Text, Any]]
Expand Down
2 changes: 1 addition & 1 deletion rasa_nlu/utils/spacy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def provide_context(self):
return {"spacy_nlp": self.nlp}

def train(self, training_data, config, **kwargs):
# type: (TrainingData) -> Dict[Text, Any]
# type: (TrainingData, RasaNLUConfig, **Any) -> None

for example in training_data.training_examples:
example.set("spacy_doc", self.nlp(example.text.lower()))
Expand Down
16 changes: 16 additions & 0 deletions tests/base/test_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import print_function
from __future__ import unicode_literals

from rasa_nlu.extractors.spacy_entity_extractor import SpacyEntityExtractor
from tests import utilities
from rasa_nlu.training_data import TrainingData, Message

Expand Down Expand Up @@ -107,3 +108,18 @@ def test_unintentional_synonyms_capitalized(component_builder):
ner_syn.train(TrainingData(training_examples=examples), _config)
assert ner_syn.synonyms.get("mexican") is None
assert ner_syn.synonyms.get("tacos") == "Mexican"


def test_spacy_ner_extractor(spacy_nlp):
    """Check entity extraction on a message whose text has a capitalised word.

    The message is given a ``spacy_doc`` built from the lower-cased sentence,
    while its own text keeps "West" capitalised. Exactly one entity is
    expected, carrying the capitalised value and the character span 16-20.
    """
    extractor = SpacyEntityExtractor()

    # Pre-computed doc deliberately uses the lower-cased sentence.
    lowercased_doc = spacy_nlp("anywhere in the west")
    message = Message("anywhere in the West", {
        "intent": "restaurant_search",
        "entities": [],
        "spacy_doc": lowercased_doc})

    extractor.process(message, spacy_nlp=spacy_nlp)

    expected_entity = {
        u'start': 16,
        u'end': 20,
        u'value': u'West',
        u'entity': u'LOC',
        u'extractor': u'ner_spacy',
    }
    assert len(message.get("entities", [])) == 1
    assert message.get("entities")[0] == expected_entity

0 comments on commit 3b77c42

Please sign in to comment.