From ecd6c066654af4ebab411b33f97c1d1611336464 Mon Sep 17 00:00:00 2001 From: Tom Bocklisch Date: Wed, 14 Nov 2018 13:01:33 +0100 Subject: [PATCH] improved documentation and args to train --- docs/evaluation.rst | 14 +-- docs/policies.rst | 2 +- rasa_core/agent.py | 9 +- rasa_core/channels/botframework.py | 5 +- rasa_core/channels/facebook.py | 23 +++-- rasa_core/channels/mattermost.py | 13 ++- rasa_core/channels/slack.py | 24 +++-- rasa_core/channels/webexteams.py | 5 +- rasa_core/cli/arguments.py | 48 ++++++--- rasa_core/config.py | 15 ++- rasa_core/domain.py | 11 +- rasa_core/evaluate.py | 24 +++-- rasa_core/events/__init__.py | 16 +-- rasa_core/featurizers.py | 11 +- rasa_core/policies/ensemble.py | 4 +- rasa_core/policies/fallback.py | 38 ++++--- rasa_core/policies/sklearn_policy.py | 45 ++++----- rasa_core/train.py | 146 ++++++++++++++++----------- rasa_core/training/dsl.py | 5 + rasa_core/visualize.py | 6 +- 20 files changed, 260 insertions(+), 204 deletions(-) diff --git a/docs/evaluation.rst b/docs/evaluation.rst index b5592c97ced..062566de3e4 100644 --- a/docs/evaluation.rst +++ b/docs/evaluation.rst @@ -19,8 +19,8 @@ by using the evaluate script: .. code-block:: bash - $ python -m rasa_core.evaluate -d models/dialogue \ - -s test_stories.md -o results + $ python -m rasa_core.evaluate --core models/dialogue \ + --stories test_stories.md -o results This will print the failed stories to ``results/failed_stories.md``. @@ -34,7 +34,7 @@ incorrect action was predicted instead. The full list of options for the script is: -.. program-output:: python -m rasa_core.evaluate default -h +.. program-output:: python -m rasa_core.evaluate default --help .. _end_to_end_evaluation: @@ -77,8 +77,9 @@ the full end-to-end evaluation command is this: .. code-block:: bash - $ python -m rasa_core.evaluate default -d models/dialogue --nlu models/nlu/current \ - -s e2e_stories.md --e2e + $ python -m rasa_core.evaluate default --core models/dialogue \ + --nlu models/nlu/current \ + --stories e2e_stories.md --e2e .. note:: @@ -118,7 +119,8 @@ mode to evaluate the models you just trained: .. code-block:: bash - $ python -m rasa_core.evaluate compare -s stories_folder -d comparison_models \ + $ python -m rasa_core.evaluate compare --stories stories_folder \ + --core comparison_models \ -o comparison_results This will evaluate each of the models on the training set, and plot some graphs diff --git a/docs/policies.rst b/docs/policies.rst index ca59eea23ec..a6294ea87f1 100644 --- a/docs/policies.rst +++ b/docs/policies.rst @@ -37,7 +37,7 @@ Default configuration --------------------- By default, we try to provide you with a good set of configuration values -and policies that suite most people. But you are encouraged to modify +and policies that suit most people. But you are encouraged to modify these to your needs: .. literalinclude:: ../rasa_core/default_config.yml diff --git a/rasa_core/agent.py b/rasa_core/agent.py index e93756fff68..fc23977a5fd 100644 --- a/rasa_core/agent.py +++ b/rasa_core/agent.py @@ -448,7 +448,7 @@ def _max_history(self): return max(max_histories, default=0) - def _are_all_featurizes_using_a_max_history(self): + def _are_all_featurizers_using_a_max_history(self): """Check if all featurizers are MaxHistoryTrackerFeaturizer.""" for policy in self.policy_ensemble.policies: @@ -477,7 +477,7 @@ def load_data(self, # automatically detect unique_last_num_states # if it was not set and # if all featurizers are MaxHistoryTrackerFeaturizer - if self._are_all_featurizes_using_a_max_history(): + if self._are_all_featurizers_using_a_max_history(): unique_last_num_states = max_history elif unique_last_num_states < max_history: # possibility of data loss @@ -503,8 +503,9 @@ def train(self, # type: (...) -> None """Train the policies / policy ensemble using dialogue data from file. - :param training_trackers: trackers to train on - :param kwargs: additional arguments passed to the underlying ML + Args: + training_trackers: trackers to train on + **kwargs: additional arguments passed to the underlying ML trainer (e.g. keras parameters) """ diff --git a/rasa_core/channels/botframework.py b/rasa_core/channels/botframework.py index 171e845474b..81e1f88314c 100644 --- a/rasa_core/channels/botframework.py +++ b/rasa_core/channels/botframework.py @@ -149,8 +149,9 @@ def __init__(self, app_id, app_password): # type: (Text, Text) -> None """Create a Bot Framework input channel. - :param app_id: Bot Framework's API id - :param app_password: Bot Framework application secret + Args: + app_id: Bot Framework's API id + app_password: Bot Framework application secret """ self.app_id = app_id diff --git a/rasa_core/channels/facebook.py b/rasa_core/channels/facebook.py index 0823983360e..2bae783cd23 100644 --- a/rasa_core/channels/facebook.py +++ b/rasa_core/channels/facebook.py @@ -236,10 +236,12 @@ def __init__(self, fb_verify, fb_secret, fb_access_token): messages. Details to setup: https://github.com/rehabstudio/fbmessenger#facebook-app-setup - :param fb_verify: FB Verification string - (can be chosen by yourself on webhook creation) - :param fb_secret: facebook application secret - :param fb_access_token: access token to post in the name of the FB page + + Args: + fb_verify: FB Verification string + (can be chosen by yourself on webhook creation) + fb_secret: facebook application secret + fb_access_token: access token to post in the name of the FB page """ self.fb_verify = fb_verify self.fb_secret = fb_secret @@ -282,12 +284,15 @@ def webhook(): @staticmethod def validate_hub_signature(app_secret, request_payload, hub_signature_header): - """Makes sure the incoming webhook requests are properly signed. + """Make sure the incoming webhook requests are properly signed. + + Args: + app_secret: Secret Key for application + request_payload: request body + hub_signature_header: X-Hub-Signature header sent with request - :param app_secret: Secret Key for application - :param request_payload: request body - :param hub_signature_header: X-Hub-Signature header sent with request - :return: boolean indicated that hub signature is validated + Returns: + bool: indicated that hub signature is validated """ # noinspection PyBroadException diff --git a/rasa_core/channels/mattermost.py b/rasa_core/channels/mattermost.py index d0e6f148516..2b214a1de3b 100644 --- a/rasa_core/channels/mattermost.py +++ b/rasa_core/channels/mattermost.py @@ -59,14 +59,13 @@ def __init__(self, url, team, user, pw): """Create a Mattermost input channel. Needs a couple of settings to properly authenticate and validate messages. - :param url: Your Mattermost team url including /v4 example - https://mysite.example.com/api/v4 - :param team: Your mattermost team name - - :param user: Your mattermost userid that will post messages - - :param pw: Your mattermost password for your user + Args: + url: Your Mattermost team url including /v4 example + https://mysite.example.com/api/v4 + team: Your mattermost team name + user: Your mattermost userid that will post messages + pw: Your mattermost password for your user """ self.url = url self.team = team diff --git a/rasa_core/channels/slack.py b/rasa_core/channels/slack.py index 20975cc9334..85618baa8ee 100644 --- a/rasa_core/channels/slack.py +++ b/rasa_core/channels/slack.py @@ -103,17 +103,19 @@ def __init__(self, slack_token, slack_channel=None, messages. Details to setup: https://github.com/slackapi/python-slackclient - :param slack_token: Your Slack Authentication token. You can find or - generate a test token - `here `_. - :param slack_channel: the string identifier for a channel to which - the bot posts, or channel name - (e.g. 'C1234ABC', 'bot-test' or '#bot-test') - If unset, messages will be sent back to the user they came from. - :param errors_ignore_retry: If error code given by slack - included in this list then it will ignore the event. - The code is listed here: - https://api.slack.com/events-api#errors + + Args: + slack_token: Your Slack Authentication token. You can find or + generate a test token + `here `_. + slack_channel: the string identifier for a channel to which + the bot posts, or channel name (e.g. 'C1234ABC', 'bot-test' + or '#bot-test') If unset, messages will be sent back + to the user they came from. + errors_ignore_retry: If error code given by slack + included in this list then it will ignore the event. + The code is listed here: + https://api.slack.com/events-api#errors """ self.slack_token = slack_token self.slack_channel = slack_channel diff --git a/rasa_core/channels/webexteams.py b/rasa_core/channels/webexteams.py index 3e78376b3cb..57360be6a63 100644 --- a/rasa_core/channels/webexteams.py +++ b/rasa_core/channels/webexteams.py @@ -61,8 +61,9 @@ def __init__(self, access_token, room=None): Needs a couple of settings to properly authenticate and validate messages. Details here https://developer.webex.com/authentication.html - :param access_token: Cisco WebexTeams bot access token. - :param room: the string identifier for a room to which the bot posts + Args: + access_token: Cisco WebexTeams bot access token. + room: the string identifier for a room to which the bot posts """ self.token = access_token self.room = room diff --git a/rasa_core/cli/arguments.py b/rasa_core/cli/arguments.py index 6282741d598..779851c5737 100644 --- a/rasa_core/cli/arguments.py +++ b/rasa_core/cli/arguments.py @@ -6,28 +6,50 @@ import pkg_resources -def add_config_arg(parser): +def add_config_arg(parser, nargs="*", **kwargs): """Add an argument to the parser to request a policy configuration.""" parser.add_argument( '-c', '--config', type=str, - nargs="*", + nargs=nargs, default=[pkg_resources.resource_filename(__name__, "../default_config.yml")], - help="Policy specification yaml file.") - return parser + help="Policy specification yaml file.", + **kwargs) -def add_domain_arg(parser): +def add_core_model_arg(parser, **kwargs): + """Add an argument to the parser to request a policy configuration.""" + + parser.add_argument( + '--core', + type=str, + help="Path to a pre-trained core model directory", + **kwargs) + + +def add_domain_arg(parser, required=True, **kwargs): """Add an argument to the parser to request a the domain file.""" parser.add_argument( '-d', '--domain', type=str, - required=True, - help="domain specification yaml file") - return parser + required=required, + help="Domain specification (yml file)", + **kwargs) + + +def add_output_arg(parser, + help_text, + required=True, + **kwargs): + parser.add_argument( + '-o', '--out', + type=str, + required=required, + help=help_text, + **kwargs) def add_model_and_story_group(parser, allow_pretrained_model=True): @@ -39,17 +61,13 @@ def add_model_and_story_group(parser, allow_pretrained_model=True): group.add_argument( '-s', '--stories', type=str, - help="file or folder containing the training stories") + help="File or folder containing stories") group.add_argument( '--url', type=str, help="If supplied, downloads a story file from a URL and " "trains on it. Fetches the data by sending a GET request " "to the supplied URL.") + if allow_pretrained_model: - group.add_argument( - '--core', - default=None, - help="path to load a pre-trained model instead of training (" - "for interactive mode only)") - return parser + add_core_model_arg(group) diff --git a/rasa_core/config.py b/rasa_core/config.py index f1c4104df3b..42b1a0d15fb 100644 --- a/rasa_core/config.py +++ b/rasa_core/config.py @@ -3,19 +3,26 @@ from __future__ import print_function from __future__ import unicode_literals -from typing import Optional, Text, Dict, Any, List +import os +import typing +from typing import Optional, Text, List from rasa_core import utils from rasa_core.policies import PolicyEnsemble +if typing.TYPE_CHECKING: + from rasa_core.policies import Policy + def load(config_file): - # type: (Optional[Text], Dict[Text, Any], int) -> List[Policy] + # type: (Optional[Text]) -> List[Policy] """Load policy data stored in the specified file.""" - if config_file: + if config_file and os.path.isfile(config_file): config_data = utils.read_yaml_file(config_file) else: - raise ValueError("You have to provide a config file") + raise ValueError("You have to provide a valid path to a config file. " + "The file '{}' could not be found." + "".format(os.path.abspath(config_file))) return PolicyEnsemble.from_dict(config_data) diff --git a/rasa_core/domain.py b/rasa_core/domain.py index 6808d5096a3..dc253d5ac2a 100644 --- a/rasa_core/domain.py +++ b/rasa_core/domain.py @@ -17,7 +17,7 @@ from rasa_core.trackers import DialogueStateTracker, SlotSet from rasa_core.utils import read_file, read_yaml_string, EndpointConfig from six import string_types -from typing import Dict, Any +from typing import Dict, Any, Tuple from typing import List from typing import Optional from typing import Text @@ -34,7 +34,7 @@ class InvalidDomain(Exception): def check_domain_sanity(domain): - """Makes sure the domain is properly configured. + """Make sure the domain is properly configured. Checks the settings and checks if there are duplicate actions, intents, slots and entities.""" @@ -46,11 +46,8 @@ def get_duplicates(my_items): if count > 1] def get_exception_message(duplicates): - """Returns a message given a list of error locations. - - Duplicates has the format of (duplicate_actions [List], name [Text]). - :param duplicates: - :return: """ + # type: (List[Tuple[List[Text], Text]]) -> Text + """Return a message given a list of error locations.""" msg = "" for d, name in duplicates: diff --git a/rasa_core/evaluate.py b/rasa_core/evaluate.py index 05a37639e4c..4f6be127c3b 100644 --- a/rasa_core/evaluate.py +++ b/rasa_core/evaluate.py @@ -71,10 +71,6 @@ def add_args_to_parser(parser): '-m', '--max_stories', type=int, help="maximum number of stories to test on") - parser.add_argument( - '-d', '--core', - type=str, - help="core model directory to evaluate") parser.add_argument( '-u', '--nlu', type=str, @@ -101,6 +97,8 @@ def add_args_to_parser(parser): "is thrown. This can be used to validate stories during " "tests, e.g. on travis.") + cli.arguments.add_core_model_arg(parser) + return parser @@ -599,11 +597,17 @@ def run_comparison_evaluation(models, stories_file, output): num_correct) -def plot_curve(output, no_stories, ax=None): - """Plot the results from run_comparison_evaluation.""" +def plot_curve(output, no_stories): + # type: (Text, List[int]) -> None + """Plot the results from run_comparison_evaluation. + + Args: + output: Output directory to save resulting plots to + no_stories: Number of stories per run + """ import matplotlib.pyplot as plt - ax = ax or plt.gca() + ax = plt.gca() # load results from file data = utils.read_json_file(os.path.join(output, 'results.json')) @@ -666,11 +670,9 @@ def plot_curve(output, no_stories, ax=None): cmdline_arguments.stories, cmdline_arguments.output) - story_n_path = os.path.join(cmdline_arguments.core, 'num_stories.p') - - with io.open(story_n_path, 'rb') as story_n_file: - number_of_stories = pickle.load(story_n_file) + story_n_path = os.path.join(cmdline_arguments.core, 'num_stories.json') + number_of_stories = utils.read_json_file(story_n_path) plot_curve(cmdline_arguments.output, number_of_stories) logger.info("Finished evaluation") diff --git a/rasa_core/events/__init__.py b/rasa_core/events/__init__.py index c75877e3174..274d31e024a 100644 --- a/rasa_core/events/__init__.py +++ b/rasa_core/events/__init__.py @@ -498,13 +498,15 @@ def __init__(self, action_name, trigger_date_time, name=None, kill_on_user_message=True, timestamp=None): """Creates the reminder - :param action_name: name of the action to be scheduled - :param trigger_date_time: date at which the execution of the action - should be triggered (either utc or with tz) - :param name: id of the reminder. if there are multiple reminders with - the same id only the last will be run - :param kill_on_user_message: ``True`` means a user message before the - trigger date will abort the reminder + Args: + action_name: name of the action to be scheduled + trigger_date_time: date at which the execution of the action + should be triggered (either utc or with tz) + name: id of the reminder. if there are multiple reminders with + the same id only the last will be run + kill_on_user_message: ``True`` means a user message before the + trigger date will abort the reminder + timestamp: creation date of the event """ self.action_name = action_name diff --git a/rasa_core/featurizers.py b/rasa_core/featurizers.py index db8598f1fc2..f7f224a242c 100644 --- a/rasa_core/featurizers.py +++ b/rasa_core/featurizers.py @@ -152,12 +152,13 @@ class LabelTokenizerSingleStateFeaturizer(SingleStateFeaturizer): bot action names into tokens and uses these tokens to create bag-of-words feature vectors. - :param Text split_symbol: - The symbol that separates words in intets and action names. + Args: + split_symbol: The symbol that separates words in + intets and action names. - :param bool use_shared_vocab: - The flag that specifies if to create the same vocabulary for - user intents and bot actions.""" + use_shared_vocab: The flag that specifies if to create + the same vocabulary for user intents and bot actions. + """ def __init__(self, use_shared_vocab=False, split_symbol='_'): # type: (bool, Text) -> None diff --git a/rasa_core/policies/ensemble.py b/rasa_core/policies/ensemble.py index fbe2ec46149..eba3939b8ca 100644 --- a/rasa_core/policies/ensemble.py +++ b/rasa_core/policies/ensemble.py @@ -65,11 +65,11 @@ def train(self, training_trackers, domain, **kwargs): if training_trackers: for policy in self.policies: policy.train(training_trackers, domain, **kwargs) - self.training_trackers = training_trackers - self.date_trained = datetime.now().strftime('%Y%m%d-%H%M%S') else: logger.info("Skipped training, because there are no " "training samples.") + self.training_trackers = training_trackers + self.date_trained = datetime.now().strftime('%Y%m%d-%H%M%S') def probabilities_using_best_policy(self, tracker, domain): # type: (DialogueStateTracker, Domain) -> Tuple[List[float], Text] diff --git a/rasa_core/policies/fallback.py b/rasa_core/policies/fallback.py index 470757999b8..467c8d608b6 100644 --- a/rasa_core/policies/fallback.py +++ b/rasa_core/policies/fallback.py @@ -3,17 +3,15 @@ from __future__ import print_function from __future__ import unicode_literals +import json import logging import os -import json -import io import typing - from typing import Any, List, Text from rasa_core import utils -from rasa_core.policies.policy import Policy from rasa_core.constants import FALLBACK_SCORE +from rasa_core.policies.policy import Policy logger = logging.getLogger(__name__) @@ -23,23 +21,11 @@ class FallbackPolicy(Policy): - """Policy which executes a fallback action if NLU confidence is low - or no other policy has a high-confidence prediction. - - :param float nlu_threshold: - minimum threshold for NLU confidence. - If intent prediction confidence is lower than this, - predict fallback action with confidence 1.0. + """Policy which predicts fallback actions. - :param float core_threshold: - if NLU confidence threshold is met, - predict fallback action with confidence `core_threshold`. - If this is the highest confidence in the ensemble, - the fallback action will be executed. - - :param Text fallback_action_name: - name of the action to execute as a fallback. - """ + A fallback can be triggered by a low confidence score on a + NLU prediction or by a low confidence score on an action + prediction. """ @staticmethod def _standard_featurizer(): @@ -51,6 +37,18 @@ def __init__(self, fallback_action_name="action_default_fallback" # type: Text ): # type: (...) -> None + """Create a new Fallback policy. + + Args: + core_threshold: if NLU confidence threshold is met, + predict fallback action with confidence `core_threshold`. + If this is the highest confidence in the ensemble, + the fallback action will be executed. + nlu_threshold: minimum threshold for NLU confidence. + If intent prediction confidence is lower than this, + predict fallback action with confidence 1.0. + fallback_action_name: name of the action to execute as a fallback + """ super(FallbackPolicy, self).__init__() diff --git a/rasa_core/policies/sklearn_policy.py b/rasa_core/policies/sklearn_policy.py index a556537a330..12dd904e2f5 100644 --- a/rasa_core/policies/sklearn_policy.py +++ b/rasa_core/policies/sklearn_policy.py @@ -32,33 +32,7 @@ class SklearnPolicy(Policy): - """Use an sklearn classifier to train a policy. - - Supports cross validation and grid search. - - :param sklearn.base.ClassifierMixin model: - The sklearn model or model pipeline. - - :param cv: - If *cv* is not None, perform a cross validation on the training - data. *cv* should then conform to the sklearn standard - (e.g. *cv=5* for a 5-fold cross-validation). - - :param dict param_grid: - If *param_grid* is not None and *cv* is given, a grid search on - the given *param_grid* is performed - (e.g. *param_grid={'n_estimators': [50, 100]}*). - - :param scoring: - Scoring strategy, using the sklearn standard. - - :param sklearn.base.TransformerMixin label_encoder: - Encoder for the labels. Must implement an *inverse_transform* - method. - - :param bool shuffle: - Whether to shuffle training data. - """ + """Use an sklearn classifier to train a policy.""" def __init__( self, @@ -71,6 +45,23 @@ def __init__( shuffle=True, # type: bool ): # type: (...) -> None + """Create a new sklearn policy. + + Args: + featurizer: Featurizer used to convert the training data into + vector format. + model: The sklearn model or model pipeline. + param_grid: If *param_grid* is not None and *cv* is given, + a grid search on the given *param_grid* is performed + (e.g. *param_grid={'n_estimators': [50, 100]}*). + cv: If *cv* is not None, perform a cross validation on + the training data. *cv* should then conform to the + sklearn standard (e.g. *cv=5* for a 5-fold cross-validation). + scoring: Scoring strategy, using the sklearn standard. + label_encoder: Encoder for the labels. Must implement an + *inverse_transform* method. + shuffle: Whether to shuffle training data. + """ if featurizer: if not isinstance(featurizer, MaxHistoryTrackerFeaturizer): diff --git a/rasa_core/train.py b/rasa_core/train.py index 391e0ab7260..73b1dd807cd 100644 --- a/rasa_core/train.py +++ b/rasa_core/train.py @@ -6,10 +6,9 @@ from builtins import str import argparse -import io import logging import os -import pickle +import tempfile from rasa_core import config, cli from rasa_core import utils @@ -30,29 +29,36 @@ def create_argument_parser(): """Parse all the command line arguments for the training script.""" parser = argparse.ArgumentParser( - description='trains a dialogue model') + description='Train a dialogue model for Rasa Core. ' + 'The training will use your conversations ' + 'in the story training data format and ' + 'your domain definition to train a dialogue ' + 'model to predict a bots actions.') parent_parser = argparse.ArgumentParser(add_help=False) - add_args_to_parser(parent_parser) - cli.arguments.add_domain_arg(parent_parser) - cli.arguments.add_config_arg(parent_parser) - cli.arguments.add_model_and_story_group(parent_parser) - utils.add_logging_option_arguments(parent_parser) - subparsers = parser.add_subparsers(help='mode', dest='mode') - subparsers.add_parser('default', - help='default mode: train a dialogue' - ' model', - parents=[parent_parser]) - compare_parser = subparsers.add_parser('compare', - help='compare mode: train multiple ' - 'dialogue models to compare ' - 'policies', - parents=[parent_parser]) - interactive_parser = subparsers.add_parser('interactive', - help='teach the bot with ' - 'interactive learning', - parents=[parent_parser]) + add_general_args(parent_parser) + + subparsers = parser.add_subparsers( + help='Training mode of core.', + dest='mode') + subparsers.required = True + + train_parser = subparsers.add_parser( + 'default', + help='train a dialogue model', + parents=[parent_parser]) + compare_parser = subparsers.add_parser( + 'compare', + help='train multiple dialogue models to compare ' + 'policies', + parents=[parent_parser]) + interactive_parser = subparsers.add_parser( + 'interactive', + help='teach the bot with interactive learning', + parents=[parent_parser]) + add_compare_args(compare_parser) add_interactive_args(interactive_parser) + add_train_args(train_parser) return parser @@ -70,6 +76,20 @@ def add_compare_args(parser): default=3, help="Number of runs for experiments") + cli.arguments.add_output_arg( + parser, + help_text="directory to persist the trained model in", + required=True) + cli.arguments.add_config_arg( + parser, + nargs="*") + cli.arguments.add_model_and_story_group( + parser, + allow_pretrained_model=False) + cli.arguments.add_domain_arg( + parser, + required=True) + def add_interactive_args(parser): parser.add_argument( @@ -93,13 +113,38 @@ def add_interactive_args(parser): action='store_true', help="retrain the model immediately based on feedback.") - -def add_args_to_parser(parser): - parser.add_argument( - '-o', '--out', - type=str, - required=False, - help="directory to persist the trained model in") + cli.arguments.add_output_arg( + parser, + help_text="directory to persist the trained model in", + required=False) + cli.arguments.add_config_arg( + parser, + nargs=1) + cli.arguments.add_model_and_story_group( + parser, + allow_pretrained_model=True) + cli.arguments.add_domain_arg( + parser, + required=False) + + +def add_train_args(parser): + cli.arguments.add_config_arg( + parser, + nargs=1) + cli.arguments.add_output_arg( + parser, + help_text="directory to persist the trained model in", + required=True) + cli.arguments.add_model_and_story_group( + parser, + allow_pretrained_model=False) + cli.arguments.add_domain_arg( + parser, + required=True) + + +def add_general_args(parser): parser.add_argument( '--augmentation', type=int, @@ -118,7 +163,7 @@ def add_args_to_parser(parser): "and their connections between story blocks in a " "file called `story_blocks_connections.pdf`.") - return parser + utils.add_logging_option_arguments(parser) def train_dialogue_model(domain_file, stories_file, output_path, @@ -220,14 +265,6 @@ def get_no_of_stories(story_file, domain): def do_default_training(cmdline_args, stories, additional_arguments): """Train a model.""" - if not cmdline_args.out: - raise ValueError("you must provide a path where the model " - "will be saved using -o / --out") - - if (isinstance(cmdline_args.config, list) and - len(cmdline_args.config) > 1): - raise ValueError("You can only pass one config file at a time") - train_dialogue_model(domain_file=cmdline_args.domain, stories_file=stories, output_path=cmdline_args.out, @@ -237,10 +274,6 @@ def do_default_training(cmdline_args, stories, additional_arguments): def do_compare_training(cmdline_args, stories, additional_arguments): - if not cmdline_args.out: - raise ValueError("you must provide a path where the model " - "will be saved using -o / --out") - train_comparison_models(stories, cmdline_args.domain, cmdline_args.out, @@ -258,8 +291,8 @@ def do_compare_training(cmdline_args, stories, additional_arguments): story_range = [no_stories - round((x / 100.0) * no_stories) for x in cmdline_args.percentages] - with io.open(os.path.join(cmdline_args.out, 'num_stories.p'), 'wb') as f: - pickle.dump(story_range, f) + story_n_path = os.path.join(cmdline_args.out, 'num_stories.json') + utils.dump_obj_as_json_to_file(story_n_path, story_range) def do_interactive_learning(cmdline_args, stories, additional_arguments): @@ -267,15 +300,12 @@ def do_interactive_learning(cmdline_args, stories, additional_arguments): _interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu, _endpoints.nlu) - if (isinstance(cmdline_args.config, list) and - len(cmdline_args.config) > 1): - raise ValueError("You can only pass one config file at a time") + if cmdline_args.core: + if cmdline_args.finetune: + raise ValueError("--core can only be used without --finetune flag.") - if cmdline_args.core and cmdline_args.finetune: - raise ValueError("--core can only be used without --finetune flag.") - elif cmdline_args.core: - logger.info("loading a pre-trained model. " - "all training-related parameters will be ignored") + logger.info("Loading a pre-trained model. This means that " + "all training-related parameters will be ignored.") _broker = PikaProducer.from_endpoint_config(_endpoints.event_broker) _tracker_store = TrackerStore.find_tracker_store( @@ -289,13 +319,14 @@ def do_interactive_learning(cmdline_args, stories, additional_arguments): tracker_store=_tracker_store, action_endpoint=_endpoints.action) else: - if not cmdline_args.out: - raise ValueError("you must provide a path where the model " - "will be saved using -o / --out") + if cmdline_args.out: + model_directory = cmdline_args.out + else: + model_directory = tempfile.mkdtemp(suffix="_core_model") _agent = train_dialogue_model(cmdline_args.domain, stories, - cmdline_args.out, + model_directory, _interpreter, _endpoints, cmdline_args.dump_stories, @@ -315,9 +346,6 @@ def do_interactive_learning(cmdline_args, stories, additional_arguments): arg_parser = create_argument_parser() set_default_subparser(arg_parser, 'default') cmdline_arguments = arg_parser.parse_args() - if not cmdline_arguments.mode: - raise ValueError("You must specify the mode you want training to run " - "in. The options are: (default|compare|interactive)") additional_args = _additional_arguments(cmdline_arguments) utils.configure_colored_logging(cmdline_arguments.loglevel) diff --git a/rasa_core/training/dsl.py b/rasa_core/training/dsl.py index 0554d1ad072..11d12fb1163 100644 --- a/rasa_core/training/dsl.py +++ b/rasa_core/training/dsl.py @@ -165,6 +165,11 @@ def read_from_folder(resource_name, domain, interpreter=RegexInterpreter(), exclusion_percentage=None): """Given a path reads all contained story files.""" + if not os.path.exists(resource_name): + raise ValueError("Story file or folder could not be found. Make " + "sure '{}' exists and points to a story folder " + "or file.".format(os.path.abspath(resource_name))) + story_steps = [] for f in nlu_utils.list_files(resource_name): steps = StoryFileReader.read_from_file(f, domain, interpreter, diff --git a/rasa_core/visualize.py b/rasa_core/visualize.py index 4c5923bfd4a..a82aac5f114 100644 --- a/rasa_core/visualize.py +++ b/rasa_core/visualize.py @@ -41,7 +41,7 @@ def create_argument_parser(): utils.add_logging_option_arguments(parser) - cli.arguments.add_config_arg(parser) + cli.arguments.add_config_arg(parser, nargs=1) cli.arguments.add_domain_arg(parser) cli.arguments.add_model_and_story_group(parser, allow_pretrained_model=False) @@ -54,10 +54,6 @@ def create_argument_parser(): utils.configure_colored_logging(cmdline_arguments.loglevel) - if (isinstance(cmdline_arguments.config, list) and - len(cmdline_arguments.config) > 1): - raise ValueError("You can only pass one config file at a time") - policies = config.load(cmdline_arguments.config[0]) agent = Agent(cmdline_arguments.domain, policies=policies)