Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Permalink
Merge pull request #1521 from RasaHQ/e2e-entity-fix
Browse files Browse the repository at this point in the history
fix error with e2e and unseen entities
  • Loading branch information
EPedrotti committed Dec 31, 2018
2 parents 534d221 + 2f54166 commit 5a65c49
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
5 changes: 5 additions & 0 deletions data/test_evaluations/story_unknown_entity.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## simple_story_with_unknown_entity
* greet: I am [rasa](name)
- utter_greet
* goodbye:/goodbye
- utter_goodbye
5 changes: 3 additions & 2 deletions rasa_core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,9 @@ def get_parsing_states(self,
intent_config = self.intent_config(intent_name)
should_use_entity = intent_config.get('use_entities', True)
if should_use_entity:
key = "entity_{0}".format(entity["entity"])
state_dict[key] = 1.0
if "entity" in entity:
key = "entity_{0}".format(entity["entity"])
state_dict[key] = 1.0

# Set all set slots with the featurization of the stored value
for key, slot in tracker.slots.items():
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

END_TO_END_STORY_FILE = "data/test_evaluations/end_to_end_story.md"

E2E_STORY_FILE_UNKNOWN_ENTITY = "data/test_evaluations/story_unknown_entity.md"

MOODBOT_MODEL_PATH = "examples/moodbot/models/dialogue"

DEFAULT_ENDPOINTS_FILE = "data/test_endpoints/example_endpoints.yml"
Expand Down
21 changes: 20 additions & 1 deletion tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from rasa_core.evaluate import (
run_story_evaluation,
collect_story_predictions)
from tests.conftest import DEFAULT_STORIES_FILE, END_TO_END_STORY_FILE
from tests.conftest import DEFAULT_STORIES_FILE, END_TO_END_STORY_FILE, \
E2E_STORY_FILE_UNKNOWN_ENTITY


# from tests.conftest import E2E_STORY_FILE_UNKNOWN_ENTITY


def test_evaluation_image_creation(tmpdir, default_agent):
Expand Down Expand Up @@ -50,3 +54,18 @@ def test_end_to_end_evaluation_script(tmpdir, default_agent):
has_prediction_target_mismatch()
assert len(story_evaluation.failed_stories) == 0
assert num_stories == 2


def test_end_to_end_evaluation_script_unknown_entity(tmpdir, default_agent):
completed_trackers = evaluate._generate_trackers(
E2E_STORY_FILE_UNKNOWN_ENTITY, default_agent, use_e2e=True)

story_evaluation, num_stories = collect_story_predictions(
completed_trackers,
default_agent,
use_e2e=True)

assert story_evaluation.evaluation_store. \
has_prediction_target_mismatch()
assert len(story_evaluation.failed_stories) == 1
assert num_stories == 1

0 comments on commit 5a65c49

Please sign in to comment.