Skip to content

Commit

Permalink
Merge branch 'master' into johannes-4088b
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes E. M. Mosig committed Feb 24, 2020
2 parents 5b18b28 + 1240405 commit 267ce89
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 9 deletions.
3 changes: 3 additions & 0 deletions changelog/5117.improvement.rst
@@ -0,0 +1,3 @@
New command-line argument --conversation-id will be added and wiil give the ability to
set specific conversation ID for each shell session, if not passed will be random.

11 changes: 10 additions & 1 deletion rasa/cli/shell.py
@@ -1,13 +1,13 @@
import argparse
import logging
import uuid

from typing import List

from rasa.cli.arguments import shell as arguments
from rasa.cli.utils import print_error
from rasa.exceptions import ModelNotFound


logger = logging.getLogger(__name__)


Expand All @@ -26,14 +26,23 @@ def add_subparser(
)
shell_parser.set_defaults(func=shell)

shell_parser.add_argument(
"--conversation-id",
default=uuid.uuid4().hex,
required=False,
help="Set the conversation ID.",
)

run_subparsers = shell_parser.add_subparsers()

shell_nlu_subparser = run_subparsers.add_parser(
"nlu",
parents=parents,
conflict_handler="resolve",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
help="Interprets messages on the command line using your NLU model.",
)

shell_nlu_subparser.set_defaults(func=shell_nlu)

arguments.set_shell_arguments(shell_parser)
Expand Down
3 changes: 1 addition & 2 deletions rasa/core/channels/console.py
Expand Up @@ -12,7 +12,6 @@
from rasa.cli import utils as cli_utils
from rasa.core import utils
from rasa.core.channels.channel import RestInput
from rasa.core.channels.channel import UserMessage
from rasa.core.constants import DEFAULT_SERVER_URL
from rasa.core.interpreter import INTENT_MESSAGE_PREFIX
from rasa.utils.io import DEFAULT_ENCODING
Expand Down Expand Up @@ -109,9 +108,9 @@ async def send_message_receive_stream(


async def record_messages(
sender_id,
server_url=DEFAULT_SERVER_URL,
auth_token="",
sender_id=UserMessage.DEFAULT_SENDER_ID,
max_message_limit=None,
use_response_stream=True,
) -> int:
Expand Down
8 changes: 7 additions & 1 deletion rasa/core/run.py
@@ -1,5 +1,6 @@
import asyncio
import logging
import uuid
import os
import shutil
from functools import partial
Expand Down Expand Up @@ -87,6 +88,7 @@ def configure_app(
port: int = constants.DEFAULT_SERVER_PORT,
endpoints: Optional[AvailableEndpoints] = None,
log_file: Optional[Text] = None,
conversation_id: Optional[Text] = uuid.uuid4().hex,
):
"""Run the agent."""
from rasa import server
Expand Down Expand Up @@ -124,8 +126,10 @@ async def configure_async_logging():
async def run_cmdline_io(running_app: Sanic):
"""Small wrapper to shut down the server once cmd io is done."""
await asyncio.sleep(1) # allow server to start

await console.record_messages(
server_url=constants.DEFAULT_SERVER_FORMAT.format("http", port)
server_url=constants.DEFAULT_SERVER_FORMAT.format("http", port),
sender_id=conversation_id,
)

logger.info("Killing Sanic server now.")
Expand Down Expand Up @@ -153,6 +157,7 @@ def serve_application(
ssl_keyfile: Optional[Text] = None,
ssl_ca_file: Optional[Text] = None,
ssl_password: Optional[Text] = None,
conversation_id: Optional[Text] = uuid.uuid4().hex,
):
from rasa import server

Expand All @@ -171,6 +176,7 @@ def serve_application(
port=port,
endpoints=endpoints,
log_file=log_file,
conversation_id=conversation_id,
)

ssl_context = server.create_ssl_context(
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/training/interactive.py
Expand Up @@ -1624,7 +1624,7 @@ def run_interactive_learning(
else:
p = None

app = run.configure_app(enable_api=True)
app = run.configure_app(enable_api=True, conversation_id="default")
endpoints = AvailableEndpoints.read_endpoints(server_args.get("endpoints"))

# before_server_start handlers make sure the agent is loaded before the
Expand Down
7 changes: 4 additions & 3 deletions tests/cli/test_rasa_shell.py
Expand Up @@ -5,9 +5,10 @@
def test_shell_help(run: Callable[..., RunResult]):
output = run("shell", "--help")

help_text = """usage: rasa shell [-h] [-v] [-vv] [--quiet] [-m MODEL] [--log-file LOG_FILE]
[--endpoints ENDPOINTS] [-p PORT] [-t AUTH_TOKEN]
[--cors [CORS [CORS ...]]] [--enable-api]
help_text = """usage: rasa shell [-h] [-v] [-vv] [--quiet]
[--conversation-id CONVERSATION_ID] [-m MODEL]
[--log-file LOG_FILE] [--endpoints ENDPOINTS] [-p PORT]
[-t AUTH_TOKEN] [--cors [CORS [CORS ...]]] [--enable-api]
[--remote-storage REMOTE_STORAGE]
[--ssl-certificate SSL_CERTIFICATE]
[--ssl-keyfile SSL_KEYFILE] [--ssl-ca-file SSL_CA_FILE]
Expand Down
4 changes: 3 additions & 1 deletion tests/core/test_channels.py
Expand Up @@ -114,7 +114,9 @@ async def test_console_input():
)

await console.record_messages(
server_url="https://example.com", max_message_limit=3
server_url="https://example.com",
max_message_limit=3,
sender_id="default",
)

r = latest_request(
Expand Down

0 comments on commit 267ce89

Please sign in to comment.