Skip to content

Commit

Permalink
Merge branch 'master' into change_defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
Ghostvv committed Apr 27, 2020
2 parents 8512704 + 850344f commit 2b077b2
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 16 deletions.
3 changes: 3 additions & 0 deletions changelog/3419.improvement.rst
@@ -0,0 +1,3 @@
Include the source filename of a story in the failed stories

Include the source filename of a story in the failed stories to make it easier to identify the file which contains the failed story.
11 changes: 8 additions & 3 deletions rasa/core/test.py
Expand Up @@ -18,6 +18,8 @@

import matplotlib

FAILED_STORIES_FILE = "failed_stories.md"

# At first, matplotlib will be initialized with default OS-specific available backend
# if that didn't happen, we'll try to set it up manually
if matplotlib.get_backend() is not None:
Expand Down Expand Up @@ -357,7 +359,10 @@ def _predict_tracker_actions(
events = list(tracker.events)

partial_tracker = DialogueStateTracker.from_events(
tracker.sender_id, events[:1], agent.domain.slots
tracker.sender_id,
events[:1],
agent.domain.slots,
sender_source=tracker.sender_source,
)

tracker_actions = []
Expand Down Expand Up @@ -496,13 +501,13 @@ def log_failed_stories(failed, out_directory):
if not out_directory:
return
with open(
os.path.join(out_directory, "failed_stories.md"), "w", encoding=DEFAULT_ENCODING
os.path.join(out_directory, FAILED_STORIES_FILE), "w", encoding=DEFAULT_ENCODING
) as f:
if len(failed) == 0:
f.write("<!-- All stories passed -->")
else:
for failure in failed:
f.write(failure.export_stories())
f.write(failure.export_stories(include_source=True))
f.write("\n\n")


Expand Down
15 changes: 12 additions & 3 deletions rasa/core/trackers.py
Expand Up @@ -104,8 +104,9 @@ def from_events(
evts: List[Event],
slots: Optional[List[Slot]] = None,
max_event_history: Optional[int] = None,
sender_source: Optional[Text] = None,
):
tracker = cls(sender_id, slots, max_event_history)
tracker = cls(sender_id, slots, max_event_history, sender_source)
for e in evts:
tracker.update(e)
return tracker
Expand All @@ -115,6 +116,7 @@ def __init__(
sender_id: Text,
slots: Optional[Iterable[Slot]],
max_event_history: Optional[int] = None,
sender_source: Optional[Text] = None,
) -> None:
"""Initialize the tracker.
Expand All @@ -133,6 +135,8 @@ def __init__(
self.slots = {slot.name: copy.deepcopy(slot) for slot in slots}
else:
self.slots = AnySlotDict()
# file source of the messages
self.sender_source = sender_source

###
# current state of the tracker - MUST be re-creatable by processing
Expand Down Expand Up @@ -471,13 +475,18 @@ def update(self, event: Event, domain: Optional[Domain] = None) -> None:
for e in domain.slots_for_entities(event.parse_data["entities"]):
self.update(e)

def export_stories(self, e2e: bool = False) -> Text:
def export_stories(self, e2e: bool = False, include_source: bool = False) -> Text:
"""Dump the tracker as a story in the Rasa Core story format.
Returns the dumped tracker as a string."""
from rasa.core.training.structures import Story

story = Story.from_events(self.applied_events(), self.sender_id)
story_name = (
f"{self.sender_id} ({self.sender_source})"
if include_source
else self.sender_id
)
story = Story.from_events(self.applied_events(), story_name)
return story.as_story_string(flat=True, e2e=e2e)

def export_stories_to_file(self, export_path: Text = "debug.md") -> None:
Expand Down
21 changes: 15 additions & 6 deletions rasa/core/training/dsl.py
Expand Up @@ -75,8 +75,9 @@ def _parse_item(self, line: Text) -> Optional["Message"]:


class StoryStepBuilder:
def __init__(self, name):
def __init__(self, name: Text, source_name: Text):
self.name = name
self.source_name = source_name
self.story_steps = []
self.current_steps = []
self.start_checkpoints = []
Expand Down Expand Up @@ -160,7 +161,11 @@ def _next_story_steps(self):
if not start_checkpoints:
start_checkpoints = [Checkpoint(STORY_START)]
current_turns = [
StoryStep(block_name=self.name, start_checkpoints=start_checkpoints)
StoryStep(
block_name=self.name,
start_checkpoints=start_checkpoints,
source_name=self.source_name,
)
]
return current_turns

Expand All @@ -174,13 +179,15 @@ def __init__(
domain: Optional[Domain] = None,
template_vars: Optional[Dict] = None,
use_e2e: bool = False,
source_name: Text = None,
):
self.story_steps = []
self.current_step_builder: Optional[StoryStepBuilder] = None
self.domain = domain
self.interpreter = interpreter
self.template_variables = template_vars if template_vars else {}
self.use_e2e = use_e2e
self.source_name = source_name

@staticmethod
async def read_from_folder(
Expand Down Expand Up @@ -250,7 +257,9 @@ async def read_from_file(
try:
with open(filename, "r", encoding=io_utils.DEFAULT_ENCODING) as f:
lines = f.readlines()
reader = StoryFileReader(interpreter, domain, template_variables, use_e2e)
reader = StoryFileReader(
interpreter, domain, template_variables, use_e2e, filename
)
return await reader.process_lines(lines)
except ValueError as err:
file_info = "Invalid story file format. Failed to parse '{}'".format(
Expand Down Expand Up @@ -327,7 +336,7 @@ async def process_lines(self, lines: List[Text]) -> List[StoryStep]:
elif line.startswith("#"):
# reached a new story block
name = line[1:].strip("# ")
self.new_story_part(name)
self.new_story_part(name, self.source_name)
elif line.startswith(">"):
# reached a checkpoint
name, conditions = self._parse_event_line(line[1:].strip())
Expand Down Expand Up @@ -389,9 +398,9 @@ def _add_current_stories_to_result(self):
self.current_step_builder.flush()
self.story_steps.extend(self.current_step_builder.story_steps)

def new_story_part(self, name):
def new_story_part(self, name: Text, source_name: Text):
self._add_current_stories_to_result()
self.current_step_builder = StoryStepBuilder(name)
self.current_step_builder = StoryStepBuilder(name, source_name)

def add_checkpoint(self, name: Text, conditions: Optional[Dict[Text, Any]]) -> None:

Expand Down
7 changes: 5 additions & 2 deletions rasa/core/training/generator.py
Expand Up @@ -81,7 +81,9 @@ def init_copy(self) -> "TrackerWithCachedStates":
self.is_augmented,
)

def copy(self, sender_id: Text = "") -> "TrackerWithCachedStates":
def copy(
self, sender_id: Text = "", sender_source: Text = ""
) -> "TrackerWithCachedStates":
"""Creates a duplicate of this tracker.
A new tracker will be created and all events
Expand All @@ -92,6 +94,7 @@ def copy(self, sender_id: Text = "") -> "TrackerWithCachedStates":

tracker = self.init_copy()
tracker.sender_id = sender_id
tracker.sender_source = sender_source

for event in self.events:
tracker.update(event, skip_states=True)
Expand Down Expand Up @@ -526,7 +529,7 @@ def _process_step(
new_sender = tracker.sender_id
else:
new_sender = step.block_name
trackers.append(tracker.copy(new_sender))
trackers.append(tracker.copy(new_sender, step.source_name))

end_trackers = []
for event in events:
Expand Down
3 changes: 3 additions & 0 deletions rasa/core/training/structures.py
Expand Up @@ -113,12 +113,14 @@ def __init__(
start_checkpoints: Optional[List[Checkpoint]] = None,
end_checkpoints: Optional[List[Checkpoint]] = None,
events: Optional[List[Event]] = None,
source_name: Optional[Text] = None,
) -> None:

self.end_checkpoints = end_checkpoints if end_checkpoints else []
self.start_checkpoints = start_checkpoints if start_checkpoints else []
self.events = events if events else []
self.block_name = block_name
self.source_name = source_name
# put a counter prefix to uuid to get reproducible sorting results
global STEP_COUNT
self.id = "{}_{}".format(STEP_COUNT, uuid.uuid4().hex)
Expand All @@ -132,6 +134,7 @@ def create_copy(self, use_new_id: bool) -> "StoryStep":
self.start_checkpoints,
self.end_checkpoints,
self.events[:],
self.source_name,
)
if not use_new_id:
copied.id = self.id
Expand Down
29 changes: 27 additions & 2 deletions tests/core/test_evaluation.py
@@ -1,7 +1,13 @@
import os
from pathlib import Path

from rasa.core.test import _generate_trackers, collect_story_predictions, test
import rasa.utils.io
from rasa.core.test import (
_generate_trackers,
collect_story_predictions,
test,
FAILED_STORIES_FILE,
)
from rasa.core.policies.memoization import MemoizationPolicy

# we need this import to ignore the warning...
Expand All @@ -18,7 +24,7 @@


async def test_evaluation_image_creation(tmpdir: Path, default_agent: Agent):
stories_path = str(tmpdir / "failed_stories.md")
stories_path = str(tmpdir / FAILED_STORIES_FILE)
img_path = str(tmpdir / "story_confmat.pdf")

await test(
Expand Down Expand Up @@ -96,6 +102,25 @@ async def test_end_to_evaluation_with_forms(form_bot_agent: Agent):
assert not story_evaluation.evaluation_store.has_prediction_target_mismatch()


async def test_source_in_failed_stories(tmpdir: Path, default_agent: Agent):
stories_path = str(tmpdir / FAILED_STORIES_FILE)

await test(
stories=E2E_STORY_FILE_UNKNOWN_ENTITY,
agent=default_agent,
out_directory=str(tmpdir),
max_stories=None,
e2e=False,
)

failed_stories = rasa.utils.io.read_file(stories_path)

assert (
f"## simple_story_with_unknown_entity ({E2E_STORY_FILE_UNKNOWN_ENTITY})"
in failed_stories
)


async def test_end_to_evaluation_trips_circuit_breaker():
agent = Agent(
domain="data/test_domains/default.yml",
Expand Down

0 comments on commit 2b077b2

Please sign in to comment.