Skip to content

Commit

Permalink
Merge branch 'master' into sql-event-broker
Browse files Browse the repository at this point in the history
  • Loading branch information
federicotdn committed Jul 25, 2019
2 parents 84c41b4 + e44270a commit 0c520d1
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 17 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ Removed

Fixed
-----
- interactive learning bug where reverted user utterances were dumped to training data
- added timeout to terminal input channel to avoid freezing input in case of server
errors
- fill slots for image, buttons, quick_replies and attachments in templates
- ``rasa train core`` in comparison mode stores the model files compressed (``tar.gz`` files)

- slot setting in interactive learning with the TwoStageFallbackPolicy

[1.1.7] - 2019-07-18
^^^^^^^^^^^^^^^^^^^^
Expand Down
16 changes: 11 additions & 5 deletions rasa/core/nlg/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,18 @@ def _fill_template(
# Getting the slot values in the template variables
template_vars = self._template_variables(filled_slots, kwargs)

# Filling the template variables in the template text
keys_to_interpolate = [
"text",
"image",
"custom",
"button",
"attachment",
"quick_replies",
]
if template_vars:
if "text" in template:
template["text"] = interpolate(template["text"], template_vars)
elif "custom" in template:
template["custom"] = interpolate(template["custom"], template_vars)
for key in keys_to_interpolate:
if key in template:
template[key] = interpolate(template[key], template_vars)
return template

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def _log_action_on_tracker(self, tracker, action_name, events, policy, confidenc
# the timestamp would indicate a time before the time
# of the action executed
e.timestamp = time.time()
tracker.update(e)
tracker.update(e, self.domain)

def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]:
sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID
Expand Down
29 changes: 20 additions & 9 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,12 @@ async def _write_data_to_file(sender_id: Text, endpoint: EndpointConfig):
tracker = await retrieve_tracker(endpoint, sender_id)
events = tracker.get("events", [])

await _write_stories_to_file(story_path, events)
serialised_domain = await retrieve_domain(endpoint)
domain = Domain.from_dict(serialised_domain)

await _write_stories_to_file(story_path, events, domain)
await _write_nlu_to_file(nlu_path, events)
await _write_domain_to_file(domain_path, events, endpoint)
await _write_domain_to_file(domain_path, events, domain)

logger.info("Successfully wrote stories and NLU data")

Expand Down Expand Up @@ -759,6 +762,9 @@ def _collect_messages(events: List[Dict[Text, Any]]) -> List[Message]:
msg = Message.build(data["text"], data["intent"]["name"], data["entities"])
msgs.append(msg)

elif event.get("event") == UserUtteranceReverted.type_name and msgs:
msgs.pop() # user corrected the nlu, remove incorrect example

return msgs


Expand All @@ -769,7 +775,7 @@ def _collect_actions(events: List[Dict[Text, Any]]) -> List[Dict[Text, Any]]:


async def _write_stories_to_file(
export_story_path: Text, events: List[Dict[Text, Any]]
export_story_path: Text, events: List[Dict[Text, Any]], domain: Domain
) -> None:
"""Write the conversation of the sender_id to the file paths."""

Expand All @@ -783,10 +789,18 @@ async def _write_stories_to_file(
append_write = "w" # make a new file if not

with open(export_story_path, append_write, encoding="utf-8") as f:
i = 1
for conversation in sub_conversations:
parsed_events = rasa.core.events.deserialise_events(conversation)
s = Story.from_events(parsed_events)
f.write("\n" + s.as_story_string(flat=True))
tracker = DialogueStateTracker.from_events(
"interactive_story_{}".format(i), evts=parsed_events, slots=domain.slots
)

if any(
isinstance(event, UserUttered) for event in tracker.applied_events()
):
i += 1
f.write("\n" + tracker.export_stories())


async def _write_nlu_to_file(
Expand Down Expand Up @@ -838,15 +852,12 @@ def _intents_from_messages(messages):


async def _write_domain_to_file(
domain_path: Text, events: List[Dict[Text, Any]], endpoint: EndpointConfig
domain_path: Text, events: List[Dict[Text, Any]], old_domain: Domain
) -> None:
"""Write an updated domain file to the file path."""

io_utils.create_path(domain_path)

domain = await retrieve_domain(endpoint)
old_domain = Domain.from_dict(domain)

messages = _collect_messages(events)
actions = _collect_actions(events)
templates = NEW_TEMPLATES
Expand Down
6 changes: 5 additions & 1 deletion tests/core/test_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from rasa.core.training import interactive
from rasa.utils.endpoints import EndpointConfig
from rasa.core.actions.action import default_actions
from rasa.core.domain import Domain
from tests.utilities import latest_request, json_of_latest_request


Expand Down Expand Up @@ -311,7 +312,10 @@ async def test_interactive_domain_persistence(mock_endpoint, tmpdir):
with aioresponses() as mocked:
mocked.get(url, payload={})

await interactive._write_domain_to_file(domain_path, events, mock_endpoint)
serialised_domain = await interactive.retrieve_domain(mock_endpoint)
old_domain = Domain.from_dict(serialised_domain)

await interactive._write_domain_to_file(domain_path, events, old_domain)

saved_domain = rasa.utils.io.read_config_file(domain_path)

Expand Down
102 changes: 102 additions & 0 deletions tests/core/test_nlg.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,19 @@ def test_nlg_fill_template_text(slot_name, slot_value):
assert result == {"text": str(slot_value)}


@pytest.mark.parametrize(
"img_slot_name, img_slot_value",
[("url", "https://www.exampleimg.com"), ("img1", "https://www.appleimg.com")],
)
def test_nlg_fill_template_image(img_slot_name, img_slot_value):
template = {"image": "{" + img_slot_name + "}"}
t = TemplatedNaturalLanguageGenerator(templates=dict())
result = t._fill_template(
template=template, filled_slots={img_slot_name: img_slot_value}
)
assert result == {"image": str(img_slot_value)}


@pytest.mark.parametrize(
"slot_name, slot_value",
[
Expand Down Expand Up @@ -160,3 +173,92 @@ def test_nlg_fill_template_w_bad_slot_name2(slot_name, slot_value):
template={"text": template_text}, filled_slots={slot_name: slot_value}
)
assert result["text"] == template_text


@pytest.mark.parametrize(
"text_slot_name, text_slot_value, img_slot_name, img_slot_value",
[
("tag_w_underscore", "a", "url", "https://www.exampleimg.com"),
("tag with space", "bacon", "img1", "https://www.appleimg.com"),
],
)
def test_nlg_fill_template_image_and_text(
text_slot_name, text_slot_value, img_slot_name, img_slot_value
):
template = {"text": "{" + text_slot_name + "}", "image": "{" + img_slot_name + "}"}
t = TemplatedNaturalLanguageGenerator(templates=dict())
result = t._fill_template(
template=template,
filled_slots={text_slot_name: text_slot_value, img_slot_name: img_slot_value},
)
assert result == {"text": str(text_slot_value), "image": str(img_slot_value)}


@pytest.mark.parametrize(
"text_slot_name, text_slot_value, cust_slot_name, cust_slot_value",
[
("tag_w_underscore", "a", "tag.with.dot", "chocolate"),
("tag with space", "bacon", "tag-w-dash", "apple pie"),
],
)
def test_nlg_fill_template_text_and_custom(
text_slot_name, text_slot_value, cust_slot_name, cust_slot_value
):
template = {
"text": "{" + text_slot_name + "}",
"custom": {
"field": "{" + cust_slot_name + "}",
"properties": {"field_prefixed": "prefix_{" + cust_slot_name + "}"},
},
}
t = TemplatedNaturalLanguageGenerator(templates=dict())
result = t._fill_template(
template=template,
filled_slots={text_slot_name: text_slot_value, cust_slot_name: cust_slot_value},
)
assert result == {
"text": str(text_slot_value),
"custom": {
"field": str(cust_slot_value),
"properties": {"field_prefixed": "prefix_" + str(cust_slot_value)},
},
}


@pytest.mark.parametrize(
"attach_slot_name, attach_slot_value", [("attach_file", "https://attach.pdf")]
)
def test_nlg_fill_template_attachment(attach_slot_name, attach_slot_value):
template = {"attachment": "{" + attach_slot_name + "}"}
t = TemplatedNaturalLanguageGenerator(templates=dict())
result = t._fill_template(
template=template, filled_slots={attach_slot_name: attach_slot_value}
)
assert result == {"attachment": str(attach_slot_value)}


@pytest.mark.parametrize(
"button_slot_name, button_slot_value", [("button_1", "button1")]
)
def test_nlg_fill_template_button(button_slot_name, button_slot_value):
template = {"button": "{" + button_slot_name + "}"}
t = TemplatedNaturalLanguageGenerator(templates=dict())
result = t._fill_template(
template=template, filled_slots={button_slot_name: button_slot_value}
)
assert result == {"button": str(button_slot_value)}


@pytest.mark.parametrize(
"quick_replies_slot_name, quick_replies_slot_value", [("qreply", "reply 1")]
)
def test_nlg_fill_template_quick_replies(
quick_replies_slot_name, quick_replies_slot_value
):
template = {"quick_replies": "{" + quick_replies_slot_name + "}"}
t = TemplatedNaturalLanguageGenerator(templates=dict())
result = t._fill_template(
template=template,
filled_slots={quick_replies_slot_name: quick_replies_slot_value},
)
assert result == {"quick_replies": str(quick_replies_slot_value)}

0 comments on commit 0c520d1

Please sign in to comment.