diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 07c609b4eb7..161ea695e9d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -52,6 +52,7 @@ Added - add ``Form`` and ``FormValidation`` events - add ``REQUESTED_SLOT`` constant - add ability to read ``action_listen`` from stories +- added train/eval scripts to compare policies Changed ------- @@ -65,10 +66,16 @@ Changed - forms were completely reworked, see changelog in ``rasa_core_sdk`` - state featurization if some form is active changed - ``Domain`` raises ``InvalidDomain`` exception +- interactive learning is now started with rasa_core.train interactive +- passing a policy config file to train a model is now required +- flags for output of evaluate script have been merged to one flag ``--output`` + where you provide a folder where any output from the script should be stored Removed ------- - removed graphviz dependency +- policy config related flags in training script (see migration guide) + Fixed ----- diff --git a/data/test_config/max_hist_config.yml b/data/test_config/max_hist_config.yml new file mode 100644 index 00000000000..04acbe38103 --- /dev/null +++ b/data/test_config/max_hist_config.yml @@ -0,0 +1,5 @@ +policies: + - name: MemoizationPolicy + max_history: 5 + - name: KerasPolicy + max_history: 5 diff --git a/data/test_config/no_max_hist_config.yml b/data/test_config/no_max_hist_config.yml new file mode 100644 index 00000000000..5ea9e2c0052 --- /dev/null +++ b/data/test_config/no_max_hist_config.yml @@ -0,0 +1,3 @@ +policies: + - name: MemoizationPolicy + - name: KerasPolicy diff --git a/default_config.yml b/default_config.yml new file mode 100644 index 00000000000..021f4b1067b --- /dev/null +++ b/default_config.yml @@ -0,0 +1,9 @@ +policies: + - name: KerasPolicy + epochs: 100 + max_history: 5 + - name: FallbackPolicy + fallback_action_name: 'action_default_fallback' + - name: MemoizationPolicy + max_history: 5 + - name: FormPolicy diff --git a/docs/evaluation.rst b/docs/evaluation.rst index 7333f625387..b5592c97ced 100644 --- a/docs/evaluation.rst +++ b/docs/evaluation.rst @@ -20,21 +20,21 @@ by using the evaluate script: .. code-block:: bash $ python -m rasa_core.evaluate -d models/dialogue \ - -s test_stories.md -o matrix.pdf --failed failed_stories.md + -s test_stories.md -o results -This will print the failed stories to ``failed_stories.md``. +This will print the failed stories to ``results/failed_stories.md``. We count any story as `failed` if at least one of the actions was predicted incorrectly. In addition, this will save a confusion matrix to a file called -``matrix.pdf``. The confusion matrix shows, for each action in your -domain, how often that action was predicted, and how often an +``results/story_confmat.pdf``. The confusion matrix shows, for each action in +your domain, how often that action was predicted, and how often an incorrect action was predicted instead. The full list of options for the script is: -.. program-output:: python -m rasa_core.evaluate -h +.. program-output:: python -m rasa_core.evaluate default -h .. _end_to_end_evaluation: @@ -77,7 +77,7 @@ the full end-to-end evaluation command is this: .. code-block:: bash - $ python -m rasa_core.evaluate -d models/dialogue --nlu models/nlu/current \ + $ python -m rasa_core.evaluate default -d models/dialogue --nlu models/nlu/current \ -s e2e_stories.md --e2e .. note:: @@ -98,14 +98,40 @@ your bot, so you don't just want to throw some away to use as a test set. Rasa Core has some scripts to help you choose and fine-tune your policy. 
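For readers who prefer Python over the command line, the default evaluation mode shown above reduces to a couple of calls against the functions introduced in this change. A minimal sketch, assuming placeholder paths and an already existing output directory (pass an interpreter to ``Agent.load`` and set ``use_e2e=True`` for the end-to-end variant):

.. code-block:: python

    from rasa_core.agent import Agent
    from rasa_core.evaluate import run_story_evaluation

    agent = Agent.load("models/dialogue")   # trained dialogue model

    # failed_stories.md and story_confmat.pdf are written to out_directory
    results = run_story_evaluation("test_stories.md",
                                   agent,
                                   max_stories=None,
                                   out_directory="results",
                                   use_e2e=False)
    print(results["report"])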
Once you are happy with it, you can then train your final policy on your -full data set. To do this, split your training data into multiple files -in a single directory. You can then use the ``train_paper`` script to -train multiple policies on the same data. You can choose one of the -files to be partially excluded. This means that Rasa Core will be -trained multiple times, with 0, 5, 25, 50, 70, 90, 95, and 100% of -the stories in that file removed from the training data. By evaluating -on the full set of stories, you can measure how well Rasa Core is -predicting the held-out stories. +full data set. To do this, you first have to train models for your different +policies. Create two (or more) policy config files of the policies you want to +compare (containing only one policy each), and then use the ``compare`` mode of +the train script to train your models: + +.. code-block:: bash + + $ python -m rasa_core.train compare -c policy_config1.yml policy_config2.yml \ + -d domain.yml -s stories_folder -o comparison_models --runs 3 --percentages \ + 0 5 25 50 70 90 95 + +For each policy configuration provided, Rasa Core will be trained multiple times +with 0, 5, 25, 50, 70 and 95% of your training stories excluded from the training +data. This is done for multiple runs, to ensure consistent results. + +Once this script has finished, you can now use the evaluate script in compare +mode to evaluate the models you just trained: + +.. code-block:: bash + + $ python -m rasa_core.evaluate compare -s stories_folder -d comparison_models \ + -o comparison_results + +This will evaluate each of the models on the training set, and plot some graphs +to show you which policy is best. By evaluating on the full set of stories, you +can measure how well Rasa Core is predicting the held-out stories. + +If you're not sure which policies to compare, we'd recommend trying out the +``EmbeddingPolicy`` and the ``KerasPolicy`` to see which one works better for +you. + +.. note:: + This training process can take a long time, so we'd suggest letting it run + somewhere in the background where it can't be interrupted Evaluating stories over http @@ -129,5 +155,3 @@ you may do so by adding the ``e2e=true`` query parameter: $ curl --data-binary @eval_stories.md "localhost:5005/evaluate?e2e=true" | python -m json.tool .. 
include:: feedback.inc - - diff --git a/docs/interactive_learning.rst b/docs/interactive_learning.rst index 7eb4721730d..bb85c46b4ec 100644 --- a/docs/interactive_learning.rst +++ b/docs/interactive_learning.rst @@ -17,7 +17,7 @@ Some people call this `Software 2.0 List[DialogueStateTracker] """Load training data from a resource.""" @@ -478,7 +479,8 @@ def load_data(self, remove_duplicates, unique_last_num_states, augmentation_factor, tracker_limit, use_story_concatenation, - debug_plots) + debug_plots, + exclusion_percentage=exclusion_percentage) def train(self, training_trackers, # type: List[DialogueStateTracker] diff --git a/rasa_core/config.py b/rasa_core/config.py index d1dd51701e8..36b798088f4 100644 --- a/rasa_core/config.py +++ b/rasa_core/config.py @@ -5,62 +5,19 @@ from typing import Optional, Text, Dict, Any, List -from rasa_core.constants import ( - DEFAULT_NLU_FALLBACK_THRESHOLD, - DEFAULT_CORE_FALLBACK_THRESHOLD, DEFAULT_FALLBACK_ACTION) from rasa_core import utils -from rasa_core.policies import PolicyEnsemble, Policy +from rasa_core.policies import PolicyEnsemble -def load(config_file, fallback_args, max_history): +def load(config_file): # type: (Optional[Text], Dict[Text, Any], int) -> List[Policy] """Load policy data stored in the specified file. fallback_args and max_history are typically command line arguments. They take precedence over the arguments specified in the config yaml. """ - - if config_file is None: - return PolicyEnsemble.default_policies(fallback_args, max_history) - - config_data = utils.read_yaml_file(config_file) - config_data = handle_precedence_and_defaults( - config_data, fallback_args, max_history) + if config_file: + config_data = utils.read_yaml_file(config_file) + else: + raise ValueError("You have to provide a config file") return PolicyEnsemble.from_dict(config_data) - - -def handle_precedence_and_defaults(config_data, fallback_args, max_history): - # type: (Dict[Text, Any], Dict[Text, Any], int) -> Dict[Text, Any] - - for policy in config_data.get('policies'): - - if policy.get('name') == 'FallbackPolicy' and fallback_args is not None: - set_fallback_args(policy, fallback_args) - - elif policy.get('name') in {'KerasPolicy', 'MemoizationPolicy'}: - set_arg(policy, "max_history", max_history, 3) - - return config_data - - -def set_arg(data_dict, argument, value, default): - - if value is not None: - data_dict[argument] = value - elif data_dict.get(argument) is None: - data_dict[argument] = default - - return data_dict - - -def set_fallback_args(policy, fallback_args): - - set_arg(policy, "nlu_threshold", - fallback_args.get("nlu_threshold"), - DEFAULT_NLU_FALLBACK_THRESHOLD) - set_arg(policy, "core_threshold", - fallback_args.get("core_threshold"), - DEFAULT_CORE_FALLBACK_THRESHOLD) - set_arg(policy, "fallback_action_name", - fallback_args.get("fallback_action_name"), - DEFAULT_FALLBACK_ACTION) diff --git a/rasa_core/evaluate.py b/rasa_core/evaluate.py index ff2ffe5c2d8..200c8cf9e38 100644 --- a/rasa_core/evaluate.py +++ b/rasa_core/evaluate.py @@ -14,6 +14,11 @@ from sklearn.exceptions import UndefinedMetricWarning from tqdm import tqdm +import os +import json +import numpy as np +import pickle +from collections import defaultdict from rasa_core import training from rasa_core import utils @@ -23,8 +28,12 @@ from rasa_core.policies import SimplePolicyEnsemble from rasa_core.trackers import DialogueStateTracker from rasa_core.training.generator import TrainingDataGenerator -from rasa_core.utils import AvailableEndpoints, pad_list_to_size +from 
rasa_core.utils import (AvailableEndpoints, pad_list_to_size, + set_default_subparser) + +from rasa_nlu import utils as nlu_utils from rasa_nlu.evaluate import plot_confusion_matrix, get_evaluation_metrics +from rasa_core.events import md_format_message logger = logging.getLogger(__name__) @@ -41,6 +50,24 @@ def create_argument_parser(): parser = argparse.ArgumentParser( description='evaluates a dialogue model') + parent_parser = argparse.ArgumentParser(add_help=False) + add_args_to_parser(parent_parser) + utils.add_logging_option_arguments(parent_parser) + subparsers = parser.add_subparsers(help='mode', dest='mode') + subparsers.add_parser('default', + help='default mode: evaluate a dialogue' + ' model', + parents=[parent_parser]) + subparsers.add_parser('compare', + help='compare mode: evaluate multiple' + ' dialogue models to compare ' + 'policies', + parents=[parent_parser]) + + return parser + + +def add_args_to_parser(parser): parser.add_argument( '-s', '--stories', type=str, @@ -52,9 +79,8 @@ def create_argument_parser(): help="maximum number of stories to test on") parser.add_argument( '-d', '--core', - required=True, type=str, - help="core model to run with the server") + help="core model directory to evaluate") parser.add_argument( '-u', '--nlu', type=str, @@ -62,21 +88,14 @@ def create_argument_parser(): parser.add_argument( '-o', '--output', type=str, - nargs="?", - const="story_confmat.pdf", - help="output path for the created evaluation plot. If not " - "specified, no plot will be generated.") + default="results", + help="output path for the any files created from the evaluation") parser.add_argument( '--e2e', '--end-to-end', action='store_true', help="Run an end-to-end evaluation for combined action and " "intent prediction. Requires a story file in end-to-end " "format.") - parser.add_argument( - '--failed', - type=str, - default="failed_stories.md", - help="output path for the failed stories") parser.add_argument( '--endpoints', default=None, @@ -88,7 +107,6 @@ def create_argument_parser(): "is thrown. This can be used to validate stories during " "tests, e.g. 
on travis.") - utils.add_logging_option_arguments(parser) return parser @@ -238,12 +256,12 @@ def __init__(self, input_channel) def as_story_string(self): - correct_message = _md_format_message(self.text, - self.intent, - self.entities) - predicted_message = _md_format_message(self.text, - self.predicted_intent, - self.predicted_entities) + correct_message = md_format_message(self.text, + self.intent, + self.entities) + predicted_message = md_format_message(self.text, + self.predicted_intent, + self.predicted_entities) return ("{}: {} " "").format(self.intent.get("name"), correct_message, @@ -409,9 +427,10 @@ def collect_story_predictions( story_eval_store = EvaluationStore() failed = [] correct_dialogues = [] + num_stories = len(completed_trackers) logger.info("Evaluating {} stories\n" - "Progress:".format(len(completed_trackers))) + "Progress:".format(num_stories)) action_list = [] @@ -443,19 +462,19 @@ def collect_story_predictions( in_training_data_fraction, include_report=False) - return StoryEvalution(evaluation_store=story_eval_store, - failed_stories=failed, - action_list=action_list, - in_training_data_fraction=in_training_data_fraction) - + return (StoryEvalution(evaluation_store=story_eval_store, + failed_stories=failed, + action_list=action_list, + in_training_data_fraction=in_training_data_fraction), + num_stories) -def log_failed_stories(failed, failed_output): - """Takes stories as a list of dicts""" - if not failed_output: +def log_failed_stories(failed, out_directory): + """Take stories as a list of dicts.""" + if not out_directory: return - - with io.open(failed_output, 'w', encoding="utf-8") as f: + with io.open(os.path.join(out_directory, 'failed_stories.md'), 'w', + encoding="utf-8") as f: if len(failed) == 0: f.write("") else: @@ -466,8 +485,7 @@ def log_failed_stories(failed, failed_output): def run_story_evaluation(resource_name, agent, max_stories=None, - out_file_stories=None, - out_file_plot=None, + out_directory=None, fail_on_prediction_errors=False, use_e2e=False): """Run the evaluation of the stories, optionally plots the results.""" @@ -475,9 +493,9 @@ def run_story_evaluation(resource_name, agent, completed_trackers = _generate_trackers(resource_name, agent, max_stories, use_e2e) - story_evaluation = collect_story_predictions(completed_trackers, agent, - fail_on_prediction_errors, - use_e2e) + story_evaluation, _ = collect_story_predictions(completed_trackers, agent, + fail_on_prediction_errors, + use_e2e) evaluation_store = story_evaluation.evaluation_store @@ -488,14 +506,14 @@ def run_story_evaluation(resource_name, agent, evaluation_store.serialise_predictions() ) - if out_file_plot: + if out_directory: plot_story_evaluation(evaluation_store.action_targets, evaluation_store.action_predictions, report, precision, f1, accuracy, story_evaluation.in_training_data_fraction, - out_file_plot) + out_directory) - log_failed_stories(story_evaluation.failed_stories, out_file_stories) + log_failed_stories(story_evaluation.failed_stories, out_directory) return { "report": report, @@ -530,7 +548,7 @@ def log_evaluation_table(golds, name, def plot_story_evaluation(test_y, predictions, report, precision, f1, accuracy, in_training_data_fraction, - out_file): + out_directory): """Plot the results of story evaluation""" from sklearn.metrics import confusion_matrix from sklearn.utils.multiclass import unique_labels @@ -549,29 +567,109 @@ def plot_story_evaluation(test_y, predictions, fig = plt.gcf() fig.set_size_inches(int(20), int(20)) - fig.savefig(out_file, 
bbox_inches='tight') + fig.savefig(os.path.join(out_directory, "story_confmat.pdf"), + bbox_inches='tight') + + +def run_comparison_evaluation(models, stories, output): + # type: (Text, Text, Text) -> None + """Evaluates multiple trained models on a test set""" + + num_correct = defaultdict(list) + + for run in nlu_utils.list_subdirectories(models): + num_correct_run = defaultdict(list) + + for model in sorted(nlu_utils.list_subdirectories(run)): + logger.info("Evaluating model {}".format(model)) + + agent = Agent.load(model) + + completed_trackers = _generate_trackers(stories, agent) + + story_eval_store, no_of_stories = \ + collect_story_predictions(completed_trackers, + agent) + + failed_stories = story_eval_store.failed_stories + policy_name = ''.join([i for i in os.path.basename(model) if not + i.isdigit()]) + num_correct_run[policy_name].append(no_of_stories - + len(failed_stories)) + + for k, v in num_correct_run.items(): + num_correct[k].append(v) + + utils.dump_obj_as_json_to_file(os.path.join(output, 'results.json'), + num_correct) + + +def plot_curve(output, no_stories, ax=None, **kwargs): + """Plot the results from run_comparison_evaluation.""" + import matplotlib.pyplot as plt + + ax = ax or plt.gca() + + # load results from file + data = utils.read_json_file(os.path.join(output, 'results.json')) + x = no_stories + + # compute mean of all the runs for keras/embed policies + for label in data.keys(): + if len(data[label]) == 0: + continue + mean = np.mean(data[label], axis=0) + std = np.std(data[label], axis=0) + ax.plot(x, mean, label=label, marker='.') + ax.fill_between(x, + [m-s for m, s in zip(mean, std)], + [m+s for m, s in zip(mean, std)], + color='#6b2def', + alpha=0.2) + ax.legend(loc=4) + ax.set_xlabel("Number of stories present during training") + ax.set_ylabel("Number of correct test stories") + plt.savefig(os.path.join(output, 'model_comparison_graph.pdf'), + format='pdf') + plt.show() if __name__ == '__main__': # Running as standalone python application arg_parser = create_argument_parser() + set_default_subparser(arg_parser, 'default') cmdline_args = arg_parser.parse_args() logging.basicConfig(level=cmdline_args.loglevel) _endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints) - _interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu, - _endpoints.nlu) + if cmdline_args.output: + nlu_utils.create_dir(cmdline_args.output) + + if not cmdline_args.core: + raise ValueError("you must provide a core model directory to evaluate " + "using -d / --core") + if cmdline_args.mode == 'default': + + _interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu, + _endpoints.nlu) + + _agent = Agent.load(cmdline_args.core, + interpreter=_interpreter) + run_story_evaluation(cmdline_args.stories, + _agent, + cmdline_args.max_stories, + cmdline_args.output, + cmdline_args.fail_on_prediction_errors, + cmdline_args.e2e) + + elif cmdline_args.mode == 'compare': + run_comparison_evaluation(cmdline_args.core, cmdline_args.stories, + cmdline_args.output) - _agent = Agent.load(cmdline_args.core, - interpreter=_interpreter) + no_stories = pickle.load(io.open(os.path.join(cmdline_args.core, + 'num_stories.p'), 'rb')) - run_story_evaluation(cmdline_args.stories, - _agent, - cmdline_args.max_stories, - cmdline_args.failed, - cmdline_args.output, - cmdline_args.fail_on_prediction_errors, - cmdline_args.e2e) + plot_curve(cmdline_args.output, no_stories) logger.info("Finished evaluation") diff --git a/rasa_core/policies/embedding_policy.py 
b/rasa_core/policies/embedding_policy.py index 1c5fc322181..661952e7643 100644 --- a/rasa_core/policies/embedding_policy.py +++ b/rasa_core/policies/embedding_policy.py @@ -154,7 +154,8 @@ def __init__( rnn_embed=None, # type: Optional[tf.Tensor] attn_embed=None, # type: Optional[tf.Tensor] copy_attn_debug=None, # type: Optional[tf.Tensor] - all_time_masks=None # type: Optional[tf.Tensor] + all_time_masks=None, # type: Optional[tf.Tensor] + **kwargs # type: Any ): # type: (...) -> None if featurizer: @@ -171,7 +172,7 @@ def __init__( except AttributeError: self.share_embedding = False - self._load_params() + self._load_params(**kwargs) # chrono initialization for forget bias self.characteristic_time = None @@ -909,10 +910,6 @@ def train(self, logger.debug('Started training embedding policy.') - if kwargs: - logger.debug("Config is updated with {}".format(kwargs)) - self._load_params(**kwargs) - # dealing with training data training_data = self.featurize_for_training(training_trackers, domain, diff --git a/rasa_core/policies/ensemble.py b/rasa_core/policies/ensemble.py index 6b872496a30..fbe2ec46149 100644 --- a/rasa_core/policies/ensemble.py +++ b/rasa_core/policies/ensemble.py @@ -17,18 +17,12 @@ import rasa_core from rasa_core import utils, training, constants -from rasa_core.constants import ( - DEFAULT_NLU_FALLBACK_THRESHOLD, - DEFAULT_CORE_FALLBACK_THRESHOLD, DEFAULT_FALLBACK_ACTION) from rasa_core.events import SlotSet, ActionExecuted, ActionExecutionRejected from rasa_core.exceptions import UnsupportedDialogueModelError -from rasa_core.featurizers import (MaxHistoryTrackerFeaturizer, - BinarySingleStateFeaturizer) -from rasa_core.policies.keras_policy import KerasPolicy +from rasa_core.featurizers import MaxHistoryTrackerFeaturizer from rasa_core.policies.fallback import FallbackPolicy from rasa_core.policies.memoization import (MemoizationPolicy, AugmentedMemoizationPolicy) -from rasa_core.policies.form_policy import FormPolicy from rasa_core.actions.action import ACTION_LISTEN_NAME @@ -208,39 +202,54 @@ def from_dict(cls, dictionary): for policy in dictionary.get('policies', []): policy_name = policy.pop('name') + if policy.get('featurizer'): + featurizer_func, featurizer_config = \ + cls.get_featurizer_from_dict(policy) - if policy_name == 'KerasPolicy': - policy_object = KerasPolicy(MaxHistoryTrackerFeaturizer( - BinarySingleStateFeaturizer(), - max_history=policy.get('max_history', 3))) - else: - constr_func = utils.class_from_module_path(policy_name) - policy_object = constr_func(**policy) + if featurizer_config.get('state_featurizer'): + state_featurizer_func, state_featurizer_config = \ + cls.get_featurizer_from_dict(featurizer_config) + + # override featurizer's state_featurizer + # with real state_featurizer class + featurizer_config['state_featurizer'] = ( + state_featurizer_func(**state_featurizer_config) + ) + + # override policy's featurizer with real featurizer class + policy['featurizer'] = featurizer_func(**featurizer_config) + + constr_func = utils.class_from_module_path(policy_name) + policy_object = constr_func(**policy) policies.append(policy_object) return policies - @classmethod - def default_policies(cls, fallback_args, max_history): - # type: (Dict[Text, Any], int) -> List[Policy] - """Load the default policy setup consisting of - FallbackPolicy, MemoizationPolicy and KerasPolicy.""" - - return [ - FallbackPolicy( - fallback_args.get("nlu_threshold", - DEFAULT_NLU_FALLBACK_THRESHOLD), - fallback_args.get("core_threshold", - DEFAULT_CORE_FALLBACK_THRESHOLD), 
- fallback_args.get("fallback_action_name", - DEFAULT_FALLBACK_ACTION)), - MemoizationPolicy( - max_history=max_history), - KerasPolicy( - MaxHistoryTrackerFeaturizer(BinarySingleStateFeaturizer(), - max_history=max_history)), - FormPolicy()] + def get_featurizer_from_dict(self, policy): + # policy can have only 1 featurizer + if len(policy['featurizer']) > 1: + raise InvalidPolicyConfig( + "policy can have only 1 featurizer") + featurizer_config = policy['featurizer'][0] + featurizer_name = featurizer_config.pop('name') + featurizer_func = utils.class_from_module_path(featurizer_name) + + return featurizer_func, featurizer_config + + def get_state_featurizer_from_dict(self, featurizer_config): + # featurizer can have only 1 state featurizer + if len(featurizer_config['state_featurizer']) > 1: + raise InvalidPolicyConfig( + "featurizer can have only 1 state featurizer") + state_featurizer_config = ( + featurizer_config['state_featurizer'][0] + ) + state_featurizer_name = state_featurizer_config.pop('name') + state_featurizer_func = utils.class_from_module_path( + state_featurizer_name) + + return state_featurizer_func, state_featurizer_config def continue_training(self, trackers, domain, **kwargs): # type: (List[DialogueStateTracker], Domain, Any) -> None @@ -311,3 +320,8 @@ def probabilities_using_best_policy(self, tracker, domain): logger.debug("Predicted next action using {}" "".format(best_policy_name)) return result, best_policy_name + + +class InvalidPolicyConfig(Exception): + """Exception that can be raised when policy config is not valid.""" + pass diff --git a/rasa_core/policies/keras_policy.py b/rasa_core/policies/keras_policy.py index 5de82926874..f14af5d7a2f 100644 --- a/rasa_core/policies/keras_policy.py +++ b/rasa_core/policies/keras_policy.py @@ -9,6 +9,7 @@ import os import warnings import tensorflow as tf +import copy import typing from typing import Any, List, Dict, Text, Optional, Tuple @@ -16,6 +17,8 @@ from rasa_core import utils from rasa_core.policies.policy import Policy from rasa_core.featurizers import TrackerFeaturizer +from rasa_core.featurizers import ( + MaxHistoryTrackerFeaturizer, BinarySingleStateFeaturizer) if typing.TYPE_CHECKING: from rasa_core.domain import Domain @@ -29,21 +32,32 @@ class KerasPolicy(Policy): defaults = { # Neural Net and training params - "rnn_size": 32 + "rnn_size": 32, + "epochs": 100, + "batch_size": 32, + "validation_split": 0.1 } + @staticmethod + def _standard_featurizer(max_history=None): + return MaxHistoryTrackerFeaturizer(BinarySingleStateFeaturizer(), + max_history=max_history) + def __init__(self, featurizer=None, # type: Optional[TrackerFeaturizer] model=None, # type: Optional[tf.keras.models.Sequential] graph=None, # type: Optional[tf.Graph] session=None, # type: Optional[tf.Session] - current_epoch=0 # type: int + current_epoch=0, # type: int + max_history=None, # type: Optional[int] + **kwargs # type: Any ): # type: (...) -> None + if not featurizer: + featurizer = self._standard_featurizer(max_history) super(KerasPolicy, self).__init__(featurizer) - self.rnn_size = self.defaults['rnn_size'] - + self._load_params(**kwargs) self.model = model # by default keras uses default tf graph and global tf session # we are going to either load them or create them in train(...) 
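Since the reworked ``PolicyEnsemble.from_dict`` simply instantiates each entry with ``constr_func(**policy)``, the hyperparameters that used to be CLI flags become plain constructor arguments of ``KerasPolicy``. A hedged sketch of what a config entry such as ``- name: KerasPolicy`` with ``max_history: 5`` and ``epochs: 200`` turns into (the values are illustrative only):

.. code-block:: python

    from rasa_core.policies.keras_policy import KerasPolicy

    # no featurizer is passed, so _standard_featurizer() builds a
    # MaxHistoryTrackerFeaturizer(BinarySingleStateFeaturizer(), max_history=5)
    policy = KerasPolicy(max_history=5, epochs=200)

    # the remaining hyperparameters fall back to the class defaults, merged
    # with the constructor kwargs in _load_params() instead of --epochs et al.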
@@ -52,6 +66,16 @@ def __init__(self, self.current_epoch = current_epoch + def _load_params(self, **kwargs): + # type: (Dict[Text, Any]) -> None + config = copy.deepcopy(self.defaults) + config.update(kwargs) + + self.rnn_size = config['rnn_size'] + self.epochs = config['epochs'] + self.batch_size = config['epochs'] + self.validation_split = config['validation_split'] + @property def max_len(self): if self.model: @@ -125,11 +149,6 @@ def train(self, ): # type: (...) -> Dict[Text: Any] - if kwargs.get('rnn_size') is not None: - logger.debug("Parameter `rnn_size` is updated with {}" - "".format(kwargs.get('rnn_size'))) - self.rnn_size = kwargs.get('rnn_size') - training_data = self.featurize_for_training(training_trackers, domain, **kwargs) @@ -145,17 +164,19 @@ def train(self, self.model = self.model_architecture(shuffled_X.shape[1:], shuffled_y.shape[1:]) - validation_split = kwargs.get("validation_split", 0.0) logger.info("Fitting model with {} total samples and a " "validation split of {}".format( training_data.num_examples(), - validation_split)) + self.validation_split)) # filter out kwargs that cannot be passed to fit params = self._get_valid_params(self.model.fit, **kwargs) - self.model.fit(shuffled_X, shuffled_y, **params) + self.model.fit(shuffled_X, shuffled_y, + epochs=self.epochs, + batch_size=self.batch_size, + **params) # the default parameter for epochs in keras fit is 1 - self.current_epoch = kwargs.get("epochs", 1) + self.current_epoch = self.defaults.get("epochs", 1) logger.info("Done fitting keras policy model") def continue_training(self, training_trackers, domain, **kwargs): diff --git a/rasa_core/train.py b/rasa_core/train.py index c7cbbfcb314..1c9fd551f38 100644 --- a/rasa_core/train.py +++ b/rasa_core/train.py @@ -7,15 +7,21 @@ import argparse import logging +import io +import os +import pickle from rasa_core import config from rasa_core import utils from rasa_core.agent import Agent +from rasa_core.domain import TemplateDomain from rasa_core.broker import PikaProducer from rasa_core.interpreter import NaturalLanguageInterpreter from rasa_core.run import AvailableEndpoints from rasa_core.training import interactive from rasa_core.tracker_store import TrackerStore +from rasa_core.training.dsl import StoryFileReader +from rasa_core.utils import set_default_subparser logger = logging.getLogger(__name__) @@ -24,15 +30,69 @@ def create_argument_parser(): parser = argparse.ArgumentParser( description='trains a dialogue model') - - add_model_and_story_group(parser) - add_args_to_parser(parser) - utils.add_logging_option_arguments(parser) + parent_parser = argparse.ArgumentParser(add_help=False) + add_args_to_parser(parent_parser) + add_model_and_story_group(parent_parser) + utils.add_logging_option_arguments(parent_parser) + subparsers = parser.add_subparsers(help='mode', dest='mode') + subparsers.add_parser('default', + help='default mode: train a dialogue' + ' model', + parents=[parent_parser]) + compare_parser = subparsers.add_parser('compare', + help='compare mode: train multiple ' + 'dialogue models to compare ' + 'policies', + parents=[parent_parser]) + interactive_parser = subparsers.add_parser('interactive', + help='teach the bot with ' + 'interactive learning', + parents=[parent_parser]) + add_compare_args(compare_parser) + add_interactive_args(interactive_parser) return parser +def add_compare_args(parser): + parser.add_argument( + '--percentages', + nargs="*", + type=int, + default=[0, 5, 25, 50, 70, 90, 95], + help="Range of exclusion percentages") + 
parser.add_argument( + '--runs', + type=int, + default=3, + help="Number of runs for experiments") + + +def add_interactive_args(parser): + parser.add_argument( + '-u', '--nlu', + type=str, + default=None, + help="trained nlu model") + parser.add_argument( + '--endpoints', + default=None, + help="Configuration file for the connectors as a yml file") + parser.add_argument( + '--skip_visualization', + default=False, + action='store_true', + help="disables plotting the visualization during " + "interactive learning") + parser.add_argument( + '--finetune', + default=False, + action='store_true', + help="retrain the model immediately based on feedback.") + + def add_args_to_parser(parser): + parser.add_argument( '-o', '--out', type=str, @@ -41,55 +101,25 @@ def add_args_to_parser(parser): parser.add_argument( '-d', '--domain', type=str, - required=False, + required=True, help="domain specification yaml file") parser.add_argument( - '-u', '--nlu', - type=str, - default=None, - help="trained nlu model") - parser.add_argument( - '--history', - type=int, - default=None, - help="max history to use of a story") - parser.add_argument( - '--epochs', - type=int, - default=100, - help="number of epochs to train the model") - parser.add_argument( - '--validation_split', - type=float, - default=0.1, - help="Percentage of training samples used for validation, " - "0.1 by default") - parser.add_argument( - '--batch_size', + '--augmentation', type=int, - default=20, - help="number of training samples to put into one training batch") - parser.add_argument( - '--interactive', - default=False, - action='store_true', - help="enable interactive training") + default=50, + help="how much data augmentation to use during training") parser.add_argument( - '--skip_visualization', - default=False, - action='store_true', - help="disables plotting the visualization during " - "interactive learning") + '-c', '--config', + type=str, + nargs="*", + default='default_config.yml', + required=True, + help="Policy specification yaml file.") parser.add_argument( - '--finetune', + '--dump_stories', default=False, action='store_true', - help="retrain the model immediately based on feedback.") - parser.add_argument( - '--augmentation', - type=int, - default=50, - help="how much data augmentation to use during training") + help="If enabled, save flattened stories to a file") parser.add_argument( '--debug_plots', default=False, @@ -97,43 +127,7 @@ def add_args_to_parser(parser): help="If enabled, will create plots showing checkpoints " "and their connections between story blocks in a " "file called `story_blocks_connections.pdf`.") - parser.add_argument( - '--dump_stories', - default=False, - action='store_true', - help="If enabled, save flattened stories to a file") - parser.add_argument( - '--endpoints', - default=None, - help="Configuration file for the connectors as a yml file") - parser.add_argument( - '--nlu_threshold', - type=float, - default=None, - required=False, - help="If NLU prediction confidence is below threshold, fallback " - "will get triggered.") - parser.add_argument( - '--core_threshold', - type=float, - default=None, - required=False, - help="If Core action prediction confidence is below the threshold " - "a fallback action will get triggered") - parser.add_argument( - '--fallback_action_name', - type=str, - default=None, - required=False, - help="When a fallback is triggered (e.g. 
because the " - "ML prediction is of low confidence) this is the name " - "of the action that will get triggered instead.") - parser.add_argument( - '-c', '--config', - type=str, - required=False, - help="Policy specification yaml file." - ) + return parser @@ -162,19 +156,14 @@ def add_model_and_story_group(parser): def train_dialogue_model(domain_file, stories_file, output_path, interpreter=None, endpoints=AvailableEndpoints(), - max_history=None, - dump_flattened_stories=False, + dump_stories=False, policy_config=None, + exclusion_percentage=None, kwargs=None): if not kwargs: kwargs = {} - fallback_args, kwargs = utils.extract_args(kwargs, - {"nlu_threshold", - "core_threshold", - "fallback_action_name"}) - - policies = config.load(policy_config, fallback_args, max_history) + policies = config.load(policy_config) agent = Agent(domain_file, generator=endpoints.nlg, @@ -189,61 +178,135 @@ def train_dialogue_model(domain_file, stories_file, output_path, "remove_duplicates", "debug_plots"}) - training_data = agent.load_data(stories_file, **data_load_args) + training_data = agent.load_data(stories_file, + exclusion_percentage=exclusion_percentage, + **data_load_args) agent.train(training_data, **kwargs) - agent.persist(output_path, dump_flattened_stories) + agent.persist(output_path, dump_stories) return agent def _additional_arguments(args): additional = { - "epochs": args.epochs, - "batch_size": args.batch_size, - "validation_split": args.validation_split, "augmentation_factor": args.augmentation, - "debug_plots": args.debug_plots, - "nlu_threshold": args.nlu_threshold, - "core_threshold": args.core_threshold, - "fallback_action_name": args.fallback_action_name + "debug_plots": args.debug_plots } # remove None values return {k: v for k, v in additional.items() if v is not None} -if __name__ == '__main__': +def train_comparison_models(story_filename, + domain, + output_path=None, + exclusion_percentages=None, + policy_configs=None, + runs=None, + dump_stories=False, + kwargs=None): + """Train multiple models for comparison of policies""" - # Running as standalone python application - arg_parser = create_argument_parser() - cmdline_args = arg_parser.parse_args() + for r in range(cmdline_args.runs): + logging.info("Starting run {}/{}".format(r + 1, cmdline_args.runs)) + for i in exclusion_percentages: + current_round = cmdline_args.percentages.index(i) + 1 + for policy_config in policy_configs: + policies = config.load(policy_config) + if len(policies) > 1: + raise ValueError("You can only specify one policy per " + "model for comparison") + policy_name = type(policies[0]).__name__ + output = os.path.join(output_path, 'run_' + str(r + 1), + policy_name + + str(current_round)) - utils.configure_colored_logging(cmdline_args.loglevel) + logging.info("Starting to train {} round {}/{}" + " with {}% exclusion".format( + policy_name, + current_round, + len(exclusion_percentages), + i)) - additional_arguments = _additional_arguments(cmdline_args) + train_dialogue_model( + domain, stories, output, + policy_config=policy_config, + exclusion_percentage=i, + kwargs=kwargs, + dump_stories=dump_stories) + + +def get_no_of_stories(stories, domain): + + """Get number of stories in a file.""" + + no_stories = len(StoryFileReader.read_from_folder(stories, + TemplateDomain.load( + domain))) + return no_stories + + +def do_default_training(cmdline_args, stories, additional_arguments): + if not cmdline_args.out: + raise ValueError("you must provide a path where the model " + "will be saved using -o / --out") + if 
(isinstance(cmdline_args.config, list) and + len(cmdline_args.config) > 1): + raise ValueError("You can only pass one config file at a time") + + train_dialogue_model(domain_file=cmdline_args.domain, + stories_file=stories, + output_path=cmdline_args.out, + dump_stories=cmdline_args.dump_stories, + policy_config=cmdline_args.config[0], + kwargs=additional_arguments) - if cmdline_args.url: - stories = utils.download_file_from_url(cmdline_args.url) - else: - stories = cmdline_args.stories +def do_compare_training(cmdline_args, stories, additional_arguments): + if not cmdline_args.out: + raise ValueError("you must provide a path where the model " + "will be saved using -o / --out") + + train_comparison_models(cmdline_args.stories, + cmdline_args.domain, + cmdline_args.out, + cmdline_args.percentages, + cmdline_args.config, + cmdline_args.runs, + cmdline_args.dump_stories, + additional_arguments) + + no_stories = get_no_of_stories(cmdline_args.stories, + cmdline_args.domain) + + # store the list of the number of stories present at each exclusion + # percentage + story_range = [no_stories - round((x/100.0) * no_stories) for x in + cmdline_args.percentages] + + pickle.dump(story_range, + io.open(os.path.join(cmdline_args.out, 'num_stories.p'), + 'wb')) + + +def do_interactive_learning(cmdline_args, stories, additional_arguments): _endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints) _interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu, _endpoints.nlu) - _broker = PikaProducer.from_endpoint_config(_endpoints.event_broker) - _tracker_store = TrackerStore.find_tracker_store(None, - _endpoints.tracker_store, - _broker) - if cmdline_args.core: - if not cmdline_args.interactive: - raise ValueError("--core can only be used together with the" - "--interactive flag.") - elif cmdline_args.finetune: - raise ValueError("--core can only be used together with the" - "--interactive flag and without --finetune flag.") - else: - logger.info("loading a pre-trained model. ", - "all training-related parameters will be ignored") + if (isinstance(cmdline_args.config, list) and + len(cmdline_args.config) > 1): + raise ValueError("You can only pass one config file at a time") + if cmdline_args.core and cmdline_args.finetune: + raise ValueError("--core can only be used without --finetune flag.") + elif cmdline_args.core: + logger.info("loading a pre-trained model. 
" + "all training-related parameters will be ignored") + + _broker = PikaProducer.from_endpoint_config(_endpoints.event_broker) + _tracker_store = TrackerStore.find_tracker_store( + None, + _endpoints.tracker_store, + _broker) _agent = Agent.load(cmdline_args.core, interpreter=_interpreter, generator=_endpoints.nlg, @@ -253,18 +316,45 @@ def _additional_arguments(args): if not cmdline_args.out: raise ValueError("you must provide a path where the model " "will be saved using -o / --out") + _agent = train_dialogue_model(cmdline_args.domain, stories, cmdline_args.out, _interpreter, _endpoints, - cmdline_args.history, cmdline_args.dump_stories, - cmdline_args.config, + cmdline_args.config[0], + None, additional_arguments) + interactive.run_interactive_learning( + _agent, stories, + finetune=cmdline_args.finetune, + skip_visualization=cmdline_args.skip_visualization) + + +if __name__ == '__main__': + + # Running as standalone python application + arg_parser = create_argument_parser() + set_default_subparser(arg_parser, 'default') + cmdline_args = arg_parser.parse_args() + if not cmdline_args.mode: + raise ValueError("You must specify the mode you want training to run " + "in. The options are: (default|compare|interactive)") + additional_arguments = _additional_arguments(cmdline_args) + + utils.configure_colored_logging(cmdline_args.loglevel) + + if cmdline_args.url: + stories = utils.download_file_from_url(cmdline_args.url) + else: + stories = cmdline_args.stories + + if cmdline_args.mode == 'default': + do_default_training(cmdline_args, stories, additional_arguments) + + elif cmdline_args.mode == 'interactive': + do_interactive_learning(cmdline_args, stories, additional_arguments) - if cmdline_args.interactive: - interactive.run_interactive_learning( - _agent, stories, - finetune=cmdline_args.finetune, - skip_visualization=cmdline_args.skip_visualization) + elif cmdline_args.mode == 'compare': + do_compare_training(cmdline_args, stories, additional_arguments) diff --git a/rasa_core/training/__init__.py b/rasa_core/training/__init__.py index ccf0f32341c..880cb493f47 100644 --- a/rasa_core/training/__init__.py +++ b/rasa_core/training/__init__.py @@ -17,7 +17,8 @@ def extract_story_graph( resource_name, # type: Text domain, # type: Domain interpreter=None, # type: Optional[NaturalLanguageInterpreter] - use_e2e=False # type: bool + use_e2e=False, # type: bool + exclusion_percentage=None # type: int ): # type: (...) -> StoryGraph from rasa_core.interpreter import RegexInterpreter @@ -26,9 +27,11 @@ def extract_story_graph( if not interpreter: interpreter = RegexInterpreter() - story_steps = StoryFileReader.read_from_folder(resource_name, - domain, interpreter, - use_e2e=use_e2e) + story_steps = StoryFileReader.read_from_folder( + resource_name, + domain, interpreter, + use_e2e=use_e2e, + exclusion_percentage=exclusion_percentage) return StoryGraph(story_steps) @@ -40,14 +43,16 @@ def load_data( augmentation_factor=20, # type: int tracker_limit=None, # type: Optional[int] use_story_concatenation=True, # type: bool - debug_plots=False # type: bool + debug_plots=False, + exclusion_percentage=None # type: int ): # type: (...) 
-> List[DialogueStateTracker] from rasa_core.training import extract_story_graph from rasa_core.training.generator import TrainingDataGenerator if resource_name: - graph = extract_story_graph(resource_name, domain) + graph = extract_story_graph(resource_name, domain, + exclusion_percentage=exclusion_percentage) g = TrainingDataGenerator(graph, domain, remove_duplicates, diff --git a/rasa_core/training/dsl.py b/rasa_core/training/dsl.py index 1157c188fbd..0554d1ad072 100644 --- a/rasa_core/training/dsl.py +++ b/rasa_core/training/dsl.py @@ -161,7 +161,8 @@ def __init__(self, domain, interpreter, template_vars=None, use_e2e=False): @staticmethod def read_from_folder(resource_name, domain, interpreter=RegexInterpreter(), - template_variables=None, use_e2e=False): + template_variables=None, use_e2e=False, + exclusion_percentage=None): """Given a path reads all contained story files.""" story_steps = [] @@ -169,6 +170,14 @@ def read_from_folder(resource_name, domain, interpreter=RegexInterpreter(), steps = StoryFileReader.read_from_file(f, domain, interpreter, template_variables, use_e2e) story_steps.extend(steps) + + # if exclusion percentage is not 100 + if exclusion_percentage and exclusion_percentage is not 100: + import random + idx = int(round(exclusion_percentage/100.0 * len(story_steps))) + random.shuffle(story_steps) + story_steps = story_steps[:-idx] + return story_steps @staticmethod diff --git a/rasa_core/utils.py b/rasa_core/utils.py index 51cdff73f16..77d7fb5581c 100644 --- a/rasa_core/utils.py +++ b/rasa_core/utils.py @@ -13,6 +13,7 @@ import re import sys import tempfile +import argparse from builtins import input, range, str from hashlib import sha1 from random import Random @@ -84,8 +85,16 @@ def class_from_module_path(module_path): # load the module, will raise ImportError if module cannot be loaded from rasa_core.policies.keras_policy import KerasPolicy from rasa_core.policies.fallback import FallbackPolicy - from rasa_core.policies.memoization import MemoizationPolicy - + from rasa_core.policies.memoization import (MemoizationPolicy, + AugmentedMemoizationPolicy) + from rasa_core.policies.embedding_policy import EmbeddingPolicy + from rasa_core.policies.form_policy import FormPolicy + from rasa_core.policies.sklearn_policy import SklearnPolicy + + from rasa_core.featurizers import (FullDialogueTrackerFeaturizer, + MaxHistoryTrackerFeaturizer, + BinarySingleStateFeaturizer, + LabelTokenizerSingleStateFeaturizer) if "." in module_path: module_name, _, class_name = module_path.rpartition('.') m = importlib.import_module(module_name) @@ -413,6 +422,12 @@ def read_file(filename, encoding="utf-8"): return f.read() +def read_json_file(filename): + """Read json from a file""" + with io.open(filename) as f: + return json.load(f) + + def list_routes(app): """List all available routes of a flask web server.""" from six.moves.urllib.parse import unquote @@ -749,3 +764,25 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + + +def set_default_subparser(parser, + default_subparser): + """default subparser selection. 
Call after setup, just before parse_args() + + parser: the name of the parser you're making changes to + default_subparser: the name of the subparser to call by default""" + subparser_found = False + for arg in sys.argv[1:]: + if arg in ['-h', '--help']: # global help if no subparser + break + else: + for x in parser._subparsers._actions: + if not isinstance(x, argparse._SubParsersAction): + continue + for sp_name in x._name_parser_map.keys(): + if sp_name in sys.argv[1:]: + subparser_found = True + if not subparser_found: + # insert default in first position before all other arguments + sys.argv.insert(1, default_subparser) diff --git a/tests/conftest.py b/tests/conftest.py index 2db7ecb9410..794535e1eb8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -105,7 +105,7 @@ def trained_moodbot_path(): stories_file="examples/moodbot/data/stories.md", output_path=MOODBOT_MODEL_PATH, interpreter=RegexInterpreter(), - max_history=None, + policy_config='default_config.yml', kwargs=None ) diff --git a/tests/test_config.py b/tests/test_config.py index 9a75e79a08b..ea0011577fa 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,35 +7,14 @@ import pytest from tests.conftest import ExamplePolicy -from rasa_core.config import load, handle_precedence_and_defaults +from rasa_core.config import load from rasa_core.policies.memoization import MemoizationPolicy -def test_handle_precedence_and_defaults_for_config(): - - config_data = {'policies': [ - {'name': 'FallbackPolicy', 'nlu_threshold': 0.5}, - {'name': 'KerasPolicy'} - ]} - fallback_args = { - 'nlu_threshold': 1, - 'core_threshold': 1, - 'fallback_action_name': 'some_name' - } - expected_config_data = {'policies': [ - {'name': 'FallbackPolicy', 'nlu_threshold': 1, - 'core_threshold': 1, 'fallback_action_name': 'some_name'}, - {'name': 'KerasPolicy', 'max_history': 3} - ]} - new_config_data = handle_precedence_and_defaults( - config_data, fallback_args, None) - assert new_config_data == expected_config_data - - @pytest.mark.parametrize("filename", glob.glob( "data/test_config/example_config.yaml")) def test_load_config(filename): - loaded = load(filename, None, None) + loaded = load(filename) assert len(loaded) == 2 assert isinstance(loaded[0], MemoizationPolicy) assert isinstance(loaded[1], ExamplePolicy) diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 8e7f0967068..5048725b505 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -3,7 +3,6 @@ from __future__ import print_function from __future__ import unicode_literals -import imghdr import os from rasa_core import evaluate @@ -14,45 +13,45 @@ def test_evaluation_image_creation(tmpdir, default_agent): - stories_path = tmpdir.join("failed_stories.md").strpath - img_path = tmpdir.join("evaluation.png").strpath + stories_path = os.path.join(tmpdir.strpath, "failed_stories.md") + img_path = os.path.join(tmpdir.strpath, "story_confmat.pdf") run_story_evaluation( resource_name=DEFAULT_STORIES_FILE, agent=default_agent, - out_file_plot=img_path, + out_directory=tmpdir.strpath, max_stories=None, - out_file_stories=stories_path, use_e2e=False ) assert os.path.isfile(img_path) - assert imghdr.what(img_path) == "png" - assert os.path.isfile(stories_path) def test_action_evaluation_script(tmpdir, default_agent): completed_trackers = evaluate._generate_trackers( DEFAULT_STORIES_FILE, default_agent, use_e2e=False) - - story_evaluation = collect_story_predictions(completed_trackers, - default_agent, - use_e2e=False) + story_evaluation, num_stories = 
collect_story_predictions( + completed_trackers, + default_agent, + use_e2e=False) assert not story_evaluation.evaluation_store. \ has_prediction_target_mismatch() assert len(story_evaluation.failed_stories) == 0 + assert num_stories == 3 def test_end_to_end_evaluation_script(tmpdir, default_agent): completed_trackers = evaluate._generate_trackers( END_TO_END_STORY_FILE, default_agent, use_e2e=True) - story_evaluation = collect_story_predictions(completed_trackers, - default_agent, - use_e2e=True) + story_evaluation, num_stories = collect_story_predictions( + completed_trackers, + default_agent, + use_e2e=True) assert not story_evaluation.evaluation_store. \ has_prediction_target_mismatch() assert len(story_evaluation.failed_stories) == 0 + assert num_stories == 2 diff --git a/tests/test_examples.py b/tests/test_examples.py index 828d13e65a5..f5d299fd03f 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -50,7 +50,8 @@ def test_formbot_example(): agent = train_dialogue_model(os.path.join(p, "domain.yml"), stories, os.path.join(p, "models", "dialogue"), - endpoints=endpoints) + endpoints=endpoints, + policy_config="default_config.yml") response = { 'events': [ {'event': 'form', 'name': 'restaurant_form', 'timestamp': None}, diff --git a/tests/test_training.py b/tests/test_training.py index c239b88e19e..69441ccf89d 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -49,17 +49,19 @@ def test_story_visualization_with_merging(default_domain): def test_training_script(tmpdir): train_dialogue_model(DEFAULT_DOMAIN_PATH, DEFAULT_STORIES_FILE, tmpdir.strpath, + policy_config='data/test_config/max_hist_config.yml', interpreter=RegexInterpreter(), kwargs={}) assert True def test_training_script_without_max_history_set(tmpdir): - train_dialogue_model(DEFAULT_DOMAIN_PATH, DEFAULT_STORIES_FILE, - tmpdir.strpath, - interpreter=RegexInterpreter(), - max_history=None, - kwargs={}) + train_dialogue_model( + DEFAULT_DOMAIN_PATH, DEFAULT_STORIES_FILE, + tmpdir.strpath, + interpreter=RegexInterpreter(), + policy_config='data/test_config/no_max_hist_config.yml', + kwargs={}) agent = Agent.load(tmpdir.strpath) for policy in agent.policy_ensemble.policies: if hasattr(policy.featurizer, 'max_history'): @@ -71,11 +73,10 @@ def test_training_script_without_max_history_set(tmpdir): def test_training_script_with_max_history_set(tmpdir): - max_history = 3 train_dialogue_model(DEFAULT_DOMAIN_PATH, DEFAULT_STORIES_FILE, tmpdir.strpath, interpreter=RegexInterpreter(), - max_history=max_history, + policy_config='data/test_config/max_hist_config.yml', kwargs={}) agent = Agent.load(tmpdir.strpath) for policy in agent.policy_ensemble.policies: @@ -83,7 +84,7 @@ def test_training_script_with_max_history_set(tmpdir): if type(policy) == FormPolicy: assert policy.featurizer.max_history == 2 else: - assert policy.featurizer.max_history == max_history + assert policy.featurizer.max_history == 5 def test_training_script_with_restart_stories(tmpdir): @@ -91,5 +92,6 @@ def test_training_script_with_restart_stories(tmpdir): "data/test_stories/stories_restart.md", tmpdir.strpath, interpreter=RegexInterpreter(), + policy_config='data/test_config/max_hist_config.yml', kwargs={}) assert True
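All of the adjusted tests funnel through the same entry point, so the new contract is easiest to see by calling it directly. A minimal sketch with placeholder domain/story paths and one of the config fixtures added in this diff; note that ``policy_config`` is now mandatory (``config.load`` raises a ``ValueError`` without it):

.. code-block:: python

    from rasa_core.interpreter import RegexInterpreter
    from rasa_core.train import train_dialogue_model

    # max_history, epochs etc. now come from the policy config file,
    # not from --history/--epochs command line flags
    train_dialogue_model(domain_file="domain.yml",        # placeholder
                         stories_file="data/stories.md",  # placeholder
                         output_path="models/dialogue",
                         interpreter=RegexInterpreter(),
                         policy_config="data/test_config/max_hist_config.yml",
                         kwargs={})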