Skip to content

Commit

Permalink
Merge 4850add into 510d0fd
Browse files Browse the repository at this point in the history
  • Loading branch information
btotharye committed Nov 26, 2019
2 parents 510d0fd + 4850add commit dde8869
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 37 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -15,6 +15,9 @@ Added

Changed
-------
- Print info message when running Rasa X and a custom model server url was specified in ``endpoints.yml``
- If a ``wait_time_between_pulls`` is configured for the model server in ``endpoints.yml``,
this will be used instead of the default one when running Rasa X

Removed
-------
Expand Down
25 changes: 20 additions & 5 deletions rasa/cli/x.py
Expand Up @@ -2,6 +2,7 @@
import asyncio
import importlib.util
import logging
import warnings
import os
import signal
import traceback
Expand Down Expand Up @@ -112,12 +113,27 @@ def _overwrite_endpoints_for_local_x(
from rasa.utils.endpoints import EndpointConfig
import questionary

# Checking if endpoint.yml has existing url and wait time values set, if so give
# warning we are overwriting the endpoint.yml file.
custom_wait_time_pulls = endpoints.model.kwargs.get("wait_time_between_pulls")
custom_url = endpoints.model.url
default_rasax_model_server_url = (
f"{rasa_x_url}/projects/default/models/tag/production"
)

if custom_url != default_rasax_model_server_url:
warnings.warn(
f"Ignoring url '{custom_url}' from 'endpoints.yml' and using "
f"'{default_rasax_model_server_url}' instead."
)

endpoints.model = EndpointConfig(
f"{rasa_x_url}/projects/default/models/tags/production",
default_rasax_model_server_url,
token=rasa_x_token,
wait_time_between_pulls=2,
wait_time_between_pulls=custom_wait_time_pulls or 2,
)

overwrite_existing_event_broker = False
if endpoints.event_broker and not _is_correct_event_broker(endpoints.event_broker):
cli_utils.print_error(
"Rasa X currently only supports a SQLite event broker with path '{}' "
Expand All @@ -132,9 +148,8 @@ def _overwrite_endpoints_for_local_x(
if not overwrite_existing_event_broker:
exit(0)

endpoints.event_broker = EndpointConfig(
type="sql", db=DEFAULT_EVENTS_DB, dialect="sqlite"
)
if not endpoints.tracker_store or overwrite_existing_event_broker:
endpoints.event_broker = EndpointConfig(type="sql", db=DEFAULT_EVENTS_DB)


def _is_correct_event_broker(event_broker: EndpointConfig) -> bool:
Expand Down
73 changes: 41 additions & 32 deletions tests/cli/test_rasa_x.py
@@ -1,18 +1,18 @@
from pathlib import Path
from unittest.mock import Mock
import warnings

from typing import Callable, Dict, Text, Any
import pytest
from typing import Callable, Dict
from _pytest.pytester import RunResult
from _pytest.monkeypatch import MonkeyPatch
import questionary
from _pytest.logging import LogCaptureFixture


from aioresponses import aioresponses

import rasa.utils.io as io_utils
from rasa.cli import x
from rasa.core.utils import AvailableEndpoints
from rasa.utils.endpoints import EndpointConfig
from rasa.core.utils import AvailableEndpoints


def test_x_help(run: Callable[..., RunResult]):
Expand Down Expand Up @@ -65,33 +65,6 @@ def test_prepare_credentials_if_already_valid(tmpdir: Path):
assert actual == credentials


@pytest.mark.parametrize(
"event_broker",
[
# Event broker was not configured.
{},
# Event broker was explicitly configured to work with Rasa X in local mode.
{"type": "sql", "dialect": "sqlite", "db": x.DEFAULT_EVENTS_DB},
# Event broker was configured but the values are not compatible for running Rasa
# X in local mode.
{"type": "sql", "dialect": "postgresql"},
],
)
def test_overwrite_endpoints_for_local_x(
event_broker: Dict[Text, Any], monkeypatch: MonkeyPatch
):
confirm = Mock()
confirm.return_value.ask.return_value = True
monkeypatch.setattr(questionary, "confirm", confirm)

event_broker_config = EndpointConfig.from_dict(event_broker)
endpoints = AvailableEndpoints(event_broker=event_broker_config)

x._overwrite_endpoints_for_local_x(endpoints, "test-token", "http://localhost:5002")

assert x._is_correct_event_broker(endpoints.event_broker)


def test_if_endpoint_config_is_valid_in_local_mode():
config = EndpointConfig(type="sql", dialect="sqlite", db=x.DEFAULT_EVENTS_DB)

Expand All @@ -111,6 +84,42 @@ def test_if_endpoint_config_is_invalid_in_local_mode(kwargs: Dict):
assert not x._is_correct_event_broker(config)


def test_overwrite_model_server_url():
endpoint_config = EndpointConfig(url="http://testserver:5002/models/default@latest")
endpoints = AvailableEndpoints(model=endpoint_config)
with pytest.warns(UserWarning):
x._overwrite_endpoints_for_local_x(endpoints, "test", "http://localhost")
assert (
endpoints.model.url == "http://localhost/projects/default/models/tag/production"
)


def test_reuse_wait_time_between_pulls():
test_wait_time = 5
endpoint_config = EndpointConfig(
url="http://localhost:5002/models/default@latest",
wait_time_between_pulls=test_wait_time,
)
endpoints = AvailableEndpoints(model=endpoint_config)
assert endpoints.model.kwargs["wait_time_between_pulls"] == test_wait_time


def test_default_wait_time_between_pulls():
endpoint_config = EndpointConfig(url="http://localhost:5002/models/default@latest")
endpoints = AvailableEndpoints(model=endpoint_config)
x._overwrite_endpoints_for_local_x(endpoints, "test", "http://localhost")
assert endpoints.model.kwargs["wait_time_between_pulls"] == 2


def test_default_model_server_url():
endpoint_config = EndpointConfig()
endpoints = AvailableEndpoints(model=endpoint_config)
x._overwrite_endpoints_for_local_x(endpoints, "test", "http://localhost")
assert (
endpoints.model.url == "http://localhost/projects/default/models/tag/production"
)


async def test_pull_runtime_config_from_server():
config_url = "http://example.com/api/config?token=token"
credentials = "rasa: http://example.com:5002/api"
Expand Down

0 comments on commit dde8869

Please sign in to comment.