Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return most /authorize errors as a redirect #71

Merged
merged 3 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions aioauth/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,11 @@ class TemporarilyUnavailableError(OAuth2Error[TRequest]):
"""

error: Literal["temporarily_unavailable"] = "temporarily_unavailable"


class InvalidRedirectURIError(OAuth2Error[TRequest]):
"""
The requested redirect URI is missing or not allowed.
"""

error: Literal["invalid_request"] = "invalid_request"
8 changes: 5 additions & 3 deletions aioauth/grant_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
"""
from typing import Generic, Optional
from .errors import (
InvalidClientError,
InvalidGrantError,
InvalidRedirectURIError,
InvalidRequestError,
InvalidScopeError,
MismatchingStateError,
Expand Down Expand Up @@ -62,7 +64,7 @@ async def validate_request(self, request: TRequest) -> Client:
)

if not client:
raise InvalidRequestError[TRequest](
raise InvalidClientError[TRequest](
request=request, description="Invalid client_id parameter value."
)

Expand Down Expand Up @@ -96,12 +98,12 @@ async def validate_request(self, request: TRequest) -> Client:
client = await super().validate_request(request)

if not request.post.redirect_uri:
raise InvalidRequestError[TRequest](
raise InvalidRedirectURIError[TRequest](
request=request, description="Mismatching redirect URI."
)

if not client.check_redirect_uri(request.post.redirect_uri):
raise InvalidRequestError[TRequest](
raise InvalidRedirectURIError[TRequest](
request=request, description="Invalid redirect URI."
)

Expand Down
9 changes: 5 additions & 4 deletions aioauth/response_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .utils import generate_token
from .errors import (
InvalidClientError,
InvalidRedirectURIError,
InvalidRequestError,
InvalidScopeError,
UnsupportedResponseTypeError,
Expand All @@ -38,7 +39,7 @@ async def validate_request(self, request: TRequest) -> Client:
code_challenge_methods: List[CodeChallengeMethod] = ["plain", "S256"]

if not request.query.client_id:
raise InvalidRequestError[TRequest](
raise InvalidClientError[TRequest](
request=request, description="Missing client_id parameter."
)

Expand All @@ -47,17 +48,17 @@ async def validate_request(self, request: TRequest) -> Client:
)

if not client:
raise InvalidRequestError[TRequest](
raise InvalidClientError[TRequest](
request=request, description="Invalid client_id parameter value."
)

if not request.query.redirect_uri:
raise InvalidRequestError[TRequest](
raise InvalidRedirectURIError[TRequest](
request=request, description="Mismatching redirect URI."
)

if not client.check_redirect_uri(request.query.redirect_uri):
raise InvalidRequestError[TRequest](
raise InvalidRedirectURIError[TRequest](
request=request, description="Invalid redirect URI."
)

Expand Down
6 changes: 3 additions & 3 deletions aioauth/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def validate_request(self, request: TRequest, allowed_methods: List[RequestMetho
)
raise MethodNotAllowedError[TRequest](request=request, headers=headers)

@catch_errors_and_unavailability
@catch_errors_and_unavailability()
async def create_token_introspection_response(self, request: TRequest) -> Response:
"""
Returns a response object with introspection of the passed token.
Expand Down Expand Up @@ -219,7 +219,7 @@ def get_client_credentials(self, request: TRequest) -> Tuple[str, str]:

return client_id, client_secret

@catch_errors_and_unavailability
@catch_errors_and_unavailability()
async def create_token_response(self, request: TRequest) -> Response:
"""Endpoint to obtain an access and/or ID token by presenting an
authorization grant or refresh token.
Expand Down Expand Up @@ -290,7 +290,7 @@ async def token(request: fastapi.Request) -> fastapi.Response:
content=content, status_code=HTTPStatus.OK, headers=default_headers
)

@catch_errors_and_unavailability
@catch_errors_and_unavailability(redirect=True)
async def create_authorization_response(self, request: TRequest) -> Response:
"""
Endpoint to interact with the resource owner and obtain an
Expand Down
87 changes: 59 additions & 28 deletions aioauth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@
import random
import string
from base64 import b64decode, b64encode
from http import HTTPStatus
from typing import Any, Callable, Coroutine, Dict, List, Optional, Set, Tuple, Union
from urllib.parse import quote, urlencode, urlparse, urlunsplit

from .collections import HTTPHeaderDict
from .errors import (
InvalidClientError,
InvalidRedirectURIError,
MethodNotAllowedError,
OAuth2Error,
ServerError,
TemporarilyUnavailableError,
Expand Down Expand Up @@ -213,7 +217,9 @@ def create_s256_code_challenge(code_verifier: str) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def catch_errors_and_unavailability(f) -> Callable[..., Coroutine[Any, Any, Response]]:
def catch_errors_and_unavailability(
redirect=False,
) -> Callable[..., Callable[..., Coroutine[Any, Any, Response]]]:
"""
Decorator that adds error catching to the function passed.

Expand All @@ -223,30 +229,55 @@ def catch_errors_and_unavailability(f) -> Callable[..., Coroutine[Any, Any, Resp
A callable with error catching capabilities.
"""

@functools.wraps(f)
async def wrapper(self, request, *args, **kwargs) -> Response:
error: Union[TemporarilyUnavailableError, ServerError]

try:
response = await f(self, request, *args, **kwargs)
except OAuth2Error as exc:
content = ErrorResponse(error=exc.error, description=exc.description)
log.debug("%s %r", exc, request)
return Response(
content=asdict(content),
status_code=exc.status_code,
headers=exc.headers,
)
except Exception:
error = ServerError(request=request)
log.exception("Exception caught while processing request.")
content = ErrorResponse(error=error.error, description=error.description)
return Response(
content=asdict(content),
status_code=error.status_code,
headers=error.headers,
)

return response

return wrapper
non_redirect_exceptions = (
(MethodNotAllowedError, InvalidClientError, InvalidRedirectURIError)
if redirect
else (OAuth2Error,)
)

def decorator(f) -> Callable[..., Coroutine[Any, Any, Response]]:
@functools.wraps(f)
async def wrapper(self, request, *args, **kwargs) -> Response:
error: Union[TemporarilyUnavailableError, ServerError]

try:
response = await f(self, request, *args, **kwargs)
except non_redirect_exceptions as exc: # type: ignore
content = ErrorResponse(error=exc.error, description=exc.description)
log.debug("%s %r", exc, request)
return Response(
content=asdict(content),
status_code=exc.status_code,
headers=exc.headers,
)
except OAuth2Error as exc:
log.debug("%s %r", exc, request)
query: Dict[str, str] = {
"error": exc.error,
}
if exc.description:
query["error_description"] = exc.description
if request.settings.ERROR_URI:
query["error_uri"] = request.settings.ERROR_URI
location = build_uri(request.query.redirect_uri, query)
return Response(
status_code=HTTPStatus.FOUND,
headers=HTTPHeaderDict({"location": location}),
)
except Exception:
error = ServerError(request=request)
log.exception("Exception caught while processing request.")
content = ErrorResponse(
error=error.error, description=error.description
)
return Response(
content=asdict(content),
status_code=error.status_code,
headers=error.headers,
)

return response

return wrapper

return decorator
2 changes: 1 addition & 1 deletion tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, available: Optional[bool] = None):
if available is not None:
self.available = available

@catch_errors_and_unavailability
@catch_errors_and_unavailability()
async def server(self, request):
raise Exception()

Expand Down
14 changes: 9 additions & 5 deletions tests/test_request_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import replace
from http import HTTPStatus
from typing import Dict, List
from urllib.parse import urlparse, parse_qs

import pytest

Expand All @@ -25,7 +26,9 @@ async def test_insecure_transport_error(server: AuthorizationServer):
request = Request(url=request_url, method="GET")

response = await server.create_authorization_response(request)
assert response.status_code == HTTPStatus.BAD_REQUEST
assert response.status_code == HTTPStatus.FOUND
query_params = parse_qs(urlparse(response.headers["Location"]).query)
assert query_params["error"] == ["insecure_transport"]


@pytest.mark.asyncio
Expand Down Expand Up @@ -59,8 +62,8 @@ async def test_invalid_client_credentials(
)

response = await server.create_token_response(request)
assert response.status_code == HTTPStatus.BAD_REQUEST
assert response.content["error"] == "invalid_request"
assert response.status_code == HTTPStatus.UNAUTHORIZED
assert response.content["error"] == "invalid_client"


@pytest.mark.asyncio
Expand Down Expand Up @@ -153,8 +156,9 @@ async def test_invalid_response_type(
user=user,
)
response = await server.create_authorization_response(request)
assert response.status_code == HTTPStatus.BAD_REQUEST
assert response.content["error"] == "unsupported_response_type"
assert response.status_code == HTTPStatus.FOUND
query_params = parse_qs(urlparse(response.headers["Location"]).query)
assert query_params["error"] == ["unsupported_response_type"]


@pytest.mark.asyncio
Expand Down
Loading