Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Permalink
improved documentation and args to train
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Nov 14, 2018
1 parent fc35d37 commit ecd6c06
Show file tree
Hide file tree
Showing 20 changed files with 260 additions and 204 deletions.
14 changes: 8 additions & 6 deletions docs/evaluation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ by using the evaluate script:

.. code-block:: bash
$ python -m rasa_core.evaluate -d models/dialogue \
-s test_stories.md -o results
$ python -m rasa_core.evaluate --core models/dialogue \
--stories test_stories.md -o results
This will print the failed stories to ``results/failed_stories.md``.
Expand All @@ -34,7 +34,7 @@ incorrect action was predicted instead.

The full list of options for the script is:

.. program-output:: python -m rasa_core.evaluate default -h
.. program-output:: python -m rasa_core.evaluate default --help

.. _end_to_end_evaluation:

Expand Down Expand Up @@ -77,8 +77,9 @@ the full end-to-end evaluation command is this:

.. code-block:: bash
$ python -m rasa_core.evaluate default -d models/dialogue --nlu models/nlu/current \
-s e2e_stories.md --e2e
$ python -m rasa_core.evaluate default --core models/dialogue \
--nlu models/nlu/current \
--stories e2e_stories.md --e2e
.. note::

Expand Down Expand Up @@ -118,7 +119,8 @@ mode to evaluate the models you just trained:

.. code-block:: bash
$ python -m rasa_core.evaluate compare -s stories_folder -d comparison_models \
$ python -m rasa_core.evaluate compare --stories stories_folder \
--core comparison_models \
-o comparison_results
This will evaluate each of the models on the training set, and plot some graphs
Expand Down
2 changes: 1 addition & 1 deletion docs/policies.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Default configuration
---------------------

By default, we try to provide you with a good set of configuration values
and policies that suite most people. But you are encouraged to modify
and policies that suit most people. But you are encouraged to modify
these to your needs:

.. literalinclude:: ../rasa_core/default_config.yml
Expand Down
9 changes: 5 additions & 4 deletions rasa_core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def _max_history(self):

return max(max_histories, default=0)

def _are_all_featurizes_using_a_max_history(self):
def _are_all_featurizers_using_a_max_history(self):
"""Check if all featurizers are MaxHistoryTrackerFeaturizer."""

