Skip to content

Commit

Permalink
Merge 91535f6 into 986e24c
Browse files Browse the repository at this point in the history
  • Loading branch information
MetcalfeTom committed Jul 25, 2019
2 parents 986e24c + 91535f6 commit 768d163
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -27,6 +27,7 @@ Removed

Fixed
-----
- interactive learning bug where reverted user utterances were dumped to training data
- added timeout to terminal input channel to avoid freezing input in case of server
errors
- ``rasa train core`` in comparison mode stores the model files compressed (``tar.gz`` files)
Expand Down
27 changes: 18 additions & 9 deletions rasa/core/training/interactive.py
Expand Up @@ -572,9 +572,12 @@ async def _write_data_to_file(sender_id: Text, endpoint: EndpointConfig):
tracker = await retrieve_tracker(endpoint, sender_id)
events = tracker.get("events", [])

await _write_stories_to_file(story_path, events)
serialised_domain = await retrieve_domain(endpoint)
domain = Domain.from_dict(serialised_domain)

await _write_stories_to_file(story_path, events, domain)
await _write_nlu_to_file(nlu_path, events)
await _write_domain_to_file(domain_path, events, endpoint)
await _write_domain_to_file(domain_path, events, domain)

logger.info("Successfully wrote stories and NLU data")

Expand Down Expand Up @@ -759,6 +762,9 @@ def _collect_messages(events: List[Dict[Text, Any]]) -> List[Message]:
msg = Message.build(data["text"], data["intent"]["name"], data["entities"])
msgs.append(msg)

elif event.get("event") == UserUtteranceReverted.type_name and msgs:
msgs.pop() # user corrected the nlu, remove incorrect example

return msgs


Expand All @@ -769,7 +775,7 @@ def _collect_actions(events: List[Dict[Text, Any]]) -> List[Dict[Text, Any]]:


async def _write_stories_to_file(
export_story_path: Text, events: List[Dict[Text, Any]]
export_story_path: Text, events: List[Dict[Text, Any]], domain: Domain
) -> None:
"""Write the conversation of the sender_id to the file paths."""

Expand All @@ -783,10 +789,16 @@ async def _write_stories_to_file(
append_write = "w" # make a new file if not

with open(export_story_path, append_write, encoding="utf-8") as f:
i = 1
for conversation in sub_conversations:
parsed_events = rasa.core.events.deserialise_events(conversation)
s = Story.from_events(parsed_events)
f.write("\n" + s.as_story_string(flat=True))
tracker = DialogueStateTracker.from_events(
"interactive_story_{}".format(i), evts=parsed_events, slots=domain.slots
)

if any(isinstance(event, UserUttered) for event in tracker.applied_events()):
i += 1
f.write("\n" + tracker.export_stories())


async def _write_nlu_to_file(
Expand Down Expand Up @@ -838,15 +850,12 @@ def _intents_from_messages(messages):


async def _write_domain_to_file(
domain_path: Text, events: List[Dict[Text, Any]], endpoint: EndpointConfig
domain_path: Text, events: List[Dict[Text, Any]], old_domain: Domain
) -> None:
"""Write an updated domain file to the file path."""

io_utils.create_path(domain_path)

domain = await retrieve_domain(endpoint)
old_domain = Domain.from_dict(domain)

messages = _collect_messages(events)
actions = _collect_actions(events)
templates = NEW_TEMPLATES
Expand Down
6 changes: 5 additions & 1 deletion tests/core/test_interactive.py
Expand Up @@ -8,6 +8,7 @@
from rasa.core.training import interactive
from rasa.utils.endpoints import EndpointConfig
from rasa.core.actions.action import default_actions
from rasa.core.domain import Domain
from tests.utilities import latest_request, json_of_latest_request


Expand Down Expand Up @@ -311,7 +312,10 @@ async def test_interactive_domain_persistence(mock_endpoint, tmpdir):
with aioresponses() as mocked:
mocked.get(url, payload={})

await interactive._write_domain_to_file(domain_path, events, mock_endpoint)
serialised_domain = await interactive.retrieve_domain(mock_endpoint)
old_domain = Domain.from_dict(serialised_domain)

await interactive._write_domain_to_file(domain_path, events, old_domain)

saved_domain = rasa.utils.io.read_config_file(domain_path)

Expand Down

0 comments on commit 768d163

Please sign in to comment.