Skip to content

Commit

Permalink
Merge branch 'master' into docs-with-contents
Browse files Browse the repository at this point in the history
  • Loading branch information
erohmensing committed Jul 25, 2019
2 parents 4574060 + b7188c2 commit b8dfd76
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 12 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Expand Up @@ -27,11 +27,12 @@ Removed

Fixed
-----
- interactive learning bug where reverted user utterances were dumped to training data
- added timeout to terminal input channel to avoid freezing input in case of server
errors
- fill slots for image, buttons, quick_replies and attachments in templates
- ``rasa train core`` in comparison mode stores the model files compressed (``tar.gz`` files)

- slot setting in interactive learning with the TwoStageFallbackPolicy

[1.1.7] - 2019-07-18
^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions rasa/cli/scaffold.py
Expand Up @@ -81,6 +81,7 @@ def print_run_or_instructions(args: argparse.Namespace, path: Text) -> None:
"jwt_secret",
"jwt_method",
"enable_api",
"remote_storage",
]
for a in attributes:
setattr(args, a, None)
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/processor.py
Expand Up @@ -522,7 +522,7 @@ def _log_action_on_tracker(self, tracker, action_name, events, policy, confidenc
# the timestamp would indicate a time before the time
# of the action executed
e.timestamp = time.time()
tracker.update(e)
tracker.update(e, self.domain)

def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]:
sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID
Expand Down
29 changes: 20 additions & 9 deletions rasa/core/training/interactive.py
Expand Up @@ -572,9 +572,12 @@ async def _write_data_to_file(sender_id: Text, endpoint: EndpointConfig):
tracker = await retrieve_tracker(endpoint, sender_id)
events = tracker.get("events", [])

await _write_stories_to_file(story_path, events)
serialised_domain = await retrieve_domain(endpoint)
domain = Domain.from_dict(serialised_domain)

await _write_stories_to_file(story_path, events, domain)
await _write_nlu_to_file(nlu_path, events)
await _write_domain_to_file(domain_path, events, endpoint)
await _write_domain_to_file(domain_path, events, domain)

logger.info("Successfully wrote stories and NLU data")

Expand Down Expand Up @@ -759,6 +762,9 @@ def _collect_messages(events: List[Dict[Text, Any]]) -> List[Message]:
msg = Message.build(data["text"], data["intent"]["name"], data["entities"])
msgs.append(msg)

elif event.get("event") == UserUtteranceReverted.type_name and msgs:
msgs.pop() # user corrected the nlu, remove incorrect example

return msgs


Expand All @@ -769,7 +775,7 @@ def _collect_actions(events: List[Dict[Text, Any]]) -> List[Dict[Text, Any]]:


async def _write_stories_to_file(
export_story_path: Text, events: List[Dict[Text, Any]]
export_story_path: Text, events: List[Dict[Text, Any]], domain: Domain
) -> None:
"""Write the conversation of the sender_id to the file paths."""

Expand All @@ -783,10 +789,18 @@ async def _write_stories_to_file(
append_write = "w" # make a new file if not

with open(export_story_path, append_write, encoding="utf-8") as f:
i = 1
for conversation in sub_conversations:
parsed_events = rasa.core.events.deserialise_events(conversation)
s = Story.from_events(parsed_events)
f.write("\n" + s.as_story_string(flat=True))
tracker = DialogueStateTracker.from_events(
"interactive_story_{}".format(i), evts=parsed_events, slots=domain.slots
)

if any(
isinstance(event, UserUttered) for event in tracker.applied_events()
):
i += 1
f.write("\n" + tracker.export_stories())


async def _write_nlu_to_file(
Expand Down Expand Up @@ -838,15 +852,12 @@ def _intents_from_messages(messages):


async def _write_domain_to_file(
domain_path: Text, events: List[Dict[Text, Any]], endpoint: EndpointConfig
domain_path: Text, events: List[Dict[Text, Any]], old_domain: Domain
) -> None:
"""Write an updated domain file to the file path."""

io_utils.create_path(domain_path)

domain = await retrieve_domain(endpoint)
old_domain = Domain.from_dict(domain)

messages = _collect_messages(events)
actions = _collect_actions(events)
templates = NEW_TEMPLATES
Expand Down
6 changes: 5 additions & 1 deletion tests/core/test_interactive.py
Expand Up @@ -8,6 +8,7 @@
from rasa.core.training import interactive
from rasa.utils.endpoints import EndpointConfig
from rasa.core.actions.action import default_actions
from rasa.core.domain import Domain
from tests.utilities import latest_request, json_of_latest_request


Expand Down Expand Up @@ -311,7 +312,10 @@ async def test_interactive_domain_persistence(mock_endpoint, tmpdir):
with aioresponses() as mocked:
mocked.get(url, payload={})

await interactive._write_domain_to_file(domain_path, events, mock_endpoint)
serialised_domain = await interactive.retrieve_domain(mock_endpoint)
old_domain = Domain.from_dict(serialised_domain)

await interactive._write_domain_to_file(domain_path, events, old_domain)

saved_domain = rasa.utils.io.read_config_file(domain_path)

Expand Down

0 comments on commit b8dfd76

Please sign in to comment.