Skip to content

Commit

Permalink
support new async and old sync versions of 'sanic-jwt
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed Jan 12, 2021
1 parent efa2a19 commit 9d398e1
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
20 changes: 15 additions & 5 deletions rasa/server.py
Expand Up @@ -32,6 +32,7 @@

import rasa
import rasa.core.utils
import rasa.utils.common
import rasa.shared.utils.common
import rasa.shared.utils.io
import rasa.utils.endpoints
Expand Down Expand Up @@ -170,8 +171,14 @@ def conversation_id_from_args(args: Any, kwargs: Any) -> Optional[Text]:
except ValueError:
return None

def sufficient_scope(request, *args: Any, **kwargs: Any) -> Optional[bool]:
jwt_data = request.app.auth.extract_payload(request)
async def sufficient_scope(
request, *args: Any, **kwargs: Any
) -> Optional[bool]:
# This is a coroutine since `sanic-jwt==1.6`
jwt_data = await rasa.utils.common.call_potential_coroutine(
request.app.auth.extract_payload(request)
)

user = jwt_data.get("user", {})

username = user.get("username", None)
Expand All @@ -196,10 +203,13 @@ async def decorated(request: Request, *args: Any, **kwargs: Any) -> Any:
if isawaitable(result):
result = await result
return result
elif app.config.get("USE_JWT") and request.app.auth.is_authenticated(
request
elif app.config.get(
"USE_JWT"
) and await rasa.utils.common.call_potential_coroutine(
# This is a coroutine since `sanic-jwt==1.6`
request.app.auth.is_authenticated(request)
):
if sufficient_scope(request, *args, **kwargs):
if await sufficient_scope(request, *args, **kwargs):
result = f(request, *args, **kwargs)
if isawaitable(result):
result = await result
Expand Down
20 changes: 19 additions & 1 deletion rasa/utils/common.py
Expand Up @@ -4,7 +4,7 @@
import shutil
import warnings
from types import TracebackType
from typing import Any, Coroutine, Dict, List, Optional, Text, Type, TypeVar
from typing import Any, Coroutine, Dict, List, Optional, Text, Type, TypeVar, Union

import rasa.core.utils
import rasa.utils.io
Expand Down Expand Up @@ -312,3 +312,21 @@ def run_in_loop(
loop.run_until_complete(asyncio.gather(*pending))

return result


async def call_potential_coroutine(
coroutine_or_return_value: Union[Any, Coroutine]
) -> Any:
"""Awaits coroutine or returns value directly if it's not a coroutine.
Args:
coroutine_or_return_value: Either the return value of a synchronous function
call or a coroutine which needs to be await first.
Returns:
The return value of the function.
"""
if asyncio.iscoroutine(coroutine_or_return_value):
return await coroutine_or_return_value

return coroutine_or_return_value
24 changes: 24 additions & 0 deletions tests/utils/test_common.py
@@ -1,5 +1,7 @@
import logging
from typing import Any

import rasa.utils.common
from rasa.utils.common import RepeatedLogFilter


Expand All @@ -22,3 +24,25 @@ def test_repeated_log_filter():
assert log_filter.filter(record2_other_args) is True
assert log_filter.filter(record3_other) is True
assert log_filter.filter(record1) is True # same as before, but not repeated


async def test_call_maybe_coroutine_with_async() -> Any:
expected = 5

async def my_function():
return expected

actual = await rasa.utils.common.call_potential_coroutine(my_function())

assert actual == expected


async def test_call_maybe_coroutine_with_sync() -> Any:
expected = 5

def my_function():
return expected

actual = await rasa.utils.common.call_potential_coroutine(my_function())

assert actual == expected

0 comments on commit 9d398e1

Please sign in to comment.