diff --git a/rasa/cli/x.py b/rasa/cli/x.py index 487b6c5acc4b..1c7c56818958 100644 --- a/rasa/cli/x.py +++ b/rasa/cli/x.py @@ -150,6 +150,8 @@ def _overwrite_endpoints_for_local_x( if not endpoints.tracker_store or overwrite_existing_event_broker: endpoints.event_broker = EndpointConfig(type="sql", db=DEFAULT_EVENTS_DB) + return endpoints + def _is_correct_event_broker(event_broker: EndpointConfig) -> bool: return all( diff --git a/tests/cli/test_rasa_x.py b/tests/cli/test_rasa_x.py index 6b6da6cbfa1c..d254bebfbd0e 100644 --- a/tests/cli/test_rasa_x.py +++ b/tests/cli/test_rasa_x.py @@ -9,6 +9,7 @@ import rasa.utils.io as io_utils from rasa.cli import x from rasa.utils.endpoints import EndpointConfig +from rasa.core.utils import AvailableEndpoints def test_x_help(run: Callable[..., RunResult]): @@ -79,6 +80,14 @@ def test_if_endpoint_config_is_invalid_in_local_mode(kwargs: Dict): config = EndpointConfig(**kwargs) assert not x._is_correct_event_broker(config) +def test_wait_time_between_pulls_custom(): + #endpoint_config = EndpointConfig(url="http://localhost:5002/api/projects/default/models/tag/production", wait_time_between_pulls=3) + endpoint_config = EndpointConfig(url="http://testserver:5002/models/default@latest", wait_time_between_pulls=5) + endpoints = AvailableEndpoints(model=endpoint_config) + + updated_endpoints = x._overwrite_endpoints_for_local_x(endpoints, "test", "http://localhost") + updated_config = updated_endpoints.model + assert updated_config.url == 'http://localhost/projects/default/models/tag/production' async def test_pull_runtime_config_from_server(): config_url = "http://example.com/api/config?token=token" @@ -99,6 +108,7 @@ async def test_pull_runtime_config_from_server(): endpoints_path, credentials_path = await x._pull_runtime_config_from_server( config_url, 1, 0 ) + with open(endpoints_path) as f: assert f.read() == endpoint_config