for policy in self.policy_ensemble.policies:
Expand Down Expand Up @@ -477,7 +477,7 @@ def load_data(self,
# automatically detect unique_last_num_states
# if it was not set and
# if all featurizers are MaxHistoryTrackerFeaturizer
if self._are_all_featurizes_using_a_max_history():
if self._are_all_featurizers_using_a_max_history():
unique_last_num_states = max_history
elif unique_last_num_states < max_history:
# possibility of data loss
Expand All @@ -503,8 +503,9 @@ def train(self,
# type: (...) -> None
"""Train the policies / policy ensemble using dialogue data from file.
:param training_trackers: trackers to train on
:param kwargs: additional arguments passed to the underlying ML
Args:
training_trackers: trackers to train on
**kwargs: additional arguments passed to the underlying ML
trainer (e.g. keras parameters)
"""

Expand Down
5 changes: 3 additions & 2 deletions rasa_core/channels/botframework.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ def __init__(self, app_id, app_password):
# type: (Text, Text) -> None
"""Create a Bot Framework input channel.
:param app_id: Bot Framework's API id
:param app_password: Bot Framework application secret
Args:
app_id: Bot Framework's API id
app_password: Bot Framework application secret
"""

self.app_id = app_id
Expand Down
23 changes: 14 additions & 9 deletions rasa_core/channels/facebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,12 @@ def __init__(self, fb_verify, fb_secret, fb_access_token):
messages. Details to setup:
https://github.com/rehabstudio/fbmessenger#facebook-app-setup
:param fb_verify: FB Verification string
(can be chosen by yourself on webhook creation)
:param fb_secret: facebook application secret
:param fb_access_token: access token to post in the name of the FB page
Args:
fb_verify: FB Verification string
(can be chosen by yourself on webhook creation)
fb_secret: facebook application secret
fb_access_token: access token to post in the name of the FB page
"""
self.fb_verify = fb_verify
self.fb_secret = fb_secret
Expand Down Expand Up @@ -282,12 +284,15 @@ def webhook():
@staticmethod
def validate_hub_signature(app_secret, request_payload,
hub_signature_header):
"""Makes sure the incoming webhook requests are properly signed.
"""Make sure the incoming webhook requests are properly signed.
Args:
app_secret: Secret Key for application
request_payload: request body
hub_signature_header: X-Hub-Signature header sent with request
:param app_secret: Secret Key for application
:param request_payload: request body
:param hub_signature_header: X-Hub-Signature header sent with request
:return: boolean indicated that hub signature is validated
Returns:
bool: indicated that hub signature is validated
"""

# noinspection PyBroadException
Expand Down
13 changes: 6 additions & 7 deletions rasa_core/channels/mattermost.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,13 @@ def __init__(self, url, team, user, pw):
"""Create a Mattermost input channel.
Needs a couple of settings to properly authenticate and validate
messages.
:param url: Your Mattermost team url including /v4 example
https://mysite.example.com/api/v4
:param team: Your mattermost team name
:param user: Your mattermost userid that will post messages
:param pw: Your mattermost password for your user
Args:
url: Your Mattermost team url including /v4 example
https://mysite.example.com/api/v4
team: Your mattermost team name
user: Your mattermost userid that will post messages
pw: Your mattermost password for your user
"""
self.url = url
self.team = team
Expand Down
24 changes: 13 additions & 11 deletions rasa_core/channels/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,19 @@ def __init__(self, slack_token, slack_channel=None,
messages. Details to setup:
https://github.com/slackapi/python-slackclient
:param slack_token: Your Slack Authentication token. You can find or
generate a test token
`here <https://api.slack.com/docs/oauth-test-tokens>`_.
:param slack_channel: the string identifier for a channel to which
the bot posts, or channel name
(e.g. 'C1234ABC', 'bot-test' or '#bot-test')
If unset, messages will be sent back to the user they came from.
:param errors_ignore_retry: If error code given by slack
included in this list then it will ignore the event.
The code is listed here:
https://api.slack.com/events-api#errors
Args:
slack_token: Your Slack Authentication token. You can find or
generate a test token
`here <https://api.slack.com/docs/oauth-test-tokens>`_.
slack_channel: the string identifier for a channel to which
the bot posts, or channel name (e.g. 'C1234ABC', 'bot-test'
or '#bot-test') If unset, messages will be sent back
to the user they came from.
errors_ignore_retry: If error code given by slack
included in this list then it will ignore the event.
The code is listed here:
https://api.slack.com/events-api#errors
"""
self.slack_token = slack_token
self.slack_channel = slack_channel
Expand Down
5 changes: 3 additions & 2 deletions rasa_core/channels/webexteams.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def __init__(self, access_token, room=None):
Needs a couple of settings to properly authenticate and validate
messages. Details here https://developer.webex.com/authentication.html
:param access_token: Cisco WebexTeams bot access token.
:param room: the string identifier for a room to which the bot posts
Args:
access_token: Cisco WebexTeams bot access token.
room: the string identifier for a room to which the bot posts
"""
self.token = access_token
self.room = room
Expand Down
48 changes: 33 additions & 15 deletions rasa_core/cli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,50 @@
import pkg_resources


def add_config_arg(parser):
def add_config_arg(parser, nargs="*", **kwargs):
"""Add an argument to the parser to request a policy configuration."""

parser.add_argument(
'-c', '--config',
type=str,
nargs="*",
nargs=nargs,
default=[pkg_resources.resource_filename(__name__,
"../default_config.yml")],
help="Policy specification yaml file.")
return parser
help="Policy specification yaml file.",
**kwargs)


def add_domain_arg(parser):
def add_core_model_arg(parser, **kwargs):
"""Add an argument to the parser to request a policy configuration."""

parser.add_argument(
'--core',
type=str,
help="Path to a pre-trained core model directory",
**kwargs)


def add_domain_arg(parser, required=True, **kwargs):
"""Add an argument to the parser to request a the domain file."""

parser.add_argument(
'-d', '--domain',
type=str,
required=True,
help="domain specification yaml file")
return parser
required=required,
help="Domain specification (yml file)",
**kwargs)


def add_output_arg(parser,
help_text,
required=True,
**kwargs):
parser.add_argument(
'-o', '--out',
type=str,
required=required,
help=help_text,
**kwargs)


def add_model_and_story_group(parser, allow_pretrained_model=True):
Expand All @@ -39,17 +61,13 @@ def add_model_and_story_group(parser, allow_pretrained_model=True):
group.add_argument(
'-s', '--stories',
type=str,
help="file or folder containing the training stories")
help="File or folder containing stories")
group.add_argument(
'--url',
type=str,
help="If supplied, downloads a story file from a URL and "
"trains on it. Fetches the data by sending a GET request "
"to the supplied URL.")

