Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Fix new interactive training feature to add custom actions to domain #1375

Merged
merged 23 commits into from Nov 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 2 additions & 5 deletions examples/restaurantbot/policy.py
Expand Up @@ -13,11 +13,8 @@
class RestaurantPolicy(KerasPolicy):
def model_architecture(self, input_shape, output_shape):
"""Build a Keras model and return a compiled model."""
from keras.layers import LSTM, Activation, Masking, Dense
from keras.models import Sequential

from keras.models import Sequential
from keras.layers import \
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import \
Masking, LSTM, Dense, TimeDistributed, Activation

# Build Model
Expand Down
2 changes: 1 addition & 1 deletion rasa_core/channels/botframework.py
Expand Up @@ -185,7 +185,7 @@ def webhook():
except Exception as e:
logger.error("Exception when trying to handle "
"message.{0}".format(e))
logger.error(e, exc_info=True)
logger.debug(e, exc_info=True)
pass

return "success"
Expand Down
2 changes: 1 addition & 1 deletion rasa_core/channels/mattermost.py
Expand Up @@ -102,7 +102,7 @@ def webhook():
except Exception as e:
logger.error("Exception when trying to handle "
"message.{0}".format(e))
logger.error(e, exc_info=True)
logger.debug(e, exc_info=True)
pass
return make_response()

Expand Down
2 changes: 1 addition & 1 deletion rasa_core/channels/telegram.py
Expand Up @@ -171,7 +171,7 @@ def message():
except Exception as e:
logger.error("Exception when trying to handle "
"message.{0}".format(e))
logger.error(e, exc_info=True)
logger.debug(e, exc_info=True)
if self.debug_mode:
raise
pass
Expand Down
2 changes: 1 addition & 1 deletion rasa_core/channels/twilio.py
Expand Up @@ -103,7 +103,7 @@ def message():
except Exception as e:
logger.error("Exception when trying to handle "
"message.{0}".format(e))
logger.error(e, exc_info=True)
logger.debug(e, exc_info=True)
if self.debug_mode:
raise
pass
Expand Down
25 changes: 21 additions & 4 deletions rasa_core/domain.py
Expand Up @@ -144,7 +144,12 @@ def merge_lists(l1, l2):
merged_intents = merge_dicts(intents_1, intents_2, override)
combined['intents'] = list(merged_intents.values())

for key in ['entities', 'actions']:
# remove existing forms from new actions
for form in combined['forms']:
if form in domain_dict['actions']:
domain_dict['actions'].remove(form)

for key in ['entities', 'actions', 'forms']:
combined[key] = merge_lists(combined[key],
domain_dict[key])

Expand Down Expand Up @@ -292,7 +297,7 @@ def action_for_index(self, index, action_endpoint):

if self.num_actions <= index or index < 0:
raise IndexError(
"Can not access action at index {}. "
"Cannot access action at index {}. "
"Domain has {} actions.".format(index, self.num_actions))
return self.action_for_name(self.action_names[index],
action_endpoint)
Expand All @@ -314,7 +319,7 @@ def _raise_action_not_found_exception(self, action_name):
action_names = "\n".join(["\t - {}".format(a)
for a in self.action_names])
raise NameError(
"Can not access action '{}', "
"Cannot access action '{}', "
"as that name is not a registered action for this domain. "
"Available actions are: \n{}"
"".format(action_name, action_names))
Expand Down Expand Up @@ -575,9 +580,21 @@ def persist_clean(self, filename):
if intent.get("use_entities"):
data["intents"][idx] = name

for name, slot in data["slots"].items():
for slot in data["slots"].values():
if slot["initial_value"] is None:
del slot["initial_value"]
if slot["auto_fill"]:
del slot["auto_fill"]
if slot["type"].startswith('rasa_core.slots'):
slot["type"] = Slot.resolve_by_type(slot["type"]).type_name

if data["config"]["store_entities_as_slots"]:
del data["config"]["store_entities_as_slots"]

# clean empty keys
data = {k: v
for k, v in data.items()
if v != {} and v != [] and v is not None}

utils.dump_obj_as_yaml_to_file(filename, data)

Expand Down
4 changes: 2 additions & 2 deletions rasa_core/processor.py
Expand Up @@ -358,8 +358,8 @@ def _run_action(self, action, tracker, dispatcher, policy=None,
logger.error("Encountered an exception while running action '{}'. "
"Bot will continue, but the actions events are lost. "
"Make sure to fix the exception in your custom "
"code.".format(action.name()), )
logger.error(e, exc_info=True)
"code.".format(action.name()))
logger.debug(e, exc_info=True)
events = []

