Skip to content

Commit

Permalink
Decoupling get_client_credentials and decode_auth_headers (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
aliev committed Nov 7, 2021
1 parent 0a8c3b3 commit c2b2711
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 65 deletions.
22 changes: 5 additions & 17 deletions aioauth/grant_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
----
"""

from typing import Tuple

from .errors import (
InvalidGrantError,
InvalidRequestError,
Expand All @@ -21,14 +18,16 @@
from .requests import Request
from .responses import TokenResponse
from .storage import BaseStorage
from .utils import decode_auth_headers, enforce_list, enforce_str, generate_token
from .utils import enforce_list, enforce_str, generate_token


class GrantTypeBase:
"""Base grant type that all other grant types inherit from."""

def __init__(self, storage: BaseStorage):
def __init__(self, storage: BaseStorage, client_id: str, client_secret: str):
self.storage = storage
self.client_id = client_id
self.client_secret = client_secret

async def create_token_response(
self, request: Request, client: Client
Expand All @@ -53,10 +52,8 @@ async def create_token_response(

async def validate_request(self, request: Request) -> Client:
"""Validates the client request to ensure it is valid."""
client_id, client_secret = self.get_client_credentials(request)

client = await self.storage.get_client(
request, client_id=client_id, client_secret=client_secret
request, client_id=self.client_id, client_secret=self.client_secret
)

if not client:
Expand All @@ -72,15 +69,6 @@ async def validate_request(self, request: Request) -> Client:

return client

def get_client_credentials(self, request: Request) -> Tuple[str, str]:
client_id = request.post.client_id
client_secret = request.post.client_secret

if client_id is None or client_secret is None:
client_id, client_secret = decode_auth_headers(request)

return client_id, client_secret


class AuthorizationCodeGrantType(GrantTypeBase):
"""
Expand Down
53 changes: 45 additions & 8 deletions aioauth/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
"""

from http import HTTPStatus
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Type, Union

from .collections import HTTPHeaderDict
from .constances import default_headers
from .errors import (
InsecureTransportError,
InvalidClientError,
InvalidRequestError,
MethodNotAllowedError,
TemporarilyUnavailableError,
Expand All @@ -33,6 +34,7 @@
from .grant_type import (
AuthorizationCodeGrantType,
ClientCredentialsGrantType,
GrantTypeBase,
PasswordGrantType,
RefreshTokenGrantType,
)
Expand All @@ -49,7 +51,13 @@
TokenInactiveIntrospectionResponse,
)
from .storage import BaseStorage
from .types import GrantType, RequestMethod, ResponseMode, ResponseType, TokenType
from .types import (
GrantType,
RequestMethod,
ResponseMode,
ResponseType,
TokenType,
)
from .utils import (
build_uri,
catch_errors_and_unavailability,
Expand Down Expand Up @@ -151,7 +159,7 @@ async def introspect(request: fastapi.Request) -> fastapi.Response:
response: An :py:class:`aioauth.responses.Response` object.
"""
self.validate_request(request, [RequestMethod.POST])
client_id, _ = decode_auth_headers(request)
client_id, _ = self.get_client_credentials(request)

token_types = set(TokenType)
token_type = TokenType.REFRESH
Expand Down Expand Up @@ -194,6 +202,22 @@ async def introspect(request: fastapi.Request) -> fastapi.Response:
content=content, status_code=HTTPStatus.OK, headers=default_headers
)

def get_client_credentials(self, request: Request) -> Tuple[str, str]:
client_id = request.post.client_id
client_secret = request.post.client_secret

if client_id is None or client_secret is None:
authorization = request.headers.get("Authorization", "")
headers = HTTPHeaderDict({"WWW-Authenticate": "Basic"})

# Get client credentials from the Authorization header.
try:
client_id, client_secret = decode_auth_headers(authorization)
except ValueError as exc:
raise InvalidClientError(request=request, headers=headers) from exc

return client_id, client_secret

@catch_errors_and_unavailability
async def create_token_response(self, request: Request) -> Response:
"""Endpoint to obtain an access and/or ID token by presenting an
Expand Down Expand Up @@ -228,20 +252,33 @@ async def token(request: fastapi.Request) -> fastapi.Response:
response: An :py:class:`aioauth.responses.Response` object.
"""
self.validate_request(request, [RequestMethod.POST])
client_id, client_secret = self.get_client_credentials(request)

if not request.post.grant_type:
# grant_type request value is empty
raise InvalidRequestError(
request=request, description="Request is missing grant type."
)

GrantTypeClass = self.grant_types.get(request.post.grant_type)
GrantTypeClass: Type[
Union[
GrantTypeBase,
AuthorizationCodeGrantType,
PasswordGrantType,
RefreshTokenGrantType,
ClientCredentialsGrantType,
]
]