if allow_pretrained_model:
group.add_argument(
'--core',
default=None,
help="path to load a pre-trained model instead of training ("
"for interactive mode only)")
return parser
add_core_model_arg(group)
15 changes: 11 additions & 4 deletions rasa_core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@
from __future__ import print_function
from __future__ import unicode_literals

from typing import Optional, Text, Dict, Any, List
import os
import typing
from typing import Optional, Text, List

from rasa_core import utils
from rasa_core.policies import PolicyEnsemble

if typing.TYPE_CHECKING:
from rasa_core.policies import Policy


def load(config_file):
# type: (Optional[Text], Dict[Text, Any], int) -> List[Policy]
# type: (Optional[Text]) -> List[Policy]
"""Load policy data stored in the specified file."""

if config_file:
if config_file and os.path.isfile(config_file):
config_data = utils.read_yaml_file(config_file)
else:
raise ValueError("You have to provide a config file")
raise ValueError("You have to provide a valid path to a config file. "
"The file '{}' could not be found."
"".format(os.path.abspath(config_file)))

return PolicyEnsemble.from_dict(config_data)
11 changes: 4 additions & 7 deletions rasa_core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from rasa_core.trackers import DialogueStateTracker, SlotSet
from rasa_core.utils import read_file, read_yaml_string, EndpointConfig
from six import string_types
from typing import Dict, Any
from typing import Dict, Any, Tuple
from typing import List
from typing import Optional
from typing import Text
Expand All @@ -34,7 +34,7 @@ class InvalidDomain(Exception):


def check_domain_sanity(domain):
"""Makes sure the domain is properly configured.
"""Make sure the domain is properly configured.
Checks the settings and checks if there are duplicate actions,
intents, slots and entities."""
Expand All @@ -46,11 +46,8 @@ def get_duplicates(my_items):
if count > 1]

def get_exception_message(duplicates):
"""Returns a message given a list of error locations.
Duplicates has the format of (duplicate_actions [List], name [Text]).
:param duplicates:
:return: """
# type: (List[Tuple[List[Text], Text]]) -> Text
"""Return a message given a list of error locations."""

msg = ""
for d, name in duplicates:
Expand Down
24 changes: 13 additions & 11 deletions rasa_core/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,6 @@ def add_args_to_parser(parser):
'-m', '--max_stories',
type=int,
help="maximum number of stories to test on")
parser.add_argument(
'-d', '--core',
type=str,
help="core model directory to evaluate")
parser.add_argument(
'-u', '--nlu',
type=str,
Expand All @@ -101,6 +97,8 @@ def add_args_to_parser(parser):
"is thrown. This can be used to validate stories during "
"tests, e.g. on travis.")

cli.arguments.add_core_model_arg(parser)

return parser


Expand Down Expand Up @@ -599,11 +597,17 @@ def run_comparison_evaluation(models, stories_file, output):
num_correct)


def plot_curve(output, no_stories, ax=None):
"""Plot the results from run_comparison_evaluation."""
def plot_curve(output, no_stories):
# type: (Text, List[int]) -> None
"""Plot the results from run_comparison_evaluation.
Args:
output: Output directory to save resulting plots to
no_stories: Number of stories per run
"""
import matplotlib.pyplot as plt

ax = ax or plt.gca()
ax = plt.gca()

# load results from file
data = utils.read_json_file(os.path.join(output, 'results.json'))
Expand Down Expand Up @@ -666,11 +670,9 @@ def plot_curve(output, no_stories, ax=None):
cmdline_arguments.stories,
cmdline_arguments.output)

story_n_path = os.path.join(cmdline_arguments.core, 'num_stories.p')

with io.open(story_n_path, 'rb') as story_n_file:
number_of_stories = pickle.load(story_n_file)
story_n_path = os.path.join(cmdline_arguments.core, 'num_stories.json')

number_of_stories = utils.read_json_file(story_n_path)
plot_curve(cmdline_arguments.output, number_of_stories)

logger.info("Finished evaluation")

0 comments on commit ecd6c06

Please sign in to comment.