diff --git a/rasa_nlu/extractors/spacy_entity_extractor.py b/rasa_nlu/extractors/spacy_entity_extractor.py
index 46c0e87f03b..c9f8280db35 100644
--- a/rasa_nlu/extractors/spacy_entity_extractor.py
+++ b/rasa_nlu/extractors/spacy_entity_extractor.py
@@ -20,13 +20,19 @@ class SpacyEntityExtractor(EntityExtractor):
 
     provides = ["entities"]
 
-    requires = ["spacy_doc"]
+    requires = ["spacy_nlp"]
 
     def process(self, message, **kwargs):
         # type: (Message, **Any) -> None
 
-        extracted = self.add_extractor_name(self.extract_entities(message.get("spacy_doc")))
-        message.set("entities", message.get("entities", []) + extracted, add_to_output=True)
+        # can't use the existing doc here (spacy_doc on the message)
+        # because tokens are lower cased which is bad for NER
+        spacy_nlp = kwargs.get("spacy_nlp", None)
+        doc = spacy_nlp(message.text)
+        extracted = self.add_extractor_name(self.extract_entities(doc))
+        message.set("entities",
+                    message.get("entities", []) + extracted,
+                    add_to_output=True)
 
     def extract_entities(self, doc):
         # type: (Doc) -> List[Dict[Text, Any]]
diff --git a/rasa_nlu/utils/spacy_utils.py b/rasa_nlu/utils/spacy_utils.py
index f24e08350f7..d90a147d629 100644
--- a/rasa_nlu/utils/spacy_utils.py
+++ b/rasa_nlu/utils/spacy_utils.py
@@ -71,7 +71,7 @@ def provide_context(self):
         return {"spacy_nlp": self.nlp}
 
     def train(self, training_data, config, **kwargs):
-        # type: (TrainingData) -> Dict[Text, Any]
+        # type: (TrainingData, RasaNLUConfig, **Any) -> None
 
         for example in training_data.training_examples:
             example.set("spacy_doc", self.nlp(example.text.lower()))
diff --git a/tests/base/test_extractors.py b/tests/base/test_extractors.py
index 765bd115923..1488d36543d 100644
--- a/tests/base/test_extractors.py
+++ b/tests/base/test_extractors.py
@@ -4,6 +4,7 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
+from rasa_nlu.extractors.spacy_entity_extractor import SpacyEntityExtractor
 from tests import utilities
 from rasa_nlu.training_data import TrainingData, Message
 
@@ -107,3 +108,18 @@ def test_unintentional_synonyms_capitalized(component_builder):
     ner_syn.train(TrainingData(training_examples=examples), _config)
     assert ner_syn.synonyms.get("mexican") is None
     assert ner_syn.synonyms.get("tacos") == "Mexican"
+
+
+def test_spacy_ner_extractor(spacy_nlp):
+    ext = SpacyEntityExtractor()
+    example = Message("anywhere in the West", {
+        "intent": "restaurant_search",
+        "entities": [],
+        "spacy_doc": spacy_nlp("anywhere in the west")})
+
+    ext.process(example, spacy_nlp=spacy_nlp)
+
+    assert len(example.get("entities", [])) == 1
+    assert example.get("entities")[0] == {
+        u'start': 16, u'extractor': u'ner_spacy',
+        u'end': 20, u'value': u'West', u'entity': u'LOC'}
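
The comment in process points at the motivation for the change: spaCy's statistical NER is case-sensitive, so the lower-cased spacy_doc produced during training is a poor input for entity extraction. The following standalone sketch (not part of the patch, and assuming a spaCy English model such as en_core_web_sm is installed) shows the effect on a sentence like the one used in the new test:

# Illustrative sketch only: compares spaCy NER output on cased vs.
# lower-cased text. Model name "en_core_web_sm" is an assumption;
# exact results depend on the installed model version.
import spacy

nlp = spacy.load("en_core_web_sm")

for text in ("anywhere in the West", "anywhere in the west"):
    doc = nlp(text)
    print(text, "->", [(ent.text, ent.label_, ent.start_char, ent.end_char)
                       for ent in doc.ents])

# Typically the cased sentence yields an entity such as ("West", "LOC", 16, 20),
# matching the spans asserted in test_spacy_ner_extractor, while the
# lower-cased variant often yields no entity at all.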