if GrantTypeClass is None:
# Requested GrantType was not found in the list of the grant_types.
raise UnsupportedGrantTypeError(request=request)
try:
GrantTypeClass = self.grant_types[request.post.grant_type]
except KeyError as exc:
# grant_type request value is invalid
raise UnsupportedGrantTypeError(request=request) from exc

grant_type = GrantTypeClass(storage=self.storage)
grant_type = GrantTypeClass(
storage=self.storage, client_id=client_id, client_secret=client_secret
)

client = await grant_type.validate_request(request)

Expand Down
21 changes: 7 additions & 14 deletions aioauth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from .collections import HTTPHeaderDict
from .errors import (
InvalidClientError,
OAuth2Error,
ServerError,
TemporarilyUnavailableError,
Expand Down Expand Up @@ -166,37 +165,31 @@ def encode_auth_headers(client_id: str, client_secret: str) -> HTTPHeaderDict:
return HTTPHeaderDict(Authorization=f"basic {authorization.decode()}")


def decode_auth_headers(request: Request) -> Tuple[str, str]:
def decode_auth_headers(authorization: str) -> Tuple[str, str]:
"""
Decodes an encrypted HTTP basic authentication string.
Returns a tuple of the form ``(client_id, client_secret)``, and
raises a :py:class:`aioauth.errors.InvalidClientError` exception if nothing
could be decoded.
Args:
request: A request object.
authorization: Authorization header string.
Returns:
Tuple of the form ``(client_id, client_secret)``.
Raises:
aioauth.errors.InvalidClientError: Could not be decoded.
"""
authorization = request.headers.get("Authorization", "")

headers = HTTPHeaderDict({"WWW-Authenticate": "Basic"})

scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "basic":
raise InvalidClientError(request=request, headers=headers)
raise ValueError("Invalid authoirzation header string.")

try:
data = b64decode(param).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise InvalidClientError(request=request, headers=headers)
except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
raise ValueError("Invalid base64 encoding.") from exc

client_id, separator, client_secret = data.partition(":")

if not separator:
raise InvalidClientError(request=request, headers=headers)
raise ValueError("Separator was not provided.")

return client_id, client_secret

Expand Down Expand Up @@ -230,7 +223,7 @@ def catch_errors_and_unavailability(f) -> Callable:
"""

@functools.wraps(f)
async def wrapper(self, request: Request, *args, **kwargs) -> Optional[Response]:
async def wrapper(self, request: Request, *args, **kwargs) -> Response:
error: Union[TemporarilyUnavailableError, ServerError]

try:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def test_responses(app):
assert response.status_code == HTTPStatus.BAD_REQUEST

response = await ac.post("/token")
assert response.status_code == HTTPStatus.BAD_REQUEST
assert response.status_code == HTTPStatus.UNAUTHORIZED

response = await ac.post("/token/introspect")
assert response.status_code == HTTPStatus.UNAUTHORIZED
4 changes: 3 additions & 1 deletion tests/test_grant_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ async def test_refresh_token_grant_type(
headers=encode_auth_headers(client_id, client_secret),
)

grant_type = RefreshTokenGrantType(db)
grant_type = RefreshTokenGrantType(
db, client_id=defaults.client_id, client_secret=defaults.client_secret
)

client = await grant_type.validate_request(request)

Expand Down
33 changes: 9 additions & 24 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,39 +38,24 @@ def test_build_uri():

def test_decode_auth_headers():
request = Request(headers=HTTPHeaderDict(), method=RequestMethod.POST)
authorization = request.headers.get("Authorization", "")

# No authorization header
with pytest.raises(InvalidClientError):
decode_auth_headers(request=request)
with pytest.raises(ValueError):
decode_auth_headers("")

# Invalid authorization header
request = Request(
headers=HTTPHeaderDict({"authorization": ""}), method=RequestMethod.POST
)
with pytest.raises(InvalidClientError):
decode_auth_headers(request=request)
with pytest.raises(ValueError):
decode_auth_headers("test")

# No separator
authorization = b64encode("usernamepassword".encode("ascii"))

request = Request(
headers=HTTPHeaderDict(Authorization=f"basic {authorization.decode()}"),
method=RequestMethod.POST,
)

with pytest.raises(InvalidClientError):
decode_auth_headers(request=request)
with pytest.raises(ValueError):
decode_auth_headers(f"basic {authorization.decode()}")

# No base64 digits
authorization = b64encode("usernamepassword".encode("ascii"))

request = Request(
headers=HTTPHeaderDict(Authorization="basic привет"),
method=RequestMethod.POST,
)

with pytest.raises(InvalidClientError):
decode_auth_headers(request=request)
with pytest.raises(ValueError):
decode_auth_headers("basic привет")


def test_base_error_uri():
Expand Down

0 comments on commit c2b2711

Please sign in to comment.