diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9aa9d85b19d..a4621ddfa1a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,7 @@ Added ----- - JWT support: parameters to allow clients to authenticate requests to the rasa_core.server using JWT's in addition to normal token based auth +- added socket.io input / output channel Changed ------- diff --git a/docs/connectors.rst b/docs/connectors.rst index 52c000c7af4..f7ecd66e964 100644 --- a/docs/connectors.rst +++ b/docs/connectors.rst @@ -491,6 +491,51 @@ the port. The endpoint for receiving botframework channel messages is ``/webhooks/botframework/webhook``. This is the url you should add in your microsoft bot service configuration. +.. _socketio_connector: + +SocketIO Setup +-------------- + +You can **either** attach the input channel running the provided +``rasa_core.run`` script, or you can attach the channel in your +own code. + +Using run script +^^^^^^^^^^^^^^^^ + +If you want to connect the socketio input channel using the run +script, e.g. using: + +.. code-block:: bash + + python -m rasa_core.run -d models/dialogue -u models/nlu/current + --port 5002 --credentials credentials.yml + +you need to supply a ``credentials.yml`` with the following content: + +.. code-block:: yaml + + socketio: + user_message_evt: user_uttered + bot_message_evt: bot_uttered + +These two configuration values define the event names used by Rasa Core +when sending or receiving messages over socket.io. + +Directly using python +^^^^^^^^^^^^^^^^^^^^^ + +Code to create a Socket.IO-compatible webserver looks like this: + +.. literalinclude:: ../tests/test_channels.py + :pyobject: test_socketio_channel + :lines: 2- + :end-before: END DOC INCLUDE + +The arguments for the ``handle_channels`` are the input channels and +the port. Once started, you should be able to connect to +``http://localhost:5005`` with your socket.io client. + .. _ngrok: Using Ngrok For Local Testing diff --git a/rasa_core/channels/__init__.py b/rasa_core/channels/__init__.py index 67c6482a053..a9b6e4dd7e7 100644 --- a/rasa_core/channels/__init__.py +++ b/rasa_core/channels/__init__.py @@ -13,6 +13,8 @@ # this prevents IDE's from optimizing the imports - we need to import the # above first, otherwise we will run into import cycles +from rasa_core.channels.socketio import SocketIOInput + pass from rasa_core.channels.botframework import BotFrameworkInput @@ -29,7 +31,7 @@ input_channel_classes = [ CmdlineInput, FacebookInput, SlackInput, TelegramInput, MattermostInput, TwilioInput, RasaChatInput, BotFrameworkInput, RocketChatInput, - CallbackInput, RestInput + CallbackInput, RestInput, SocketIOInput ] # type: List[InputChannel] # Mapping from a input channel name to its class to allow name based lookup. diff --git a/rasa_core/channels/socketio.py b/rasa_core/channels/socketio.py new file mode 100644 index 00000000000..2efd7dc54ea --- /dev/null +++ b/rasa_core/channels/socketio.py @@ -0,0 +1,144 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import logging +import socketio +from flask import Blueprint, jsonify +from typing import Optional, Text + +from rasa_core.channels import InputChannel +from rasa_core.channels.channel import ( + UserMessage, + OutputChannel) + +logger = logging.getLogger(__name__) + + +class SocketBlueprint(Blueprint): + def __init__(self, sio, *args, **kwargs): + self.sio = sio + super(SocketBlueprint, self).__init__(*args, **kwargs) + + def register(self, app, options, first_registration=False): + app.wsgi_app = socketio.Middleware(self.sio, app.wsgi_app) + super(SocketBlueprint, self).register(app, options, first_registration) + + +class SocketIOOutput(OutputChannel): + + @classmethod + def name(cls): + return "socketio" + + def __init__(self, sio, bot_message_evt): + self.sio = sio + self.bot_message_evt = bot_message_evt + + def send(self, recipient_id, message): + # type: (Text, Any) -> None + """Sends a message to the recipient.""" + self.sio.emit(message, room=recipient_id) + + def _send_message(self, recipient_id, response): + # type: (Text, Any) -> None + """Sends a message to the recipient using the bot event.""" + self.sio.emit(self.bot_message_evt, response, room=recipient_id) + + def send_text_message(self, recipient_id, message): + # type: (Text, Text) -> None + """Send a message through this channel.""" + + self._send_message(recipient_id, {"text": message}) + + def send_image_url(self, recipient_id, image_url): + # type: (Text, Text) -> None + """Sends an image. Default will just post the url as a string.""" + message = { + "attachment": { + "type": "image", + "payload": {"src": image_url} + } + } + self._send_message(recipient_id, message) + + def send_text_with_buttons(self, recipient_id, text, buttons, **kwargs): + # type: (Text, Text, List[Dict[Text, Any]], **Any) -> None + """Sends buttons to the output.""" + + message = { + "text": text, + "quick_replies": [] + } + + for button in buttons: + message["quick_replies"].append({ + "content_type": "text", + "title": button['title'], + "payload": button['payload'] + }) + + self._send_message(recipient_id, message) + + def send_custom_message(self, recipient_id, elements): + # type: (Text, List[Dict[Text, Any]]) -> None + """Sends elements to the output.""" + + message = {"attachment": { + "type": "template", + "payload": { + "template_type": "generic", + "elements": elements[0] + }}} + + self._send_message(recipient_id, message) + + +class SocketIOInput(InputChannel): + """A socket.io input channel.""" + + @classmethod + def name(cls): + return "socketio" + + @classmethod + def from_credentials(cls, credentials): + credentials = credentials or {} + return cls(credentials.get("user_message_evt", "user_uttered"), + credentials.get("bot_message_evt", "bot_uttered"), + credentials.get("namespace")) + + def __init__(self, + user_message_evt="user_uttered", # type: Text + bot_message_evt="bot_uttered", # type: Text + namespace=None # type: Optional[Text] + ): + self.bot_message_evt = bot_message_evt + self.user_message_evt = user_message_evt + self.namespace = namespace + + def blueprint(self, on_new_message): + sio = socketio.Server() + socketio_webhook = SocketBlueprint(sio, 'socketio_webhook', __name__) + + @socketio_webhook.route("/", methods=['GET']) + def health(): + return jsonify({"status": "ok"}) + + @sio.on('connect', namespace=self.namespace) + def connect(sid, environ): + logger.debug("User {} connected to socketio endpoint.".format(sid)) + + @sio.on('disconnect', namespace=self.namespace) + def disconnect(sid): + logger.debug("User {} disconnected from socketio endpoint." + "".format(sid)) + + @sio.on(self.user_message_evt, namespace=self.namespace) + def handle_message(sid, data): + output_channel = SocketIOOutput(sio, self.bot_message_evt) + message = UserMessage(data['message'], output_channel, sid) + on_new_message(message) + + return socketio_webhook diff --git a/rasa_core/run.py b/rasa_core/run.py index 3f20196eff2..96f5eb3f346 100644 --- a/rasa_core/run.py +++ b/rasa_core/run.py @@ -78,7 +78,6 @@ def create_argument_parser(): help="Configuration file for the connectors as a yml file") parser.add_argument( '-c', '--connector', - default="cmdline", choices=list(BUILTIN_CHANNELS.keys()), help="service to connect to") parser.add_argument( @@ -121,6 +120,9 @@ def read_endpoints(endpoint_file): def _create_external_channels(channel, credentials_file): # type: (Optional[Text], Optional[Text]) -> List[InputChannel] + if not channel and not credentials_file: + channel = "cmdline" + if credentials_file: all_credentials = read_yaml_file(credentials_file) else: diff --git a/requirements.txt b/requirements.txt index 6282e17bd14..e90070a16b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,3 +39,4 @@ pymongo==3.5.1 python-dateutil==2.7.3 rocketchat_API==0.6.22 flask-jwt-simple==0.0.3 +python-socketio==2.0.0 diff --git a/setup.py b/setup.py index 91c1ae82a5c..9619efdbda3 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,8 @@ "python-dateutil~=2.7", "rasa_nlu~=0.13.0", "rasa_core_sdk~=0.11.0", - "flask-jwt-simple~=0.0.3" + "flask-jwt-simple~=0.0.3", + "python-socketio~=2.0", ] extras_requires = { diff --git a/tests/test_channels.py b/tests/test_channels.py index 54ee2945248..dcd4fc41e1f 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -320,6 +320,38 @@ def test_callback_channel(): s.stop() +# USED FOR DOCS - don't rename without changing in the docs +def test_socketio_channel(): + from rasa_core.channels.socketio import SocketIOInput + from rasa_core.agent import Agent + from rasa_core.interpreter import RegexInterpreter + + # load your trained agent + agent = Agent.load(MODEL_PATH, interpreter=RegexInterpreter()) + + input_channel = SocketIOInput( + # event name for messages sent from the user + user_message_evt="user_uttered", + # event name for messages sent from the bot + bot_message_evt="bot_uttered", + # socket.io namespace to use for the messages + namespace=None + ) + + # set serve_forever=False if you want to keep the server running + s = agent.handle_channels([input_channel], 5004, serve_forever=False) + # END DOC INCLUDE + # the above marker marks the end of the code snipped included + # in the docs + try: + assert s.started + routes_list = utils.list_routes(s.application) + assert routes_list.get("/webhooks/socketio/").startswith( + 'socketio_webhook.health') + finally: + s.stop() + + def test_callback_calls_endpoint(): from rasa_core.channels.callback import CallbackOutput