
Commit

Merge pull request #1088 from RasaHQ/enchanced-cli
Interactive learning enhancements.  Fixes #981
MetcalfeTom committed Oct 8, 2018
2 parents 157e4ed + 7bd6ff0
commit 1bc6692
Showing 9 changed files with 103 additions and 38 deletions.
10 changes: 4 additions & 6 deletions .gitignore
@@ -41,12 +41,10 @@ tmp/
tokens.dat
graph.png
debug.md
examples/moodbot/models
examples/restaurantbot/models
examples/concertbot/models
examples/moodbot/*.png
examples/moodbot/errors.json
examples/concertbot/data*
examples/concertbot/models*
examples/moodbot/models*
docs/key
docs/key.pub
secrets.tar
failed_stories.md
failed_stories.md
5 changes: 5 additions & 0 deletions CHANGELOG.rst
@@ -14,9 +14,13 @@ This project adheres to `Semantic Versioning`_ starting with version 0.2.0.
Added
-----
- openapi documentation of server API
- NLU data learned through interactive learning will now be stored in a separate markdown-format file (any previous NLU
data is merged)
- Command line interface for interactive learning now displays policy confidence alongside the action name
- added action prediction confidence & policy to ``ActionExecuted`` event
- both the date and the time at which a model was trained are now included in the policy's metadata when it is persisted


Changed
-------
- improved response format for ``/predict`` endpoint
@@ -26,6 +30,7 @@ Changed
Removed
-------


Fixed
-----
- fixed an issue with boolean slots where False and None had the same value
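The two CHANGELOG entries about prediction confidence come together on the ``ActionExecuted`` event: every executed action now records which policy selected it and at what confidence, and the interactive CLI prints that number next to the action name. A minimal sketch of the new fields, assuming the constructor accepts ``policy`` and ``confidence`` as optional arguments (the ``__str__`` change and the server test further down in this diff use exactly these field names); the policy name and value below are illustrative only:

# Sketch only: illustrates the new metadata on ActionExecuted events.
# The keyword arguments are an assumption based on the __str__ and test
# changes in this commit; the policy name is a made-up example.
from rasa_core.events import ActionExecuted

evt = ActionExecuted("action_search_concerts",
                     policy="KerasPolicy",   # hypothetical: whichever policy made the prediction
                     confidence=0.72)

print(evt)
# ActionExecuted(action: action_search_concerts, policy: KerasPolicy, confidence: 0.72)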
2 changes: 1 addition & 1 deletion data/test_trackers/tracker_moodbot.json
@@ -120,4 +120,4 @@
"confidence": 1.0
}
]
}
}
7 changes: 4 additions & 3 deletions docs/interactive_learning.rst
@@ -74,7 +74,8 @@ so make sure you have the domain & data for it. You can download
the data from :gh-code:`examples/concertbot`.

If you ask ``/search_concerts``, the bot should suggest
``action_search_concerts`` and then ``action_listen``.
``action_search_concerts`` and then ``action_listen`` (the confidence at which
the policy selected its next action will be displayed next to the action name).
Now let's enter ``/compare_reviews`` as the next user message.
The bot **might** choose the wrong one out of the two
possibilities (depending on the training run, it might also be correct):
@@ -91,8 +92,8 @@ possibilities (depending on the training run, it might also be correct):
2 /search_concerts
intent: search_concerts 1.00
───────────────────────────────────────────────────────────────
3 action_search_concerts
action_listen
3 action_search_concerts 0.72
action_listen 0.78
───────────────────────────────────────────────────────────────
4 /compare_reviews
intent: compare_reviews 1.00
4 changes: 2 additions & 2 deletions examples/concertbot/Makefile
@@ -21,6 +21,6 @@ run-actions:
train-core:
python train.py

run-online:
run-interactive:
make run-actions&
python train_online.py
python train_interactive.py
examples/concertbot/train_online.py → examples/concertbot/train_interactive.py (renamed)
@@ -6,7 +6,7 @@
import logging

from rasa_core import utils, train, run
from rasa_core.training import online
from rasa_core.training import interactive

logger = logging.getLogger(__name__)

@@ -25,4 +25,4 @@ def train_agent():
if __name__ == '__main__':
utils.configure_colored_logging(loglevel="INFO")
agent = train_agent()
online.run_online_learning(agent)
interactive.run_interactive_learning(agent)
1 change: 1 addition & 0 deletions rasa_core/events/__init__.py
@@ -709,6 +709,7 @@ def __str__(self):
return ("ActionExecuted(action: {}, policy: {}, confidence: {})"
"".format(self.action_name, self.policy, self.confidence))


def __hash__(self):
return hash(self.action_name)

106 changes: 82 additions & 24 deletions rasa_core/training/interactive.py
@@ -31,12 +31,17 @@
from rasa_core.training.structures import Story
from rasa_core.utils import EndpointConfig
from rasa_nlu.training_data.formats import MarkdownWriter, MarkdownReader
from rasa_nlu.training_data.loading import load_data, _guess_format
from rasa_nlu.training_data.message import Message
from rasa_nlu.training_data import TrainingData

logger = logging.getLogger(__name__)

MAX_VISUAL_HISTORY = 3

DEFAULT_FILE_EXPORT_PATH = "stories.md"
PATHS = {"stories": "data/stories.md",
"nlu": "data/nlu.md",
"backup": "data/nlu_interactive.md"}

# choose other intent, making sure this doesn't clash with an existing intent
OTHER_INTENT = uuid.uuid4().hex
@@ -138,7 +143,7 @@ def send_action(endpoint, sender_id, action_name):

def send_event(endpoint, sender_id, evt):
# type: (EndpointConfig, Text, Dict[Text, Any]) -> Dict[Text, Any]
"""Log an event to a concersation."""
"""Log an event to a conversation."""

subpath = "/conversations/{}/tracker/events".format(sender_id)

@@ -151,7 +156,7 @@ def send_event(endpoint, sender_id, evt):

def replace_events(endpoint, sender_id, evts):
# type: (EndpointConfig, Text, List[Dict[Text, Any]]) -> Dict[Text, Any]
"""Replace all the events of a concersation with the provided ones."""
"""Replace all the events of a conversation with the provided ones."""

subpath = "/conversations/{}/tracker/events".format(sender_id)

@@ -394,6 +399,8 @@ def add_user_cell(data, cell):
for idx, evt in enumerate(evts):
if evt.get("event") == "action":
bot_column.append(colored(evt['name'], 'autocyan'))
if evt['confidence'] is not None:
bot_column[-1] += (colored(" {:03.2f}".format(evt['confidence']), 'autowhite'))

elif evt.get("event") == 'user':
if bot_column:
@@ -470,10 +477,15 @@ def _ask_if_quit(sender_id, endpoint):

if not answers or answers["abort"] == "quit":
# this is also the default answer if the user presses Ctrl-C
export_file_path = _request_export_stories_info()
_write_stories_to_file(export_file_path, sender_id, endpoint)
logger.info("Successfully wrote stories to "
"{}.".format(export_file_path))
story_path, nlu_path = _request_export_info()

tracker = retrieve_tracker(endpoint, sender_id)
evts = tracker.get("events", [])

_write_stories_to_file(story_path, evts)
_write_nlu_to_file(nlu_path, evts)

logger.info("Successfully wrote stories and NLU data")
sys.exit()
elif answers["abort"] == "continue":
# in this case we will just return, and the original
@@ -511,9 +523,9 @@ def _request_action_from_user(predictions, sender_id, endpoint):
return action_name


def _request_export_stories_info():
# type: () -> Text
"""Request file path and export stories to that path"""
def _request_export_info():
# type: () -> (Text, Text)
"""Request file path and export stories & nlu data to that path"""

def validate_path(path):
try:
@@ -522,26 +534,31 @@ def validate_path(path):
except Exception as e:
return "Failed to open file. {}".format(e)

# export current stories and quit
# export training data and quit
questions = [{
"name": "export",
"name": "export stories",
"type": "input",
"message": "Export stories to (if file exists, this "
"will append the stories)",
"default": DEFAULT_FILE_EXPORT_PATH,
"default": PATHS["stories"],
"validate": validate_path
}]
}, {"name": "export nlu",
"type": "input",
"message": "Export NLU data to (if file exists, this "
"will merge learned data with previous training examples)",
"default": PATHS["nlu"],
"validate": validate_path}]

answers = prompt(questions)
if not answers:
sys.exit()

return answers["export"]
return answers["export stories"], answers["export nlu"]


def _split_conversation_at_restarts(evts):
# type: (List[Dict[Text, Any]]) -> List[List[Dict[Text, Any]]]
""""Split a conversation at restart events.
"""Split a conversation at restart events.
Returns an array of event lists, without the restart events."""

@@ -561,22 +578,63 @@ def _split_conversation_at_restarts(evts):
return sub_conversations


def _write_stories_to_file(export_file_path, sender_id, endpoint):
# type: (Text, Text, EndpointConfig) -> None
"""Write the conversation of the sender_id to the file path."""
def _collect_messages(evts):
# type: (List[Dict[Text, Any]]) -> List[Dict[Text, Any]]
"""Collect the message text and parsed data from the UserMessage events into a list"""

msgs = []

for evt in evts:
if evt.get("event") == "user":
data = evt.get("parse_data")
msg = Message.build(data["text"], data["intent"]["name"], data["entities"])
msgs.append(msg)

return msgs


tracker = retrieve_tracker(endpoint, sender_id)
evts = tracker.get("events", [])
def _write_stories_to_file(export_story_path, evts):
# type: (Text, List[Dict[Text, Any]]) -> None
"""Write the conversation of the sender_id to the file paths."""

sub_conversations = _split_conversation_at_restarts(evts)

with io.open(export_file_path, 'a', encoding="utf-8") as f:
with io.open(export_story_path, 'a', encoding="utf-8") as f:
for conversation in sub_conversations:
parsed_events = events.deserialise_events(conversation)
s = Story.from_events(parsed_events)
f.write(s.as_story_string(flat=True) + "\n")


def _write_nlu_to_file(export_nlu_path, evts):
# type: (Text, List[Dict[Text, Any]]) -> None
"""Write the nlu data of the sender_id to the file paths."""

msgs = _collect_messages(evts)

try:
previous_examples = load_data(export_nlu_path)

except:
questions = [{"name": "export nlu",
"type": "input",
"message": "Could not load existing NLU data, please specify where to store NLU data "
"learned in this session (this will overwrite any existing file)",
"default": PATHS["backup"]}]

answers = prompt(questions)
export_nlu_path = answers["export nlu"]
previous_examples = TrainingData()

nlu_data = previous_examples.merge(TrainingData(msgs))

with io.open(export_nlu_path, 'w', encoding="utf-8") as f:
if _guess_format(export_nlu_path) in ["md", "unk"]:
f.write(nlu_data.as_markdown())
else:
f.write(nlu_data.as_json())


def _predict_till_next_listen(endpoint, # type: EndpointConfig
sender_id, # type: Text
finetune # type: bool
@@ -784,8 +842,8 @@ def _enter_user_message(sender_id, endpoint, exit_text):
}]

answers = _ask_questions(
questions, sender_id, endpoint,
is_abort=lambda a: a["message"] == exit_text)
questions, sender_id, endpoint,
is_abort=lambda a: a["message"] == exit_text)

send_message(endpoint, sender_id, answers["message"])

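The export path added above merges the messages collected during the session into any existing NLU training data before writing it back out as markdown. A standalone sketch of that merge step, using only the ``rasa_nlu`` helpers imported at the top of this module (the example messages are made up; the ``data/`` paths match the ``PATHS`` defaults introduced above):

# Sketch of the merge-and-export behaviour implemented in _write_nlu_to_file,
# run outside an interactive session. Example messages are illustrative.
import io

from rasa_nlu.training_data import TrainingData
from rasa_nlu.training_data.loading import load_data
from rasa_nlu.training_data.message import Message

# Messages as they would be collected from the "user" events of a tracker.
msgs = [Message.build("I want to see a concert", "search_concerts", []),
        Message.build("compare the reviews", "compare_reviews", [])]

try:
    previous_examples = load_data("data/nlu.md")   # existing training data, if any
except Exception:
    previous_examples = TrainingData()             # start fresh if the file can't be loaded

nlu_data = previous_examples.merge(TrainingData(msgs))

with io.open("data/nlu_interactive.md", "w", encoding="utf-8") as f:
    f.write(nlu_data.as_markdown())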
2 changes: 2 additions & 0 deletions tests/test_server.py
@@ -88,6 +88,8 @@ def test_requesting_non_existent_tracker(app):
"policy": None,
"confidence": None,
"name": "action_listen",
"policy": None,
"confidence": None,
"timestamp": 1514764800}]
assert content["latest_message"] == {"text": None,
"intent": {},
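The updated test shows the shape of action events returned by the server: each one now carries ``policy`` and ``confidence`` keys (both ``None`` when the action was not selected by a policy). A sketch of reading them over HTTP, assuming a rasa_core server running locally on the default port 5005; the route is an assumption based on this test and the ``/conversations/{}/tracker/events`` subpath used in ``interactive.py``:

# Sketch: fetch a conversation tracker and print the new per-action metadata.
# Host, port and the sender id ("default") are assumptions for illustration.
import requests

tracker = requests.get("http://localhost:5005/conversations/default/tracker").json()

for evt in tracker.get("events", []):
    if evt.get("event") == "action":
        # policy/confidence are None for actions that were not predicted by a policy
        print(evt["name"], evt.get("policy"), evt.get("confidence"))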
