Skip to content

Commit

Permalink
Merge f7724ed into 84afa6e
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed Jul 11, 2019
2 parents 84afa6e + f7724ed commit fa5e2e3
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 11 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ This project adheres to `Semantic Versioning`_ starting with version 1.0.

Added
-----

- rest channel supports setting a message's input_channel through a field
``input_channel`` in the request body

Changed
-------
Expand Down
24 changes: 16 additions & 8 deletions rasa/core/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,33 +379,38 @@ async def on_message_wrapper(
text: Text,
queue: Queue,
sender_id: Text,
input_channel,
) -> None:
collector = QueueOutputChannel(queue)

message = UserMessage(
text, collector, sender_id, input_channel=RestInput.name()
)
message = UserMessage(text, collector, sender_id, input_channel=input_channel)
await on_new_message(message)

await queue.put("DONE") # pytype: disable=bad-return-type

async def _extract_sender(self, req) -> Optional[Text]:
async def _extract_sender(self, req: Request) -> Optional[Text]:
return req.json.get("sender", None)

# noinspection PyMethodMayBeStatic
def _extract_message(self, req):
def _extract_message(self, req: Request) -> Optional[Text]:
return req.json.get("message", None)

def _extract_input_channel(self, req: Request) -> Text:
return req.json.get("input_channel") or self.name()

def stream_response(
self,
on_new_message: Callable[[UserMessage], Awaitable[None]],
text: Text,
sender_id: Text,
input_channel: Text,
) -> Callable[[Any], Awaitable[None]]:
async def stream(resp: Any) -> None:
q = Queue()
task = asyncio.ensure_future(
self.on_message_wrapper(on_new_message, text, q, sender_id)
self.on_message_wrapper(
on_new_message, text, q, sender_id, input_channel
)
)
while True:
result = await q.get() # pytype: disable=bad-return-type
Expand Down Expand Up @@ -435,10 +440,13 @@ async def receive(request: Request):
should_use_stream = rasa.utils.endpoints.bool_arg(
request, "stream", default=False
)
input_channel = self._extract_input_channel(request)

if should_use_stream:
return response.stream(
self.stream_response(on_new_message, text, sender_id),
self.stream_response(
on_new_message, text, sender_id, input_channel
),
content_type="text/event-stream",
)
else:
Expand All @@ -447,7 +455,7 @@ async def receive(request: Request):
try:
await on_new_message(
UserMessage(
text, collector, sender_id, input_channel=self.name()
text, collector, sender_id, input_channel=input_channel
)
)
except CancelledError:
Expand Down
5 changes: 4 additions & 1 deletion rasa/core/channels/rasa_chat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Dict

import aiohttp
import logging
from sanic.exceptions import abort

from rasa.core.channels.channel import RestInput
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
from sanic.request import Request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,7 +46,7 @@ async def _check_token(self, token):
)
return None

async def _extract_sender(self, req):
async def _extract_sender(self, req: Request) -> Dict:
"""Fetch user from the Rasa X Admin API"""

if req.headers.get("Authorization"):
Expand Down
21 changes: 20 additions & 1 deletion tests/core/test_channels.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from unittest.mock import patch
from unittest.mock import patch, MagicMock

import pytest
import responses
Expand Down Expand Up @@ -784,3 +784,22 @@ def test_channel_registration_with_absolute_url_prefix_overwrites_route():
routes_list = utils.list_routes(app)
assert routes_list.get("custom_webhook_RestInput.health").startswith(test_route)
assert ignored_base_route not in routes_list.get("custom_webhook_RestInput.health")


@pytest.mark.parametrize(
"test_input, expected",
[
({}, "rest"),
({"input_channel": None}, "rest"),
({"input_channel": "custom"}, "custom"),
],
)
def test_extract_input_channel(test_input, expected):
from rasa.core.channels.channel import RestInput

input_channel = RestInput()

fake_request = MagicMock()
fake_request.json = test_input

assert input_channel._extract_input_channel(fake_request) == expected

0 comments on commit fa5e2e3

Please sign in to comment.