diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3451f61f259..2a6658f59b6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,6 +21,7 @@ Added - added ``/parse`` endpoint to query for NLU results - File based event store - ability to configure event store using the endpoints file +- added ability to use multiple env vars per line in yaml files Changed ------- @@ -34,12 +35,17 @@ Changed - configuration key ``store_type`` of the tracker store endpoint configuration has been renamed to ``type`` to allow usage accross endpoints - renamed ``policy_metadata.json`` to ``metadata.json`` for persisted models +- ``scores`` array returned by the ``/conversations/{sender_id}/predict`` + endpoint is now sorted according to the actions' scores. Removed ------- +- removed ``admin_token`` from ``RasaChatInput`` since it wasn't used Fixed ----- +- When a ``fork`` is used in interactive learning, every forked + storyline is saved (not just the last) [0.13.2] - 2019-02-06 ^^^^^^^^^^^^^^^^^^^^^ diff --git a/README.md b/README.md index a5380eae996..8a9bcb32534 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ # Rasa Core - [![Join the chat on Rasa Community Forum](https://img.shields.io/badge/forum-join%20discussions-brightgreen.svg)](https://forum.rasa.com/?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![PyPI version](https://img.shields.io/pypi/v/rasa_core.svg)](https://pypi.python.org/pypi/rasa-core) [![Supported Python Versions](https://img.shields.io/pypi/pyversions/rasa_core.svg)](https://pypi.python.org/pypi/rasa_core) @@ -8,29 +7,10 @@ [![Coverage Status](https://coveralls.io/repos/github/RasaHQ/rasa_core/badge.svg?branch=master)](https://coveralls.io/github/RasaHQ/rasa_core?branch=master) [![Documentation Status](https://img.shields.io/badge/docs-stable-brightgreen.svg)](https://rasa.com/docs/core) - -- **What do Rasa Core & NLU do? 🤔** - [Read About the Rasa Stack](https://rasa.com/products/rasa-stack/) - -- **I'd like to read the detailed docs 🤓** - [Read The Docs](https://rasa.com/docs/core) - -- **I'm ready to install Rasa Core! 🚀** - [Installation](https://rasa.com/docs/core/installation.html) - -- **I'm ready to start building! 🤖** - [Rasa Stack starter-pack](https://github.com/RasaHQ/starter-pack-rasa-stack) - -- **I have a question ❓** - [Rasa Community Forum](https://forum.rasa.com) - -- **I would like to contribute 🤗** - [How to contribute](#how-to-contribute) - -## Introduction + Rasa Core is a framework for building conversational software, which includes -Chat Bots on: +chatbots on: - Facebook Messenger - Slack - Microsoft Bot Framework @@ -51,6 +31,26 @@ Rasa Core lets you do that in a scalable way. There's a lot more background information in this [blog post](https://medium.com/rasa-blog/a-new-approach-to-conversational-software-2e64a5d05f2a). +--- +- **What do Rasa Core & NLU do? 🤔** + [Read About the Rasa Stack](https://rasa.com/products/rasa-stack/) + +- **I'd like to read the detailed docs 🤓** + [Read The Docs](https://rasa.com/docs/core) + +- **I'm ready to install Rasa Core! 🚀** + [Installation](https://rasa.com/docs/core/installation/) + +- **I'm ready to start building! 🤖** + [Rasa Stack starter-pack](https://github.com/RasaHQ/starter-pack-rasa-stack) + +- **I have a question ❓** + [Rasa Community Forum](https://forum.rasa.com) + +- **I would like to contribute 🤗** + [How to contribute](#how-to-contribute) + +--- ## Where to get help There is extensive documentation: @@ -65,7 +65,6 @@ Please use [Rasa Community Forum](https://forum.rasa.com) for quick answers to questions. - ### README Contents: - [How to contribute](#how-to-contribute) - [Development Internals](#development-internals) @@ -74,7 +73,7 @@ questions. ### How to contribute We are very happy to receive and merge your contributions. There is some more information about the style of the code and docs in the -[documentation](https://nlu.rasa.com/contribute.html). +[documentation](https://rasa.com/docs/contributing/). In general the process is rather simple: 1. create an issue describing the feature you want to work on (or diff --git a/docs/_static/spec/server.yml b/docs/_static/spec/server.yml index 883128a80e1..cf09e648e4d 100644 --- a/docs/_static/spec/server.yml +++ b/docs/_static/spec/server.yml @@ -343,10 +343,11 @@ paths: - Tracker summary: Predict the next action description: >- - Runs the conversations tracker through the models - policies to predict the next action. The action is - not executed, just returned. The state of the tracker - is not modified. + Runs the conversations tracker through the model's + policies to predict the scores of all actions present + in the model's domain. Actions are returned in the + 'scores' array, sorted on their 'score' values. + The state of the tracker is not modified. operationId: predictAction parameters: - $ref: '#/components/parameters/senderId' diff --git a/docs/migrations.rst b/docs/migrations.rst index 116eaebff20..a33944787bb 100644 --- a/docs/migrations.rst +++ b/docs/migrations.rst @@ -8,6 +8,20 @@ Migration Guide This page contains information about changes between major versions and how you can migrate from one version to another. + +.. _migration-to-0-14-0: + +0.13.x to 0.14.0 +---------------- + +Function Naming +~~~~~~~~~~~~~~~ +- renamed ``train_dialogue_model`` to ``train``. Please use ``train`` from + now on. +- renamed ``rasa_core.evaluate`` to ``rasa_core.test``. Please use ``test`` + from now on. + + .. _migration-to-0-13-0: 0.12.x to 0.13.0 diff --git a/examples/concertbot/train.py b/examples/concertbot/train.py index a91151d23eb..5bd09697ea7 100644 --- a/examples/concertbot/train.py +++ b/examples/concertbot/train.py @@ -8,13 +8,11 @@ def train_dialogue(domain_file='domain.yml', stories_file='data/stories.md', model_path='models/dialogue', policy_config='policy_config.yml'): - return train.train(domain_file=domain_file, - stories_file=stories_file, - output_path=model_path, - policy_config=policy_config, - kwargs={'augmentation_factor': 50, - 'validation_split': 0.2} - ) + return train(domain_file=domain_file, + stories_file=stories_file, + output_path=model_path, + policy_config=policy_config, + kwargs={'augmentation_factor': 50, 'validation_split': 0.2}) if __name__ == '__main__': diff --git a/examples/concertbot/train_interactive.py b/examples/concertbot/train_interactive.py index 784088d47bc..913f21cd8a3 100644 --- a/examples/concertbot/train_interactive.py +++ b/examples/concertbot/train_interactive.py @@ -7,11 +7,10 @@ def train_agent(): - return train.train(domain_file="domain.yml", - stories_file="data/stories.md", - output_path="models/dialogue", - policy_config='policy_config.yml' - ) + return train(domain_file="domain.yml", + stories_file="data/stories.md", + output_path="models/dialogue", + policy_config='policy_config.yml') if __name__ == '__main__': diff --git a/rasa_core/__init__.py b/rasa_core/__init__.py index 657adaa56b7..2b2c07a1cef 100644 --- a/rasa_core/__init__.py +++ b/rasa_core/__init__.py @@ -3,7 +3,7 @@ import rasa_core.version from rasa_core.train import train -from rasa_core.test import test as test +from rasa_core.test import test from rasa_core.visualize import visualize logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/rasa_core/agent.py b/rasa_core/agent.py index 3ebe07e8796..98862f4e34a 100644 --- a/rasa_core/agent.py +++ b/rasa_core/agent.py @@ -225,8 +225,7 @@ def __init__( "FormPolicy to your policy ensemble." ) - self.interpreter = None - self.set_interpreter(interpreter) + self.interpreter = NaturalLanguageInterpreter.create(interpreter) self.nlg = NaturalLanguageGenerator.create(generator, self.domain) self.tracker_store = self.create_tracker_store( @@ -368,7 +367,7 @@ def handle_text( message_preprocessor: Optional[Callable[[Text], Text]] = None, output_channel: Optional[OutputChannel] = None, sender_id: Optional[Text] = UserMessage.DEFAULT_SENDER_ID - ) -> Optional[List[Any]]: + ) -> Optional[List[Dict[Text, Any]]]: """Handle a single message. If a message preprocessor is passed, the message will be passed to that @@ -707,21 +706,3 @@ def _form_policy_not_present(self) -> bool: return (self.domain and self.domain.form_names and not any(isinstance(p, FormPolicy) for p in self.policy_ensemble.policies)) - - def set_interpreter(self, - interpreter: Optional[NaturalLanguageInterpreter] - ) -> None: - from rasa_nlu.model import Interpreter - - if not (isinstance(interpreter, NaturalLanguageInterpreter) or - isinstance(interpreter, Interpreter)): - if interpreter is not None: - logger.warning( - "Passing a value for interpreter to an agent " - "where the value is not an interpreter " - "is deprecated. Construct the interpreter, before" - "passing it to the agent, e.g. " - "`interpreter = NaturalLanguageInterpreter.create(nlu)`.") - - interpreter = NaturalLanguageInterpreter.create(interpreter, None) - self.interpreter = interpreter diff --git a/rasa_core/channels/rasa_chat.py b/rasa_core/channels/rasa_chat.py index 74712fb8316..015da014837 100644 --- a/rasa_core/channels/rasa_chat.py +++ b/rasa_core/channels/rasa_chat.py @@ -21,12 +21,10 @@ def from_credentials(cls, credentials): if not credentials: cls.raise_missing_credentials_exception() - return cls(credentials.get("url"), - credentials.get("admin_token")) + return cls(credentials.get("url")) - def __init__(self, url, admin_token=None): + def __init__(self, url): self.base_url = url - self.admin_token = admin_token def _check_token(self, token): url = "{}/users/me".format(self.base_url) diff --git a/rasa_core/cli/__init__.py b/rasa_core/cli/__init__.py index 719d40f6d78..cfd1a080158 100644 --- a/rasa_core/cli/__init__.py +++ b/rasa_core/cli/__init__.py @@ -3,12 +3,3 @@ import rasa_core.cli.run import rasa_core.cli.train import rasa_core.cli.visualization - - -def stories_from_cli_args(cmdline_arguments): - from rasa_core import utils - - if cmdline_arguments.url: - return utils.download_file_from_url(cmdline_arguments.url) - else: - return cmdline_arguments.stories diff --git a/rasa_core/cli/train.py b/rasa_core/cli/train.py index 09057856f31..16367de17fb 100644 --- a/rasa_core/cli/train.py +++ b/rasa_core/cli/train.py @@ -102,3 +102,12 @@ def add_general_args(parser): "file called `story_blocks_connections.html`.") arguments.add_logging_option_arguments(parser) + + +def stories_from_cli_args(cmdline_arguments): + from rasa_core import utils + + if cmdline_arguments.url: + return utils.download_file_from_url(cmdline_arguments.url) + else: + return cmdline_arguments.stories diff --git a/rasa_core/constants.py b/rasa_core/constants.py index e4402ad4d49..396d4766dae 100644 --- a/rasa_core/constants.py +++ b/rasa_core/constants.py @@ -4,7 +4,7 @@ DEFAULT_SERVER_URL = DEFAULT_SERVER_FORMAT.format(DEFAULT_SERVER_PORT) -MINIMUM_COMPATIBLE_VERSION = "0.13.0a6" +MINIMUM_COMPATIBLE_VERSION = "0.14.0a1" DOCS_BASE_URL = "https://rasa.com/docs/core" diff --git a/rasa_core/evaluate.py b/rasa_core/evaluate.py index 8b0369d7b5c..562be326038 100644 --- a/rasa_core/evaluate.py +++ b/rasa_core/evaluate.py @@ -1,10 +1,10 @@ import logging -import rasa_core.test as test +from rasa_core.test import main logger = logging.getLogger(__name__) if __name__ == '__main__': # pragma: no cover logger.warning("Calling `rasa_core.evaluate` is deprecated. " "Please use `rasa_core.test` instead.") - test.main() + main() diff --git a/rasa_core/interpreter.py b/rasa_core/interpreter.py index f9c5257b4c7..4ddaf75c0dd 100644 --- a/rasa_core/interpreter.py +++ b/rasa_core/interpreter.py @@ -21,7 +21,10 @@ def parse(self, text): @staticmethod def create(obj, endpoint=None): - if isinstance(obj, NaturalLanguageInterpreter): + from rasa_nlu.model import Interpreter + + if (isinstance(obj, NaturalLanguageInterpreter) or + isinstance(obj, Interpreter)): return obj if not isinstance(obj, str): diff --git a/rasa_core/server.py b/rasa_core/server.py index 3837924f4e7..c4a6f598ff3 100644 --- a/rasa_core/server.py +++ b/rasa_core/server.py @@ -406,6 +406,9 @@ def predict(sender_id): try: # Fetches the appropriate bot response in a json format responses = agent.predict_next(sender_id) + responses['scores'] = sorted(responses['scores'], + key = lambda k: (-k['score'], + k['action'])) return jsonify(responses) except Exception as e: diff --git a/rasa_core/test.py b/rasa_core/test.py index 0794f508ff4..3a2676cf157 100644 --- a/rasa_core/test.py +++ b/rasa_core/test.py @@ -10,6 +10,7 @@ import rasa_core.cli.arguments from typing import List, Optional, Any, Text, Dict, Tuple +import rasa_core.cli.train from rasa_core import training, cli from rasa_core import utils from rasa_core.events import ActionExecuted, UserUttered @@ -615,7 +616,7 @@ def main(): _agent = Agent.load(cmdline_arguments.core, interpreter=_interpreter) - stories = cli.stories_from_cli_args(cmdline_arguments) + stories = rasa_core.cli.train.stories_from_cli_args(cmdline_arguments) test(stories, _agent, cmdline_arguments.max_stories, cmdline_arguments.output, diff --git a/rasa_core/train.py b/rasa_core/train.py index 936d87d10a6..64fe831e352 100644 --- a/rasa_core/train.py +++ b/rasa_core/train.py @@ -4,6 +4,7 @@ import tempfile from typing import Text, Dict, Optional +import rasa_core.cli.train from rasa_core import config, cli from rasa_core import utils from rasa_core.broker import PikaProducer @@ -243,7 +244,8 @@ def do_interactive_learning(cmdline_args, stories, additional_arguments=None): utils.configure_colored_logging(cmdline_arguments.loglevel) - training_stories = cli.stories_from_cli_args(cmdline_arguments) + training_stories = rasa_core.cli.train.stories_from_cli_args( + cmdline_arguments) if cmdline_arguments.mode == 'default': do_default_training(cmdline_arguments, diff --git a/rasa_core/training/interactive.py b/rasa_core/training/interactive.py index f5bf3ea541a..283734ac2fd 100644 --- a/rasa_core/training/interactive.py +++ b/rasa_core/training/interactive.py @@ -12,7 +12,8 @@ from gevent.pywsgi import WSGIServer from terminaltables import SingleTable, AsciiTable from threading import Thread -from typing import Any, Text, Dict, List, Optional, Callable, Union, Tuple +from typing import (Any, Text, Dict, List, Optional, Callable, Union, Tuple, + TYPE_CHECKING) from rasa_core import utils, server, events, constants from rasa_core.actions.action import ACTION_LISTEN_NAME, default_action_names @@ -38,6 +39,8 @@ from rasa_nlu.training_data.loading import load_data, _guess_format from rasa_nlu.training_data.message import Message +if TYPE_CHECKING: + from rasa_core.agent import Agent try: FileNotFoundError except NameError: @@ -53,9 +56,9 @@ MAX_VISUAL_HISTORY = 3 -PATHS = {"stories": "data/core/stories.md", - "nlu": "data/nlu/nlu.md", - "backup": "data/nlu/nlu_interactive.md", +PATHS = {"stories": "data/stories.md", + "nlu": "data/nlu.md", + "backup": "data/nlu_interactive.md", "domain": "domain.yml"} # choose other intent, making sure this doesn't clash with an existing intent @@ -553,6 +556,21 @@ def _slot_history(tracker_dump: Dict[Text, Any]) -> List[Text]: return slot_strs +def _write_data_to_file(sender_id: Text, endpoint: EndpointConfig): + """Write stories and nlu data to file.""" + + story_path, nlu_path, domain_path = _request_export_info() + + tracker = retrieve_tracker(endpoint, sender_id) + evts = tracker.get("events", []) + + _write_stories_to_file(story_path, evts) + _write_nlu_to_file(nlu_path, evts) + _write_domain_to_file(domain_path, evts, endpoint) + + logger.info("Successfully wrote stories and NLU data") + + def _ask_if_quit(sender_id: Text, endpoint: EndpointConfig) -> bool: """Display the exit menu. @@ -568,16 +586,7 @@ def _ask_if_quit(sender_id: Text, endpoint: EndpointConfig) -> bool: if not answer or answer == "quit": # this is also the default answer if the user presses Ctrl-C - story_path, nlu_path, domain_path = _request_export_info() - - tracker = retrieve_tracker(endpoint, sender_id) - evts = tracker.get("events", []) - - _write_stories_to_file(story_path, evts) - _write_nlu_to_file(nlu_path, evts) - _write_domain_to_file(domain_path, evts, endpoint) - - logger.info("Successfully wrote stories and NLU data") + _write_data_to_file(sender_id, endpoint) sys.exit() elif answer == "continue": # in this case we will just return, and the original @@ -599,13 +608,10 @@ def _request_action_from_user( _print_history(sender_id, endpoint) - sorted_actions = sorted(predictions, - key=lambda k: (-k['score'], k['action'])) - choices = [{"name": "{:03.2f} {:40}".format(a.get("score"), a.get("action")), "value": a.get("action")} - for a in sorted_actions] + for a in predictions] choices = ([{"name": "", "value": OTHER_ACTION}] + choices) @@ -1235,14 +1241,19 @@ def record_messages(endpoint: EndpointConfig, except ForkTracker: _print_history(sender_id, endpoint) - evts = _request_fork_from_user(sender_id, endpoint) - sender_id = uuid.uuid4().hex + evts_fork = _request_fork_from_user(sender_id, endpoint) + + send_event(endpoint, sender_id, + Restarted().as_dict()) + + if evts_fork: + for evt in evts_fork: + send_event(endpoint, sender_id, evt) - if evts is not None: - replace_events(endpoint, sender_id, evts) - sender_ids.append(sender_id) - _print_history(sender_id, endpoint) - _plot_trackers(sender_ids, plot_file, endpoint) + logger.info("Restarted conversation at fork.") + + _print_history(sender_id, endpoint) + _plot_trackers(sender_ids, plot_file, endpoint) except Exception: logger.exception("An exception occurred while recording messages.") @@ -1317,7 +1328,7 @@ def visualisation_png(): abort(404) -def run_interactive_learning(agent: 'Agent', +def run_interactive_learning(agent: "Agent", stories: Text = None, finetune: bool = False, serve_forever: bool = True, diff --git a/rasa_core/training/visualization.py b/rasa_core/training/visualization.py index 6a488d7ca5a..708a3b44543 100644 --- a/rasa_core/training/visualization.py +++ b/rasa_core/training/visualization.py @@ -31,8 +31,8 @@ def __init__(self, nlu_training_data): @staticmethod def _create_reverse_mapping( - data: 'TrainingData' - ) -> Dict[Dict[Text, Any], List['Message']]: + data: "TrainingData" + ) -> Dict[Dict[Text, Any], List["Message"]]: """Create a mapping from intent to messages This allows a faster intent lookup.""" @@ -363,7 +363,7 @@ def visualize_neighborhood( output_file: Optional[Text] = None, max_history: int = 2, interpreter: NaturalLanguageInterpreter = RegexInterpreter(), - nlu_training_data: Optional['TrainingData'] = None, + nlu_training_data: Optional["TrainingData"] = None, should_merge_nodes: bool = True, max_distance: int = 1, fontsize: int = 12 @@ -480,7 +480,7 @@ def visualize_stories( output_file: Optional[Text], max_history: int, interpreter: NaturalLanguageInterpreter = RegexInterpreter(), - nlu_training_data: Optional['TrainingData'] = None, + nlu_training_data: Optional["TrainingData"] = None, should_merge_nodes: bool = True, fontsize: int = 12, silent: bool = False diff --git a/rasa_core/utils.py b/rasa_core/utils.py index 614c3a62fa0..adbc6ac6c2b 100644 --- a/rasa_core/utils.py +++ b/rasa_core/utils.py @@ -299,8 +299,14 @@ def replace_environment_variables(): def env_var_constructor(loader, node): """Process environment variables found in the YAML.""" value = loader.construct_scalar(node) - prefix, env_var, remaining_path = env_var_pattern.match(value).groups() - return prefix + os.environ[env_var] + remaining_path + expanded_vars = os.path.expandvars(value) + if '$' in expanded_vars: + not_expanded = [w for w in expanded_vars.split() if '$' in w] + raise ValueError( + "Error when trying to expand the environment variables" + " in '{}'. Please make sure to also set these environment" + " variables: '{}'.".format(value, not_expanded)) + return expanded_vars yaml.SafeConstructor.add_constructor(u'!env_var', env_var_constructor) diff --git a/rasa_core/version.py b/rasa_core/version.py index 31ada01c987..0d07139dd67 100644 --- a/rasa_core/version.py +++ b/rasa_core/version.py @@ -1,2 +1,2 @@ -__version__ = '0.13.2' +__version__ = '0.14.0a1' diff --git a/rasa_core/visualize.py b/rasa_core/visualize.py index e1549b3a4f5..d98cc5da921 100644 --- a/rasa_core/visualize.py +++ b/rasa_core/visualize.py @@ -4,6 +4,7 @@ from typing import Text import rasa_core.cli.arguments +import rasa_core.cli.train from rasa_core import utils import rasa_core.cli @@ -60,9 +61,7 @@ def visualize(config_path: Text, domain_path: Text, stories_path: Text, args = arg_parser.parse_args() utils.configure_colored_logging(args.loglevel) - stories = rasa_core.cli.stories_from_cli_args(args) + stories = rasa_core.cli.train.stories_from_cli_args(args) visualize(args.config[0], args.domain, stories, args.nlu_data, args.output, args.max_history) - - diff --git a/tests/conftest.py b/tests/conftest.py index dc6f967800c..00e950a3364 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,7 +96,7 @@ def default_processor(default_domain, default_nlg): @pytest.fixture(scope="session") def trained_moodbot_path(): - train.train( + train( domain_file="examples/moodbot/domain.yml", stories_file="examples/moodbot/data/stories.md", output_path=MOODBOT_MODEL_PATH, diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 7cd3bd9ac48..150f5343215 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -1,11 +1,11 @@ import os -from rasa_core import test from rasa_core.test import ( test, + _generate_trackers, collect_story_predictions) -from tests.conftest import DEFAULT_STORIES_FILE, END_TO_END_STORY_FILE, \ - E2E_STORY_FILE_UNKNOWN_ENTITY +from tests.conftest import (DEFAULT_STORIES_FILE, END_TO_END_STORY_FILE, + E2E_STORY_FILE_UNKNOWN_ENTITY) # from tests.conftest import E2E_STORY_FILE_UNKNOWN_ENTITY @@ -28,7 +28,7 @@ def test_evaluation_image_creation(tmpdir, default_agent): def test_action_evaluation_script(tmpdir, default_agent): - completed_trackers = test._generate_trackers( + completed_trackers = _generate_trackers( DEFAULT_STORIES_FILE, default_agent, use_e2e=False) story_evaluation, num_stories = collect_story_predictions( completed_trackers, @@ -42,7 +42,7 @@ def test_action_evaluation_script(tmpdir, default_agent): def test_end_to_end_evaluation_script(tmpdir, default_agent): - completed_trackers = test._generate_trackers( + completed_trackers = _generate_trackers( END_TO_END_STORY_FILE, default_agent, use_e2e=True) story_evaluation, num_stories = collect_story_predictions( @@ -57,7 +57,7 @@ def test_end_to_end_evaluation_script(tmpdir, default_agent): def test_end_to_end_evaluation_script_unknown_entity(tmpdir, default_agent): - completed_trackers = test._generate_trackers( + completed_trackers = _generate_trackers( E2E_STORY_FILE_UNKNOWN_ENTITY, default_agent, use_e2e=True) story_evaluation, num_stories = collect_story_predictions( diff --git a/tests/test_server.py b/tests/test_server.py index 3c519496673..736afc728d0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -207,6 +207,20 @@ def test_predict(http_app, app): assert response.status_code == 200 +def test_sorted_predict(http_app, app): + client = RasaCoreClient(EndpointConfig(http_app)) + cid = str(uuid.uuid1()) + for event in test_events[:3]: + client.append_event_to_tracker(cid, event) + response = app.post("http://dummy/conversations/{}/predict".format(cid)) + content = response.get_json() + scores = content["scores"] + sorted_scores = sorted(scores, + key = lambda k: (-k['score'], + k['action'])) + assert scores == sorted_scores + + def test_evaluate(app): with io.open(DEFAULT_STORIES_FILE, 'r') as f: stories = f.read() diff --git a/tests/test_training.py b/tests/test_training.py index 05acef00ad6..891bb7d0363 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -10,11 +10,6 @@ from tests.conftest import DEFAULT_DOMAIN_PATH, DEFAULT_STORIES_FILE -def test_story_visualization_script(): - from rasa_core.visualize import create_argument_parser - assert create_argument_parser() is not None - - def test_story_visualization(default_domain, tmpdir): story_steps = StoryFileReader.read_from_file( "data/test_stories/stories.md", default_domain, @@ -86,7 +81,7 @@ def test_training_script_with_max_history_set(tmpdir): def test_training_script_with_restart_stories(tmpdir): train(DEFAULT_DOMAIN_PATH, - "data/test_stories/stories_restart.md", + "data/test_stories/stories_restart.md", tmpdir.strpath, interpreter=RegexInterpreter(), policy_config='data/test_config/max_hist_config.yml', @@ -96,8 +91,8 @@ def test_training_script_with_restart_stories(tmpdir): def configs_for_random_seed_test(): # define the configs for the random_seed tests - return [('data/test_config/keras_random_seed.yaml'), - ('data/test_config/embedding_random_seed.yaml')] + return ['data/test_config/keras_random_seed.yaml', + 'data/test_config/embedding_random_seed.yaml'] @pytest.mark.parametrize("config_file", configs_for_random_seed_test()) diff --git a/tests/test_utils.py b/tests/test_utils.py index f6d3cc7dbdf..06444b1541c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -142,6 +142,15 @@ def test_read_yaml_string_with_env_var(): assert r['user'] == 'user' and r['password'] == 'pass' +def test_read_yaml_string_with_multiple_env_vars_per_line(): + config_with_env_var = """ + user: ${USER_NAME} ${PASS} + password: ${PASS} + """ + r = utils.read_yaml_string(config_with_env_var) + assert r['user'] == 'user pass' and r['password'] == 'pass' + + def test_read_yaml_string_with_env_var_prefix(): config_with_env_var_prefix = """ user: db_${USER_NAME} @@ -174,5 +183,5 @@ def test_read_yaml_string_with_env_var_not_exist(): user: ${USER_NAME} password: ${PASSWORD} """ - with pytest.raises(KeyError): + with pytest.raises(ValueError): r = utils.read_yaml_string(config_with_env_var_not_exist)