From 13b553bfa8d29f5837830cd7ad59e718e14a15bb Mon Sep 17 00:00:00 2001 From: Michael Gorven Date: Thu, 2 Feb 2023 13:05:04 -0800 Subject: [PATCH 1/2] Add InvalidRedirectURIError, use InvalidClientError for all client ID/secret errors --- aioauth/errors.py | 8 ++++++++ aioauth/grant_type.py | 8 +++++--- aioauth/response_type.py | 9 +++++---- tests/test_request_validator.py | 4 ++-- tests/utils.py | 16 ++++++++-------- 5 files changed, 28 insertions(+), 17 deletions(-) diff --git a/aioauth/errors.py b/aioauth/errors.py index ad01c40..93ccea8 100644 --- a/aioauth/errors.py +++ b/aioauth/errors.py @@ -171,3 +171,11 @@ class TemporarilyUnavailableError(OAuth2Error[TRequest]): """ error: Literal["temporarily_unavailable"] = "temporarily_unavailable" + + +class InvalidRedirectURIError(OAuth2Error[TRequest]): + """ + The requested redirect URI is missing or not allowed. + """ + + error: Literal["invalid_request"] = "invalid_request" diff --git a/aioauth/grant_type.py b/aioauth/grant_type.py index 5cdffa4..e39ccfd 100644 --- a/aioauth/grant_type.py +++ b/aioauth/grant_type.py @@ -9,7 +9,9 @@ """ from typing import Generic from .errors import ( + InvalidClientError, InvalidGrantError, + InvalidRedirectURIError, InvalidRequestError, InvalidScopeError, MismatchingStateError, @@ -58,7 +60,7 @@ async def validate_request(self, request: TRequest) -> Client: ) if not client: - raise InvalidRequestError[TRequest]( + raise InvalidClientError[TRequest]( request=request, description="Invalid client_id parameter value." ) @@ -91,12 +93,12 @@ async def validate_request(self, request: TRequest) -> Client: client = await super().validate_request(request) if not request.post.redirect_uri: - raise InvalidRequestError[TRequest]( + raise InvalidRedirectURIError[TRequest]( request=request, description="Mismatching redirect URI." ) if not client.check_redirect_uri(request.post.redirect_uri): - raise InvalidRequestError[TRequest]( + raise InvalidRedirectURIError[TRequest]( request=request, description="Invalid redirect URI." ) diff --git a/aioauth/response_type.py b/aioauth/response_type.py index bf199d2..c32b6f1 100644 --- a/aioauth/response_type.py +++ b/aioauth/response_type.py @@ -12,6 +12,7 @@ from .utils import generate_token from .errors import ( InvalidClientError, + InvalidRedirectURIError, InvalidRequestError, InvalidScopeError, UnsupportedResponseTypeError, @@ -38,7 +39,7 @@ async def validate_request(self, request: TRequest) -> Client: code_challenge_methods: List[CodeChallengeMethod] = ["plain", "S256"] if not request.query.client_id: - raise InvalidRequestError[TRequest]( + raise InvalidClientError[TRequest]( request=request, description="Missing client_id parameter." ) @@ -47,17 +48,17 @@ async def validate_request(self, request: TRequest) -> Client: ) if not client: - raise InvalidRequestError[TRequest]( + raise InvalidClientError[TRequest]( request=request, description="Invalid client_id parameter value." ) if not request.query.redirect_uri: - raise InvalidRequestError[TRequest]( + raise InvalidRedirectURIError[TRequest]( request=request, description="Mismatching redirect URI." ) if not client.check_redirect_uri(request.query.redirect_uri): - raise InvalidRequestError[TRequest]( + raise InvalidRedirectURIError[TRequest]( request=request, description="Invalid redirect URI." ) diff --git a/tests/test_request_validator.py b/tests/test_request_validator.py index a4e49dd..e0fb9f4 100644 --- a/tests/test_request_validator.py +++ b/tests/test_request_validator.py @@ -59,8 +59,8 @@ async def test_invalid_client_credentials( ) response = await server.create_token_response(request) - assert response.status_code == HTTPStatus.BAD_REQUEST - assert response.content["error"] == "invalid_request" + assert response.status_code == HTTPStatus.UNAUTHORIZED + assert response.content["error"] == "invalid_client" @pytest.mark.asyncio diff --git a/tests/utils.py b/tests/utils.py index 8a540f1..4424e0b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,11 +12,11 @@ "client_id": Response( content=asdict( ErrorResponse( - error="invalid_request", + error="invalid_client", description="Missing client_id parameter.", ) ), - status_code=HTTPStatus.BAD_REQUEST, + status_code=HTTPStatus.UNAUTHORIZED, headers=default_headers, ), "response_type": Response( @@ -159,11 +159,11 @@ "client_id": Response( content=asdict( ErrorResponse( - error="invalid_request", + error="invalid_client", description="Invalid client_id parameter value.", ) ), - status_code=HTTPStatus.BAD_REQUEST, + status_code=HTTPStatus.UNAUTHORIZED, headers=default_headers, ), "response_type": Response( @@ -261,21 +261,21 @@ "client_id": Response( content=asdict( ErrorResponse( - error="invalid_request", + error="invalid_client", description="Invalid client_id parameter value.", ) ), - status_code=HTTPStatus.BAD_REQUEST, + status_code=HTTPStatus.UNAUTHORIZED, headers=default_headers, ), "client_secret": Response( content=asdict( ErrorResponse( - error="invalid_request", + error="invalid_client", description="Invalid client_id parameter value.", ) ), - status_code=HTTPStatus.BAD_REQUEST, + status_code=HTTPStatus.UNAUTHORIZED, headers=default_headers, ), "username": Response( From d6553b7c086551f3c3b5dcb4c2f72ef39b2d3129 Mon Sep 17 00:00:00 2001 From: Michael Gorven Date: Thu, 2 Feb 2023 13:17:35 -0800 Subject: [PATCH 2/2] Return most /authorize errors as a redirect The spec defines that most errors to `/authorize` should be returned to the redirect URI with `error` and `error_description` query params: If the resource owner denies the access request or if the request fails for reasons other than a missing or invalid redirection URI, the authorization server informs the client by adding the following parameters to the query component of the redirection URI using the "application/x-www-form-urlencoded" format, per Appendix B: Add `InvalidRedirectURIError` and use `InvalidClientError` for all `client_id` errors since these should not redirect. Update `catch_errors_and_unavailability` to optionally return errors as redirects. --- aioauth/server.py | 6 +-- aioauth/utils.py | 87 ++++++++++++++++++++++----------- tests/test_endpoint.py | 2 +- tests/test_request_validator.py | 10 ++-- tests/utils.py | 66 +++++++++---------------- 5 files changed, 94 insertions(+), 77 deletions(-) diff --git a/aioauth/server.py b/aioauth/server.py index 4cb70f0..ca7ca9b 100644 --- a/aioauth/server.py +++ b/aioauth/server.py @@ -125,7 +125,7 @@ def validate_request(self, request: TRequest, allowed_methods: List[RequestMetho ) raise MethodNotAllowedError[TRequest](request=request, headers=headers) - @catch_errors_and_unavailability + @catch_errors_and_unavailability() async def create_token_introspection_response(self, request: TRequest) -> Response: """ Returns a response object with introspection of the passed token. @@ -219,7 +219,7 @@ def get_client_credentials(self, request: TRequest) -> Tuple[str, str]: return client_id, client_secret - @catch_errors_and_unavailability + @catch_errors_and_unavailability() async def create_token_response(self, request: TRequest) -> Response: """Endpoint to obtain an access and/or ID token by presenting an authorization grant or refresh token. @@ -290,7 +290,7 @@ async def token(request: fastapi.Request) -> fastapi.Response: content=content, status_code=HTTPStatus.OK, headers=default_headers ) - @catch_errors_and_unavailability + @catch_errors_and_unavailability(redirect=True) async def create_authorization_response(self, request: TRequest) -> Response: """ Endpoint to interact with the resource owner and obtain an diff --git a/aioauth/utils.py b/aioauth/utils.py index 48081be..13ef3e6 100644 --- a/aioauth/utils.py +++ b/aioauth/utils.py @@ -19,11 +19,15 @@ import random import string from base64 import b64decode, b64encode +from http import HTTPStatus from typing import Any, Callable, Coroutine, Dict, List, Optional, Set, Tuple, Union from urllib.parse import quote, urlencode, urlparse, urlunsplit from .collections import HTTPHeaderDict from .errors import ( + InvalidClientError, + InvalidRedirectURIError, + MethodNotAllowedError, OAuth2Error, ServerError, TemporarilyUnavailableError, @@ -213,7 +217,9 @@ def create_s256_code_challenge(code_verifier: str) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode() -def catch_errors_and_unavailability(f) -> Callable[..., Coroutine[Any, Any, Response]]: +def catch_errors_and_unavailability( + redirect=False, +) -> Callable[..., Callable[..., Coroutine[Any, Any, Response]]]: """ Decorator that adds error catching to the function passed. @@ -223,30 +229,55 @@ def catch_errors_and_unavailability(f) -> Callable[..., Coroutine[Any, Any, Resp A callable with error catching capabilities. """ - @functools.wraps(f) - async def wrapper(self, request, *args, **kwargs) -> Response: - error: Union[TemporarilyUnavailableError, ServerError] - - try: - response = await f(self, request, *args, **kwargs) - except OAuth2Error as exc: - content = ErrorResponse(error=exc.error, description=exc.description) - log.debug("%s %r", exc, request) - return Response( - content=asdict(content), - status_code=exc.status_code, - headers=exc.headers, - ) - except Exception: - error = ServerError(request=request) - log.exception("Exception caught while processing request.") - content = ErrorResponse(error=error.error, description=error.description) - return Response( - content=asdict(content), - status_code=error.status_code, - headers=error.headers, - ) - - return response - - return wrapper + non_redirect_exceptions = ( + (MethodNotAllowedError, InvalidClientError, InvalidRedirectURIError) + if redirect + else (OAuth2Error,) + ) + + def decorator(f) -> Callable[..., Coroutine[Any, Any, Response]]: + @functools.wraps(f) + async def wrapper(self, request, *args, **kwargs) -> Response: + error: Union[TemporarilyUnavailableError, ServerError] + + try: + response = await f(self, request, *args, **kwargs) + except non_redirect_exceptions as exc: # type: ignore + content = ErrorResponse(error=exc.error, description=exc.description) + log.debug("%s %r", exc, request) + return Response( + content=asdict(content), + status_code=exc.status_code, + headers=exc.headers, + ) + except OAuth2Error as exc: + log.debug("%s %r", exc, request) + query: Dict[str, str] = { + "error": exc.error, + } + if exc.description: + query["error_description"] = exc.description + if request.settings.ERROR_URI: + query["error_uri"] = request.settings.ERROR_URI + location = build_uri(request.query.redirect_uri, query) + return Response( + status_code=HTTPStatus.FOUND, + headers=HTTPHeaderDict({"location": location}), + ) + except Exception: + error = ServerError(request=request) + log.exception("Exception caught while processing request.") + content = ErrorResponse( + error=error.error, description=error.description + ) + return Response( + content=asdict(content), + status_code=error.status_code, + headers=error.headers, + ) + + return response + + return wrapper + + return decorator diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 53c18c1..b650e59 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -27,7 +27,7 @@ def __init__(self, available: Optional[bool] = None): if available is not None: self.available = available - @catch_errors_and_unavailability + @catch_errors_and_unavailability() async def server(self, request): raise Exception() diff --git a/tests/test_request_validator.py b/tests/test_request_validator.py index e0fb9f4..b363d72 100644 --- a/tests/test_request_validator.py +++ b/tests/test_request_validator.py @@ -2,6 +2,7 @@ from dataclasses import replace from http import HTTPStatus from typing import Dict, List +from urllib.parse import urlparse, parse_qs import pytest @@ -25,7 +26,9 @@ async def test_insecure_transport_error(server: AuthorizationServer): request = Request(url=request_url, method="GET") response = await server.create_authorization_response(request) - assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.status_code == HTTPStatus.FOUND + query_params = parse_qs(urlparse(response.headers["Location"]).query) + assert query_params["error"] == ["insecure_transport"] @pytest.mark.asyncio @@ -153,8 +156,9 @@ async def test_invalid_response_type( user=user, ) response = await server.create_authorization_response(request) - assert response.status_code == HTTPStatus.BAD_REQUEST - assert response.content["error"] == "unsupported_response_type" + assert response.status_code == HTTPStatus.FOUND + query_params = parse_qs(urlparse(response.headers["Location"]).query) + assert query_params["error"] == ["unsupported_response_type"] @pytest.mark.asyncio diff --git a/tests/utils.py b/tests/utils.py index 4424e0b..4595174 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,14 +20,11 @@ headers=default_headers, ), "response_type": Response( - content=asdict( - ErrorResponse( - error="invalid_request", - description="Missing response_type parameter.", - ) + content={}, + status_code=HTTPStatus.FOUND, + headers=HTTPHeaderDict( + location="https://ownauth.com/callback?error=invalid_request&error_description=Missing%20response_type%20parameter." ), - status_code=HTTPStatus.BAD_REQUEST, - headers=default_headers, ), "redirect_uri": Response( content=asdict( @@ -40,24 +37,18 @@ headers=default_headers, ), "code_challenge": Response( - content=asdict( - ErrorResponse( - error="invalid_request", - description="Code challenge required.", - ) + content={}, + status_code=HTTPStatus.FOUND, + headers=HTTPHeaderDict( + location="https://ownauth.com/callback?error=invalid_request&error_description=Code%20challenge%20required." ), - status_code=HTTPStatus.BAD_REQUEST, - headers=default_headers, ), "nonce": Response( - content=asdict( - ErrorResponse( - error="invalid_request", - description="Nonce required for response_type id_token.", - ) + content={}, + status_code=HTTPStatus.FOUND, + headers=HTTPHeaderDict( + location="https://ownauth.com/callback?error=invalid_request&error_description=Nonce%20required%20for%20response_type%20id_token." ), - status_code=HTTPStatus.BAD_REQUEST, - headers=default_headers, ), }, "POST": { @@ -167,14 +158,11 @@ headers=default_headers, ), "response_type": Response( - content=asdict( - ErrorResponse( - error="unsupported_response_type", - description="", - ) + content={}, + status_code=HTTPStatus.FOUND, + headers=HTTPHeaderDict( + location="https://ownauth.com/callback?error=unsupported_response_type" ), - status_code=HTTPStatus.BAD_REQUEST, - headers=default_headers, ), "redirect_uri": Response( content=asdict( @@ -187,24 +175,18 @@ headers=default_headers, ), "code_challenge_method": Response( - content=asdict( - ErrorResponse( - error="invalid_request", - description="Transform algorithm not supported.", - ) + content={}, + status_code=HTTPStatus.FOUND, + headers=HTTPHeaderDict( + location="https://ownauth.com/callback?error=invalid_request&error_description=Transform%20algorithm%20not%20supported." ), - status_code=HTTPStatus.BAD_REQUEST, - headers=default_headers, ), "scope": Response( - content=asdict( - ErrorResponse( - error="invalid_scope", - description="", - ) + content={}, + status_code=HTTPStatus.FOUND, + headers=HTTPHeaderDict( + location="https://ownauth.com/callback?error=invalid_scope" ), - status_code=HTTPStatus.BAD_REQUEST, - headers=default_headers, ), }, "POST": {