self._log_action_on_tracker(tracker, action.name(), events, policy,
Expand Down
6 changes: 5 additions & 1 deletion rasa_core/server.py
Expand Up @@ -232,7 +232,11 @@ def execute_action(sender_id):
except ValueError as e:
return error(400, "ValueError", e)
except Exception as e:
logger.exception(e)
logger.error("Encountered an exception while running action '{}'. "
"Bot will continue, but the actions events are lost. "
"Make sure to fix the exception in your custom "
"code.".format(action_to_execute))
logger.debug(e, exc_info=True)
return error(500, "ValueError",
"Server failure. Error: {}".format(e))

Expand Down
36 changes: 23 additions & 13 deletions rasa_core/training/interactive.py
Expand Up @@ -9,17 +9,16 @@
import logging
import numpy as np
import os
import pkg_resources
import requests
import six
import textwrap
import uuid
from PyInquirer import prompt
from colorclass import Color
from flask import Flask, send_from_directory, send_file, abort
from flask import Flask, send_file, abort
from gevent.pywsgi import WSGIServer
from rasa_core import utils, server, events, constants
from rasa_core.actions.action import ACTION_LISTEN_NAME
from rasa_core.actions.action import ACTION_LISTEN_NAME, default_action_names
from rasa_core.agent import Agent
from rasa_core.channels import UserMessage
from rasa_core.channels.channel import button_to_string
Expand Down Expand Up @@ -176,14 +175,19 @@ def send_action(endpoint, # type: EndpointConfig
return _response_as_json(r)
except requests.exceptions.HTTPError:
if is_new_action:
logger.warning("You have created a new action: {} "
"which was not successfully executed. \n"
warning_questions = [{
"name": "warning",
"type": "confirm",
"message": "WARNING: You have created a new action: '{}', "
"which was not successfully executed. "
"If this action does not return any events, "
"you do not need to do anything. \n"
"you do not need to do anything. "
"If this is a custom action which returns events, "
"you are recommended to implement this action "
"in your action server and try again."
"".format(action_name))
"".format(action_name)
}]
_ask_questions(warning_questions, sender_id, endpoint)

payload = ActionExecuted(action_name).as_dict()

Expand Down Expand Up @@ -728,7 +732,9 @@ def _collect_actions(evts):
# type: (List[Dict[Text, Any]]) -> List[Dict[Text, Any]]
"""Collect all the `ActionExecuted` events into a list."""

return [evt for evt in evts if evt.get("event") == ActionExecuted.type_name]
return [evt
for evt in evts
if evt.get("event") == ActionExecuted.type_name]


def _write_stories_to_file(export_story_path, evts):
Expand Down Expand Up @@ -784,7 +790,7 @@ def _entities_from_messages(messages):


def _intents_from_messages(messages):
"""Return all intents that occur in atleast one of the messages."""
"""Return all intents that occur in at least one of the messages."""

# set of distinct intents
intents = {m.data["intent"]
Expand All @@ -806,10 +812,14 @@ def _write_domain_to_file(domain_path, evts, endpoint):

domain_dict = dict.fromkeys(domain.keys(), {}) # type: Dict[Text, Any]

# TODO for now there is no way to distinguish between action and form
domain_dict["forms"] = []
domain_dict["intents"] = _intents_from_messages(messages)
domain_dict["entities"] = _entities_from_messages(messages)
domain_dict["actions"] = list({e["name"] for e in actions})
# do not automatically add default actions to the domain dict
domain_dict["actions"] = list({e["name"]
for e in actions
if e["name"] not in default_action_names()})

new_domain = Domain.from_dict(domain_dict)

Expand Down Expand Up @@ -922,7 +932,7 @@ def _confirm_form_validation(action_name, tracker, endpoint, sender_id):
# handle contradiction with learned behaviour
warning_questions = [{
"name": "warning",
"type": "input",
"type": "confirm",
"message": "ERROR: FormPolicy predicted no form validation "
"based on previous training stories. "
"Make sure to remove contradictory stories "
Expand Down Expand Up @@ -1298,8 +1308,8 @@ def _start_interactive_learning_io(endpoint, stories, on_finish,
finetune=False,
skip_visualization=False):
# type: (EndpointConfig, Text, Callable[[], None], bool, bool) -> None
"""Start the interactive learning message recording in a separate thread."""

"""Start the interactive learning message recording in a separate thread.
"""
p = Thread(target=record_messages,
kwargs={
"endpoint": endpoint,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -40,7 +40,7 @@ python-dateutil==2.7.3
rocketchat_API==0.6.22
colorclass==2.2.0
terminaltables==3.1.0
PyInquirer==1.0.2
PyInquirer==1.0.3
flask-jwt-simple==0.0.3
python-socketio==2.0.0
prompt_toolkit==1.0.14
Expand Down
24 changes: 24 additions & 0 deletions tests/test_interactive.py
Expand Up @@ -6,6 +6,7 @@
from rasa_core import utils
from rasa_core.training import interactive
from rasa_core.utils import EndpointConfig
from rasa_core.actions.action import default_actions


@pytest.fixture
Expand Down Expand Up @@ -205,3 +206,26 @@ def test_undo_latest_msg(mock_endpoint):
replaced_evts = json.loads(b)
assert len(replaced_evts) == 6
assert replaced_evts == evts[:6]


def test_interactive_domain_persistence(mock_endpoint, tmpdir):
# Test method interactive._write_domain_to_file

tracker_dump = "data/test_trackers/tracker_moodbot.json"
tracker_json = utils.read_json_file(tracker_dump)

events = tracker_json.get("events", [])

domain_path = tmpdir.join("interactive_domain_save.yml").strpath

url = '{}/domain'.format(mock_endpoint.url)
httpretty.register_uri(httpretty.GET, url, body='{}')

httpretty.enable()
interactive._write_domain_to_file(domain_path, events, mock_endpoint)
httpretty.disable()

saved_domain = utils.read_yaml_file(domain_path)

for default_action in default_actions():
assert default_action.name() not in saved_domain["actions"]