Skip to content

Commit

Permalink
Merge branch 'master' into http-api-spec
Browse files Browse the repository at this point in the history
  • Loading branch information
tabergma committed Aug 19, 2019
2 parents fa066a0 + df88eb4 commit ffce592
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ Changed
- compare mode of ``rasa train core`` allows the whole core config comparison,
naming style of models trained for comparison is changed (this is a breaking change)
- Pika keeps a single connection open, instead of open and closing on each incoming event
- ``RasaChatInput`` fetches the public key from the Rasa X API. The key is used to
decode the bearer token containing the conversation ID. This requires
``rasa-x>=0.20.2``.

Fixed
-----
Expand Down
72 changes: 56 additions & 16 deletions rasa/core/channels/rasa_chat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Text, Optional
import json
from typing import Text, Optional, Dict

import aiohttp
import logging
from sanic.exceptions import abort
import jwt

from rasa.core import constants
from rasa.core.channels.channel import RestInput
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
from sanic.request import Request
Expand All @@ -27,35 +30,72 @@ def from_credentials(cls, credentials):

def __init__(self, url):
self.base_url = url
self.jwt_key = None
self.jwt_algorithm = None

async def _check_token(self, token):
url = "{}/auth/verify".format(self.base_url)
headers = {"Authorization": token}
logger.debug("Requesting user information from auth server {}.".format(url))

async def _fetch_public_key(self) -> None:
public_key_url = "{}/version".format(self.base_url)
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers=headers, timeout=DEFAULT_REQUEST_TIMEOUT
public_key_url, timeout=DEFAULT_REQUEST_TIMEOUT
) as resp:
if resp.status == 200:
return await resp.json()
status_code = resp.status
if status_code != 200:
logger.error(
"Failed to fetch JWT public key from URL '{}' with "
"status code {}: {}"
"".format(public_key_url, status_code, await resp.text())
)
return
rjs = await resp.json()
public_key_field = "keys"
if public_key_field in rjs:
self.jwt_key = rjs["keys"][0]["key"]
self.jwt_algorithm = rjs["keys"][0]["alg"]
logger.debug(
"Fetched JWT public key from URL '{}' for algorithm '{}':\n{}"
"".format(public_key_url, self.jwt_algorithm, self.jwt_key)
)
else:
logger.info(
"Failed to check token: {}. "
"Content: {}".format(token, await resp.text())
logger.error(
"Retrieved json response from URL '{}' but could not find "
"'{}' field containing the JWT public key. Please make sure "
"you use an up-to-date version of Rasa X (>= 0.20.2). "
"Response was: {}"
"".format(public_key_url, public_key_field, json.dumps(rjs))
)
return None

def _decode_jwt(self, bearer_token: Text) -> Dict:
authorization_header_value = bearer_token.replace(
constants.BEARER_TOKEN_PREFIX, ""
)
return jwt.decode(
authorization_header_value, self.jwt_key, algorithms=self.jwt_algorithm
)

async def _decode_bearer_token(self, bearer_token: Text) -> Optional[Dict]:
if self.jwt_key is None:
await self._fetch_public_key()

# noinspection PyBroadException
try:
return self._decode_jwt(bearer_token)
except jwt.exceptions.InvalidSignatureError:
logger.error("JWT public key invalid, fetching new one.")
await self._fetch_public_key()
return self._decode_jwt(bearer_token)
except Exception:
logger.exception("Failed to decode bearer token.")

async def _extract_sender(self, req: Request) -> Optional[Text]:
"""Fetch user from the Rasa X Admin API"""

if req.headers.get("Authorization"):
user = await self._check_token(req.headers.get("Authorization"))

user = await self._decode_bearer_token(req.headers["Authorization"])
if user:
return user["username"]

user = await self._check_token(req.args.get("token", default=None))
user = await self._decode_bearer_token(req.args.get("token", default=None))
if user:
return user["username"]

Expand Down
2 changes: 2 additions & 0 deletions rasa/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@
USER_INTENT_OUT_OF_SCOPE = "out_of_scope"

ACTION_NAME_SENDER_ID_CONNECTOR_STR = "__sender_id:"

BEARER_TOKEN_PREFIX = "Bearer "
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ SQLAlchemy~=1.3.3
kafka-python==1.4.6
sklearn-crfsuite==0.3.6
psycopg2-binary==2.8.2
PyJWT==1.7.1
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"SQLAlchemy~=1.3.0",
"kafka-python~=1.4",
"sklearn-crfsuite~=0.3.6",
"PyJWT~=1.7",
]

extras_requires = {
Expand Down
19 changes: 19 additions & 0 deletions tests/core/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,22 @@ def test_extract_input_channel(test_input, expected):
fake_request.json = test_input

assert input_channel._extract_input_channel(fake_request) == expected


async def test_rasa_chat_input():
from rasa.core.channels import RasaChatInput

rasa_x_api_url = "https://rasa-x.com:5002"
rasa_chat_input = RasaChatInput(rasa_x_api_url)
public_key = "random_key123"
jwt_algorithm = "RS256"
with aioresponses() as mocked:
mocked.get(
rasa_x_api_url + "/version",
payload={"keys": [{"key": public_key, "alg": jwt_algorithm}]},
repeat=True,
status=200,
)
await rasa_chat_input._fetch_public_key()
assert rasa_chat_input.jwt_key == public_key
assert rasa_chat_input.jwt_algorithm == jwt_algorithm

0 comments on commit ffce592

Please sign in to comment.