From d99ad6bd89c35cd859f35e0185f319973cfeacd8 Mon Sep 17 00:00:00 2001 From: Tom Bocklisch Date: Wed, 12 Sep 2018 15:35:58 +0200 Subject: [PATCH] added socketio implementation --- rasa_core/channels/__init__.py | 4 +- rasa_core/channels/socketio.py | 144 +++++++++++++++++++++++++++++++++ rasa_core/run.py | 4 +- requirements.txt | 1 + setup.py | 3 +- 5 files changed, 153 insertions(+), 3 deletions(-) create mode 100644 rasa_core/channels/socketio.py diff --git a/rasa_core/channels/__init__.py b/rasa_core/channels/__init__.py index 67c6482a053..a9b6e4dd7e7 100644 --- a/rasa_core/channels/__init__.py +++ b/rasa_core/channels/__init__.py @@ -13,6 +13,8 @@ # this prevents IDE's from optimizing the imports - we need to import the # above first, otherwise we will run into import cycles +from rasa_core.channels.socketio import SocketIOInput + pass from rasa_core.channels.botframework import BotFrameworkInput @@ -29,7 +31,7 @@ input_channel_classes = [ CmdlineInput, FacebookInput, SlackInput, TelegramInput, MattermostInput, TwilioInput, RasaChatInput, BotFrameworkInput, RocketChatInput, - CallbackInput, RestInput + CallbackInput, RestInput, SocketIOInput ] # type: List[InputChannel] # Mapping from a input channel name to its class to allow name based lookup. diff --git a/rasa_core/channels/socketio.py b/rasa_core/channels/socketio.py new file mode 100644 index 00000000000..dd65977c332 --- /dev/null +++ b/rasa_core/channels/socketio.py @@ -0,0 +1,144 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import logging +import socketio +from flask import Blueprint, jsonify +from typing import Optional, Text + +from rasa_core.channels import InputChannel +from rasa_core.channels.channel import ( + UserMessage, + OutputChannel) + +logger = logging.getLogger(__name__) + + +class SocketBlueprint(Blueprint): + def __init__(self, sio, *args, **kwargs): + self.sio = sio + super(SocketBlueprint, self).__init__(*args, **kwargs) + + def register(self, app, options, first_registration=False): + app.wsgi_app = socketio.Middleware(self.sio, app.wsgi_app) + super(SocketBlueprint, self).register(app, options, first_registration) + + +class SocketIOOutput(OutputChannel): + + @classmethod + def name(cls): + return "socketio" + + def __init__(self, sio, bot_message_evt): + self.sio = sio + self.bot_message_evt = bot_message_evt + + def send(self, recipient_id, message): + # type: (Text, Any) -> None + """Sends a message to the recipient.""" + self.sio.emit(message, room=recipient_id) + + def _send_message(self, recipient_id, response): + # type: (Text, Any) -> None + """Sends a message to the recipient using the bot event.""" + self.sio.emit(self.bot_message_evt, response, room=recipient_id) + + def send_text_message(self, recipient_id, message): + # type: (Text, Text) -> None + """Send a message through this channel.""" + + self._send_message(recipient_id, {"text": message}) + + def send_image_url(self, recipient_id, image_url): + # type: (Text, Text) -> None + """Sends an image. Default will just post the url as a string.""" + message = { + "attachment": { + "type": "image", + "payload": {"src": image_url} + } + } + self._send_message(recipient_id, message) + + def send_text_with_buttons(self, recipient_id, text, buttons, **kwargs): + # type: (Text, Text, List[Dict[Text, Any]], **Any) -> None + """Sends buttons to the output.""" + + message = { + "text": text, + "quick_replies": [] + } + + for button in buttons: + message["quick_replies"].append({ + "content_type": "text", + "title": button['title'], + "payload": button['payload'] + }) + + self._send_message(recipient_id, message) + + def send_custom_message(self, recipient_id, elements): + # type: (Text, List[Dict[Text, Any]]) -> None + """Sends elements to the output.""" + + message = {"attachment": { + "type": "template", + "payload": { + "template_type": "generic", + "elements": elements[0] + }}} + + self._send_message(recipient_id, message) + + +class SocketIOInput(InputChannel): + """A socket.io input channel.""" + + @classmethod + def name(cls): + return "socketio" + + @classmethod + def from_credentials(cls, credentials): + credentials = credentials or {} + return cls(credentials.get("user_message_evt", "user_uttered"), + credentials.get("bot_message_evt", "bot_uttered"), + credentials.get("namespace")) + + def __init__(self, + user_message_evt="user_uttered", # type: Text + bot_message_evt="bot_uttered", # type: Text + namespace=None # type: Optional[Text] + ): + self.bot_message_evt = bot_message_evt + self.user_message_evt = user_message_evt + self.namespace = namespace + + def blueprint(self, on_new_message): + sio = socketio.Server() + socketio_webhook = SocketBlueprint(sio, 'socketio_webhook', __name__) + + @socketio_webhook.route("/", methods=['GET']) + def health(): + return jsonify({"status": "ok"}) + + @sio.on('connect', namespace=self.namespace) + def connect(sid, environ): + logger.debug("User {} connected to socketio endpoint.".format(sid)) + + @sio.on('disconnect', namespace=self.name()) + def disconnect(sid): + logger.debug("User {} disconnected from socketio endpoint." + "".format(sid)) + + @sio.on(self.user_message_evt, namespace=self.namespace) + def handle_message(sid, data): + output_channel = SocketIOOutput(sio, self.bot_message_evt) + message = UserMessage(data['message'], output_channel, sid) + on_new_message(message) + + return socketio_webhook diff --git a/rasa_core/run.py b/rasa_core/run.py index 3f20196eff2..96f5eb3f346 100644 --- a/rasa_core/run.py +++ b/rasa_core/run.py @@ -78,7 +78,6 @@ def create_argument_parser(): help="Configuration file for the connectors as a yml file") parser.add_argument( '-c', '--connector', - default="cmdline", choices=list(BUILTIN_CHANNELS.keys()), help="service to connect to") parser.add_argument( @@ -121,6 +120,9 @@ def read_endpoints(endpoint_file): def _create_external_channels(channel, credentials_file): # type: (Optional[Text], Optional[Text]) -> List[InputChannel] + if not channel and not credentials_file: + channel = "cmdline" + if credentials_file: all_credentials = read_yaml_file(credentials_file) else: diff --git a/requirements.txt b/requirements.txt index 6282e17bd14..e90070a16b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,3 +39,4 @@ pymongo==3.5.1 python-dateutil==2.7.3 rocketchat_API==0.6.22 flask-jwt-simple==0.0.3 +python-socketio==2.0.0 diff --git a/setup.py b/setup.py index 91c1ae82a5c..9619efdbda3 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,8 @@ "python-dateutil~=2.7", "rasa_nlu~=0.13.0", "rasa_core_sdk~=0.11.0", - "flask-jwt-simple~=0.0.3" + "flask-jwt-simple~=0.0.3", + "python-socketio~=2.0", ] extras_requires = {